
fix(mimo): Scale encoder gradients for heterogeneous DP in multimodule finalization#3197

Closed
kamran-nvidia wants to merge 8 commits into main from kamran/fix_grad_scale_mimo

Conversation

@kamran-nvidia
Contributor

What does this PR do ?

  • When encoder_dp > llm_dp in heterogeneous parallelism, the LLM's loss normalization divides by all the tokens it processes, but after bridge fan-out each encoder DP rank only carries gradients for llm_dp / encoder_dp of the samples, making encoder gradients too small. Fix: scale encoder gradients up by encoder_dp / llm_dp after DDP finalization.
  • Broadcast num_tokens from the LLM's last pipeline stage to all ranks so that encoder-only ranks (which see num_tokens=0) can correctly scale their gradients by 1/total_num_tokens.
  • Add a helper _get_dp_size_from_grid() that reads the DP size from grid shape metadata and works on all ranks, including those outside the grid.
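To make the normalization mismatch concrete, here is a worked-numbers sketch. The specific DP sizes are illustrative, not taken from the PR:

```python
# Illustrative numbers only: suppose encoder_dp = 8 and llm_dp = 2.
# The LLM normalizes the loss over all tokens it processes, but after
# bridge fan-out each encoder DP rank back-props gradients for only
# llm_dp / encoder_dp of the samples, so its gradients come out too
# small by exactly that factor.

encoder_dp, llm_dp = 8, 2

fraction_of_samples = llm_dp / encoder_dp   # each encoder rank sees 1/4
correction = encoder_dp / llm_dp            # so scale its grads up by 4

# The correction exactly cancels the shortfall:
assert fraction_of_samples * correction == 1.0
```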
Details
In finalize_model_grads_multimodule, after each module's _finalize_model_grads call, we now compare the module's DP size against the LLM's DP size. If they differ (heterogeneous case), we call module.scale_gradients(module_dp / llm_dp) to compensate for the normalization mismatch.
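A minimal, runnable sketch of that control flow, using toy stand-ins (ToyModule, plain-list gradients) rather than the PR's actual DDP types:

```python
# Toy stand-ins for the real DDP modules; only the control flow mirrors
# the PR's description of finalize_model_grads_multimodule.

class ToyModule:
    def __init__(self, grad, dp_size):
        self.grad = grad        # stand-in for a flattened gradient buffer
        self.dp_size = dp_size  # DP size, e.g. read from grid metadata

    def _finalize_model_grads(self):
        pass  # real code would run the DDP gradient all-reduce here

    def scale_gradients(self, factor):
        self.grad = [g * factor for g in self.grad]

def finalize_model_grads_multimodule(modules, llm_dp):
    for module in modules:
        module._finalize_model_grads()
        if module.dp_size != llm_dp:  # heterogeneous case
            # Loss normalization assumed llm_dp ranks; compensate.
            module.scale_gradients(module.dp_size / llm_dp)

encoder = ToyModule([0.5, 1.0], dp_size=8)
llm = ToyModule([0.5, 1.0], dp_size=2)
finalize_model_grads_multimodule([encoder, llm], llm_dp=2)
# encoder gradients scaled by 8 / 2 = 4; LLM gradients untouched
```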

The num_tokens broadcast ensures that when calculate_per_token_loss=True, all ranks — including encoder-only ranks that never accumulate token counts — receive the correct total from the LLM's last pipeline stage.
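A single-process simulation of that broadcast (real code would use torch.distributed.broadcast from the LLM's last pipeline stage; the rank layout below is made up for illustration):

```python
# Simulate 4 ranks: ranks 0-2 are encoder-only and accumulate no tokens,
# rank 3 is the LLM's last pipeline stage and owns the true total.

def broadcast_num_tokens(local_counts, src_rank):
    """Every rank adopts the source rank's value (stand-in for a broadcast)."""
    total = local_counts[src_rank]
    return [total] * len(local_counts)

local_num_tokens = [0, 0, 0, 4096]
num_tokens = broadcast_num_tokens(local_num_tokens, src_rank=3)

# With calculate_per_token_loss=True, every rank can now scale its
# gradients by 1 / total_num_tokens instead of dividing by zero.
grad_scale = [1.0 / n for n in num_tokens]
```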

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

aroshanghias-nvd and others added 8 commits March 17, 2026 18:28
…rallelism

Squash of all Phase 4 MiMo work from mimo/phase4-training (47870e4),
rebased onto upstream/main at f1fb06a.

Includes:
- MimoModelProvider with ModuleSpec-based API and heterogeneous LLaVA support
- MiMo training loop (pretrain_mimo, train_mimo, mimo_step)
- Heterogeneous TP/PP/DP parallelism plumbing (mimo_parallel_utils)
- MiMo data loading (collate, dataset, loaders, hf_provider, mock_provider)
- Data loader dispatch routing for MIMO models (loaders.py)
- MiMo DDP wrapping and model builder
- Kamran's loss mask and heterogeneous LLaVA dataset support
- Megatron-LM submodule pinned to PR #3212 head
- Full unit test coverage (provider, config, step, collate, pretrain tests)

Phase 5 (checkpointing/evaluation) is stacked in a separate branch.

Original commit history preserved in backup/mimo-phase4-training-v0 (47870e4).

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Squash of all Phase 5 MiMo checkpointing/evaluation work from
mimo/phase5-checkpointing (989842e), stacked on Phase 4 rebuild.

Includes:
- Checkpoint save/resume wiring for heterogeneous MIMO models
- MiMo evaluation infrastructure (eval.py MIMO extensions)
- Distributed batch slicing for evaluation (dp_utils.slice_batch_for_mimo)
- E2E training tests (test_mimo_training_e2e, test_mimo_training_llava)
- E2E checkpoint resume tests (test_mimo_checkpoint_resume_e2e)
- Parallelism test runner (run_mimo_parallelism_tests.sh)
- Full checkpoint unit test coverage (test_mimo_checkpointing — 1159 lines)

Original commit history preserved in backup/mimo-phase5-checkpointing-v0 (989842e).

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
…iring

The _make_setup_output fixture was missing pg_collections and
checkpointing_context attributes needed by the Phase 5 checkpoint
code path in pretrain_mimo. Also set checkpoint config fields to None
and build_data_iterators_fn return value so the test completes
without hitting unrelated code paths.

Pre-existing test gap at 989842e.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
…int test

Restores the checkpoint resume test wrapper from mimo/wip-phase4-training.
Runs save→resume round-trip across multiple parallelism configs.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
…l_grads_multimodule

Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@kamran-nvidia kamran-nvidia marked this pull request as ready for review April 7, 2026 17:21
@yaoyu-33 yaoyu-33 added the bug, area:training, and needs-review labels Apr 9, 2026
Base automatically changed from mimo/phase5-checkpointing-rebuild to main April 14, 2026 21:23
@pruprakash

QA RCCA Analysis

1. Fix Reference

2. Root Cause

When encoder_dp > llm_dp in heterogeneous parallelism, the LLM's loss normalization divides by all tokens it processes, but after bridge fan-out each encoder DP rank only carries gradients for llm_dp / encoder_dp of the samples. This makes encoder gradients too small.

3. Status

4. Conclusion

Verdict: UNRESOLVED - FIX NOT MERGED

PR #3197 had extensive test coverage prepared (22 test files) but was never merged. No regression test was added; the fix must be merged first.


Labels

  • area:training (Training loop, callbacks, and runtime integration)
  • bug (Something isn't working)
  • needs-review (PR is ready for code review and waiting on a reviewer)
  • qa_rcca_done


5 participants