nemo_rl.models.megatron.data#
Module Contents#
Classes#
| Class | Description |
|---|---|
| ProcessedInputs | Processed microbatch inputs used for model forward pass. |
| ProcessedMicrobatch | Container for a processed microbatch ready for model forward pass. |
Functions#
| Function | Description |
|---|---|
| make_processed_microbatch_iterator | Wrap a raw microbatch iterator to yield processed microbatches. |
| get_microbatch_iterator | Create a processed microbatch iterator from a batch of data. |
| process_microbatch | Process a microbatch for Megatron model forward pass. |
| process_global_batch | Process a global batch and compute normalization factors. |
| _pack_sequences_for_megatron | Pack sequences for Megatron model processing with optional context parallelism. |
| _get_pack_sequence_parameters_for_megatron | Get pack sequence parameters for Megatron model processing with optional context parallelism. |
| _unpack_sequences_from_megatron | Unpack sequences from Megatron output format. |
API#
- class nemo_rl.models.megatron.data.ProcessedInputs#
Processed microbatch inputs used for model forward pass.
- input_ids: torch.Tensor#
None
- input_ids_cp_sharded: torch.Tensor#
None
- attention_mask: Optional[torch.Tensor]#
None
- position_ids: Optional[torch.Tensor]#
None
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams]#
None
- cu_seqlens_padded: Optional[torch.Tensor]#
None
- class nemo_rl.models.megatron.data.ProcessedMicrobatch#
Container for a processed microbatch ready for model forward pass.
This dataclass holds both the original data dictionary and the processed tensors needed for the Megatron model forward pass.
- Attributes:
data_dict – The original BatchedDataDict containing raw batch data
input_ids – Processed input token IDs (may be packed for sequence packing)
input_ids_cp_sharded – Context-parallel sharded input token IDs
attention_mask – Attention mask tensor (None for packed sequences)
position_ids – Position IDs tensor (None for packed sequences)
packed_seq_params – PackedSeqParams for sequence packing (None if not packing)
cu_seqlens_padded – Padded cumulative sequence lengths (None if not packing)
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]#
None
- input_ids: torch.Tensor#
None
- input_ids_cp_sharded: torch.Tensor#
None
- attention_mask: Optional[torch.Tensor]#
None
- position_ids: Optional[torch.Tensor]#
None
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams]#
None
- cu_seqlens_padded: Optional[torch.Tensor]#
None
- nemo_rl.models.megatron.data.make_processed_microbatch_iterator(
- raw_iterator: Iterator[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any]],
- cfg: dict[str, Any],
- seq_length_key: Optional[str],
- pad_individual_seqs_to_multiple_of: int,
- pad_packed_seq_to_multiple_of: int,
- straggler_timer: megatron.core.utils.StragglerDetector,
- pad_full_seq_to: Optional[int],
- )#
Wrap a raw microbatch iterator to yield processed microbatches.
This function takes a raw iterator that yields BatchedDataDict objects and wraps it to yield ProcessedMicrobatch objects that contain both the original data and the processed tensors ready for model forward pass.
- Parameters:
raw_iterator – Iterator yielding raw BatchedDataDict microbatches
cfg – Configuration dictionary containing sequence_packing settings
seq_length_key – Key for sequence length in data dict (required for packing)
pad_individual_seqs_to_multiple_of – Padding multiple for individual sequences
pad_packed_seq_to_multiple_of – Padding multiple for packed sequences
straggler_timer – Straggler detection timer used to instrument microbatch processing
pad_full_seq_to – Target length for full sequence padding (optional)
- Yields:
ProcessedMicrobatch objects containing processed tensors ready for model forward
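The wrapping pattern itself is a plain generator: consume each raw item, process it, and yield the result lazily. A minimal sketch with a stand-in `process` callable (the real function applies `process_microbatch` with the padding settings; the names below are illustrative):

```python
from typing import Any, Callable, Iterator

def wrap_iterator(
    raw_iterator: Iterator[dict[str, Any]],
    process: Callable[[dict[str, Any]], dict[str, Any]],
) -> Iterator[dict[str, Any]]:
    """Yield processed items lazily; raw microbatches are consumed one at a time."""
    for raw_mb in raw_iterator:
        yield process(raw_mb)

# Example: "processing" here just records the token count of each microbatch.
raw = iter([{"input_ids": [[1, 2, 3]]}, {"input_ids": [[4, 5]]}])
processed = list(
    wrap_iterator(raw, lambda d: {**d, "n_tokens": sum(len(s) for s in d["input_ids"])})
)
```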
- nemo_rl.models.megatron.data.get_microbatch_iterator(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- cfg: dict[str, Any],
- mbs: int,
- straggler_timer: megatron.core.utils.StragglerDetector,
- seq_length_key: Optional[str] = None,
- )#
Create a processed microbatch iterator from a batch of data.
This function creates an iterator that yields ProcessedMicrobatch objects, which contain both the original data dictionary and the processed tensors ready for model forward pass.
- Parameters:
data – The batch data to create microbatches from
cfg – Configuration dictionary
mbs – Microbatch size
straggler_timer – Straggler detection timer used to instrument microbatch processing
seq_length_key – Key for sequence lengths in data dict (auto-detected if None)
- Returns:
Tuple containing the iterator and metadata
iterator: Iterator yielding ProcessedMicrobatch objects
data_iterator_len: Number of microbatches in the iterator
micro_batch_size: Size of each microbatch
seq_dim_size: Sequence length dimension size
padded_seq_length: Padded sequence length for pipeline parallelism (may differ from seq_length)
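Splitting a global batch into fixed-size microbatches and returning the iterator alongside its metadata can be sketched in pure Python (plain dicts of lists stand in for BatchedDataDict; the function name and return layout are illustrative, simplified from the tuple described above):

```python
from typing import Iterator

def make_microbatches(
    data: dict[str, list], mbs: int
) -> tuple[Iterator[dict[str, list]], int, int, int]:
    """Return (iterator, num_microbatches, micro_batch_size, seq_dim_size)."""
    batch_size = len(data["input_ids"])
    assert batch_size % mbs == 0, "batch size must divide evenly into microbatches"
    seq_dim_size = max(len(seq) for seq in data["input_ids"])

    def it() -> Iterator[dict[str, list]]:
        # Slice every key of the batch the same way so microbatches stay aligned.
        for start in range(0, batch_size, mbs):
            yield {k: v[start:start + mbs] for k, v in data.items()}

    return it(), batch_size // mbs, mbs, seq_dim_size

batch = {"input_ids": [[1, 2], [3, 4, 5], [6], [7, 8]]}
it, n_mb, micro_bs, seq_len = make_microbatches(batch, mbs=2)
```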
- nemo_rl.models.megatron.data.process_microbatch(
- data_dict: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- seq_length_key: Optional[str] = None,
- pad_individual_seqs_to_multiple_of: int = 1,
- pad_packed_seq_to_multiple_of: int = 1,
- pad_full_seq_to: Optional[int] = None,
- pack_sequences: bool = False,
- straggler_timer: Optional[megatron.core.utils.StragglerDetector] = None,
- )#
Process a microbatch for Megatron model forward pass.
- nemo_rl.models.megatron.data.process_global_batch(
- data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
- loss_fn: nemo_rl.algorithms.interfaces.LossFunction,
- dp_group: torch.distributed.ProcessGroup,
- *,
- batch_idx: int,
- batch_size: int,
- )#
Process a global batch and compute normalization factors.
- Parameters:
data – Full dataset
batch_idx – Index of batch to extract
batch_size – Size of batch to extract
loss_fn – Loss function (used to check loss type)
dp_group – Data parallel process group
- Returns:
batch: The extracted batch
global_valid_seqs: Number of valid sequences across all ranks
global_valid_toks: Number of valid tokens across all ranks
- Return type:
Dictionary containing the keys listed above
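The normalization factors are counts of valid sequences and tokens, summed over all data-parallel ranks. A single-rank sketch using a token-level loss mask (in the distributed version the local counts would be all-reduced over `dp_group`; the function name is illustrative):

```python
def count_valid(token_mask: list[list[int]]) -> tuple[int, int]:
    """Count sequences with at least one valid token, and total valid tokens.

    In the distributed setting these local counts are summed across the
    data-parallel group (e.g. via torch.distributed.all_reduce) to obtain
    the global normalization factors used by the loss function.
    """
    valid_seqs = sum(1 for seq in token_mask if any(seq))
    valid_toks = sum(sum(seq) for seq in token_mask)
    return valid_seqs, valid_toks

# Three sequences; the last is fully masked out (e.g. all prompt tokens).
mask = [[1, 1, 0], [1, 1, 1], [0, 0, 0]]
global_valid_seqs, global_valid_toks = count_valid(mask)
```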
- nemo_rl.models.megatron.data._pack_sequences_for_megatron(
- input_ids: torch.Tensor,
- seq_lengths: torch.Tensor,
- pad_individual_seqs_to_multiple_of: int = 1,
- pad_packed_seq_to_multiple_of: int = 1,
- pad_packed_seq_to: Optional[int] = None,
- cp_rank: int = 0,
- cp_size: int = 1,
- )#
Pack sequences for Megatron model processing with optional context parallelism.
- Parameters:
input_ids – Input token IDs [batch_size, seq_length]
seq_lengths – Actual sequence lengths for each sample [batch_size]
pad_individual_seqs_to_multiple_of – Pad individual sequences to a multiple of this value
pad_packed_seq_to_multiple_of – Pad packed sequences to a multiple of this value
pad_packed_seq_to – Pad packed sequences to this length (before CP sharding). The three padding parameters above can be computed with _get_pack_sequence_parameters_for_megatron; we do not recommend setting them manually.
cp_rank – Context parallelism rank of the current process
cp_size – Context parallelism size
- Returns:
packed_input_ids: Packed input tensor [1, T]
input_ids_cp_sharded: Sharded input tensor [cp_size, T // cp_size]
packed_seq_params: PackedSeqParams object
cu_seqlens: Cumulative sequence lengths
cu_seqlens_padded: Padded cumulative sequence lengths
- Return type:
Tuple of the five values listed above
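The core of sequence packing — pad each sequence to a multiple, concatenate into one flat row, and record cumulative boundaries — can be sketched in pure Python (lists instead of tensors; `PAD` and the function names are illustrative, and CP sharding is omitted):

```python
PAD = 0  # illustrative pad token id

def pad_to_multiple(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple

def pack_sequences(
    seqs: list[list[int]],
    pad_individual_seqs_to_multiple_of: int = 1,
    pad_packed_seq_to_multiple_of: int = 1,
) -> tuple[list[int], list[int], list[int]]:
    """Return (packed, cu_seqlens, cu_seqlens_padded).

    cu_seqlens marks true sequence boundaries; cu_seqlens_padded marks the
    boundaries after per-sequence padding, which is what the packed buffer uses.
    """
    packed: list[int] = []
    cu_seqlens, cu_seqlens_padded = [0], [0]
    for seq in seqs:
        padded_len = pad_to_multiple(len(seq), pad_individual_seqs_to_multiple_of)
        packed.extend(seq + [PAD] * (padded_len - len(seq)))
        cu_seqlens.append(cu_seqlens[-1] + len(seq))
        cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_len)
    # Finally pad the whole packed buffer to its own multiple.
    total = pad_to_multiple(len(packed), pad_packed_seq_to_multiple_of)
    packed += [PAD] * (total - len(packed))
    return packed, cu_seqlens, cu_seqlens_padded

packed, cu, cu_pad = pack_sequences(
    [[1, 2, 3], [4, 5]],
    pad_individual_seqs_to_multiple_of=4,
    pad_packed_seq_to_multiple_of=8,
)
```

With these settings, a length-3 and a length-2 sequence are each padded to 4 tokens, giving an 8-token packed row with true boundaries `[0, 3, 5]` and padded boundaries `[0, 4, 8]`.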
- nemo_rl.models.megatron.data._get_pack_sequence_parameters_for_megatron(
- megatron_cfg: dict,
- max_seq_len_in_batch: int,
- )#
Get pack sequence parameters for Megatron model processing with optional context parallelism.
- Parameters:
megatron_cfg – Megatron configuration
max_seq_len_in_batch – Maximum sequence length in batch
- Returns:
pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value
pad_packed_seq_to_multiple_of: Pad packed sequences to a multiple of this value
pad_packed_seq_to: Pad packed sequences to this value (before CP)
- Return type:
Tuple of the three values listed above
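The exact derivation lives in the implementation; the sketch below only illustrates how such parameters can fall out of a parallelism config. The specific multiples (2 × cp_size for the load-balanced per-rank split, and rounding the batch maximum up for the packed length) are assumptions for illustration, not the library's actual formula:

```python
def get_pack_sequence_parameters(
    cp_size: int, max_seq_len_in_batch: int
) -> tuple[int, int, int]:
    """Illustrative derivation of (pad_individual, pad_packed_multiple, pad_packed_to)."""
    # Assumption: with context parallelism each sequence is split into
    # 2 * cp_size chunks (two load-balanced halves per rank), so individual
    # sequence lengths must divide evenly by that factor.
    pad_individual = 2 * cp_size if cp_size > 1 else 1
    pad_packed_multiple = pad_individual
    # Assumption: round the longest sequence in the batch up to the multiple
    # as a stand-in for a fixed packed length.
    pad_packed_to = ((max_seq_len_in_batch + pad_packed_multiple - 1)
                     // pad_packed_multiple) * pad_packed_multiple
    return pad_individual, pad_packed_multiple, pad_packed_to
```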
- nemo_rl.models.megatron.data._unpack_sequences_from_megatron(
- output_tensor: torch.Tensor,
- seq_lengths: torch.Tensor,
- cu_seqlens: torch.Tensor,
- cu_seqlens_padded: Optional[torch.Tensor],
- original_batch_size: int,
- original_seq_length: int,
- )#
Unpack sequences from Megatron output format.
- Parameters:
output_tensor – Packed output tensor [1, T, vocab_size]
seq_lengths – Actual sequence lengths for each sample
cu_seqlens – Cumulative sequence lengths
cu_seqlens_padded – Padded cumulative sequence lengths (if CP was used)
original_batch_size – Original batch size
original_seq_length – Original maximum sequence length
- Returns:
Unpacked output tensor [batch_size, seq_length, vocab_size]
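Unpacking inverts the packed layout: slice each sequence out of the flat buffer at its padded boundary and right-pad every row back to a common length. A pure-Python sketch (lists instead of tensors, no vocab dimension; the function name is illustrative):

```python
def unpack_sequences(
    packed: list,
    seq_lengths: list[int],
    cu_seqlens_padded: list[int],
    original_seq_length: int,
    pad_value=0,
) -> list[list]:
    """Slice each sequence from the flat packed buffer and right-pad it."""
    rows = []
    for i, n in enumerate(seq_lengths):
        start = cu_seqlens_padded[i]
        seq = packed[start:start + n]  # take only the true tokens, dropping padding
        rows.append(seq + [pad_value] * (original_seq_length - n))
    return rows

# Two sequences of lengths 3 and 2, each padded to 4 slots in the flat buffer.
packed = [1, 2, 3, 0, 4, 5, 0, 0]
rows = unpack_sequences(packed, seq_lengths=[3, 2],
                        cu_seqlens_padded=[0, 4, 8], original_seq_length=3)
```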
- nemo_rl.models.megatron.data.get_and_validate_seqlen()#