Skip to content

Commit 52fee2b

Browse files
akoumpaclaude
andcommitted
Merge branch 'main' into codex/refactor-component-builders
Resolve conflicts in loss/mtp.py and recipes/llm/train_ft.py by keeping both sides: Alexandros' typed MTPLossConfig (scaling_factor/ignore_index) alongside main's PP/CP seq_idx + cu_seqlens plumbing from #2316. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
2 parents 934fc04 + 987ba24 commit 52fee2b

25 files changed

Lines changed: 2692 additions & 270 deletions

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,16 @@ def save_model(
352352
model_state = ModelState(model, self.config.is_peft)
353353
state_dict = model_state.state_dict()
354354

355-
# Convert to HF format if using custom model implementations
355+
# Convert to HF format if using custom model implementations.
356356
state_dict = _maybe_adapt_state_dict_to_hf(
357357
model_state.model[0],
358358
state_dict,
359359
quantization=False,
360360
device_mesh=self.moe_mesh,
361361
v4_compatible=self.config.v4_compatible,
362362
)
363+
# MoE adapters return non-contiguous views; safetensors.save rejects those.
364+
_materialize_to_hf_views_for_save(state_dict)
363365
# Build the consolidated model.safetensors.index.json if needed
364366
fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)
365367
fqn_to_dtype_mapping = self._maybe_build_original_dtype_mapping(model_state, state_dict)
@@ -576,6 +578,8 @@ def load_model(
576578
reader_key_mapping = None if has_state_dict_adapter else key_mapping
577579
storage_reader = self._get_storage_reader(model_path, reader_key_mapping, is_init_step=is_init_step)
578580

581+
# MoE adapters return views into model storage; DCP writes safetensors
582+
# data straight through them and from_hf skips the rebuild.
579583
state_dict = _maybe_adapt_state_dict_to_hf(
580584
model_state.model[0],
581585
state_dict,
@@ -1758,6 +1762,27 @@ def _maybe_adapt_state_dict_to_hf(
17581762
return state_dict
17591763

17601764

1765+
def _materialize_to_hf_views_for_save(state_dict: dict[str, torch.Tensor]) -> None:
1766+
"""Replace non-contiguous tensor values in ``state_dict`` with contiguous copies in place.
1767+
1768+
MoE adapters return non-contiguous strided views into the model's grouped
1769+
expert storage for the optimized load path; ``safetensors.torch.save``
1770+
(which the DCP HF storage writer calls) rejects non-contiguous tensors,
1771+
so we materialize one tensor at a time here with ``empty_cache`` between
1772+
iterations. Per-tensor transient is bounded to a single expert weight
1773+
instead of allocating the full grouped set up front.
1774+
"""
1775+
if not state_dict:
1776+
return
1777+
cuda_available = torch.cuda.is_available()
1778+
for key, value in list(state_dict.items()):
1779+
if isinstance(value, torch.Tensor) and not value.is_contiguous():
1780+
state_dict[key] = value.contiguous()
1781+
del value
1782+
if cuda_available:
1783+
torch.cuda.empty_cache()
1784+
1785+
17611786
def _equally_divide_layers(num_shards: int, keys: list[str]) -> dict[str, int]:
17621787
"""
17631788
Equally divide the state dict keys into num_shards shards.

nemo_automodel/components/distributed/thd_utils.py

Lines changed: 74 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,21 @@ def process_input_for_thd(
5454
[total_tokens, hidden_dim] for 3D embeddings
5555
- 'labels': Reshaped labels tensor of shape [total_tokens]
5656
- 'position_ids': Reshaped tensor of shape [total_tokens]
57-
- 'cu_seqlens': Cumulative padded sequence lengths tensor of shape [num_sequences + 1] (int32)
57+
- 'cu_seqlens': Cumulative REAL sequence lengths tensor of shape [num_sequences + 1] (int32)
5858
where num_sequences is the total count of non-padded sequences across the batch.
59-
NOTE: This contains cumulative lengths from seq_lens_padded (not seq_lens) since
60-
CP doesn't support padding between sequences (resulting in NaNs). The labels or loss mask
61-
will ensure that loss is computed correctly.
59+
Built from seq_lens (the unpadded real lengths). When the trailing pack-pad is
60+
purely at the end (cp_size == 1), the last entry is grown to total_tokens to absorb
61+
that pad and avoid TE's ``pad_between_seqs=True`` path; see the absorption block in
62+
the function body for the gate.
63+
- 'cu_seqlens_padded': (optional) Cumulative PADDED sequence lengths tensor of the same
64+
shape as ``cu_seqlens``. Only emitted when it differs from ``cu_seqlens`` after
65+
absorption (i.e., when padding lives between sub-sequences, which is the CP case).
66+
Forwarded to TE as ``cu_seqlens_q_padded`` / ``cu_seqlens_kv_padded`` with
67+
``pad_between_seqs=True`` so the kernel reads memory offsets from the padded
68+
variant while attending only over the real-length slots.
69+
- 'max_seqlen': Scalar int32 tensor equal to ``max(cu_seqlens[i+1] - cu_seqlens[i])``
70+
after any absorption. Honors TE's contract that
71+
``max_seqlen_q >= max(cu_seqlens_q[i+1] - cu_seqlens_q[i])``.
6272
- 'padding_mask': Boolean tensor of shape [total_tokens] indicating padding positions
6373
- Non-tensor keys from input batch are preserved (e.g., 'qkv_format')
6474
@@ -77,8 +87,11 @@ def process_input_for_thd(
7787
>>> # result['input_ids'].shape: [12] (2D input collapsed to 1D)
7888
>>> # result['labels'].shape: [12]
7989
>>> # result['position_ids'].shape: [12]
80-
>>> # result['cu_seqlens']: tensor([0, 4, 6, 12], dtype=torch.int32)
90+
>>> # result['cu_seqlens']: tensor([0, 3, 5, 11], dtype=torch.int32)
91+
>>> # Breakdown: [0] + cumsum([3, 2, 6]) = [0, 3, 5, 11] (from seq_lens — real lengths)
92+
>>> # result['cu_seqlens_padded']: tensor([0, 4, 6, 12], dtype=torch.int32)
8193
>>> # Breakdown: [0] + cumsum([4, 2, 6]) = [0, 4, 6, 12] (from seq_lens_padded)
94+
>>> # result['max_seqlen']: tensor(6, dtype=torch.int32) # max slot width in cu_seqlens
8295
>>> # result['padding_mask'].shape: [12]
8396
"""
8497
input_ids = batch["input_ids"]
@@ -96,13 +109,13 @@ def process_input_for_thd(
96109
input_ids_thd = input_ids.reshape(total_tokens, -1).squeeze(-1)
97110
labels_thd = labels.reshape(total_tokens, -1).squeeze(-1)
98111

112+
cu_seqlens = None
113+
cu_seqlens_padded = None
114+
max_seqlen = None
99115
if seq_lens is not None:
100-
# Filter out padding values and flatten
101-
# seq_lens shape: [batch_size, num_packs] -> flatten and remove padding values
102116
seq_lens_flat = seq_lens.reshape(-1)
103117
valid_seq_lens = seq_lens_flat[seq_lens_flat != seq_lens_padding_value]
104118

105-
# Compute cumulative sequence lengths for attention
106119
cu_seqlens = torch.cat(
107120
[
108121
torch.tensor([0], dtype=valid_seq_lens.dtype, device=valid_seq_lens.device),
@@ -112,7 +125,6 @@ def process_input_for_thd(
112125
cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(device=valid_seq_lens.device)
113126

114127
if seq_lens_padded is not None:
115-
# Same processing for padded sequence lengths
116128
seq_lens_padded_flat = seq_lens_padded.reshape(-1)
117129
valid_seq_lens_padded = seq_lens_padded_flat[seq_lens_padded_flat != seq_lens_padding_value]
118130

@@ -121,16 +133,46 @@ def process_input_for_thd(
121133
)
122134
cu_seqlens_padded = cu_seqlens_padded.to(dtype=torch.int32).to(device=valid_seq_lens_padded.device)
123135

136+
# Trailing-only pack-pad (cp_size==1): absorb into cu_seqlens[-1] so
137+
# the emit gate below drops cu_seqlens_padded and TE skips its
138+
# pad_between_seqs=True path. CP>1 differs in multiple entries and
139+
# falls through; both arrays are emitted and TE handles padding.
140+
if (
141+
cu_seqlens is not None
142+
and cu_seqlens_padded is not None
143+
and cu_seqlens.numel() == cu_seqlens_padded.numel()
144+
and cu_seqlens.numel() > 1
145+
and torch.equal(cu_seqlens[:-1], cu_seqlens_padded[:-1])
146+
):
147+
_total = int(total_tokens)
148+
_real_total = int(cu_seqlens[-1].item())
149+
if _real_total < _total:
150+
_extended = cu_seqlens.clone()
151+
_extended[-1] = _total
152+
cu_seqlens = _extended
153+
cu_seqlens_padded = cu_seqlens.clone()
154+
155+
# Compute max_seqlen from the FINAL cu_seqlens to honor TE's contract
156+
# (``max_seqlen_q >= max(cu_seqlens[i+1] - cu_seqlens[i])``, see TE's
157+
# cpp_extensions/fused_attn.py:152-159).
158+
if cu_seqlens is not None and cu_seqlens.numel() > 1:
159+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().to(dtype=torch.int32)
160+
124161
result = {
125162
"input_ids": input_ids_thd,
126163
"position_ids": position_ids_thd,
127-
# Pass cu_seqlens_padded here since CP doesn't support padding between sequences correctly, the labels or loss mask will ensure that loss is computed correctly.
128-
"cu_seqlens": cu_seqlens_padded,
164+
"cu_seqlens": cu_seqlens,
129165
"labels": labels_thd,
130166
"padding_mask": (input_ids_thd == padding_token_id),
131167
}
168+
# Emit cu_seqlens_padded only when it differs from cu_seqlens — its
169+
# presence is what flips TE's pad_between_seqs=True path in
170+
# attention/utils.py.
171+
if cu_seqlens_padded is not None and not torch.equal(cu_seqlens_padded, cu_seqlens):
172+
result["cu_seqlens_padded"] = cu_seqlens_padded
173+
if max_seqlen is not None:
174+
result["max_seqlen"] = max_seqlen
132175

133-
# Preserve qkv_format and other non-tensor keys from the original batch
134176
for key, value in batch.items():
135177
if key not in result and not isinstance(value, torch.Tensor):
136178
result[key] = value
@@ -175,8 +217,14 @@ def split_batch_into_thd_chunks(
175217
- 'input_ids': [num_chunks, tokens_per_chunk] or [num_chunks, tokens_per_chunk, hidden_dim]
176218
- 'labels': [num_chunks, tokens_per_chunk]
177219
- 'position_ids': [num_chunks, tokens_per_chunk]
178-
- 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (padded with seq_lens_padding_value).
179-
Contains cumulative lengths from seq_lens_padded for CP compatibility.
220+
- 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (right-padded with
221+
seq_lens_padding_value across chunks for rectangularity). Built from seq_lens
222+
(real lengths) per chunk; see ``process_input_for_thd`` for the absorption
223+
semantics applied per chunk.
224+
- 'cu_seqlens_padded': (optional) Same shape, emitted whenever ANY chunk emits it.
225+
For chunks that absorbed (no separate padded variant), this row equals the
226+
chunk's ``cu_seqlens``.
227+
- 'max_seqlen': [num_chunks] per-chunk scalar tensor.
180228
- 'padding_mask': [num_chunks, tokens_per_chunk]
181229
- Non-tensor keys from input batch are preserved
182230
- When num_chunks <= 1:
@@ -230,12 +278,21 @@ def pad_and_stack(tensor_list, padding_value):
230278
for i in range(num_chunks)
231279
]
232280

233-
# Stack results
234-
return {
281+
stacked: dict = {
235282
"input_ids": torch.stack([c["input_ids"] for c in chunk_results]),
236283
"labels": torch.stack([c["labels"] for c in chunk_results]),
237284
"position_ids": torch.stack([c["position_ids"] for c in chunk_results]),
238285
"cu_seqlens": pad_and_stack([c["cu_seqlens"] for c in chunk_results], seq_lens_padding_value),
239286
"padding_mask": torch.stack([c["padding_mask"] for c in chunk_results]),
240-
**{k: v for k, v in chunk_results[0].items() if not isinstance(v, torch.Tensor)},
241287
}
288+
# Emit cu_seqlens_padded whenever any chunk emits it; absorbed chunks
289+
# fall back to their cu_seqlens (semantically equal) for rectangularity.
290+
if any("cu_seqlens_padded" in c for c in chunk_results):
291+
stacked["cu_seqlens_padded"] = pad_and_stack(
292+
[c.get("cu_seqlens_padded", c["cu_seqlens"]) for c in chunk_results],
293+
seq_lens_padding_value,
294+
)
295+
if all("max_seqlen" in c for c in chunk_results):
296+
stacked["max_seqlen"] = torch.stack([c["max_seqlen"] for c in chunk_results])
297+
stacked.update({k: v for k, v in chunk_results[0].items() if not isinstance(v, torch.Tensor)})
298+
return stacked

nemo_automodel/components/loss/mtp.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def calculate_mtp_loss(
3333
scaling_factor: float = 0.1,
3434
num_label_tokens: Optional[int] = None,
3535
ignore_index: int = -100,
36+
cu_seqlens: Optional[torch.Tensor] = None,
37+
seq_idx: Optional[torch.Tensor] = None,
3638
) -> torch.Tensor:
3739
"""Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss.
3840
@@ -53,6 +55,14 @@ def calculate_mtp_loss(
5355
base loss for sum-reduction normalization).
5456
ignore_index: Label value masked out of the CE loss for the trailing
5557
``k+1`` rolled positions at depth ``k``.
58+
cu_seqlens: Optional cumulative sequence lengths ``[num_seqs+1]``
59+
(THD-pack layout). When supplied and ``seq_idx`` is not, builds
60+
a per-token sub-sequence index via searchsorted. Without packing
61+
this can be omitted.
62+
seq_idx: Optional per-token sub-sequence index ``[B, S]`` (or ``[S]``).
63+
Equality classes are what matter; absolute values can be any
64+
ints. Takes precedence over ``cu_seqlens``. Used to mask label
65+
rolls whose source position lies in a different sub-sequence.
5666
5767
Returns:
5868
Scalar MTP loss with autograd graph.
@@ -61,15 +71,62 @@ def calculate_mtp_loss(
6171
raise ValueError("Provide exactly one of mtp_per_depth_h or mtp_per_depth_logits")
6272

6373
mtp_outputs = mtp_per_depth_logits if mtp_per_depth_logits is not None else mtp_per_depth_h
74+
75+
# Reconcile per-depth output and label dims for the THD-packed non-PP path:
76+
# the model unsqueezes outputs from ``[T, *]`` back to ``[1, T, *]`` (model.py
77+
# post-MTP-forward), while labels arrive as 1D ``[T]`` from
78+
# ``process_input_for_thd``. ``FusedLinearCrossEntropy`` / ``cut_cross_entropy``
79+
# asserts ``hidden_states.shape[:-1] == labels.shape`` so squeeze the synthetic
80+
# batch axis when labels are flat.
81+
if labels.dim() == 1:
82+
mtp_outputs = [h.squeeze(0) if (h.dim() == 3 and h.shape[0] == 1) else h for h in mtp_outputs]
83+
6484
D = len(mtp_outputs)
6585
cur_labels = labels
6686
total = mtp_outputs[0].new_zeros(())
87+
88+
if seq_idx is None and cu_seqlens is not None:
89+
cs = cu_seqlens
90+
if cs.dim() == 2 and cs.shape[0] == 1:
91+
cs = cs.squeeze(0)
92+
if cs.dim() == 1:
93+
# Span the full (padded) token axis; cu_seqlens[-1] excludes tail pad.
94+
# Matches the model's mamba seq_idx build (nemotron_v3/layers.py).
95+
total_len = labels.shape[-1]
96+
positions = torch.arange(total_len, device=labels.device)
97+
# ``right=True`` so a position equal to a boundary (the first token
98+
# of sub-seq k, position == cu_seqlens[k]) maps to k, not k-1.
99+
seq_idx = torch.searchsorted(cs[1:].contiguous(), positions, right=True)
100+
if labels.dim() == 2:
101+
seq_idx = seq_idx.unsqueeze(0).expand(labels.shape[0], -1)
102+
elif seq_idx is not None:
103+
if seq_idx.dim() == 1 and labels.dim() == 2:
104+
seq_idx = seq_idx.unsqueeze(0).expand(labels.shape[0], -1)
105+
elif seq_idx.dim() == 2 and labels.dim() == 1 and seq_idx.shape[0] == 1:
106+
seq_idx = seq_idx.squeeze(0)
107+
# Under PP the caller must chunk seq_idx to per-microbatch shape; a
108+
# mismatch is a wiring bug, not a runtime condition to swallow.
109+
if seq_idx.shape != labels.shape:
110+
raise ValueError(
111+
f"calculate_mtp_loss: seq_idx.shape={tuple(seq_idx.shape)} does not "
112+
f"match labels.shape={tuple(labels.shape)}; under PP, chunk seq_idx "
113+
f"into per-microbatch pieces before passing it in."
114+
)
115+
67116
for k, mtp_output in enumerate(mtp_outputs):
68117
cur_labels = roll_tensor(cur_labels, shifts=-1, dim=-1)
69118
masked = cur_labels.clone()
70119
n_invalid = min(k + 1, masked.shape[-1])
71120
masked[..., -n_invalid:] = ignore_index
72121

122+
# Mask labels whose rolled source (position t+k+1) lives in a
123+
# different sub-seq than position t — predictions across sub-seq
124+
# boundaries are nonsensical.
125+
if seq_idx is not None:
126+
rolled_seq_idx = roll_tensor(seq_idx, shifts=-(k + 1), dim=-1)
127+
cross_seq = rolled_seq_idx != seq_idx
128+
masked = torch.where(cross_seq, torch.full_like(masked, ignore_index), masked)
129+
73130
if mtp_per_depth_logits is not None:
74131
if isinstance(loss_fn, FusedLinearCrossEntropy):
75132
raise ValueError("MTP logits are incompatible with FusedLinearCrossEntropy")
@@ -105,7 +162,15 @@ def calculate_mtp_loss(
105162

106163

107164
class PipelineCausalLMLoss(nn.Module):
108-
"""Pipeline schedule loss that can add MTP auxiliary CE on the last stage."""
165+
"""Pipeline schedule loss that can add MTP auxiliary CE on the last stage.
166+
167+
Per-microbatch ``seq_idx`` is read from a trailing element of the
168+
last-stage output tuple — the model appends an ``[B, S] int32`` tail
169+
when MTP is enabled. This binds each microbatch's seq_idx to its loss
170+
call via the PP runtime's output→loss contract, so the wiring is
171+
schedule-agnostic. Legacy ``cu_seqlens`` (THD path) is a fallback for
172+
models that don't emit a seq_idx tail.
173+
"""
109174

110175
def __init__(
111176
self,
@@ -119,8 +184,26 @@ def __init__(
119184
self.model = model
120185
self.scaling_factor = scaling_factor
121186
self.ignore_index = ignore_index
187+
# Legacy THD-pack fallback used when the model has no seq_idx tail.
188+
self.cu_seqlens: Optional[torch.Tensor] = None
189+
190+
@staticmethod
191+
def _extract_seq_idx_tail(output) -> tuple[Optional[torch.Tensor], object]:
192+
"""Detect and strip a trailing per-microbatch seq_idx from output.
193+
194+
Convention: with MTP enabled the last-stage output is
195+
``(logits, *mtp_per_depth_h, seq_idx)`` with an ``[B, S] int32``
196+
tail — dtype alone discriminates.
197+
"""
198+
if isinstance(output, tuple) and len(output) > 0:
199+
last = output[-1]
200+
if isinstance(last, torch.Tensor) and last.dtype == torch.int32 and last.dim() == 2:
201+
return last, output[:-1]
202+
return None, output
122203

123204
def forward(self, output, labels: torch.Tensor) -> torch.Tensor:
205+
seq_idx_mb, output = self._extract_seq_idx_tail(output)
206+
124207
if isinstance(output, tuple):
125208
logits = output[0]
126209
hidden_states = None
@@ -156,6 +239,8 @@ def forward(self, output, labels: torch.Tensor) -> torch.Tensor:
156239
model=self.model,
157240
scaling_factor=scaling_factor,
158241
ignore_index=self.ignore_index,
242+
cu_seqlens=self.cu_seqlens,
243+
seq_idx=seq_idx_mb,
159244
)
160245
return loss
161246

0 commit comments

Comments
 (0)