nemo_rl.utils.automodel_checkpoint#

Automodel checkpoint utilities for DTensor policy workers.

This module provides a wrapper class around the nemo_automodel Checkpointer for saving and loading model checkpoints in DTensor-based policy workers.

Module Contents#

Classes#

AutomodelCheckpointManager

Manages checkpointing for DTensor-based models using nemo_automodel’s Checkpointer.

Functions#

detect_checkpoint_format

Detect model save format and PEFT status from checkpoint directory.

_infer_checkpoint_root

Infer checkpoint root directory from weights path.

API#

class nemo_rl.utils.automodel_checkpoint.AutomodelCheckpointManager(
dp_mesh: torch.distributed.device_mesh.DeviceMesh,
tp_mesh: torch.distributed.device_mesh.DeviceMesh,
model_state_dict_keys: Optional[list[str]] = None,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
)#

Manages checkpointing for DTensor-based models using nemo_automodel’s Checkpointer.

This class provides a clean interface for saving and loading model checkpoints, wrapping the nemo_automodel Checkpointer with configuration management.

Attributes:
  • checkpointer – The underlying nemo_automodel Checkpointer instance.

  • checkpoint_config – The current checkpoint configuration.

  • model_state_dict_keys – List of model state dict keys for checkpoint validation.

Initialization

Initialize the AutomodelCheckpointManager.

Parameters:
  • dp_mesh – The data parallel device mesh.

  • tp_mesh – The tensor parallel device mesh.

  • model_state_dict_keys – Optional list of model state dict keys.

  • moe_mesh – Optional MoE device mesh.
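Example (a minimal construction sketch, assuming torch.distributed has already been initialized and the world size factors into a 2x2 data-parallel by tensor-parallel layout; the mesh dimension names are illustrative):

```python
from torch.distributed.device_mesh import init_device_mesh

from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager

# Build a 2-D device mesh: one dimension for data parallelism, one for
# tensor parallelism. Requires torch.distributed to be initialized
# (e.g., launched via torchrun on 4 GPUs).
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

manager = AutomodelCheckpointManager(
    dp_mesh=mesh["dp"],
    tp_mesh=mesh["tp"],
)
```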

_get_dp_rank() int#

Get the data parallel rank.

_get_tp_rank() int#

Get the tensor parallel rank.

init_checkpointer(
config_updates: Optional[dict[str, Any]] = None,
checkpoint_root: Optional[str] = None,
) None#

Initialize the Automodel Checkpointer if not already created.

This method creates a new Checkpointer instance with the provided configuration. If a checkpointer already exists, this method does nothing.

Parameters:
  • config_updates – Dict of CheckpointingConfig fields to set during initialization.

  • checkpoint_root – Optional root directory for checkpoints.
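Example (a hedged sketch continuing from the constructor above; the specific CheckpointingConfig fields passed through config_updates are assumptions based on the fields mentioned elsewhere in this module):

```python
# Create the underlying Checkpointer with an initial configuration.
manager.init_checkpointer(
    config_updates={"model_save_format": "safetensors", "is_peft": False},
    checkpoint_root="/tmp/checkpoints/policy",
)

# A second call is a no-op because the checkpointer already exists.
manager.init_checkpointer()
```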

update_checkpointer_config(
config_updates: Optional[dict[str, Any]] = None,
checkpoint_root: Optional[str] = None,
) None#

Update the configuration of an existing Checkpointer.

This method updates the mutable config fields on the existing Checkpointer instance. If no checkpointer exists, this method does nothing.

Note: Some config changes (like model_save_format) require rebuilding the checkpointer’s internal addons list. This method handles that automatically.

Parameters:
  • config_updates – Dict of CheckpointingConfig fields to update.

  • checkpoint_root – Optional root directory for checkpoints.
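Example (continuing from the sketch above; switching model_save_format triggers the automatic addons rebuild described in the note):

```python
# Change the save format on the already-initialized checkpointer.
manager.update_checkpointer_config(
    config_updates={"model_save_format": "torch_save"},
)
```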

_rebuild_checkpointer_addons() None#

Rebuild the checkpointer’s _addons list based on current config.

The Checkpointer’s _addons list is populated during init based on config. When config changes (e.g., model_save_format or is_peft), we need to rebuild the addons list to match the new config.

set_model_state_dict_keys(keys: list[str]) None#

Set the model state dict keys for checkpoint validation.

Parameters:

keys – List of model state dict keys.

load_base_model(
model: torch.nn.Module,
model_name: str,
hf_cache_dir: Optional[str] = None,
dequantize_base_checkpoint: bool = False,
peft_init_method: Optional[str] = None,
) None#

Load base model weights using the Automodel Checkpointer.

This method loads the initial HuggingFace model weights into the parallelized model.

Parameters:
  • model – The model to load weights into.

  • model_name – Name or path of the model.

  • hf_cache_dir – Optional HuggingFace cache directory.

  • dequantize_base_checkpoint – Whether to dequantize the base checkpoint.

  • peft_init_method – Optional PEFT initialization method.

Raises:

AssertionError – If checkpointer has not been initialized.
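Example (a minimal sketch, assuming `model` is the already-parallelized torch.nn.Module and the checkpointer was initialized via init_checkpointer(); the model name is an illustrative HuggingFace identifier):

```python
# Loads the initial HuggingFace weights into the parallelized model.
manager.load_base_model(
    model,
    model_name="meta-llama/Llama-3.1-8B",  # illustrative model id
    hf_cache_dir=None,
    dequantize_base_checkpoint=False,
)
```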

save_checkpoint(
model: torch.nn.Module,
weights_path: str,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer_path: Optional[str] = None,
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
tokenizer: Optional[transformers.AutoTokenizer] = None,
tokenizer_path: Optional[str] = None,
checkpointing_cfg: Optional[nemo_rl.utils.checkpoint.CheckpointingConfig] = None,
lora_enabled: bool = False,
peft_config: Optional[nemo_automodel.components._peft.lora.PeftConfig] = None,
) None#

Save a checkpoint of the model.

Optimizer state is saved only if both optimizer and optimizer_path are provided.

Parameters:
  • model – The model to save.

  • weights_path – Path to save model weights.

  • optimizer – Optional optimizer to save.

  • optimizer_path – Optional path to save optimizer state.

  • scheduler – Optional learning rate scheduler.

  • tokenizer – Optional tokenizer to save with the checkpoint.

  • tokenizer_path – Optional path to save tokenizer separately.

  • checkpointing_cfg – Checkpointing configuration.

  • lora_enabled – Whether LoRA is enabled.

  • peft_config – Optional PEFT configuration.
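Example (a hedged sketch; the directory layout follows the "weights/model" convention referenced by detect_checkpoint_format and _infer_checkpoint_root, and `model`, `optimizer`, and `tokenizer` are assumed to come from the surrounding training setup):

```python
manager.save_checkpoint(
    model,
    weights_path="/tmp/checkpoints/policy/weights/model",
    optimizer=optimizer,
    optimizer_path="/tmp/checkpoints/policy/optimizer",
    tokenizer=tokenizer,
)
```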

load_checkpoint(
model: torch.nn.Module,
weights_path: str,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer_path: Optional[str] = None,
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
) None#

Load a checkpoint into the model using Automodel Checkpointer.

Parameters:
  • model – The model to load weights into.

  • weights_path – Path to the checkpoint weights.

  • optimizer – Optional optimizer to load state into.

  • optimizer_path – Optional path to optimizer checkpoint.

  • scheduler – Optional learning rate scheduler.
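Example (a minimal sketch restoring the checkpoint saved above; optimizer state is restored only because both optimizer and optimizer_path are supplied):

```python
manager.load_checkpoint(
    model,
    weights_path="/tmp/checkpoints/policy/weights/model",
    optimizer=optimizer,
    optimizer_path="/tmp/checkpoints/policy/optimizer",
)
```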

nemo_rl.utils.automodel_checkpoint.detect_checkpoint_format(weights_path: str) tuple[str, bool]#

Detect model save format and PEFT status from checkpoint directory.

Parameters:

weights_path – Path to the checkpoint directory (e.g., weights/model)

Returns:

A tuple (model_save_format, is_peft), where model_save_format is "torch_save" for DCP checkpoints or "safetensors" for safetensors checkpoints, and is_peft is True if PEFT/adapter patterns are detected.

Return type:

tuple
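Example (a minimal sketch; the path is illustrative):

```python
from nemo_rl.utils.automodel_checkpoint import detect_checkpoint_format

model_save_format, is_peft = detect_checkpoint_format(
    "/tmp/checkpoints/policy/weights/model"
)
if is_peft:
    print(f"Found a PEFT/adapter checkpoint saved as {model_save_format}")
```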

nemo_rl.utils.automodel_checkpoint._infer_checkpoint_root(weights_path: str) str#

Infer checkpoint root directory from weights path.

When weights_path ends with “…/weights/model”, we need the parent of the weights directory (the checkpoint root), not the weights directory itself.

Parameters:

weights_path – Path to model weights (e.g., “/path/to/policy/weights/model”)

Returns:

Checkpoint root directory (e.g., “/path/to/policy”)

Return type:

str
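Example (using the path from the docstring above):

```python
from nemo_rl.utils.automodel_checkpoint import _infer_checkpoint_root

root = _infer_checkpoint_root("/path/to/policy/weights/model")
assert root == "/path/to/policy"
```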