Skip to content

Commit 12749a4

Browse files
HuiyingLiclaudekhazic
authored
feat(nemotron-omni): enable context parallelism for VLM path (#2125)
* feat(nemotron-omni): enable context parallelism for VLM path Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix(nemotron-omni): route prepare_inputs_embeds_for_cp through forward for FSDP2 The vision tower needs FSDP2's forward pre-hooks to all-gather its Linear weights. Calling prepare_inputs_embeds_for_cp directly bypassed those hooks and produced "mixed Tensor and DTensor" errors. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix(capabilities): drill into language_model.config for hybrid VLM wrappers NemotronOmni's hybrid layers_block_type lives on the inner language_model.config, not on the outer wrapper. Without this fix, validate_for_mesh rejects cp_size>1 + sdpa as "requires TE attention", even though hybrid+sdpa is supported. Also: enable activation_checkpointing in both cordv2 CP yamls so ep_size=4 fits in 8x80GB (ep_size=8 frees enough memory without it, but matching configs is required for fair parity). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * test(nemotron-omni): correct CP parity yamls to match prior design Prior CP-omni work (conversation dca92b49) kept ep_size=8 for both baseline and CP=2 runs - only cp_size differed. Reducing ep_size in the test (ep4cp2) reroutes tokens through different experts, producing spurious per-step divergence plus OOM. Match the prior design: - nemotron_omni_v3_cord_v2_ep8cp1.yaml: cp=1, ep=8 (no AC, no ckpt) - nemotron_omni_v3_cord_v2_ep8cp2.yaml: cp=2, ep=8 (no AC, no ckpt) Result: step-0 abs diff 0.0008, overall mean abs diff 0.00086, p95 per step 0.020. Consistent with bf16 + CP attention non-associativity over ~50 layers. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * refactor(nemotron-omni): align VLM CP path with Gemma4 (PR #1914) Split prepare_inputs_embeds_for_cp into: - prepare_model_inputs_for_cp(input_ids, ...) -> dict (worker) - prepare_inputs_embeds_for_cp(...) -> Tensor (thin wrapper) Both take individual tensors instead of a batch dict, matching the gemma4_moe model API in PR #1914 so future VLMs only need to define these two methods to opt into the recipe's CP path. Forward flag rename: return_inputs_embeds_only -> _pre_embed_only. Recipe (recipes/vlm/finetune.py): - Drop _vlm_cp_deferred deferral and the prepare block inside sync_ctx. - Do prepare-then-shard inline at the top of _forward_backward_step, before make_cp_batch_and_ctx (matches PR #1914's flow). - Mirror the same prepare step in _run_validation_epoch. The deferred make_cp_batch_and_ctx pattern was based on a wrong mental model: sync_ctx controls FSDP grad sync, not param materialization. FSDP2 forward pre-hooks fire on any model(...) call regardless of sync_ctx, so the deferral was solving a non-problem. 20-step parity check vs pre-refactor: step-0 bit-identical (cp1) / within 0.5% rel (cp2). No behavior change beyond bf16/MoE-routing noise. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix(vlm/finetune): two latent bugs in validation under cp_size>1 Both surfaced when running val with cp_size=2 (cp_size=1 took the early-return path in make_cp_batch_and_ctx and hid them): 1. AttributeError: 'FSDPNemotronOmniForConditionalGeneration' object has no attribute 'device' at finetune.py:1281. The FSDP2 wrapper does not expose .device. Use self.dist_env.device for the position_ids synthesis target device. 2. KeyError: 'labels' at cp_utils.py:297. Validation popped labels before make_cp_batch_and_ctx, but that function reads batch["labels"] to register it as a CP buffer. Pop after, mirroring the train path. Verified: 20-step ep8cp2 run on medpix now reaches [val] step 19 | epoch 0 | loss 1.0888 cleanly. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * test(nemotron-omni-cp): switch parity yamls to TE backend + medpix - backend.attn: sdpa -> te. TE attention engages real ring/p2p CP via DotProductAttention.set_context_parallel_group; the SDPA-CP path used DTensor allgather only (cp_utils.py:355 hardcoded "allgather", with TODO to expose). Verified via temporary debug instrumentation: 6/6 attention blocks set up with cp_comm_type=p2p, 0 skipped. - dataset: cord_v2 -> medpix (mmoukouba/MedPix-VQA). cord_v2 loss saturates below 0.1 within 10 steps, making relative-diff parity metrics noisy. medpix stays in 1.5-2.5 range so the bf16 + CP numerical noise floor is properly bounded. - wandb logging enabled for both runs. Result on medpix: cp1 vs cp2 overall mean abs diff 0.00127 (0.065% relative), all 4 windows pass <0.5% rel. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * refactor(vlm): lift CP multimodal kwarg list to shared util Previously the VLM-CP pre-shard step in recipes/vlm/finetune.py hardcoded a Nemotron-Omni-specific 7-key tuple in two places (train and val paths). This umbrella approach mirrors the existing pattern in components/datasets/vlm/fake_image.py::_VISION_TOKEN_ID_ATTRS where vision-token attribute names from every known VLM family are unioned in one tuple and consumers iterate to pick whichever keys the live model has. Add VLM_INPUT_KEYS to components/utils/model_utils.py — the same module that hosts filter_forward_kwargs (the symmetric runtime kwarg filter). This makes the umbrella accessible from anywhere that already imports from model_utils (recipes, _transformers, models, distributed, datasets, tests) without circular-import risk, since model_utils only depends on shared/import_utils. The umbrella covers: - Nemotron-Omni: pixel_values, image_flags, imgs_sizes, pixel_values_videos, sound_features, sound_attention_mask - Gemma4 (PR #1914): image_position_ids, mm_token_type_ids - Kimi-VL / Qwen-VL / Mistral4: image_grid_hws, image_grid_thw, image_sizes - Phi-4-MM (future): audio_input_values, audio_attention_mask Recipe replaces both hardcoded mm_keys tuples with VLM_INPUT_KEYS and uses it for both the kwarg filter and the post-prepare drop set. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * cleanup(vlm/finetune): collapse nested with, drop dead val code After the refactor that moved the CP pre-shard step out of sync_ctx, three follow-up cleanups in _forward_backward_step / _run_validation_epoch: - Merge `with sync_ctx: with train_ctx():` (the only thing between the two was the pre-shard block which now lives at the top of the function). Use combined `with sync_ctx, train_ctx():`. - Delete the val-side synthetic position_ids block. cp_utils .make_cp_batch_and_ctx already injects a 1D arange when position_ids is missing and cp_mesh.size > 1 (cp_utils.py:288), so the recipe-side fallback was duplicating that work. - Trim two over-explanatory comments down to just the why-non-obvious bit (FSDP2 forward pre-hook all-gathers vision-tower weights). No behavior change. Net -45/+28 lines. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * chore(nemotron-omni-cp): drop ep8cp1 baseline yaml The cp_size=1 yaml was a paired-config baseline for the CP=2 parity test only. For everyday use the existing nemotron_omni_v3_cord_v2.yaml covers the non-CP path. Keep the ep8cp2.yaml as the canonical CP-enabled example. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * test(nemotron-omni-cp): unit tests for VLM CP enablement (49 tests) Covers all behavior changes from the CP work in this branch: utils/test_vlm_input_keys.py (9 tests) - VLM_INPUT_KEYS umbrella shape, no duplicates, importability - per-VLM-family coverage: Nemotron-Omni, Gemma4, Qwen-VL, Kimi-VL, Mistral4, Phi-4-MM - excludes labels/position_ids/attention_mask (NOT multimodal inputs) _transformers/test_capabilities_hybrid_vlm.py (13 tests) - _is_hybrid: outer config markers (layers_block_type, hybrid_override_pattern, is_hybrid_model) - drill into language_model.config when outer lacks markers - empty pattern, missing config, lowercase 'm', None inner config distributed/test_cp_utils_inputs_embeds.py (8 tests) - XOR contract: exactly one of input_ids/inputs_embeds in batch - inputs_embeds becomes primary cp_buffer when present - input_ids path unchanged (backward-compat) - position_ids synthesized from inputs_embeds.shape[1] (seq dim, not hidden) - cp_size<=1 short-circuit applies on inputs_embeds path - padding_mask + 3D mRoPE position_ids on inputs_embeds path models/nemotron_omni/test_nemotron_omni_cp.py (12 tests) - prepare_model_inputs_for_cp returns dict with inputs_embeds - text-only / image / video / sound modality scatter at correct positions - dynamic-res branch takes priority over static when imgs_sizes given - sound branch is no-op when sound_encoder is None - prepare_inputs_embeds_for_cp thin wrapper returns Tensor matching dict path - forward(_pre_embed_only=True) early-returns prepared dict, skips LM - forward(inputs_embeds=...) skips multimodal scatter block recipes/test_finetune_vlm_cp_wiring.py (7 tests) - recipe routes through model.__call__ with _pre_embed_only=True (so FSDP2 forward pre-hook fires) - all VLM_INPUT_KEYS popped after prepare; non-mm keys preserved - missing/None mm keys not forwarded as kwargs - prepare step skipped when model lacks prepare_model_inputs_for_cp - prepare step runs under torch.no_grad - val: do NOT pop labels before make_cp_batch_and_ctx (KeyError fix) - val: position_ids uses self.dist_env.device, not model_parts[0].device (AttributeError fix on FSDP-wrapped model) All 49 tests pass on CPU (no GPU/distributed needed). Tests use stubs and monkeypatching; no real model load. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * test(nemotron-omni-cp): switch ep8cp2 backend to TE linear + deepep dispatcher Align the CP-enabled yaml with the non-CP cordv2 baseline: - linear: torch -> te - enable_deepep: false -> dispatcher: deepep (matches cord_v2.yaml) Now ep8cp2 differs from non-CP cordv2 only in: - attn: sdpa -> te (TE-CP path is the supported one) - cp_size: 1 -> 2 (the actual CP-under-test) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix(nemotron-omni-cp): use HF model ID in ep8cp2 yaml Replace local lustre path with the canonical ``nvidia/Nemotron-3-Nano-Omni-30B-A3B-Reasoning-BF16`` model ID, matching the existing ``nemotron_omni_cord_v2.yaml`` and ``_peft.yaml``. Also resolves a secrets-detector false positive (``Base64 High Entropy String`` triggered on the long lustre path at line 27). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * review: address Claude review feedback on PR #2125 - examples/.../ep8cp2.yaml: fix typo in header comment ("ep4cp2" -> "ep8cp2") and the example run command's filename. - tests/unit_tests/distributed/test_cp_utils_inputs_embeds.py: add 5 tests covering the cp-divisor padding path (cp_utils.py:318-348) that prior tests skipped because seq_len was already cp*2-aligned: * pads all cp_buffers (primary, labels, position_ids) to multiple of 2*cp_size * labels pad with -100 (CE ignore_index); int buffers (input_ids, position_ids) pad with 0; float buffers (inputs_embeds) pad with zeros * loss_mask + padding_mask also padded when present * no-op when seq already aligned (identity-preserved) * input_ids path padding semantics (int=0, labels=-100) 13 tests pass in test_cp_utils_inputs_embeds.py (8 original + 5 new). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * review: address Claude review feedback (round 2) - cp_utils.py: track padding_mask index in cp_buffers and mirror it back into batch after the cp-divisor pad, alongside inputs_embeds/input_ids /labels/position_ids. Avoids a latent shape-mismatch trap for any future model that consumes padding_mask in its forward signature. Added a test (test_padding_mirrors_padding_mask_back_into_batch). - nemotron_omni/model.py: widen forward's return-type annotation to Union[dict, Tuple, CausalLMOutputWithPast] since _pre_embed_only=True returns the prepared-inputs dict directly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * ci: ruff format finetune.py for hidden_states one-liner `ruff format` collapsed the if/else expression onto one line. Resolves the `linting` and `Nemo_Linting_Test` CI failures. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * ci: fix L0_Unit_Tests_CPU regressions from CP work Pre-existing tests broke under the new CP code paths: - test_cp_utils.py (3 tests): bumped seq_len from 3/6 to 4/8 so the cp-divisor padding path (added in 24baff1) no longer fires and pre-existing position_ids/padding_mask assertions still hold. - test_nemotron_omni_dynamic_res.py (2 tests): removed the inputs_embeds=... kwarg from the forward() calls. Caller-supplied inputs_embeds is the CP path which by design skips the multimodal scatter. These tests want the scatter to fire, so they now pass only input_ids + multimodal kwargs and let forward compute the embeds internally. - vlm/finetune.py: replace `self.dist_setup.cp_size > 1` with the device_mesh-derived form already used in the val path. Pre-existing test_finetune_vlm_helpers stubs do not set `dist_setup`, only `device_mesh` + `pp_enabled`, so the train path now matches val and works under the same stub recipe. 174 tests pass across: utils/test_vlm_input_keys.py + _transformers/test_capabilities_hybrid_vlm.py + distributed/test_cp_utils.py + distributed/test_cp_utils_inputs_embeds.py + models/nemotron_omni/ + recipes/test_finetune_vlm_helpers.py + recipes/test_finetune_vlm_cp_wiring.py Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * review: align val CP guard + add torch.no_grad (round 4) Mirror the train-side defensive guard in _run_validation_epoch: - check ``"cp" in device_mesh.mesh_dim_names`` before indexing (avoids KeyError on DDP/non-CP meshes) - wrap the prepare-inputs-embeds call in ``torch.no_grad()`` to match train and avoid retaining vision-tower activations during val Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * review: add not-pp_enabled guard to val _cp_active (round 5) Mirror the train-side ``_cp_active`` guard so val also bails on PP+CP+VLM. PP+CP+VLM is a separate effort (see PR description); without this guard val and train would inconsistently exercise the prepare step. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> * fix(cp_utils): semantic pad sentinels for padding_mask + future buffers The cp-divisor padding pass picked the fill value by tensor identity (``buf is labels`` -> -100, else 0), which silently used the wrong sentinel for ``padding_mask``: bool ``True`` == "this position is pad, ignore" but the dtype-default fill of 0 == ``False`` told the MoE router that the cp-pad slots are real tokens. Concrete impact: on the LLM-CP-SFT path (``default_collater`` auto-emits ``padding_mask``), every step under ``cp_size > 1`` with a non-cp-aligned seq_len would route the cp-pad slots to experts -- wasting expert capacity and skewing load-balance loss. Latent today on the Nemotron-Omni VLM path (cordv2 collator doesn't emit ``padding_mask``) but real for DeepSeek-V3/V4, Qwen3-MoE, Gemma4-MoE, GLM4-MoE, GPT-OSS, etc. once they enable CP=2+. Fix: replace ``is buf labels`` special-case with a per-buffer-key ``PAD_FILL`` table that encodes each tensor's "ignore" sentinel: - labels: -100 (CE ignore_index) - padding_mask: True (bool: True == ignore) - attention_mask: False (HF: 0 == ignore) - default: 0 (input_ids, position_ids, ...) The mapping from cp_buffer index to batch key uses a small ``batch_buffer_keys`` registry replacing the ad-hoc ``padding_mask_idx``. The post-pad batch-mirror loop now iterates the registry, so any future batch-sourced cp_buffer (e.g. cu_seqlens variants) is mirrored back automatically -- no per-key special-case in two places. Two new regression tests: - ``padding_mask`` pad slots are ``True``, not ``False`` - ``attention_mask`` mapping documented in PAD_FILL (path currently unreachable because attention_mask is popped earlier; encoded for when a future PR revisits the strip). 33 tests pass in cp_utils + cp_utils_inputs_embeds. Co-authored-by: khazic <khazzz1c@gmail.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com> --------- Signed-off-by: HuiyingLi <willwin.lee@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: khazic <khazzz1c@gmail.com>
1 parent 28db2c1 commit 12749a4

13 files changed

Lines changed: 1654 additions & 49 deletions

File tree

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# NemotronOmni v3 (Reasoning) fine-tuning on CORD-V2 -- ep8cp2 (CP=2 test)
2+
#
3+
# Run:
4+
# cd automodel-omni
5+
# automodel examples/vlm_finetune/nemotron_omni/nemotron_omni_v3_cord_v2_ep8cp2.yaml --nproc-per-node 8
6+
7+
recipe: FinetuneRecipeForVLM
8+
9+
step_scheduler:
10+
global_batch_size: 8
11+
local_batch_size: 1
12+
ckpt_every_steps: 100000
13+
val_every_steps: 100000
14+
max_steps: 100
15+
16+
dist_env:
17+
backend: nccl
18+
timeout_minutes: 30
19+
20+
rng:
21+
_target_: nemo_automodel.components.training.rng.StatefulRNG
22+
seed: 1234
23+
ranked: true
24+
25+
model:
26+
_target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
27+
pretrained_model_name_or_path: nvidia/Nemotron-3-Nano-Omni-30B-A3B-Reasoning-BF16
28+
trust_remote_code: true
29+
torch_dtype: torch.bfloat16
30+
backend:
31+
_target_: nemo_automodel.components.models.common.BackendConfig
32+
attn: te
33+
linear: te
34+
rms_norm: torch_fp32
35+
rope_fusion: false
36+
dispatcher: deepep
37+
fake_balanced_gate: false
38+
enable_hf_state_dict_adapter: true
39+
40+
processor:
41+
_target_: transformers.AutoProcessor.from_pretrained
42+
pretrained_model_name_or_path: nvidia/Nemotron-3-Nano-Omni-30B-A3B-Reasoning-BF16
43+
trust_remote_code: true
44+
45+
checkpoint:
46+
enabled: false
47+
48+
distributed:
49+
strategy: fsdp2
50+
tp_size: 1
51+
cp_size: 2
52+
pp_size: 1
53+
ep_size: 8
54+
sequence_parallel: false
55+
56+
freeze_config:
57+
freeze_embeddings: true
58+
freeze_vision_tower: true
59+
freeze_audio_tower: true
60+
freeze_language_model: false
61+
62+
loss_fn:
63+
_target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
64+
65+
dataset:
66+
_target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
67+
path_or_dataset: mmoukouba/MedPix-VQA
68+
split: train
69+
70+
dataloader:
71+
_target_: torchdata.stateful_dataloader.StatefulDataLoader
72+
num_workers: 0
73+
collate_fn:
74+
_target_: nemo_automodel.components.datasets.vlm.collate_fns.nemotron_omni_collate_fn
75+
max_length: 4096
76+
drop_last: true
77+
78+
validation_dataset:
79+
_target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
80+
path_or_dataset: mmoukouba/MedPix-VQA
81+
split: validation
82+
83+
validation_dataloader:
84+
_target_: torchdata.stateful_dataloader.StatefulDataLoader
85+
num_workers: 1
86+
collate_fn:
87+
_target_: nemo_automodel.components.datasets.vlm.collate_fns.nemotron_omni_collate_fn
88+
max_length: 4096
89+
90+
wandb:
91+
entity: Nemo-automodel
92+
project: huiyingl_workspace
93+
name: nomni_v3_medpix_ep8cp2_te
94+
95+
optimizer:
96+
_target_: torch.optim.AdamW
97+
lr: 1e-4
98+
weight_decay: 0.01
99+
betas: [0.9, 0.95]

nemo_automodel/_transformers/capabilities.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,22 @@ def _is_hybrid(model: "nn.Module") -> bool:
9696
9797
Detected via config attributes used by NemotronH (``layers_block_type``)
9898
and HF hybrid models (``hybrid_override_pattern``, ``is_hybrid_model``).
99+
For VLM wrappers, also inspect the inner ``language_model``'s config.
99100
"""
100-
config = getattr(model, "config", None)
101-
if config is None:
102-
return False
103-
for attr in ("layers_block_type", "hybrid_override_pattern"):
104-
pattern = getattr(config, attr, None)
105-
if pattern and any(str(c).upper() == "M" for c in pattern):
101+
candidates = [getattr(model, "config", None)]
102+
inner = getattr(model, "language_model", None)
103+
if inner is not None:
104+
candidates.append(getattr(inner, "config", None))
105+
for config in candidates:
106+
if config is None:
107+
continue
108+
for attr in ("layers_block_type", "hybrid_override_pattern"):
109+
pattern = getattr(config, attr, None)
110+
if pattern and any(str(c).upper() == "M" for c in pattern):
111+
return True
112+
if getattr(config, "is_hybrid_model", False) is True:
106113
return True
107-
return getattr(config, "is_hybrid_model", False) is True
114+
return False
108115

109116

110117
class ModelSupports:

nemo_automodel/components/distributed/cp_utils.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,23 @@ def _get_mesh_size(mesh):
271271
# so that SDPA handles causal masking internally.
272272
batch.pop("attention_mask", None)
273273

274+
# Determine the primary sequence tensor: inputs_embeds (VLM with CP, where
275+
# multimodal token replacement happened pre-shard) or input_ids (standard LLM).
276+
has_inputs_embeds = "inputs_embeds" in batch
277+
has_input_ids = "input_ids" in batch
278+
assert has_inputs_embeds ^ has_input_ids, (
279+
"make_cp_batch_and_ctx requires exactly one of 'inputs_embeds' or 'input_ids' in batch"
280+
)
281+
if has_inputs_embeds:
282+
primary_seq_tensor = batch["inputs_embeds"]
283+
else:
284+
primary_seq_tensor = batch["input_ids"]
285+
seq_len = primary_seq_tensor.shape[1]
286+
274287
# Skip 1D injection if position_ids already in batch (e.g. mRoPE pre-computed)
275288
if "position_ids" not in batch and (_get_mesh_size(cp_mesh) > 1 or _get_mesh_size(tp_mesh) > 1):
276-
batch["position_ids"] = torch.arange(0, batch["input_ids"].shape[1]).unsqueeze(0).to(batch["input_ids"].device)
289+
batch["position_ids"] = torch.arange(0, seq_len).unsqueeze(0).to(primary_seq_tensor.device)
277290

278-
input_ids = batch["input_ids"]
279291
position_ids = batch["position_ids"]
280292

281293
# Determine correct seq dim for CP sharding
@@ -284,12 +296,19 @@ def _get_mesh_size(mesh):
284296

285297
labels = batch["labels"]
286298

287-
# Collect all available tensors for context parallel
288-
cp_buffers = [input_ids, labels, position_ids]
299+
# Collect all available tensors for context parallel. We track each
300+
# cp_buffer's batch key (when sourced from ``batch``) so the padding pass
301+
# below can pick the semantically-correct fill sentinel and mirror the
302+
# padded tensor back into ``batch``. ``loss_mask`` is passed as an arg
303+
# (not in batch) so it has no key.
304+
primary_key = "inputs_embeds" if has_inputs_embeds else "input_ids"
305+
cp_buffers = [primary_seq_tensor, labels, position_ids]
306+
# inputs_embeds is [B, S, H] → seq_dim=1; input_ids is [B, S] → seq_dim=1
289307
cp_seq_dims = [1, 1, pos_seq_dim]
290-
cp_no_restore_buffers = {input_ids, labels}
308+
cp_no_restore_buffers = {primary_seq_tensor, labels}
309+
batch_buffer_keys: dict[int, str] = {0: primary_key, 1: "labels", 2: "position_ids"}
291310

292-
# Add loss_mask if available
311+
# Add loss_mask if available (passed as arg, not in batch -> no key)
293312
if loss_mask is not None:
294313
cp_buffers.append(loss_mask)
295314
cp_seq_dims.append(1)
@@ -298,10 +317,50 @@ def _get_mesh_size(mesh):
298317
# Add padding_mask if available in batch
299318
if "padding_mask" in batch:
300319
padding_mask = batch["padding_mask"]
320+
batch_buffer_keys[len(cp_buffers)] = "padding_mask"
301321
cp_buffers.append(padding_mask)
302322
cp_seq_dims.append(1)
303323
cp_no_restore_buffers.add(padding_mask)
304324

325+
# Pad sequence length to be divisible by 2 * cp_size (required by
326+
# context_parallel load balancing). The inputs_embeds path can hit
327+
# arbitrary seq lengths from the VLM collator, so we pad here rather
328+
# than relying on dataset-side padding.
329+
#
330+
# Per-buffer pad sentinels: each tensor's "ignore" value is semantic, not
331+
# dtype-derived. ``labels``/``padding_mask``/``attention_mask`` are all
332+
# int/bool but have different ignore conventions. Falling through to 0
333+
# for ``padding_mask`` (== False == "real token") would tell the MoE
334+
# router to route the cp-pad slots to experts -- silently wasting capacity
335+
# and skewing load-balance loss.
336+
PAD_FILL = {
337+
"labels": -100, # CE ignore_index
338+
"padding_mask": True, # bool: True == "this position is pad, ignore"
339+
"attention_mask": False, # HF: 0 == "this position is pad, ignore"
340+
# everything else (input_ids, position_ids, ...) -> 0
341+
}
342+
cp_divisor = cp_mesh.size() * 2
343+
if seq_len % cp_divisor != 0:
344+
pad_len = cp_divisor - (seq_len % cp_divisor)
345+
new_no_restore = set()
346+
for i, (buf, dim) in enumerate(zip(cp_buffers, cp_seq_dims)):
347+
pad_shape = list(buf.shape)
348+
pad_shape[dim] = pad_len
349+
if buf.dtype.is_floating_point:
350+
pad_val = torch.zeros(pad_shape, dtype=buf.dtype, device=buf.device)
351+
else:
352+
fill_val = PAD_FILL.get(batch_buffer_keys.get(i), 0)
353+
pad_val = torch.full(pad_shape, fill_val, dtype=buf.dtype, device=buf.device)
354+
old_buf = buf
355+
cp_buffers[i] = torch.cat([buf, pad_val], dim=dim)
356+
if old_buf in cp_no_restore_buffers:
357+
new_no_restore.add(cp_buffers[i])
358+
cp_no_restore_buffers = new_no_restore
359+
# Mirror every batch-sourced cp_buffer back into ``batch`` so any
360+
# downstream consumer reading from the dict sees the padded shape.
361+
for idx, key in batch_buffer_keys.items():
362+
batch[key] = cp_buffers[idx]
363+
305364
cp_ctx = create_context_parallel_ctx(
306365
cp_mesh=cp_mesh,
307366
cp_buffers=cp_buffers,

0 commit comments

Comments
 (0)