Add LLM PP>1 support for colocated MIMO training (NMFW-19)#4784
Draft
yashaswikarnati wants to merge 4 commits into
Conversation
ColocatedBridgeCommunicator changes that unlock dest PP>1 for the
colocated MIMO bridge:
* Validation: only require src PP=1; dest PP>1 is allowed because
the three-phase schedule handles LLM pipeline orchestration. cp,
ep, and expt_dp must all be 1 (otherwise the get_rank_enum(['tp'])
enumeration order entangles them with dp/pp and breaks the
enum_idx -> (actual_dp, actual_pp) decomposition below).
* _build_rank_mappings: dest_pp>1 makes get_rank_enum(['tp'])
yield one TP group per (cp, dp, pp, ep, expt_dp) combination,
i.e. the enumeration index conflates dp with pp. Decompose
enum_idx = actual_dp * dest_pp_size + actual_pp, store the actual
DP/TP pair in rank_to_dest_pos and the actual PP stage in a new
rank_to_dest_pp_idx (see the sketch after this message). Without
this, fan-out forward narrows to the wrong slot for half the dest
ranks (the slot uses enum_idx as if it were dp_idx) and backward
gather groups silently miss the dest ranks at PP>=1.
* _build_gather_groups: take an optional pp_size + rank_to_pp_idx
and iterate over dest PP stages so fan-out gather groups are
built per (src_dp_slot, dest_pp_stage, sibling_tp_shard). Without
this, each dest PP stage's LLM-DP siblings would either share a
group with another stage's siblings or fall off the group list
entirely (gather_pg=None on those ranks → backward
all_gather_into_tensor falls through to group=None → WORLD →
collective mismatch with the size-2 sibling subgroups → NCCL
watchdog timeout). The fan-in path still uses the pp_size=1
default since src_pp must be 1.
RankRole._colocated: when a grid map is supplied, derive each
module's first/last-stage flags from its pp group rather than
hardcoding both to True. This lets the colocated path advertise
true LLM PP stage info to MimoModel.
Tests: replace the "dest PP must be 1" rejection with
test_dest_pp_gt_one_accepted.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
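For illustration, a minimal sketch of the enum_idx decomposition above, assuming the TP groups enumerate in (dp, pp)-major order with cp=ep=expt_dp=1; the function name and group layout are illustrative, not the actual _build_rank_mappings:

```python
def build_dest_rank_mappings(dest_tp_groups, dest_pp_size):
    """Recover (actual_dp, actual_pp) from the TP-group enumeration index
    when dest PP>1. With cp=ep=expt_dp=1 the index is exactly
    actual_dp * dest_pp_size + actual_pp, so a divmod splits it."""
    rank_to_dest_pos, rank_to_dest_pp_idx = {}, {}
    for enum_idx, tp_group in enumerate(dest_tp_groups):
        actual_dp, actual_pp = divmod(enum_idx, dest_pp_size)
        for tp_idx, rank in enumerate(tp_group):
            rank_to_dest_pos[rank] = (actual_dp, tp_idx)  # DP/TP slot for fan-out narrowing
            rank_to_dest_pp_idx[rank] = actual_pp         # PP stage for gather-group building
    return rank_to_dest_pos, rank_to_dest_pp_idx
```

E.g. with dest_pp_size=2 and TP groups [[0,1], [2,3], [4,5], [6,7]], indices 0..3 decompose to (dp, pp) = (0,0), (0,1), (1,0), (1,1); treating enum_idx as dp_idx would instead place ranks 2-3 at dp=1, which is the wrong-slot failure described above.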
Encoder collectives (PP=1) cannot run inside a 1F1B pipeline because
the LLM PP stages are staggered. The new colocated_schedule.py
separates the two with three phases:
Phase 1 — Encoder forward + colocated TP/DP transform on the full
batch. All ranks synchronized.
Phase 2 — LLM 1F1B pipeline with detached encoder embeddings sliced
per microbatch. Only LLM P2P comm.
Phase 3 — Encoder backward on the full batch (gradient broadcast
from PP stage 0 to 1+ first), then a single deferred
finalize_model_grads_func reduces LLM and encoder grads
together.
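An illustrative outline of the three phases, using hypothetical callables for each piece (this is not the actual colocated_schedule.py API):

```python
def colocated_step(encoder_and_bridge, llm_1f1b_schedule, finalize, batch):
    """Sketch of the three-phase schedule; all names are assumptions."""
    # Phase 1: encoder forward + colocated TP/DP transform on the full
    # batch. Every rank runs this together, so encoder collectives never
    # execute inside the staggered pipeline.
    embeddings = encoder_and_bridge(batch)

    # Phase 2: LLM 1F1B over per-microbatch slices of a *detached* copy;
    # only LLM P2P communication happens here, and the 1F1B backward
    # leaves the embedding gradient on detached.grad.
    detached = embeddings.detach().requires_grad_(True)
    losses = llm_1f1b_schedule(detached, batch)

    # Phase 3: encoder backward on the full batch (the real schedule first
    # broadcasts the embedding grad from PP stage 0 to stages 1+), then a
    # single deferred finalize reduces LLM and encoder grads together.
    embeddings.backward(detached.grad)
    finalize()
    return losses
```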
Phase 3's deferral relies on a context manager that swaps the inner
PP schedule's finalize for a capturing no-op, then forwards the
captured num_tokens and force_all_reduce to the original after
encoder backward — required so calculate_per_token_loss=True configs
see the right global divisor and any caller-requested all-reduce
semantics survive the swap.
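A minimal sketch of that deferral, assuming the finalize hook lives on a config object and takes (model, num_tokens, force_all_reduce); both are assumptions about the surrounding API:

```python
from contextlib import contextmanager

@contextmanager
def deferred_finalize(config):
    """Swap the inner schedule's finalize_model_grads_func for a no-op that
    records its arguments; the caller replays them after encoder backward."""
    real_finalize = config.finalize_model_grads_func
    captured = {}

    def capture_only(model, num_tokens=None, force_all_reduce=False):
        # Record instead of reducing, so calculate_per_token_loss configs
        # keep the right global divisor and any caller-requested all-reduce
        # semantics survive the swap.
        captured.update(model=model, num_tokens=num_tokens,
                        force_all_reduce=force_all_reduce)

    config.finalize_model_grads_func = capture_only
    try:
        yield captured
    finally:
        config.finalize_model_grads_func = real_finalize
```

Usage sketch: run the LLM pipeline inside the context, run encoder backward, then call config.finalize_model_grads_func(**captured) exactly once.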
DP-direction support:
* Fan-in (enc_dp > llm_dp): _slice_for_encoder_dp narrows the
encoder input on the bridge's source side (the data iterator
yields LLM-DP-sized per-rank batches, the encoder sees a smaller
slot).
* Fan-out (enc_dp < llm_dp): the data iterator yields encoder-DP-
sized batches; the bridge narrows encoder embeddings to each
LLM-DP rank's slot in encode_and_communicate, and
_build_lm_microbatches narrows the LLM-side passthrough fields
(input_ids, labels, loss_mask, position_ids, attention_mask) to
the same slot via _fan_out_slot so they line up with the bridge
output for the LLM forward. Without this, LLM forward on fan-out
ranks would feed a full encoder-DP-sized batch of input_ids
against a narrow LLM-DP-sized encoder embedding.
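A hedged sketch of that fan-out narrowing; the batch-first layout, slot arithmetic, and signature are assumptions rather than the actual _fan_out_slot:

```python
import torch

def fan_out_slot(field: torch.Tensor, llm_dp_rank: int,
                 enc_dp: int, llm_dp: int) -> torch.Tensor:
    """Narrow an encoder-DP-sized passthrough field (batch dim 0 assumed)
    to this LLM-DP rank's slot so it lines up with the bridge's narrowed
    encoder embeddings."""
    fan_out = llm_dp // enc_dp            # LLM-DP ranks per encoder slot
    slot = llm_dp_rank % fan_out          # this rank's slice in the slot
    per_slot = field.size(0) // fan_out
    return field.narrow(0, slot * per_slot, per_slot)
```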
Shape contract: encoder inputs are 3D [seq, batch, hidden] (batch dim
= 1); encoder outputs may be 2D [seq*batch, hidden] or 3D after the
bridge collapses leading dims. Other layouts (e.g. [B, C, H, W]
images) are rejected upfront.
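The contract implies an upfront check along these lines (hypothetical helper, not the bridge's actual validation code):

```python
import torch

def check_encoder_shapes(inp: torch.Tensor, out: torch.Tensor) -> None:
    """Enforce the shape contract: 3D [seq, batch=1, hidden] inputs;
    2D [seq*batch, hidden] or 3D outputs. Everything else is rejected."""
    if inp.dim() != 3 or inp.size(1) != 1:
        raise ValueError(f"expected [seq, 1, hidden] input, got {tuple(inp.shape)}")
    if out.dim() not in (2, 3):
        raise ValueError(f"expected 2D/3D output, got {tuple(out.shape)}")
```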
MimoModel: thread per-PP-stage first/last-stage flags through the
pre/post-process plumbing so PP>1 LLM stages know when to apply
embeddings vs output_layer; piggybacks on RankRole._colocated.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
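The per-stage flags reduce to a comparison against the module's position in its pp group; a sketch with assumed names:

```python
def stage_flags(pp_rank: int, pp_size: int) -> tuple[bool, bool]:
    """First/last-stage flags as described above: pre_process gates the
    embedding lookup, post_process gates the output_layer."""
    pre_process = pp_rank == 0             # first stage applies embeddings
    post_process = pp_rank == pp_size - 1  # last stage applies output_layer
    return pre_process, post_process
```

With pp_size=1 this degenerates to (True, True), matching the previously hardcoded behavior.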
End-to-end correctness test for colocated MIMO training with LLM
PP>1, consolidated into the existing
test_mimo_colocated_correctness.py rather than a new file. The
existing oracle infrastructure (deterministic batch broadcast,
ref-to-dist param copy, encoder weight + first-layer-grad +
LLM-input + LLM-logits diff stats) covers the bulk of the new
PP>1 cases; the parametrize matrix is extended to llm_pp ∈ {1, 2}
crossed with fan-in/fan-out, and three new helpers handle the
extra PP-specific glue:
* _llm_pp_remap_name + _copy_llm_params_pp_aware +
_assert_llm_weights_match_pp_aware: map dist's local
decoder.layers.{idx} on PP stage s to ref's PP=1
decoder.layers.{s*layers_per_stage + idx} (sketched after this
list), used both for the initial param copy and for post-step
LLM weight comparison when dist_llm_tp == enc_tp.
* _copy_ref_llm_with_tp_and_pp_remap — two-phase copy for cases
where dist_llm_tp != enc_tp (fan-out + PP=2 on 8 GPUs forces
this). Phase 1 gathers the full ref LLM params across
ref_tp_group; Phase 2 slices into each rank's local PP+TP
shard. The naive "all-gather inside the dist iteration" loop
deadlocks because dist's PP layout and ref's TP layout don't
align — different ranks in the same ref_tp_group iterate
different dist params and never reach a common collective.
* _wire_training_hooks gains an optional llm_grid arg; with PP>1
it broadcasts num_tokens from the last LLM PP rank to earlier
stages before the DP all-reduce so every rank arrives at the
same N_global divisor.
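A sketch of the layer-index remap from the first helper above; the regex and signature are assumptions:

```python
import re

def llm_pp_remap_name(name: str, pp_stage: int, layers_per_stage: int) -> str:
    """Map dist's local 'decoder.layers.{idx}' on PP stage s to ref's PP=1
    'decoder.layers.{s * layers_per_stage + idx}'. Non-layer params pass
    through unchanged."""
    m = re.match(r"(.*decoder\.layers\.)(\d+)(\..*)", name)
    if m is None:
        return name
    global_idx = pp_stage * layers_per_stage + int(m.group(2))
    return f"{m.group(1)}{global_idx}{m.group(3)}"

# e.g. stage 1 with 4 layers per stage:
# llm_pp_remap_name("decoder.layers.0.mlp.weight", 1, 4)
#   -> "decoder.layers.4.mlp.weight"
```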
_run_forward_backward dispatches between forward_backward_no_pipelining
(llm_pp == 1) and colocated_forward_backward_with_pp (llm_pp > 1).
The test wraps the body in try/except + an explicit failure-print
because pytest's distributed traceback formatter tends to SIGABRT
across asymmetric pass/fail ranks before printing per-rank
tracebacks.
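The failure-print wrapper is the usual pattern of flushing this rank's traceback before re-raising; a sketch with an assumed name:

```python
import traceback

def run_with_explicit_failure_print(case_fn):
    """Print the per-rank traceback before re-raising, since pytest's
    distributed formatter can SIGABRT across asymmetric pass/fail ranks
    before any traceback is shown."""
    try:
        case_fn()
    except Exception:
        traceback.print_exc()
        raise
```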
teardown_method tracks MimoModels and calls model.destroy() before
destroy_all_grids() — mirrors the existing PP=1 pattern so
ColocatedBridgeCommunicator subgroups don't leak NCCL process
groups across tests.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
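The teardown ordering looks roughly like this (attribute name assumed; destroy_all_grids is the existing test-suite helper named above):

```python
def teardown_method(self, method):
    # Destroy tracked MimoModels first so ColocatedBridgeCommunicator's
    # NCCL subgroups are released before the grids they were built from.
    for model in self._tracked_mimo_models:
        model.destroy()
    destroy_all_grids()
```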
This PR has been automatically converted to draft because all PRs must start as drafts. When you are ready for review, click Ready for Review to begin the review process. See the contribution guide for more details.
Summary
Adds PP>1 support for the language model in colocated MIMO training. Encoder modules stay PP=1 on all ranks, while the LLM runs a 1F1B pipeline over microbatches of precomputed encoder embeddings.
Key pieces:
* Defer finalize_model_grads_func until after encoder backward so LLM and encoder grads are finalized together.

Test Plan

* python3 -m py_compile megatron/core/models/mimo/colocated_schedule.py megatron/core/models/mimo/comm/colocated_communicator.py megatron/core/models/mimo/config/role.py megatron/core/models/mimo/model/base.py tests/unit_tests/models/test_mimo_colocated_communicator.py tests/unit_tests/models/test_mimo_colocated_correctness.py
* git diff --check upstream/main...HEAD
* black, pylint, isort
* GPU tests not run (nvidia-smi is unavailable):