Draft: Add LLM PP>1 support for colocated MIMO training (NMFW-19) #4784

yashaswikarnati wants to merge 4 commits into NVIDIA:main from yashaswikarnati:ykarnati/nmfw-19-pp-rebase


Conversation

@yashaswikarnati
Contributor

Summary

Adds PP>1 support for the language model in colocated MIMO training. Encoder modules stay PP=1 on all ranks, while the LLM runs a 1F1B pipeline over microbatches of precomputed encoder embeddings.

Key pieces:

  • Allows colocated layouts where the destination language grid has PP>1 while the source encoder grid remains PP=1.
  • Adds a three-phase colocated schedule:
    1. encoder forward plus colocated TP/DP communication over the full batch,
    2. LLM 1F1B over detached encoder-embedding microbatches,
    3. encoder backward after broadcasting the LLM-side encoder gradient from PP stage 0.
  • Defers finalize_model_grads_func until after encoder backward so LLM and encoder grads are finalized together.
  • Handles fan-in, fan-out, and equal-DP colocated bridge directions, including fan-out narrowing of LLM-side passthrough tensors.
  • Extends colocated correctness coverage to PP>1 fan-in/fan-out cases and adds validation for non-divisible fan-in encoder batches.
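The three-phase schedule above can be sketched in a single-process toy (no pipeline or bridge communication; `colocated_step`, `encoder`, and `llm` are illustrative stand-ins, not the megatron-core API). The key mechanic is detaching the encoder embeddings before the LLM phase so each phase backprops only its own module, then replaying the accumulated slice gradients through the encoder:

```python
import torch

def colocated_step(encoder, llm, batch, num_microbatches):
    # Phase 1: encoder forward over the full batch (the real schedule also
    # runs the colocated TP/DP communication here, on all ranks together).
    embeddings = encoder(batch)

    # Phase 2: LLM "1F1B" over detached microbatch slices. Detaching cuts
    # the autograd graph, so each backward only touches LLM parameters;
    # the encoder-side gradient accumulates into each slice's .grad.
    micro = [c.detach().requires_grad_(True)
             for c in embeddings.chunk(num_microbatches, dim=0)]
    losses = []
    for mb in micro:
        loss = llm(mb).sum()
        loss.backward()
        losses.append(loss.item())

    # Phase 3: stitch the slice gradients back together (the real code
    # first broadcasts this gradient from LLM PP stage 0) and run encoder
    # backward over the full batch; finalize_model_grads runs after this.
    enc_grad = torch.cat([mb.grad for mb in micro], dim=0)
    embeddings.backward(enc_grad)
    return losses
```

Because the concatenated slice gradients equal the full-batch gradient, this produces the same encoder gradients as one undetached forward/backward.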

Test Plan

  • python3 -m py_compile megatron/core/models/mimo/colocated_schedule.py megatron/core/models/mimo/comm/colocated_communicator.py megatron/core/models/mimo/config/role.py megatron/core/models/mimo/model/base.py tests/unit_tests/models/test_mimo_colocated_communicator.py tests/unit_tests/models/test_mimo_colocated_correctness.py
  • git diff --check upstream/main...HEAD
  • pre-commit hooks during commit/push: black, pylint, isort
  • 8-GPU distributed test not run in this container (nvidia-smi is unavailable):
uv run python -m torch.distributed.run --nproc_per_node=8 \
  -m pytest tests/unit_tests/models/test_mimo_colocated_correctness.py -v -s

yashaswikarnati and others added 4 commits May 13, 2026 17:37
ColocatedBridgeCommunicator changes that unlock dest PP>1 for the
colocated MIMO bridge:

* Validation: only require src PP=1; dest PP>1 is allowed because
  the three-phase schedule handles LLM pipeline orchestration. cp,
  ep, expt_dp must all be 1 (otherwise the get_rank_enum(['tp'])
  enumeration order entangles them with dp/pp and breaks the
  enum_idx -> (actual_dp, actual_pp) decomposition below).
* _build_rank_mappings: dest_pp>1 makes get_rank_enum(['tp'])
  yield one TP group per (cp, dp, pp, ep, expt_dp) combination —
  i.e. the enumeration index conflates dp with pp. Decompose
  enum_idx = actual_dp * dest_pp_size + actual_pp, store the
  actual DP/TP pair in rank_to_dest_pos and the actual PP stage
  in a new rank_to_dest_pp_idx. Without this, fan-out forward
  narrows to the wrong slot for half the dest ranks (slot uses
  enum_idx as if it were dp_idx) and backward gather groups
  silently miss the dest ranks at PP>=1.
* _build_gather_groups: take optional pp_size + rank_to_pp_idx
  and iterate over dest PP stages so fan-out gather groups are
  built per (src_dp_slot, dest_pp_stage, sibling_tp_shard).
  Without this each dest PP stage's LLM-DP siblings would either
  share a group with another stage's siblings or fall off the
  group list entirely (gather_pg=None on those ranks → backward
  all_gather_into_tensor falls through to group=None → WORLD →
  collective mismatch with the size-2 sibling subgroups → NCCL
  watchdog timeout). Fan-in path still uses the pp_size=1 default
  since src_pp must be 1.
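The index decomposition and group keying described above can be sketched as follows (hypothetical helper names; the real logic lives inside _build_rank_mappings and _build_gather_groups):

```python
def decompose_enum_idx(enum_idx, dest_pp_size):
    # With dest PP>1, get_rank_enum(['tp']) enumerates one TP group per
    # (dp, pp) combination, so the enumeration index conflates dp with pp:
    #   enum_idx = actual_dp * dest_pp_size + actual_pp
    # divmod recovers both coordinates.
    return divmod(enum_idx, dest_pp_size)

def fan_out_gather_keys(num_src_dp_slots, dest_pp_size, sibling_tp_shards):
    # One gather group per (src_dp_slot, dest_pp_stage, sibling_tp_shard).
    # With dest_pp_size=1 this degenerates to the old fan-in keying, which
    # is why the fan-in path keeps the pp_size=1 default.
    return [(slot, pp, shard)
            for slot in range(num_src_dp_slots)
            for pp in range(dest_pp_size)
            for shard in range(sibling_tp_shards)]
```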

RankRole._colocated: when a grid map is supplied, derive each
module's first/last-stage flags from its pp group rather than
hardcoding both to True. Lets the colocated path advertise true
LLM PP stage info to MimoModel.

Tests: replace the "dest PP must be 1" rejection with
test_dest_pp_gt_one_accepted.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Encoder collectives (PP=1) cannot run inside a 1F1B pipeline because
the LLM PP stages are staggered. The new colocated_schedule.py
separates the two with three phases:

  Phase 1 — Encoder forward + colocated TP/DP transform on the full
            batch. All ranks synchronized.
  Phase 2 — LLM 1F1B pipeline with detached encoder embeddings sliced
            per microbatch. Only LLM P2P comm.
  Phase 3 — Encoder backward on the full batch (gradient broadcast
            from PP stage 0 to 1+ first), then a single deferred
            finalize_model_grads_func reduces LLM and encoder grads
            together.

Phase 3's deferral relies on a context manager that swaps the inner
PP schedule's finalize for a capturing no-op, then forwards the
captured num_tokens and force_all_reduce to the original after
encoder backward — required so calculate_per_token_loss=True configs
see the right global divisor and any caller-requested all-reduce
semantics survive the swap.

DP-direction support:
  * Fan-in (enc_dp > llm_dp): _slice_for_encoder_dp narrows the
    encoder input on the bridge's source side (the data iterator
    yields LLM-DP-sized per-rank batches, the encoder sees a smaller
    slot).
  * Fan-out (enc_dp < llm_dp): the data iterator yields encoder-DP-
    sized batches; the bridge narrows encoder embeddings to each
    LLM-DP rank's slot in encode_and_communicate, and
    _build_lm_microbatches narrows the LLM-side passthrough fields
    (input_ids, labels, loss_mask, position_ids, attention_mask) to
    the same slot via _fan_out_slot so they line up with the bridge
    output for the LLM forward. Without this, LLM forward on fan-out
    ranks would feed a full encoder-DP-sized batch of input_ids
    against a narrow LLM-DP-sized encoder embedding.
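The fan-out slot arithmetic can be sketched as below. This is a toy illustration only: the assumption that consecutive LLM-DP ranks share one encoder-DP rank is mine, and the real _fan_out_slot / narrowing code may map ranks differently:

```python
def fan_out_slot(llm_dp_rank, enc_dp_size, llm_dp_size):
    # With enc_dp < llm_dp, each encoder-DP rank fans out to
    # llm_dp_size // enc_dp_size LLM-DP ranks; return this LLM rank's
    # slot index within its encoder rank's batch (assumes consecutive
    # LLM-DP ranks share an encoder-DP rank).
    fan = llm_dp_size // enc_dp_size
    return llm_dp_rank % fan

def narrow_to_slot(batch, slot, num_slots):
    # Narrow a leading-batch-dim sequence to one slot's contiguous chunk,
    # as done for both the encoder embeddings and the LLM-side passthrough
    # fields so they line up for the LLM forward.
    per_slot = len(batch) // num_slots
    return batch[slot * per_slot:(slot + 1) * per_slot]
```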

Shape contract: encoder inputs are 3D [seq, batch, hidden] (batch dim
= 1); encoder outputs may be 2D [seq*batch, hidden] or 3D after the
bridge collapses leading dims. Other layouts (e.g. [B, C, H, W]
images) are rejected upfront.
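A minimal sketch of that upfront check, operating on plain shape tuples (illustrative only, not the actual validation code):

```python
def validate_encoder_shapes(in_shape, out_shape):
    # Inputs must be 3D [seq, batch, hidden] with batch == 1; outputs may
    # be 2D [seq*batch, hidden] or 3D. Anything else, e.g. 4D
    # [B, C, H, W] image layouts, is rejected upfront.
    if len(in_shape) != 3 or in_shape[1] != 1:
        raise ValueError(f"expected [seq, 1, hidden] encoder input, got {in_shape}")
    if len(out_shape) not in (2, 3):
        raise ValueError(f"expected 2D or 3D encoder output, got {out_shape}")
```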

MimoModel: thread per-PP-stage first/last-stage flags through the
pre/post-process plumbing so PP>1 LLM stages know when to apply
embeddings vs output_layer; piggybacks on RankRole._colocated.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

End-to-end correctness test for colocated MIMO training with LLM
PP>1, consolidated into the existing
test_mimo_colocated_correctness.py rather than a new file. The
existing oracle infrastructure (deterministic batch broadcast,
ref-to-dist param copy, encoder weight + first-layer-grad +
LLM-input + LLM-logits diff stats) covers the bulk of the new
PP>1 cases; the parametrize matrix is extended to llm_pp ∈ {1, 2}
crossed with fan-in/fan-out, and three new helpers handle the
extra PP-specific glue:

  * _llm_pp_remap_name + _copy_llm_params_pp_aware +
    _assert_llm_weights_match_pp_aware — map dist's local
    decoder.layers.{idx} on PP stage s to ref's PP=1
    decoder.layers.{s*layers_per_stage + idx}, used both for the
    initial param copy and for post-step LLM weight comparison
    when dist_llm_tp == enc_tp.
  * _copy_ref_llm_with_tp_and_pp_remap — two-phase copy for cases
    where dist_llm_tp != enc_tp (fan-out + PP=2 on 8 GPUs forces
    this). Phase 1 gathers the full ref LLM params across
    ref_tp_group; Phase 2 slices into each rank's local PP+TP
    shard. The naive "all-gather inside the dist iteration" loop
    deadlocks because dist's PP layout and ref's TP layout don't
    align — different ranks in the same ref_tp_group iterate
    different dist params and never reach a common collective.
  * _wire_training_hooks gains an optional llm_grid arg; with PP>1
    it broadcasts num_tokens from the last LLM PP rank to earlier
    stages before the DP all-reduce so every rank arrives at the
    same N_global divisor.
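The core of the name remapping in the first helper can be sketched as (a hypothetical re-implementation of _llm_pp_remap_name's index mapping):

```python
import re

def llm_pp_remap_name(name, pp_stage, layers_per_stage):
    # Map a dist rank's local decoder.layers.{idx} on PP stage `pp_stage`
    # to the ref model's PP=1 global index
    # decoder.layers.{pp_stage * layers_per_stage + idx}.
    # Names without a decoder layer index pass through unchanged.
    return re.sub(
        r"decoder\.layers\.(\d+)",
        lambda m: f"decoder.layers.{pp_stage * layers_per_stage + int(m.group(1))}",
        name,
    )
```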

_run_forward_backward dispatches between forward_backward_no_pipelining
(llm_pp == 1) and colocated_forward_backward_with_pp (llm_pp > 1).
The test wraps the body in try/except + an explicit failure-print
because pytest's distributed traceback formatter tends to SIGABRT
across asymmetric pass/fail ranks before printing per-rank
tracebacks.

teardown_method tracks MimoModels and calls model.destroy() before
destroy_all_grids() — mirrors the existing PP=1 pattern so
ColocatedBridgeCommunicator subgroups don't leak NCCL process
groups across tests.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@yashaswikarnati yashaswikarnati requested review from a team as code owners May 13, 2026 17:40
@copy-pr-bot

copy-pr-bot Bot commented May 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft May 13, 2026 17:40
@github-actions
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.
