
feat(nemotron-omni): enable context parallelism for VLM path #2125

Merged: HuiyingLi merged 20 commits into main from huiyingl/nemotron-omni-cp on May 7, 2026

Conversation

HuiyingLi (Contributor) commented May 4, 2026

Summary

Enables Context Parallelism (CP=2) for the Nemotron-Omni VLM training path on the recipes/vlm/finetune.py + components/models/nemotron_omni/ stack. Aligned with the Gemma4 CP design from PR #1914 so future VLM-CP enablement is a model-only change.

What changed

  • Model (components/models/nemotron_omni/model.py): added prepare_model_inputs_for_cp(input_ids, ...), a worker that returns a dict, and prepare_inputs_embeds_for_cp(...), a thin wrapper that returns a Tensor. forward(_pre_embed_only=True) short-circuits to the prepare step so FSDP2's forward pre-hook fires (and unshards the vision tower's weights) before the multimodal scatter. forward(inputs_embeds=...) skips the multimodal-replacement block (the embeds are already correct after CP sharding).
  • Recipe (recipes/vlm/finetune.py): the prepare-then-shard step now runs inline at the top of _forward_backward_step and is mirrored in _run_validation_epoch (see the sketch after this list). Drops the _vlm_cp_deferred deferral and the sync_ctx-wrapped prepare block; those were based on a wrong mental model (sync_ctx controls FSDP grad sync, not param materialization). Routes through model.__call__ so FSDP2 hooks fire.
  • CP utilities (components/distributed/cp_utils.py): make_cp_batch_and_ctx accepts inputs_embeds as the primary CP buffer when present (XOR contract: exactly one of input_ids/inputs_embeds). position_ids synthesis uses primary_seq_tensor.shape[1] so the embeds path gets the right seq dim.
  • Capabilities (_transformers/capabilities.py): _is_hybrid drills into language_model.config for VLM wrappers. Without this, the gate rejected cp_size>1 + attn=sdpa for Nemotron-Omni even though hybrid+sdpa is supported.
  • Shared umbrella (components/utils/model_utils.py): added VLM_INPUT_KEYS covering Nemotron-Omni, Gemma4, Qwen-VL, Kimi-VL, Mistral4, Phi-4-MM modality kwargs. Mirrors the discovery pattern of fake_image._VISION_TOKEN_ID_ATTRS. Recipe iterates this instead of hardcoding per-model lists in two places.
  • Validation bug fixes (latent under cp_size=1, surfaced under cp_size=2):
    1. 'FSDPNemotronOmniForConditionalGeneration' has no attribute 'device' at finetune.py:1281 → use self.dist_env.device
    2. KeyError: 'labels' in make_cp_batch_and_ctx because val popped labels first → pop after, mirror train path
  • Test config (examples/vlm_finetune/nemotron_omni/nemotron_omni_v3_cord_v2_ep8cp2.yaml): single CP-enabled yaml with attn: te, linear: te, dispatcher: deepep, cp_size: 2, ep_size: 8 on medpix.
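
The train-path wiring this implies, as a rough sketch (import paths, the make_cp_batch_and_ctx signature, and the return handling are assumptions; prepare_model_inputs_for_cp, _pre_embed_only, VLM_INPUT_KEYS, and make_cp_batch_and_ctx are the names this PR introduces or touches):

# Hedged sketch only, not the recipe's actual code. Shows the ordering:
# FSDP2-hooked prepare (via model.__call__) -> drop input_ids/modality kwargs ->
# CP shard on inputs_embeds -> regular forward.
# Assumed imports (paths per the PR description):
#   from components.utils.model_utils import VLM_INPUT_KEYS
#   from components.distributed.cp_utils import make_cp_batch_and_ctx
import torch

def forward_step_with_vlm_cp(model, batch, cp_mesh):
    if cp_mesh is not None and cp_mesh.size() > 1 and hasattr(model, "prepare_model_inputs_for_cp"):
        # Collect whichever modality kwargs this batch actually carries.
        mm_kwargs = {k: batch[k] for k in VLM_INPUT_KEYS if batch.get(k) is not None}
        with torch.no_grad():  # the prepare step runs under no_grad, per the unit tests
            prepared = model(input_ids=batch["input_ids"], _pre_embed_only=True, **mm_kwargs)
        batch["inputs_embeds"] = prepared["inputs_embeds"]
        batch.pop("input_ids")                 # XOR contract: embeds or ids, never both
        for k in VLM_INPUT_KEYS:               # modality kwargs are consumed by the prepare step
            batch.pop(k, None)

    # inputs_embeds (when present) becomes the primary CP buffer.
    batch, cp_ctx = make_cp_batch_and_ctx(batch, cp_mesh)   # signature assumed
    with cp_ctx:
        out = model(**batch)
    return out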

Convergence parity (medpix VQA, 100 steps each, 8 H100 80GB)

CP=1 (existing cordv2 baseline) vs CP=2 (this PR) — overall mean abs diff 0.00127 (0.065% relative), all 4 windows pass <0.5% rel. The small non-zero step-0 diff confirms the CP attention path is actually exercised (a bit-exact match would mean CP=2 was a no-op).
[image: CP=1 vs CP=2 training-loss comparison]

Wandb runs:
  • ep8cp2 (final, this PR's config): https://wandb.ai/Nemo-automodel/huiyingl_workspace/runs/kq6oibgx
  • ep8cp2 prior runs (sdpa, then te): hfkqtoss / nisw5dqr / r3w8to6u
  • ep8cp1 baseline (cp=1): irtbjayz / uqwiv7tq

Also tested Qwen3 30B LLM EP8CP2 on the main branch vs. this PR; no regression observed:
main: https://wandb.ai/Nemo-automodel/huiyingl_workspace/runs/jxhl7w2i
PR #2125: https://wandb.ai/Nemo-automodel/huiyingl_workspace/runs/cflbo4fd

Tests

49 new unit tests (CPU-only, no GPU/distributed required), all pass:

  • tests/unit_tests/utils/test_vlm_input_keys.py (9) — umbrella shape + per-VLM-family coverage
  • tests/unit_tests/_transformers/test_capabilities_hybrid_vlm.py (13) — _is_hybrid drill-down semantics
  • tests/unit_tests/distributed/test_cp_utils_inputs_embeds.py (8) — make_cp_batch_and_ctx inputs_embeds path
  • tests/unit_tests/models/nemotron_omni/test_nemotron_omni_cp.py (12) — prepare_*_for_cp and forward(_pre_embed_only=True) semantics
  • tests/unit_tests/recipes/test_finetune_vlm_cp_wiring.py (7) — recipe wiring + val-bug regression tests

Test plan

  • pytest tests/unit_tests/utils/test_vlm_input_keys.py tests/unit_tests/_transformers/test_capabilities_hybrid_vlm.py tests/unit_tests/distributed/test_cp_utils_inputs_embeds.py tests/unit_tests/models/nemotron_omni/test_nemotron_omni_cp.py tests/unit_tests/recipes/test_finetune_vlm_cp_wiring.py (all 49 pass)
  • torchrun --nproc-per-node=8 examples/vlm_finetune/finetune.py -c examples/vlm_finetune/nemotron_omni/nemotron_omni_v3_cord_v2_ep8cp2.yaml (100 steps + val, no errors)
  • Compare loss curves vs the existing cordv2 cp=1 baseline (run with the same yaml + --distributed.cp_size=1); window-mean diff should stay <0.5% relative on medpix-scale losses (a throwaway parity-check sketch follows this list).
  • Verify TE-CP ring is engaged via temporary [CP-DEBUG] log in moe/parallelizer.py::apply_cp if running on a new model — should report cp_comm_type=p2p and 0 skipped attn blocks.
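
A throwaway sketch of the window-mean parity check from the third item (the window count, tolerance handling, and input format are assumptions; the actual loss values come from the wandb runs listed above):

import numpy as np

def window_parity(loss_cp1, loss_cp2, n_windows=4, rel_tol=0.005):
    # Compare two equal-length loss curves window by window; each window's
    # mean must agree within rel_tol (0.5% here) for the parity check to pass.
    loss_cp1, loss_cp2 = np.asarray(loss_cp1), np.asarray(loss_cp2)
    assert loss_cp1.shape == loss_cp2.shape
    report = []
    for a, b in zip(np.array_split(loss_cp1, n_windows), np.array_split(loss_cp2, n_windows)):
        abs_diff = abs(a.mean() - b.mean())
        rel_diff = abs_diff / abs(a.mean())
        report.append({"abs": abs_diff, "rel": rel_diff, "pass": rel_diff < rel_tol})
    return report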

Notes

  • The new prepare step is gated behind not self.pp_enabled; PP+CP+VLM is a separate effort (PP scheduler chunking plus CP would need a new design), matching the gate in PR #1914 (feat: add Context Parallelism support for Gemma4 dense and MoE VLM).
  • TE-CP at cp_size=2 uses ring/p2p attention (engaged automatically when attn: te). At cp=2, ring vs allgather memory savings are ~0% because ring still needs 2/cp_size of the K/V; benefits show at cp≥4 (quick arithmetic below).
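
Quick arithmetic behind the second note (plain Python, no repo code; the 2/cp_size ring-residency figure is the one quoted above):

# Per-rank K/V residency: allgather materializes the full K/V (fraction 1.0);
# ring/p2p keeps roughly 2/cp_size resident (current chunk + the one in flight).
for cp in (2, 4, 8):
    ring_fraction = 2 / cp
    print(f"cp={cp}: ring holds {ring_fraction:.2f} of K/V -> {1 - ring_fraction:.0%} savings vs allgather")
# cp=2: ring holds 1.00 of K/V -> 0% savings vs allgather
# cp=4: ring holds 0.50 of K/V -> 50% savings vs allgather
# cp=8: ring holds 0.25 of K/V -> 75% savings vs allgather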

🤖 Generated with Claude Code

copy-pr-bot (bot) commented May 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

HuiyingLi and others added 13 commits May 4, 2026 16:26
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…d for FSDP2

The vision tower needs FSDP2's forward pre-hooks to all-gather its
Linear weights. Calling prepare_inputs_embeds_for_cp directly bypassed
those hooks and produced "mixed Tensor and DTensor" errors.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…appers

NemotronOmni's hybrid layers_block_type lives on the inner
language_model.config, not on the outer wrapper. Without this fix,
validate_for_mesh rejects cp_size>1 + sdpa as "requires TE attention",
even though hybrid+sdpa is supported.
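
A hedged sketch of that drill-down (whether _is_hybrid receives the model or a config object is not shown in this thread; the sketch assumes the model, and the marker names come from the unit-test descriptions later in this thread):

def _is_hybrid(model) -> bool:
    # Illustrative only. Check the hybrid markers on the outer config first,
    # then fall back to the inner language_model.config for VLM wrappers.
    def has_markers(cfg) -> bool:
        return bool(
            getattr(cfg, "layers_block_type", None)
            or getattr(cfg, "hybrid_override_pattern", None)
            or getattr(cfg, "is_hybrid_model", False)
        )

    outer_cfg = getattr(model, "config", None)
    if outer_cfg is not None and has_markers(outer_cfg):
        return True
    inner = getattr(model, "language_model", None)
    inner_cfg = getattr(inner, "config", None)
    return inner_cfg is not None and has_markers(inner_cfg)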

Also: enable activation_checkpointing in both cordv2 CP yamls so
ep_size=4 fits in 8x80GB (ep_size=8 frees enough memory without it,
but matching configs is required for fair parity).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Prior CP-omni work (conversation dca92b49) kept ep_size=8 for both
baseline and CP=2 runs - only cp_size differed. Reducing ep_size in the
test (ep4cp2) reroutes tokens through different experts, producing
spurious per-step divergence plus OOM. Match the prior design:

- nemotron_omni_v3_cord_v2_ep8cp1.yaml: cp=1, ep=8 (no AC, no ckpt)
- nemotron_omni_v3_cord_v2_ep8cp2.yaml: cp=2, ep=8 (no AC, no ckpt)

Result: step-0 abs diff 0.0008, overall mean abs diff 0.00086, p95 per
step 0.020. Consistent with bf16 + CP attention non-associativity over
~50 layers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Split prepare_inputs_embeds_for_cp into:
  - prepare_model_inputs_for_cp(input_ids, ...) -> dict  (worker)
  - prepare_inputs_embeds_for_cp(...)           -> Tensor (thin wrapper)

Both take individual tensors instead of a batch dict, matching the
gemma4_moe model API in PR #1914 so future VLMs only need to define
these two methods to opt into the recipe's CP path.

Forward flag rename: return_inputs_embeds_only -> _pre_embed_only.

Recipe (recipes/vlm/finetune.py):
  - Drop _vlm_cp_deferred deferral and the prepare block inside sync_ctx.
  - Do prepare-then-shard inline at the top of _forward_backward_step,
    before make_cp_batch_and_ctx (matches PR #1914's flow).
  - Mirror the same prepare step in _run_validation_epoch.

The deferred make_cp_batch_and_ctx pattern was based on a wrong mental
model: sync_ctx controls FSDP grad sync, not param materialization.
FSDP2 forward pre-hooks fire on any model(...) call regardless of
sync_ctx, so the deferral was solving a non-problem.

20-step parity check vs pre-refactor: step-0 bit-identical (cp1) /
within 0.5% rel (cp2). No behavior change beyond bf16/MoE-routing noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Both surfaced when running val with cp_size=2 (cp_size=1 took the
early-return path in make_cp_batch_and_ctx and hid them):

1. AttributeError: 'FSDPNemotronOmniForConditionalGeneration' object has
   no attribute 'device' at finetune.py:1281. The FSDP2 wrapper does not
   expose .device. Use self.dist_env.device for the position_ids
   synthesis target device.

2. KeyError: 'labels' at cp_utils.py:297. Validation popped labels
   before make_cp_batch_and_ctx, but that function reads batch["labels"]
   to register it as a CP buffer. Pop after, mirroring the train path.

Verified: 20-step ep8cp2 run on medpix now reaches
[val] step 19 | epoch 0 | loss 1.0888 cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
- backend.attn: sdpa -> te.  TE attention engages real ring/p2p CP
  via DotProductAttention.set_context_parallel_group; the SDPA-CP
  path used DTensor allgather only (cp_utils.py:355 hardcoded
  "allgather", with TODO to expose).  Verified via temporary debug
  instrumentation: 6/6 attention blocks set up with cp_comm_type=p2p,
  0 skipped.

- dataset: cord_v2 -> medpix (mmoukouba/MedPix-VQA).  cord_v2 loss
  saturates below 0.1 within 10 steps, making relative-diff parity
  metrics noisy.  medpix stays in 1.5-2.5 range so the bf16 + CP
  numerical noise floor is properly bounded.

- wandb logging enabled for both runs.

Result on medpix: cp1 vs cp2 overall mean abs diff 0.00127 (0.065%
relative), all 4 windows pass <0.5% rel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Previously the VLM-CP pre-shard step in recipes/vlm/finetune.py
hardcoded a Nemotron-Omni-specific 7-key tuple in two places (train
and val paths).  This umbrella approach mirrors the existing pattern
in components/datasets/vlm/fake_image.py::_VISION_TOKEN_ID_ATTRS
where vision-token attribute names from every known VLM family are
unioned in one tuple and consumers iterate to pick whichever keys
the live model has.

Add VLM_INPUT_KEYS to components/utils/model_utils.py — the same
module that hosts filter_forward_kwargs (the symmetric runtime kwarg
filter).  This makes the umbrella accessible from anywhere that
already imports from model_utils (recipes, _transformers, models,
distributed, datasets, tests) without circular-import risk, since
model_utils only depends on shared/import_utils.

The umbrella covers:
  - Nemotron-Omni:  pixel_values, image_flags, imgs_sizes,
                    pixel_values_videos, sound_features,
                    sound_attention_mask
  - Gemma4 (PR #1914): image_position_ids, mm_token_type_ids
  - Kimi-VL / Qwen-VL / Mistral4: image_grid_hws, image_grid_thw,
                    image_sizes
  - Phi-4-MM (future): audio_input_values, audio_attention_mask

Recipe replaces both hardcoded mm_keys tuples with VLM_INPUT_KEYS
and uses it for both the kwarg filter and the post-prepare drop set.
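
The umbrella as described reads roughly as follows (key list copied from this message; the grouping comments and exact tuple ordering are not guaranteed to match the source file):

# Union of per-VLM-family multimodal forward kwargs; consumers iterate this
# and pick whichever keys the live batch/model actually has.
VLM_INPUT_KEYS = (
    # Nemotron-Omni
    "pixel_values", "image_flags", "imgs_sizes", "pixel_values_videos",
    "sound_features", "sound_attention_mask",
    # Gemma4 (PR #1914)
    "image_position_ids", "mm_token_type_ids",
    # Kimi-VL / Qwen-VL / Mistral4
    "image_grid_hws", "image_grid_thw", "image_sizes",
    # Phi-4-MM (future)
    "audio_input_values", "audio_attention_mask",
)

# Consumer-side pattern (illustrative): pick whichever keys the live batch has.
# mm_kwargs = {k: batch[k] for k in VLM_INPUT_KEYS if batch.get(k) is not None}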

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
After the refactor that moved the CP pre-shard step out of sync_ctx,
three follow-up cleanups in _forward_backward_step / _run_validation_epoch:

  - Merge `with sync_ctx: with train_ctx():` (the only thing between
    the two was the pre-shard block which now lives at the top of
    the function).  Use combined `with sync_ctx, train_ctx():`.
  - Delete the val-side synthetic position_ids block.  cp_utils
    .make_cp_batch_and_ctx already injects a 1D arange when
    position_ids is missing and cp_mesh.size > 1 (cp_utils.py:288),
    so the recipe-side fallback was duplicating that work.
  - Trim two over-explanatory comments down to just the why-non-obvious
    bit (FSDP2 forward pre-hook all-gathers vision-tower weights).

No behavior change.  Net -45/+28 lines.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The cp_size=1 yaml was a paired-config baseline for the CP=2 parity
test only.  For everyday use the existing nemotron_omni_v3_cord_v2.yaml
covers the non-CP path.  Keep the ep8cp2.yaml as the canonical
CP-enabled example.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Covers all behavior changes from the CP work in this branch:

  utils/test_vlm_input_keys.py (9 tests)
    - VLM_INPUT_KEYS umbrella shape, no duplicates, importability
    - per-VLM-family coverage: Nemotron-Omni, Gemma4, Qwen-VL, Kimi-VL,
      Mistral4, Phi-4-MM
    - excludes labels/position_ids/attention_mask (NOT multimodal inputs)

  _transformers/test_capabilities_hybrid_vlm.py (13 tests)
    - _is_hybrid: outer config markers (layers_block_type,
      hybrid_override_pattern, is_hybrid_model)
    - drill into language_model.config when outer lacks markers
    - empty pattern, missing config, lowercase 'm', None inner config

  distributed/test_cp_utils_inputs_embeds.py (8 tests)
    - XOR contract: exactly one of input_ids/inputs_embeds in batch
    - inputs_embeds becomes primary cp_buffer when present
    - input_ids path unchanged (backward-compat)
    - position_ids synthesized from inputs_embeds.shape[1] (seq dim, not hidden)
    - cp_size<=1 short-circuit applies on inputs_embeds path
    - padding_mask + 3D mRoPE position_ids on inputs_embeds path

  models/nemotron_omni/test_nemotron_omni_cp.py (12 tests)
    - prepare_model_inputs_for_cp returns dict with inputs_embeds
    - text-only / image / video / sound modality scatter at correct positions
    - dynamic-res branch takes priority over static when imgs_sizes given
    - sound branch is no-op when sound_encoder is None
    - prepare_inputs_embeds_for_cp thin wrapper returns Tensor matching dict path
    - forward(_pre_embed_only=True) early-returns prepared dict, skips LM
    - forward(inputs_embeds=...) skips multimodal scatter block

  recipes/test_finetune_vlm_cp_wiring.py (7 tests)
    - recipe routes through model.__call__ with _pre_embed_only=True
      (so FSDP2 forward pre-hook fires)
    - all VLM_INPUT_KEYS popped after prepare; non-mm keys preserved
    - missing/None mm keys not forwarded as kwargs
    - prepare step skipped when model lacks prepare_model_inputs_for_cp
    - prepare step runs under torch.no_grad
    - val: do NOT pop labels before make_cp_batch_and_ctx (KeyError fix)
    - val: position_ids uses self.dist_env.device, not model_parts[0].device
      (AttributeError fix on FSDP-wrapped model)

All 49 tests pass on CPU (no GPU/distributed needed).  Tests use stubs
and monkeypatching; no real model load.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…ispatcher

Align the CP-enabled yaml with the non-CP cordv2 baseline:
  - linear: torch -> te
  - enable_deepep: false -> dispatcher: deepep (matches cord_v2.yaml)

Now ep8cp2 differs from non-CP cordv2 only in:
  - attn: sdpa -> te  (TE-CP path is the supported one)
  - cp_size: 1 -> 2  (the actual CP-under-test)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Replace local lustre path with the canonical
``nvidia/Nemotron-3-Nano-Omni-30B-A3B-Reasoning-BF16`` model ID,
matching the existing ``nemotron_omni_cord_v2.yaml`` and ``_peft.yaml``.

Also resolves a secrets-detector false positive (``Base64 High Entropy
String`` triggered on the long lustre path at line 27).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi HuiyingLi force-pushed the huiyingl/nemotron-omni-cp branch from 0d0ebc8 to d98bf05 Compare May 4, 2026 23:26
HuiyingLi (Contributor, Author) commented:

/claude review

Comment on lines +1 to +5
# NemotronOmni v3 (Reasoning) fine-tuning on CORD-V2 -- ep4cp2 (CP=2 test)
#
# Run:
# cd automodel-omni
# automodel examples/vlm_finetune/nemotron_omni/nemotron_omni_v3_cord_v2_ep4cp2.yaml --nproc-per-node 8

Typo: the comment says "ep4cp2" but the file is ep8cp2 (and ep_size: 8). The run command also references the wrong filename.

Suggested change
# NemotronOmni v3 (Reasoning) fine-tuning on CORD-V2 -- ep4cp2 (CP=2 test)
#
# Run:
# cd automodel-omni
# automodel examples/vlm_finetune/nemotron_omni/nemotron_omni_v3_cord_v2_ep4cp2.yaml --nproc-per-node 8
# NemotronOmni v3 (Reasoning) fine-tuning on CORD-V2 -- ep8cp2 (CP=2 test)
#
# Run:
# cd automodel-omni
# automodel examples/vlm_finetune/nemotron_omni/nemotron_omni_v3_cord_v2_ep8cp2.yaml --nproc-per-node 8

Comment on lines +318 to +348
# Pad sequence length to be divisible by 2 * cp_size (required by
# context_parallel load balancing). The inputs_embeds path can hit
# arbitrary seq lengths from the VLM collator, so we pad here rather
# than relying on dataset-side padding.
cp_divisor = cp_mesh.size() * 2
if seq_len % cp_divisor != 0:
    pad_len = cp_divisor - (seq_len % cp_divisor)
    new_no_restore = set()
    for i, (buf, dim) in enumerate(zip(cp_buffers, cp_seq_dims)):
        pad_shape = list(buf.shape)
        pad_shape[dim] = pad_len
        if buf.dtype in (torch.float16, torch.bfloat16, torch.float32):
            pad_val = torch.zeros(pad_shape, dtype=buf.dtype, device=buf.device)
        else:
            # labels use -100 (ignore_index); position_ids and other ints get 0.
            fill_val = -100 if (buf is labels) else 0
            pad_val = torch.full(pad_shape, fill_val, dtype=buf.dtype, device=buf.device)
        old_buf = buf
        cp_buffers[i] = torch.cat([buf, pad_val], dim=dim)
        if old_buf in cp_no_restore_buffers:
            new_no_restore.add(cp_buffers[i])
    cp_no_restore_buffers = new_no_restore
    # Mirror the padded primary tensor back into the batch so the model
    # forward sees the padded shape (we already padded labels/position_ids
    # via cp_buffers, but those are still the original objects in `batch`).
    if has_inputs_embeds:
        batch["inputs_embeds"] = cp_buffers[0]
    else:
        batch["input_ids"] = cp_buffers[0]
    batch["labels"] = cp_buffers[1]
    batch["position_ids"] = cp_buffers[2]

The new CP-divisor padding block (lines 318–348) handles multiple dtypes, rebuilds cp_no_restore_buffers, and mirrors padded tensors back into batch. None of the new unit tests in test_cp_utils_inputs_embeds.py exercise this path. Consider adding a test where seq_len % (cp_size * 2) != 0 to verify:

  • all buffers (including loss_mask / padding_mask if present) are padded to the correct length
  • labels padding uses -100 (not 0)
  • batch["inputs_embeds"] / batch["input_ids"], batch["labels"], and batch["position_ids"] are updated to the padded versions

  - examples/.../ep8cp2.yaml: fix typo in header comment ("ep4cp2" -> "ep8cp2")
    and the example run command's filename.

  - tests/unit_tests/distributed/test_cp_utils_inputs_embeds.py: add 5 tests
    covering the cp-divisor padding path (cp_utils.py:318-348) that prior
    tests skipped because seq_len was already cp*2-aligned:
      * pads all cp_buffers (primary, labels, position_ids) to multiple of 2*cp_size
      * labels pad with -100 (CE ignore_index); int buffers (input_ids,
        position_ids) pad with 0; float buffers (inputs_embeds) pad with zeros
      * loss_mask + padding_mask also padded when present
      * no-op when seq already aligned (identity-preserved)
      * input_ids path padding semantics (int=0, labels=-100)

13 tests pass in test_cp_utils_inputs_embeds.py (8 original + 5 new).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
HuiyingLi (Contributor, Author) commented:

/claude review

Comment on lines 873 to 875
_pre_embed_only: bool = False,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:

Bug (type annotation): When _pre_embed_only=True, this method returns a plain dict (from prepare_model_inputs_for_cp), not Union[Tuple, CausalLMOutputWithPast]. This will mislead type checkers and callers inspecting the signature.

Suggested change
_pre_embed_only: bool = False,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
_pre_embed_only: bool = False,
**kwargs,
) -> Union[dict, Tuple, CausalLMOutputWithPast]:

Comment on lines +340 to +348
# Mirror the padded primary tensor back into the batch so the model
# forward sees the padded shape (we already padded labels/position_ids
# via cp_buffers, but those are still the original objects in `batch`).
if has_inputs_embeds:
    batch["inputs_embeds"] = cp_buffers[0]
else:
    batch["input_ids"] = cp_buffers[0]
batch["labels"] = cp_buffers[1]
batch["position_ids"] = cp_buffers[2]

Nit: The batch-mirroring after padding updates inputs_embeds/input_ids, labels, and position_ids, but not padding_mask (if present at cp_buffers index 4+). If a future model accepts padding_mask in its forward signature, it would receive the un-padded tensor while other tensors are padded — a shape mismatch.

Not a problem today (Nemotron-Omni doesn't use padding_mask), but worth a comment or mirroring all batch-sourced buffers back to avoid a latent trap. For example, you could add:

if "padding_mask" in batch and len(cp_buffers) > 4:
    batch["padding_mask"] = cp_buffers[-1]  # or track index explicitly

claude (bot, Contributor) left a comment:

Light Review — feat(nemotron-omni): enable context parallelism for VLM path

Well-structured PR. The CP integration follows the established Gemma4 pattern, the bug fixes for validation are clearly motivated, and the test coverage (49 new unit tests) is thorough — including regression tests for both surfaced bugs.

Two inline comments posted:

  1. Return type annotation mismatch on forward() — when _pre_embed_only=True, the method returns a dict, but the annotation claims Union[Tuple, CausalLMOutputWithPast]. Minor fix.

  2. Latent padding_mask shape mismatch — the new CP-divisor padding block in cp_utils.py mirrors padded tensors back into the batch dict for inputs_embeds/input_ids, labels, and position_ids, but not for padding_mask (if present). Not a problem for Nemotron-Omni today, but a future trap worth addressing or at least commenting.

Everything else LGTM — the XOR contract in make_cp_batch_and_ctx, the _is_hybrid drill-down for VLM wrappers, the VLM_INPUT_KEYS umbrella, and the recipe wiring all look correct.

HuiyingLi (Contributor, Author) commented:

/ok to test 63e291b

claude (bot) previously approved these changes May 5, 2026 and left a comment:

LGTM

Well-structured CP enablement for Nemotron-Omni VLM. The key design choices are sound:

  • Routing through model.__call__ (not the bound method) so FSDP2 forward pre-hooks fire for vision tower unshard is correct.
  • The XOR contract in make_cp_batch_and_ctx cleanly separates the inputs_embeds (VLM-CP) and input_ids (LLM) paths.
  • The validation bug fixes (labels pop ordering, .device on FSDP wrapper) address real latent issues that would surface under CP>1.
  • CP-divisor padding logic correctly uses -100 for labels and 0 for other int buffers, and the num_label_tokens computation before padding ensures correct loss normalization.
  • 49 new unit tests provide thorough coverage of all new paths.

hijkzzz commented May 5, 2026

please merge it

The cp-divisor padding pass picked the fill value by tensor identity
(``buf is labels`` -> -100, else 0), which silently used the wrong
sentinel for ``padding_mask``: bool ``True`` == "this position is pad,
ignore" but the dtype-default fill of 0 == ``False`` told the MoE
router that the cp-pad slots are real tokens.

Concrete impact: on the LLM-CP-SFT path (``default_collater`` auto-emits
``padding_mask``), every step under ``cp_size > 1`` with a non-cp-aligned
seq_len would route the cp-pad slots to experts -- wasting expert
capacity and skewing load-balance loss.  Latent today on the Nemotron-Omni
VLM path (cordv2 collator doesn't emit ``padding_mask``) but real for
DeepSeek-V3/V4, Qwen3-MoE, Gemma4-MoE, GLM4-MoE, GPT-OSS, etc. once
they enable CP=2+.

Fix: replace the ``buf is labels`` special-case with a per-buffer-key
``PAD_FILL`` table (sketched after the list below) that encodes each tensor's "ignore" sentinel:
  - labels:         -100   (CE ignore_index)
  - padding_mask:   True   (bool: True == ignore)
  - attention_mask: False  (HF: 0 == ignore)
  - default:        0      (input_ids, position_ids, ...)
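
A minimal sketch of that table (the dict shape and lookup helper are assumptions; the sentinel values are the ones listed above):

# Fill value used when cp-divisor padding extends a batch-sourced buffer;
# each entry encodes that tensor's "ignore" sentinel.
PAD_FILL = {
    "labels": -100,           # CE ignore_index
    "padding_mask": True,     # bool mask: True == this position is pad, ignore it
    "attention_mask": False,  # HF convention: 0/False == ignore
}

def pad_fill_for(key: str):
    return PAD_FILL.get(key, 0)   # default 0 covers input_ids, position_ids, ...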

The mapping from cp_buffer index to batch key uses a small
``batch_buffer_keys`` registry replacing the ad-hoc ``padding_mask_idx``.
The post-pad batch-mirror loop now iterates the registry, so any
future batch-sourced cp_buffer (e.g. cu_seqlens variants) is mirrored
back automatically -- no per-key special-case in two places.

Two new regression tests:
  - ``padding_mask`` pad slots are ``True``, not ``False``
  - ``attention_mask`` mapping documented in PAD_FILL (path currently
    unreachable because attention_mask is popped earlier; encoded for
    when a future PR revisits the strip).

33 tests pass in cp_utils + cp_utils_inputs_embeds.

Co-authored-by: khazic <khazzz1c@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
HuiyingLi (Contributor, Author) commented:

/ok to test 9dd50ec

@HuiyingLi HuiyingLi merged commit 12749a4 into main May 7, 2026
74 checks passed
@HuiyingLi HuiyingLi deleted the huiyingl/nemotron-omni-cp branch May 7, 2026 00:42