nemo_rl.models.megatron.common

Module Contents

Functions

_round_up_to_multiple

forward_step_arbitrary_loss

Forward training step with support for packed sequences and context parallelism.

broadcast_tensor

Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata.

get_moe_metrics

Returns Mixture of Experts (MoE) auxiliary-loss metrics.

API

nemo_rl.models.megatron.common._round_up_to_multiple(value: int, multiple: int) → int
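
The source gives no description for this helper. Judging only by its name and signature, it presumably rounds value up to the nearest multiple of multiple. The sketch below illustrates that assumed behavior; the function name round_up_to_multiple is hypothetical and this is not the library's implementation.

```python
def round_up_to_multiple(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple of `multiple` (assumed behavior)."""
    # Integer ceiling division, then scale back up to a multiple.
    return ((value + multiple - 1) // multiple) * multiple


print(round_up_to_multiple(10, 4))  # 12
print(round_up_to_multiple(8, 4))   # 8 (already a multiple, unchanged)
```
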
nemo_rl.models.megatron.common.forward_step_arbitrary_loss(
state: megatron.bridge.training.state.GlobalState,
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
data_iterator: Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]],
model: megatron.core.models.gpt.GPTModel,
loss_fn: nemo_rl.algorithms.loss_functions.LossFunction,
pack_sequences: bool = False,
defer_fp32_logits: Optional[bool] = None,
cp_normalize: bool = True,
policy_cfg: Optional[dict] = None,
)

Forward training step with support for packed sequences and context parallelism.

Parameters:
  • state (GlobalState) – Global state for the run

  • global_valid_seqs – Global count of valid sequences

  • global_valid_toks – Global count of valid tokens

  • data_iterator – Input data iterator

  • model (GPTModel) – The GPT Model

  • loss_fn (LossFunction) – Loss function to apply

  • pack_sequences (bool) – Whether to pack sequences for efficiency

  • defer_fp32_logits (Optional[bool]) – Whether to skip the conversion of logits to fp32

  • cp_normalize (bool) – Whether to normalize the loss by cp_size

  • policy_cfg (Optional[dict]) – Policy configuration containing generation parameters

Notes on packed sequences with context parallelism (CP):
  • When CP > 1, each sequence is padded to a multiple of (cp_size * 2)

  • The factor of 2 ensures load balancing for causal attention

  • cu_seqlens tracks actual sequence boundaries

  • cu_seqlens_padded tracks padded sequence boundaries for CP

  • Requires TransformerEngine >= 1.10 for CP support
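
The padding rule in the notes above can be sketched in plain Python. This is a minimal illustration, assuming that cu_seqlens and cu_seqlens_padded are cumulative offsets starting at 0; the helper names pad_len and build_cu_seqlens are hypothetical and do not come from the library.

```python
from itertools import accumulate


def pad_len(seq_len: int, cp_size: int) -> int:
    """Pad one sequence length up to a multiple of cp_size * 2."""
    multiple = cp_size * 2
    return ((seq_len + multiple - 1) // multiple) * multiple


def build_cu_seqlens(seq_lens: list[int], cp_size: int) -> tuple[list[int], list[int]]:
    """Return (cu_seqlens, cu_seqlens_padded) as cumulative offsets (assumed layout)."""
    # Padding only applies when context parallelism is active (CP > 1).
    padded = [pad_len(n, cp_size) if cp_size > 1 else n for n in seq_lens]
    cu_seqlens = [0, *accumulate(seq_lens)]       # actual sequence boundaries
    cu_seqlens_padded = [0, *accumulate(padded)]  # padded boundaries used for CP
    return cu_seqlens, cu_seqlens_padded


cu, cu_padded = build_cu_seqlens([5, 9, 3], cp_size=2)
print(cu)         # [0, 5, 14, 17]
print(cu_padded)  # [0, 8, 20, 24]
```

With cp_size=2 each length is rounded up to a multiple of 4, so every padded sequence splits into 2 * cp_size equal chunks, which is what makes it possible to balance causal-attention work across CP ranks.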

nemo_rl.models.megatron.common.broadcast_tensor(
tensor: torch.Tensor | None,
src_rank: int,
group: torch.distributed.ProcessGroup,
) → torch.Tensor

Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata.

Handles the case where the input tensor might be None on non-source ranks. If the input tensor is provided on non-source ranks, it must have the correct shape and dtype matching the tensor on the source rank.

Parameters:
  • tensor – The tensor to broadcast on the source rank. Can be None on non-source ranks (will be created with correct shape/dtype). If not None on non-source ranks, it’s used as the buffer for the broadcast and must match the source tensor’s metadata.

  • src_rank (int) – The global rank of the source process.

  • group – The process group for communication.

Returns:

The broadcasted tensor. On non-source ranks, this will be the tensor received from the source.

Return type:

torch.Tensor

Raises:
  • ValueError – If the tensor is None on the source rank, or if a tensor provided on a non-source rank has mismatched shape/dtype/device.

  • TypeError – If broadcasting metadata fails (e.g., due to pickling issues).
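
The two-phase pattern described above (metadata first, tensor data second) can be sketched with standard torch.distributed collectives. This is an illustration under stated assumptions, not the library's code: the names broadcast_tensor_sketch and check_buffer and the device parameter are hypothetical, the sketch assumes an initialized process group, and it checks only shape and dtype, not device.

```python
import torch
import torch.distributed as dist


def check_buffer(buf: torch.Tensor, shape: torch.Size, dtype: torch.dtype) -> None:
    """Raise ValueError if a caller-supplied buffer does not match the source metadata."""
    if buf.shape != shape or buf.dtype != dtype:
        raise ValueError(
            f"buffer has shape={tuple(buf.shape)}, dtype={buf.dtype}; "
            f"expected shape={tuple(shape)}, dtype={dtype}"
        )


def broadcast_tensor_sketch(tensor, src_rank, group, device="cpu"):
    """Two-phase broadcast: metadata as Python objects, then the tensor data."""
    is_src = dist.get_rank() == src_rank  # src_rank is a global rank
    if is_src and tensor is None:
        raise ValueError("tensor must not be None on the source rank")

    # Phase 1: share shape and dtype so other ranks can allocate a buffer.
    metadata = [tensor.shape, tensor.dtype] if is_src else [None, None]
    dist.broadcast_object_list(metadata, src=src_rank, group=group)
    shape, dtype = metadata

    # Phase 2: allocate or validate the receive buffer, then broadcast the data.
    if not is_src:
        if tensor is None:
            tensor = torch.empty(shape, dtype=dtype, device=device)
        else:
            check_buffer(tensor, shape, dtype)
    dist.broadcast(tensor, src=src_rank, group=group)
    return tensor
```

Sending the metadata through broadcast_object_list is what lets non-source ranks pass None: they learn the shape and dtype before any tensor buffer has to exist.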

nemo_rl.models.megatron.common.get_moe_metrics(
loss_scale: float,
total_loss_dict: Optional[dict] = None,
per_layer_logging: bool = False,
) → dict[str, Any]

Returns Mixture of Experts (MoE) auxiliary-loss metrics.

This function reduces MoE auxiliary losses across ranks, aggregates them, and returns a dictionary of metrics.

Parameters:
  • loss_scale – Scale factor to apply to each auxiliary loss (e.g., 1/num_microbatches).

  • total_loss_dict – If provided, accumulate means into this dict (by name).

  • per_layer_logging – If True, include per-layer values in the returned dict.

Returns:

A flat dict of aggregated metrics. For each aux loss name, the mean value is returned under the same key (e.g., “load_balancing_loss”). If per_layer_logging is True, per-layer values are returned under keys of the form “moe/{name}_layer_{i}”.

Return type:

dict[str, Any]
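
The aggregation described above can be sketched on plain Python values. This is a simplified illustration, not the library's implementation: it takes the per-layer losses as an explicit per_layer_losses argument (hypothetical; the real function has no such parameter), skips the cross-rank reduction, and assumes the per-layer key format moe/{name}_layer_{i}.

```python
from typing import Any, Optional


def aggregate_moe_metrics(
    per_layer_losses: dict[str, list[float]],
    loss_scale: float,
    total_loss_dict: Optional[dict] = None,
    per_layer_logging: bool = False,
) -> dict[str, Any]:
    """Scale per-layer aux losses, average them per name, and flatten into one dict."""
    metrics: dict[str, Any] = {}
    for name, values in per_layer_losses.items():
        scaled = [v * loss_scale for v in values]
        mean = sum(scaled) / len(scaled)
        metrics[name] = mean  # mean is stored under the aux loss name itself
        if total_loss_dict is not None:
            # Accumulate the mean by name across calls (assumed semantics).
            total_loss_dict[name] = total_loss_dict.get(name, 0.0) + mean
        if per_layer_logging:
            for i, v in enumerate(scaled):
                metrics[f"moe/{name}_layer_{i}"] = v  # assumed key format
    return metrics


totals: dict[str, float] = {}
out = aggregate_moe_metrics(
    {"load_balancing_loss": [0.5, 1.5]},
    loss_scale=0.5,
    total_loss_dict=totals,
    per_layer_logging=True,
)
print(out["load_balancing_loss"])  # 0.5
print(totals)                      # {'load_balancing_loss': 0.5}
```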