nemo_rl.utils.automodel_checkpoint#
Automodel checkpoint utilities for DTensor policy workers.
This module provides a wrapper class around the nemo_automodel Checkpointer for saving and loading model checkpoints in DTensor-based policy workers.
Module Contents#
Classes#
AutomodelCheckpointManager – Manages checkpointing for DTensor-based models using nemo_automodel’s Checkpointer.
Functions#
detect_checkpoint_format – Detect model save format and PEFT status from checkpoint directory.
_infer_checkpoint_root – Infer checkpoint root directory from weights path.
API#
- class nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager(
- dp_mesh: torch.distributed.device_mesh.DeviceMesh,
- tp_mesh: torch.distributed.device_mesh.DeviceMesh,
- model_state_dict_keys: Optional[list[str]] = None,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- )#
Manages checkpointing for DTensor-based models using nemo_automodel’s Checkpointer.
This class provides a clean interface for saving and loading model checkpoints, wrapping the nemo_automodel Checkpointer with configuration management.
- Attributes:
checkpointer – The underlying nemo_automodel Checkpointer instance.
checkpoint_config – The current checkpoint configuration.
model_state_dict_keys – List of model state dict keys for checkpoint validation.
Initialization
Initialize the AutomodelCheckpointManager.
- Parameters:
dp_mesh – The data parallel device mesh.
tp_mesh – The tensor parallel device mesh.
model_state_dict_keys – Optional list of model state dict keys.
moe_mesh – Optional MoE device mesh.
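A minimal construction sketch, assuming torch.distributed is already initialized (e.g., via torchrun); the 2×4 (dp × tp) mesh shape is illustrative:

```python
from torch.distributed.device_mesh import init_device_mesh

from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager

# Illustrative 2x4 mesh; requires world_size == 8 and an initialized
# process group.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

manager = AutomodelCheckpointManager(
    dp_mesh=mesh["dp"],
    tp_mesh=mesh["tp"],
)
```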
- _get_dp_rank() → int#
Get the data parallel rank.
- _get_tp_rank() → int#
Get the tensor parallel rank.
- init_checkpointer(
- config_updates: Optional[dict[str, Any]] = None,
- checkpoint_root: Optional[str] = None,
- )#
Initialize the Automodel Checkpointer if not already created.
This method creates a new Checkpointer instance with the provided configuration. If a checkpointer already exists, this method does nothing.
- Parameters:
config_updates – Dict of CheckpointingConfig fields to set during initialization.
checkpoint_root – Optional root directory for checkpoints.
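A usage sketch; the config_updates field names (model_save_format, is_peft) are taken from the notes elsewhere on this page and may not be exhaustive, and the root path is illustrative:

```python
# First call creates the underlying Checkpointer.
manager.init_checkpointer(
    config_updates={"model_save_format": "safetensors", "is_peft": False},
    checkpoint_root="/ckpts/policy",
)

# Subsequent calls are no-ops once a checkpointer exists.
manager.init_checkpointer()
```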
- update_checkpointer_config(
- config_updates: Optional[dict[str, Any]] = None,
- checkpoint_root: Optional[str] = None,
- )#
Update the configuration of an existing Checkpointer.
This method updates the mutable config fields on the existing Checkpointer instance. If no checkpointer exists, this method does nothing.
Note: Some config changes (like model_save_format) require rebuilding the checkpointer’s internal addons list. This method handles that automatically.
- Parameters:
config_updates – Dict of CheckpointingConfig fields to update.
checkpoint_root – Optional root directory for checkpoints.
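A short sketch; switching model_save_format exercises the automatic addons rebuild described in the note above:

```python
# Updates mutable config fields on the existing Checkpointer; the
# model_save_format change triggers an internal addons rebuild.
manager.update_checkpointer_config(
    config_updates={"model_save_format": "torch_save"},
)
```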
- _rebuild_checkpointer_addons() → None#
Rebuild the checkpointer’s _addons list based on current config.
The Checkpointer’s _addons list is populated during init based on config. When config changes (e.g., model_save_format or is_peft), we need to rebuild the addons list to match the new config.
- set_model_state_dict_keys(keys: list[str]) → None#
Set the model state dict keys for checkpoint validation.
- Parameters:
keys – List of model state dict keys.
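A one-line sketch, assuming model is the torch.nn.Module being checkpointed:

```python
# Record the model's state dict keys so later loads can be validated.
manager.set_model_state_dict_keys(list(model.state_dict().keys()))
```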
- load_base_model(
- model: torch.nn.Module,
- model_name: str,
- hf_cache_dir: Optional[str] = None,
- dequantize_base_checkpoint: bool = False,
- peft_init_method: Optional[str] = None,
- )#
Load base model weights using the Automodel Checkpointer.
This method loads the initial HuggingFace model weights into the parallelized model.
- Parameters:
model – The model to load weights into.
model_name – Name or path of the model.
hf_cache_dir – Optional HuggingFace cache directory.
dequantize_base_checkpoint – Whether to dequantize the base checkpoint.
peft_init_method – Optional PEFT weight-initialization method.
- Raises:
AssertionError – If checkpointer has not been initialized.
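A usage sketch; the checkpointer must be initialized before loading (otherwise the assertion fires), and the model name and sharded model object are illustrative:

```python
manager.init_checkpointer(checkpoint_root="/ckpts/policy")

# Load the initial HuggingFace weights into the parallelized model.
manager.load_base_model(
    model=model,
    model_name="meta-llama/Llama-3.1-8B",
    dequantize_base_checkpoint=False,
)
```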
- save_checkpoint(
- model: torch.nn.Module,
- weights_path: str,
- optimizer: Optional[torch.optim.Optimizer] = None,
- optimizer_path: Optional[str] = None,
- scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
- tokenizer: Optional[transformers.AutoTokenizer] = None,
- tokenizer_path: Optional[str] = None,
- checkpointing_cfg: Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None,
- lora_enabled: bool = False,
- peft_config: Optional[nemo_automodel.components._peft.lora.PeftConfig] = None,
- )#
Save a checkpoint of the model.
The optimizer states are saved only if both optimizer and optimizer_path are provided.
- Parameters:
model – The model to save.
weights_path – Path to save model weights.
optimizer – Optional optimizer to save.
optimizer_path – Optional path to save optimizer state.
scheduler – Optional learning rate scheduler.
tokenizer – Optional tokenizer to save with the checkpoint.
tokenizer_path – Optional path to save tokenizer separately.
checkpointing_cfg – Checkpointing configuration.
lora_enabled – Whether LoRA is enabled.
peft_config – Optional PEFT configuration.
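A save sketch with illustrative paths, assuming model, optimizer, and tokenizer are the training objects; optimizer state is written only because both optimizer and optimizer_path are supplied:

```python
manager.save_checkpoint(
    model=model,
    weights_path="/ckpts/step_100/policy/weights/model",
    optimizer=optimizer,
    optimizer_path="/ckpts/step_100/policy/optimizer",
    tokenizer=tokenizer,
    tokenizer_path="/ckpts/step_100/policy/tokenizer",
)
```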
- load_checkpoint(
- model: torch.nn.Module,
- weights_path: str,
- optimizer: Optional[torch.optim.Optimizer] = None,
- optimizer_path: Optional[str] = None,
- scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
- )#
Load a checkpoint into the model using Automodel Checkpointer.
- Parameters:
model – The model to load weights into.
weights_path – Path to the checkpoint weights.
optimizer – Optional optimizer to load state into.
optimizer_path – Optional path to optimizer checkpoint.
scheduler – Optional learning rate scheduler.
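A matching load sketch that restores the state written in the save example above:

```python
manager.load_checkpoint(
    model=model,
    weights_path="/ckpts/step_100/policy/weights/model",
    optimizer=optimizer,
    optimizer_path="/ckpts/step_100/policy/optimizer",
)
```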
- nemo_rl.utils.automodel_checkpoint.detect_checkpoint_format(weights_path: str) → tuple[str, bool]#
Detect model save format and PEFT status from checkpoint directory.
- Parameters:
weights_path – Path to the checkpoint directory (e.g., weights/model)
- Returns:
A tuple (model_save_format, is_peft), where model_save_format is “torch_save” for DCP checkpoints or “safetensors” for safetensors checkpoints, and is_peft is True if PEFT/adapter patterns are detected.
- Return type:
tuple
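A small usage sketch (path illustrative):

```python
from nemo_rl.utils.automodel_checkpoint import detect_checkpoint_format

fmt, is_peft = detect_checkpoint_format("/ckpts/step_100/policy/weights/model")
if fmt == "safetensors" and not is_peft:
    print("full-model safetensors checkpoint")
```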
- nemo_rl.utils.automodel_checkpoint._infer_checkpoint_root(weights_path: str) → str#
Infer checkpoint root directory from weights path.
When weights_path ends with “…/weights/model”, we need the parent of the weights directory (the checkpoint root), not the weights directory itself.
- Parameters:
weights_path – Path to model weights (e.g., “/path/to/policy/weights/model”)
- Returns:
Checkpoint root directory (e.g., “/path/to/policy”)
- Return type:
str
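A sketch mirroring the example paths in the docstring above:

```python
from nemo_rl.utils.automodel_checkpoint import _infer_checkpoint_root

# The root is the parent of the "weights" directory.
root = _infer_checkpoint_root("/path/to/policy/weights/model")
assert root == "/path/to/policy"
```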