nemo_rl.models.megatron.data#

Module Contents#

Classes#

ProcessedInputs

Processed microbatch inputs used for model forward pass.

ProcessedMicrobatch

Container for a processed microbatch ready for model forward pass.

Functions#

make_processed_microbatch_iterator

Wrap a raw microbatch iterator to yield processed microbatches.

get_microbatch_iterator

Create a processed microbatch iterator from a batch of data.

process_microbatch

Process a microbatch for Megatron model forward pass.

process_global_batch

Process a global batch and compute normalization factors.

_pack_sequences_for_megatron

Pack sequences for Megatron model processing with optional context parallelism.

_get_pack_sequence_parameters_for_megatron

Get pack sequence parameters for Megatron model processing with optional context parallelism.

_unpack_sequences_from_megatron

Unpack sequences from Megatron output format.

get_and_validate_seqlen

Get and validate sequence lengths from a batch of data.

API#

class nemo_rl.models.megatron.data.ProcessedInputs#

Processed microbatch inputs used for model forward pass.

input_ids: torch.Tensor#

None

input_ids_cp_sharded: torch.Tensor#

None

attention_mask: Optional[torch.Tensor]#

None

position_ids: Optional[torch.Tensor]#

None

packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams]#

None

cu_seqlens_padded: Optional[torch.Tensor]#

None

class nemo_rl.models.megatron.data.ProcessedMicrobatch#

Container for a processed microbatch ready for model forward pass.

This dataclass holds both the original data dictionary and the processed tensors needed for the Megatron model forward pass.

.. attribute:: data_dict

The original BatchedDataDict containing raw batch data

.. attribute:: input_ids

Processed input token IDs (may be packed for sequence packing)

.. attribute:: input_ids_cp_sharded

Context-parallel sharded input token IDs

.. attribute:: attention_mask

Attention mask tensor (None for packed sequences)

.. attribute:: position_ids

Position IDs tensor (None for packed sequences)

.. attribute:: packed_seq_params

PackedSeqParams for sequence packing (None if not packing)

.. attribute:: cu_seqlens_padded

Padded cumulative sequence lengths (None if not packing)

data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]#

None

input_ids: torch.Tensor#

None

input_ids_cp_sharded: torch.Tensor#

None

attention_mask: Optional[torch.Tensor]#

None

position_ids: Optional[torch.Tensor]#

None

packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams]#

None

cu_seqlens_padded: Optional[torch.Tensor]#

None

nemo_rl.models.megatron.data.make_processed_microbatch_iterator(
raw_iterator: Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]],
cfg: dict[str, Any],
seq_length_key: Optional[str],
pad_individual_seqs_to_multiple_of: int,
pad_packed_seq_to_multiple_of: int,
straggler_timer: megatron.core.utils.StragglerDetector,
pad_full_seq_to: Optional[int],
) → Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch]#

Wrap a raw microbatch iterator to yield processed microbatches.

This function takes a raw iterator that yields BatchedDataDict objects and wraps it to yield ProcessedMicrobatch objects that contain both the original data and the processed tensors ready for model forward pass.

Parameters:
  • raw_iterator – Iterator yielding raw BatchedDataDict microbatches

  • cfg – Configuration dictionary containing sequence_packing settings

  • seq_length_key – Key for sequence length in data dict (required for packing)

  • pad_individual_seqs_to_multiple_of – Padding multiple for individual sequences

  • pad_packed_seq_to_multiple_of – Padding multiple for packed sequences

  • straggler_timer – Megatron straggler detection timer used to time microbatch processing

  • pad_full_seq_to – Target length for full sequence padding (optional)

Yields:

ProcessedMicrobatch objects containing processed tensors ready for model forward
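
The wrapping pattern above can be sketched with a lazy generator. This is a simplified stand-in, not the real implementation: `RawBatch` and `process_one` are hypothetical placeholders for `BatchedDataDict` and `process_microbatch`.

```python
from typing import Any, Dict, Iterator

RawBatch = Dict[str, Any]  # hypothetical stand-in for BatchedDataDict

def process_one(raw: RawBatch) -> Dict[str, Any]:
    # Stand-in for process_microbatch: keep the original dict and
    # expose the "processed" tensors alongside it.
    return {"data_dict": raw, "input_ids": raw["input_ids"]}

def make_processed_iterator(raw_iter: Iterator[RawBatch]) -> Iterator[Dict[str, Any]]:
    # Lazy generator: each microbatch is processed only when the
    # training loop pulls it, mirroring the wrapping described above.
    for raw in raw_iter:
        yield process_one(raw)

processed = list(
    make_processed_iterator(iter([{"input_ids": [1, 2]}, {"input_ids": [3]}]))
)
```

Because the wrapper is a generator, processing cost is paid per microbatch as the forward pass consumes it rather than up front.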

nemo_rl.models.megatron.data.get_microbatch_iterator(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
cfg: dict[str, Any],
mbs: int,
straggler_timer: megatron.core.utils.StragglerDetector,
seq_length_key: Optional[str] = None,
) → Tuple[Iterator[nemo_rl.models.megatron.data.ProcessedMicrobatch], int, int, int, int]#

Create a processed microbatch iterator from a batch of data.

This function creates an iterator that yields ProcessedMicrobatch objects, which contain both the original data dictionary and the processed tensors ready for model forward pass.

Parameters:
  • data – The batch data to create microbatches from

  • cfg – Configuration dictionary

  • mbs – Microbatch size

  • straggler_timer – Megatron straggler detection timer

  • seq_length_key – Key for sequence lengths in data dict (auto-detected if None)

Returns:

Tuple containing the iterator and metadata

  • iterator: Iterator yielding ProcessedMicrobatch objects

  • data_iterator_len: Number of microbatches in the iterator

  • micro_batch_size: Size of each microbatch

  • seq_dim_size: Sequence length dimension size

  • padded_seq_length: Padded sequence length for pipeline parallelism (may differ from seq_length)
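
The batch-to-microbatch split can be sketched as follows. This is a minimal illustration with plain dicts and lists standing in for `BatchedDataDict` tensors; it assumes the batch size divides evenly by `mbs`, and omits padding and sequence packing.

```python
from typing import Any, Dict, Iterator, List, Tuple

def get_microbatch_iterator(
    data: Dict[str, List[Any]], mbs: int
) -> Tuple[Iterator[Dict[str, List[Any]]], int]:
    # Total number of samples; assumes every key has the same length.
    batch_size = len(next(iter(data.values())))
    num_microbatches = batch_size // mbs  # assumes batch_size % mbs == 0

    def gen() -> Iterator[Dict[str, List[Any]]]:
        # Slice every field of the dict identically for each microbatch.
        for i in range(num_microbatches):
            yield {k: v[i * mbs:(i + 1) * mbs] for k, v in data.items()}

    return gen(), num_microbatches

iterator, n = get_microbatch_iterator({"input_ids": [[1], [2], [3], [4]]}, mbs=2)
microbatches = list(iterator)
```

The real function additionally returns the microbatch size, sequence dimension, and the padded sequence length needed for pipeline parallelism.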

nemo_rl.models.megatron.data.process_microbatch(
data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
seq_length_key: Optional[str] = None,
pad_individual_seqs_to_multiple_of: int = 1,
pad_packed_seq_to_multiple_of: int = 1,
pad_full_seq_to: Optional[int] = None,
pack_sequences: bool = False,
straggler_timer: Optional[megatron.core.utils.StragglerDetector] = None,
) → tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[megatron.core.packed_seq_params.PackedSeqParams], Optional[torch.Tensor]]#

Process a microbatch for Megatron model forward pass.

nemo_rl.models.megatron.data.process_global_batch(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
dp_group: torch.distributed.ProcessGroup,
*,
batch_idx: int,
batch_size: int,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Process a global batch and compute normalization factors.

Parameters:
  • data – Full dataset

  • batch_idx – Index of batch to extract

  • batch_size – Size of batch to extract

  • loss_fn – Loss function (used to check loss type)

  • dp_group – Data parallel process group

Returns:

  • batch: The extracted batch

  • global_valid_seqs: Number of valid sequences across all ranks

  • global_valid_toks: Number of valid tokens across all ranks

Return type:

Tuple of
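
The normalization factors can be illustrated with a token mask. This is a single-rank sketch using plain lists; the names `count_valid`, `token_mask`, and the mask layout are illustrative assumptions, and the distributed all-reduce over `dp_group` is only indicated in a comment.

```python
from typing import List, Tuple

def count_valid(token_mask: List[List[int]]) -> Tuple[int, int]:
    # A sequence counts as valid if it has at least one unmasked token;
    # valid tokens are the unmasked entries themselves.
    valid_seqs = sum(1 for row in token_mask if any(row))
    valid_toks = sum(sum(row) for row in token_mask)
    return valid_seqs, valid_toks

# Local counts for one rank. The real function would all-reduce these
# across dp_group so every rank normalizes the loss by the same
# global totals (global_valid_seqs, global_valid_toks).
local_valid_seqs, local_valid_toks = count_valid([[1, 1, 0], [0, 0, 0], [1, 0, 0]])
```

Normalizing by global rather than per-rank counts keeps the loss scale consistent when ranks receive different amounts of valid data.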

nemo_rl.models.megatron.data._pack_sequences_for_megatron(
input_ids: torch.Tensor,
seq_lengths: torch.Tensor,
pad_individual_seqs_to_multiple_of: int = 1,
pad_packed_seq_to_multiple_of: int = 1,
pad_packed_seq_to: Optional[int] = None,
cp_rank: int = 0,
cp_size: int = 1,
) → tuple[torch.Tensor, megatron.core.packed_seq_params.PackedSeqParams, torch.Tensor, Optional[torch.Tensor]]#

Pack sequences for Megatron model processing with optional context parallelism.

Parameters:
  • input_ids – Input token IDs [batch_size, seq_length]

  • seq_lengths – Actual sequence lengths for each sample [batch_size]

  • pad_individual_seqs_to_multiple_of – Pad individual sequences to a multiple of this value

  • pad_packed_seq_to_multiple_of – Pad packed sequences to a multiple of this value

  • pad_packed_seq_to

    Pad packed sequences to this value (before CP)

    • These three padding parameters can be computed with _get_pack_sequence_parameters_for_megatron; we do not recommend setting them manually.

  • cp_rank – Context parallelism rank

  • cp_size – Context parallelism size

Returns:

  • packed_input_ids: Packed input tensor [1, T]

  • input_ids_cp_sharded: Sharded input tensor [cp_size, T // cp_size]

  • packed_seq_params: PackedSeqParams object

  • cu_seqlens: Cumulative sequence lengths

  • cu_seqlens_padded: Padded cumulative sequence lengths

Return type:

Tuple of
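
The core packing bookkeeping can be sketched with plain lists. This is a simplified single-rank illustration (no CP sharding, no torch); the function name and the use of integer lists instead of tensors are assumptions for the sketch, but the `cu_seqlens` / `cu_seqlens_padded` distinction mirrors the returns described above.

```python
from typing import List, Tuple

def pack_sequences(
    seqs: List[List[int]], pad_to_multiple: int = 1, pad_id: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    # Concatenate sequences into one flat "packed" row, padding each
    # sequence up to a multiple so CP ranks could shard it evenly.
    packed: List[int] = []
    cu_seqlens, cu_seqlens_padded = [0], [0]
    for seq in seqs:
        padded_len = -(-len(seq) // pad_to_multiple) * pad_to_multiple  # ceil
        packed.extend(seq + [pad_id] * (padded_len - len(seq)))
        cu_seqlens.append(cu_seqlens[-1] + len(seq))                   # real boundaries
        cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_len)   # padded boundaries
    return packed, cu_seqlens, cu_seqlens_padded

packed, cu_seqlens, cu_seqlens_padded = pack_sequences([[5, 6, 7], [8]], pad_to_multiple=4)
```

Keeping both cumulative-length lists lets the attention kernel skip cross-sequence attention (via the real boundaries) while the padded boundaries make each chunk evenly divisible for sharding.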

nemo_rl.models.megatron.data._get_pack_sequence_parameters_for_megatron(
megatron_cfg: dict,
max_seq_len_in_batch: int,
)#

Get pack sequence parameters for Megatron model processing with optional context parallelism.

Parameters:
  • megatron_cfg – Megatron configuration

  • max_seq_len_in_batch – Maximum sequence length in batch

Returns:

  • pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value

  • pad_packed_seq_to_multiple_of: Pad packed sequences to a multiple of this value

  • pad_packed_seq_to: Pad packed sequences to this value (before CP)

Return type:

Tuple of
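
A sketch of how such parameters might be derived, under the assumption (common in Megatron-style context parallelism) that each sequence is split into 2 × cp_size chunks for load balancing, so lengths must be padded to a multiple of 2 × cp_size. The function body here is illustrative, not the actual logic, and `max_seq_len_in_batch` is kept only for signature parity.

```python
from typing import Optional, Tuple

def get_pack_sequence_parameters(
    cp_size: int, max_seq_len_in_batch: int
) -> Tuple[int, int, Optional[int]]:
    # Assumption: CP shards each sequence into 2 * cp_size chunks for
    # load-balanced attention, so pad individual sequences accordingly.
    pad_individual_seqs_to_multiple_of = 2 * cp_size if cp_size > 1 else 1
    # The packed sequence as a whole must also divide evenly across ranks.
    pad_packed_seq_to_multiple_of = 2 * cp_size if cp_size > 1 else 1
    # No fixed target length unless something like pipeline parallelism
    # requires a uniform sequence length across microbatches.
    pad_packed_seq_to: Optional[int] = None
    return (
        pad_individual_seqs_to_multiple_of,
        pad_packed_seq_to_multiple_of,
        pad_packed_seq_to,
    )

params = get_pack_sequence_parameters(cp_size=2, max_seq_len_in_batch=4096)
```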

nemo_rl.models.megatron.data._unpack_sequences_from_megatron(
output_tensor: torch.Tensor,
seq_lengths: torch.Tensor,
cu_seqlens: torch.Tensor,
cu_seqlens_padded: Optional[torch.Tensor],
original_batch_size: int,
original_seq_length: int,
) → torch.Tensor#

Unpack sequences from Megatron output format.

Parameters:
  • output_tensor – Packed output tensor [1, T, vocab_size]

  • seq_lengths – Actual sequence lengths for each sample

  • cu_seqlens – Cumulative sequence lengths

  • cu_seqlens_padded – Padded cumulative sequence lengths (if CP was used)

  • original_batch_size – Original batch size

  • original_seq_length – Original maximum sequence length

Returns:

Unpacked output tensor [batch_size, seq_length, vocab_size]
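
Unpacking is the inverse of the packing bookkeeping: slice each sample back out of the flat packed row at its padded offset, then right-pad to a rectangular batch. A minimal list-based sketch (the real function operates on a `[1, T, vocab_size]` tensor; the function name here is illustrative):

```python
from typing import List

def unpack_sequences(
    packed: List[int],
    seq_lengths: List[int],
    cu_seqlens_padded: List[int],
    pad_id: int = 0,
) -> List[List[int]]:
    # Slice each sample out of the packed row using the padded
    # cumulative offsets, then right-pad to the batch's max length.
    max_len = max(seq_lengths)
    rows = []
    for i, n in enumerate(seq_lengths):
        start = cu_seqlens_padded[i]
        rows.append(packed[start:start + n] + [pad_id] * (max_len - n))
    return rows

rows = unpack_sequences(
    [5, 6, 7, 0, 8, 0, 0, 0], seq_lengths=[3, 1], cu_seqlens_padded=[0, 4, 8]
)
```

Note that only the first `seq_lengths[i]` entries of each padded chunk are real model output; the padding introduced during packing is discarded.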

nemo_rl.models.megatron.data.get_and_validate_seqlen(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
)#

Get and validate sequence lengths from a batch of data.