diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py new file mode 100644 index 00000000000..e55a588ebcb --- /dev/null +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -0,0 +1,430 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Three-phase schedule for colocated MIMO training with LLM PP>1. + +Phase 1: Encoder forward + communicate for the full batch (all ranks synchronized). +Phase 2: LLM 1F1B pipeline with detached encoder embeddings sliced per microbatch. +Phase 3: Encoder backward for the full batch (all ranks synchronized). + +Encoder runs on all ranks (PP=1) and its TP/DP collectives require all ranks +to participate simultaneously. The 1F1B pipeline staggers ranks across PP stages, +so encoder collectives cannot run inside the pipeline. The three-phase design +separates encoder (synchronized) from LLM (pipelined) by detaching the autograd +graph at the encoder-LLM boundary. + +Shape contract: encoder input tensors are 3D ``[seq, batch, hidden]`` with +the batch dim at ``dim=1``. Encoder output embeddings are either 3D +``[seq, batch, hidden]`` (batch dim = 1) or 2D ``[seq*batch, hidden]`` +(batch dim = 0); the bridge may collapse the leading two dims. Other +layouts (e.g. ``[B, C, H, W]`` images) are not supported. + +DP-direction contract: fan-in (enc_dp > llm_dp), fan-out (enc_dp < llm_dp), +and equal-DP are all supported. The ColocatedBridgeCommunicator handles +the encoder-side reshape on both forward (fan-in: all-gather, fan-out: +narrow) and backward (fan-in: scatter, fan-out: all-gather). The schedule's +job is to hand each side its correctly-sized slice of the global batch: + + * Fan-in: data iterator yields LLM-DP-sized per-rank batches; the + schedule narrows encoder inputs to the encoder rank's smaller slot + in ``_slice_for_encoder_dp`` before encode_and_communicate. + * Fan-out: data iterator yields encoder-DP-sized per-rank batches; the + bridge narrows encoder embeddings to the LLM-DP rank's slot inside + encode_and_communicate, and ``_build_lm_microbatches`` narrows the + LLM-side passthrough fields (input_ids, labels, loss_mask, + position_ids) to the same slot so they line up with the bridge + output for the LLM forward. +""" + +from contextlib import contextmanager +from functools import partial +from typing import Optional + +import torch +import torch.distributed as dist + +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.pipeline_parallel import schedules + + +def colocated_forward_backward_with_pp( + mimo_model, + data_iterator, + num_microbatches: int, + encoder_grid: Optional[HyperCommGrid] = None, + llm_grid: Optional[HyperCommGrid] = None, + encoder_name: str = "images", + forward_only: bool = False, + **schedule_kwargs, +): + """Three-phase colocated training: encoder batch -> LLM pipeline -> encoder backward. + + Args: + mimo_model: MimoModel with colocated communicators and lm_has_pp=True. + data_iterator: Yields dicts with input_ids, labels, etc. + num_microbatches: Number of microbatches for the LLM pipeline. + encoder_grid: Encoder HyperCommGrid (for DP fan-in slicing). + llm_grid: LLM HyperCommGrid (for PP group). + encoder_name: Modality name for the encoder (e.g., "images"). + forward_only: Skip backward passes if True. + **schedule_kwargs: Passed to forward_backward_pipelining_without_interleaving. + Must include p2p_communicator, pg_collection, seq_length, micro_batch_size. 
+ """ + pp_group = llm_grid.get_pg("pp") if llm_grid and 'pp' in llm_grid.dim_names else None + is_pp_first = pp_group is None or pp_group.rank() == 0 + + # ── Phase 1: Encoder forward on full batch (one pass) ──────────────── + # All ranks participate (encoder is PP=1, communicate is collective). + all_batches = [next(data_iterator) for _ in range(num_microbatches)] + full_encoder_input = _concat_encoder_inputs(all_batches, encoder_name) + _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid) + + enc_out = mimo_model.encode_and_communicate({encoder_name: full_encoder_input}) + + # Detach so Phase 2 runs no encoder collectives; microbatch views accumulate + # .grad into detached_full.grad automatically. + detached_full = {k: v.detach().requires_grad_(True) for k, v in enc_out.items()} + lm_data = _build_lm_microbatches( + detached_full, all_batches, num_microbatches, encoder_grid, llm_grid + ) + + # ── Phase 2: LLM 1F1B pipeline ────────────────────────────────────── + # Only LLM P2P communication (within PP group). No encoder collectives. + cache_iter = iter(lm_data) + + def _lm_forward_step(data_iterator_unused, model, *args): + cached = next(cache_iter) + forward_kwargs = dict( + input_ids=cached['input_ids'], + labels=cached['labels'], + loss_mask=cached['loss_mask'], + position_ids=cached['position_ids'], + encoder_embeddings=cached['encoder_embeddings'], + ) + if cached.get('attention_mask') is not None: + forward_kwargs['attention_mask'] = cached['attention_mask'] + if cached.get('packing_kwargs') is not None: + forward_kwargs['packing_kwargs'] = cached['packing_kwargs'] + output_tensor, loss_mask = model(**forward_kwargs) + return output_tensor, partial(_loss_func, cached['loss_mask']) + + # Swap in a capturing finalize so the inner PP schedule does not run DDP + # grad sync before Phase 3 has produced encoder grads. The capture also + # records ``num_tokens`` and ``force_all_reduce`` that the inner schedule + # would have passed — we forward them to the original finalize after + # Phase 3 so per-token-loss configs see the correct global divisor and + # any caller-requested all-reduce semantics are preserved. + with _deferred_finalize(mimo_model.config) as (original_finalize, capture): + losses = schedules.forward_backward_pipelining_without_interleaving( + forward_step_func=_lm_forward_step, + data_iterator=cache_iter, + model=[mimo_model], + num_microbatches=num_microbatches, + forward_only=forward_only, + **schedule_kwargs, + ) + + # ── Phase 3: Encoder backward (one pass, all ranks sync) ──────────── + # detached_full.grad was populated by Phase 2's per-microbatch LLM backward + # (accumulated across microbatch view slices on PP stage 0). + # Broadcast to PP stage 1+ then run one encoder backward for the full batch. + if not forward_only and enc_out: + _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first) + for key in enc_out: + grad = detached_full[key].grad + if grad is not None: + torch.autograd.backward(enc_out[key], grad_tensors=grad) + + # Single post-Phase-3 finalize: reduces LLM grads (from Phase 2) and + # encoder grads (from Phase 3) together. Without this call, encoder + # grads remain local to each rank and Adam steps on un-reduced grads, + # causing silent divergence from the equal-DP reference. Forward the + # captured force_all_reduce so callers requesting that semantics + # (e.g. final-microbatch sync with overlap_grad_reduce) get it. 
+ if not forward_only and original_finalize is not None: + original_finalize( + [mimo_model], + capture.num_tokens, + pg_collection=schedule_kwargs.get('pg_collection'), + force_all_reduce=capture.force_all_reduce, + ) + + return losses + + +# ── Helpers ────────────────────────────────────────────────────────────── + + +def _fan_out_slot(encoder_grid, llm_grid): + """Return ``(scale, slot)`` for fan-out LLM-side narrowing. + + For fan-out (``llm_dp > enc_dp``) the data iterator yields encoder-DP- + sized per-rank batches. The bridge narrows encoder embeddings to this + LLM-DP rank's slot inside ``encode_and_communicate``; LLM-side fields + (input_ids, labels, ...) must be narrowed to the SAME slot so they + line up with the bridge output for the LLM forward. Returns + ``(scale, slot)`` where ``slot`` is this rank's index inside the + fan-out sibling group; ``(1, 0)`` for equal-DP and fan-in (where the + LLM-side fields are already correctly sized for the LLM-DP rank). + """ + if encoder_grid is None or llm_grid is None: + return 1, 0 + enc_dp = encoder_grid.get_pg("dp").size() + llm_dp = llm_grid.get_pg("dp").size() + if llm_dp <= enc_dp: + return 1, 0 + scale = llm_dp // enc_dp + slot = llm_grid.get_pg("dp").rank() % scale + return scale, slot + + +def _modality_present(batch, encoder_name): + """Return True iff this batch carries inputs for ``encoder_name``.""" + mod_in = batch.get('modality_inputs') + return bool(mod_in) and encoder_name in mod_in and mod_in[encoder_name] is not None + + +def _concat_encoder_inputs(all_batches, encoder_name): + """Concatenate encoder inputs from all microbatches along batch dim (dim=1). + + All encoder input tensors must be 3D ``[seq, batch, hidden]``. All + microbatches must uniformly have or lack ``modality_inputs[encoder_name]``; + mixed batches are rejected because Phase 2 reuses one detached encoder + output across every LLM microbatch. + """ + first = all_batches[0] + has_first = _modality_present(first, encoder_name) + for idx, b in enumerate(all_batches): + if _modality_present(b, encoder_name) != has_first: + raise ValueError( + f"colocated_forward_backward_with_pp requires uniform " + f"modality_inputs across microbatches for '{encoder_name}'; " + f"microbatch 0 has it = {has_first} but microbatch {idx} differs." + ) + if not has_first: + return {} + result = {} + for enc_name in first['modality_inputs'][encoder_name]: + result[enc_name] = {} + for key in first['modality_inputs'][encoder_name][enc_name]: + vals = [b['modality_inputs'][encoder_name][enc_name][key] for b in all_batches] + tensors = [v for v in vals if isinstance(v, torch.Tensor)] + if tensors: + for v in tensors: + if v.ndim != 3: + raise ValueError( + f"encoder input '{enc_name}.{key}' must be 3D " + f"[seq, batch, hidden], got shape={tuple(v.shape)}" + ) + result[enc_name][key] = torch.cat(tensors, dim=1) + else: + result[enc_name][key] = vals[0] + return result + + +def _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid): + """Slice concatenated encoder input for fan-in (enc_dp > llm_dp). + + Encoder input tensors must be 3D ``[seq, batch, hidden]``. For fan-in + the data iterator yields LLM-DP-sized per-rank batches; this helper + narrows them to the encoder rank's smaller slot before forward. + Equal-DP and fan-out (where the per-rank batch is already encoder-DP- + sized — the bridge narrows on the LLM side) are no-ops. 
+ """ + if encoder_grid is None or llm_grid is None: + return + enc_dp = encoder_grid.get_pg("dp").size() + llm_dp = llm_grid.get_pg("dp").size() + if enc_dp <= llm_dp: + return + scale = enc_dp // llm_dp + slot = encoder_grid.get_pg("dp").rank() % scale + for enc_name in full_encoder_input: + for key, tensor in full_encoder_input[enc_name].items(): + if not isinstance(tensor, torch.Tensor): + continue + if tensor.ndim != 3: + raise ValueError( + f"encoder input '{enc_name}.{key}' must be 3D " + f"[seq, batch, hidden], got shape={tuple(tensor.shape)}" + ) + bs = tensor.shape[1] + if bs % scale != 0: + raise ValueError( + f"Encoder fan-in: tensor batch={bs} not divisible by scale={scale}." + ) + ss = bs // scale + if ss == 0: + raise ValueError( + f"Encoder fan-in produces zero-sized batch: " + f"total_batch={bs}, scale={scale}. Increase micro_batch_size." + ) + full_encoder_input[enc_name][key] = tensor[ + :, slot * ss : (slot + 1) * ss, : + ].contiguous() + + +def _build_lm_microbatches( + detached_full, all_batches, num_microbatches, encoder_grid=None, llm_grid=None +): + """Slice detached encoder output into per-microbatch views for the LLM pipeline. + + Encoder embeddings are either 3D ``[seq, batch, hidden]`` (batch dim = 1) + or 2D ``[seq*batch, hidden]`` (batch dim = 0); the bridge may collapse + the leading two dims. Other layouts are rejected. Pass-through fields + (input_ids, labels, loss_mask, position_ids, attention_mask, packing_kwargs) + are copied per microbatch from the corresponding ``all_batches`` entry. + + For fan-out (``llm_dp > enc_dp``) the per-microbatch passthrough fields + arrive at the encoder-DP-sized batch; this helper narrows them to the + LLM-DP rank's slot via :func:`_fan_out_slot` so they line up with the + bridge-narrowed encoder embeddings. Fan-in and equal-DP leave the + fields unchanged (``scale=1, slot=0``). + """ + fan_out_scale, fan_out_slot = _fan_out_slot(encoder_grid, llm_grid) + + def _maybe_narrow(tensor): + """Narrow a batch-dim-0 tensor to this LLM-DP rank's fan-out slot.""" + if fan_out_scale == 1 or tensor is None or not isinstance(tensor, torch.Tensor): + return tensor + bs = tensor.shape[0] + if bs % fan_out_scale != 0: + raise ValueError( + f"Fan-out narrowing: tensor batch={bs} not divisible by " f"scale={fan_out_scale}." + ) + ss = bs // fan_out_scale + return tensor[fan_out_slot * ss : (fan_out_slot + 1) * ss].contiguous() + + def _maybe_narrow_attn(tensor, ref_batch): + """Narrow ``attention_mask`` only when its dim-0 matches the input batch. + + Some callers pass attention_mask as ``[b, 1, s, s]`` (batch-first, + narrow the way ``input_ids`` is narrowed); others pass shapes that + broadcast across batch (e.g. ``[1, 1, s, s]`` causal mask). We only + narrow when dim-0 equals the pre-narrowing batch size, leaving + broadcastable masks alone. 
+ """ + if ( + fan_out_scale == 1 + or tensor is None + or not isinstance(tensor, torch.Tensor) + or ref_batch is None + or not isinstance(ref_batch, torch.Tensor) + or tensor.ndim < 1 + or tensor.shape[0] != ref_batch.shape[0] + ): + return tensor + return _maybe_narrow(tensor) + + def _passthrough(batch_idx): + b = all_batches[batch_idx] + input_ids = b.get('input_ids') + return { + 'input_ids': _maybe_narrow(input_ids), + 'labels': _maybe_narrow(b.get('labels')), + 'loss_mask': _maybe_narrow(b.get('loss_mask')), + 'position_ids': _maybe_narrow(b.get('position_ids')), + 'attention_mask': _maybe_narrow_attn(b.get('attention_mask'), input_ids), + 'packing_kwargs': b.get('packing_kwargs'), + } + + if not detached_full: + # Text-only batch: no encoder embeddings to slice + return [{'encoder_embeddings': {}, **_passthrough(mb)} for mb in range(num_microbatches)] + + sample = next(iter(detached_full.values())) + if sample.ndim not in (2, 3): + raise ValueError( + f"encoder output must be 2D [seq*batch, hidden] or 3D " + f"[seq, batch, hidden], got shape={tuple(sample.shape)}" + ) + batch_dim = 1 if sample.ndim == 3 else 0 + total_batch = sample.shape[batch_dim] + if total_batch % num_microbatches != 0: + raise ValueError( + f"Encoder output batch dim ({total_batch}) must be divisible " + f"by num_microbatches ({num_microbatches})" + ) + mb_size = total_batch // num_microbatches + + lm_data = [] + for mb in range(num_microbatches): + s, e = mb * mb_size, (mb + 1) * mb_size + mb_enc = {k: (v[:, s:e, :] if v.ndim == 3 else v[s:e, :]) for k, v in detached_full.items()} + lm_data.append({'encoder_embeddings': mb_enc, **_passthrough(mb)}) + return lm_data + + +def _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first): + """Broadcast encoder gradient from PP stage 0 to stage 1+ ranks.""" + if pp_group is None or pp_group.size() <= 1: + return + src = dist.get_global_rank(pp_group, 0) + for key in enc_out: + if is_pp_first: + if detached_full[key].grad is None: + raise RuntimeError( + f"No encoder gradient on PP stage 0 for '{key}'; " + f"Phase 2 LLM backward did not populate detached_full.grad." + ) + dist.broadcast(detached_full[key].grad, src=src, group=pp_group) + else: + grad = torch.empty_like(detached_full[key]) + dist.broadcast(grad, src=src, group=pp_group) + detached_full[key].grad = grad + + +def _loss_func(loss_mask, output_tensor): + """Default loss function for the LLM pipeline. + + Returns the 3-tuple ``(local_sum, local_num_tokens, log_dict)`` contract + expected when ``calculate_per_token_loss=True`` is set on the + TransformerConfig. When it is not set, the schedule divides + ``local_sum`` by ``local_num_tokens`` (clamped to 1), so the 3-tuple + form is also safe for standard per-microbatch-mean configs. + """ + if output_tensor is None: + zero_loss = torch.tensor(0.0, device='cuda', requires_grad=True) + zero_count = torch.tensor(0, device='cuda', dtype=torch.int) + return zero_loss, zero_count, {'loss_reduced': 0.0} + masked = output_tensor.float() * loss_mask.float() + local_sum = masked.sum() + local_num_tokens = loss_mask.float().sum().to(torch.int) + return local_sum, local_num_tokens, {'loss_reduced': local_sum.detach().item()} + + +class _CapturingFinalize: + """Capture finalize args the inner PP schedule would have passed. + + The three-phase schedule defers grad finalization until after Phase 3 + runs encoder backward. 
Replacing the config's ``finalize_model_grads_func`` + with this object absorbs the inner schedule's invocation and stores + ``num_tokens`` (required for ``calculate_per_token_loss=True`` configs + whose finalize hook divides by the global valid-token count) and + ``force_all_reduce`` (preserves any caller-requested all-reduce + semantics on the final microbatch) so the post-Phase-3 call to the + original finalize can forward both. + """ + + def __init__(self): + self.num_tokens = None + self.force_all_reduce = False + + def __call__(self, model_list, num_tokens, *args, **kwargs): + self.num_tokens = num_tokens + self.force_all_reduce = kwargs.get('force_all_reduce', False) + return None + + +@contextmanager +def _deferred_finalize(config): + """Suppress the PP schedule's end-of-run DDP grad sync; yield the + original finalize and a capture object so callers can invoke the + original (with the captured ``num_tokens``) once after Phase 3. + """ + original = config.finalize_model_grads_func + capture = _CapturingFinalize() + config.finalize_model_grads_func = capture + try: + yield original, capture + finally: + config.finalize_model_grads_func = original diff --git a/megatron/core/models/mimo/comm/colocated_communicator.py b/megatron/core/models/mimo/comm/colocated_communicator.py index 4c43dcdf3cd..93df93381d5 100644 --- a/megatron/core/models/mimo/comm/colocated_communicator.py +++ b/megatron/core/models/mimo/comm/colocated_communicator.py @@ -95,12 +95,7 @@ def __init__( elif self.dest_dp_size > self.src_dp_size: self.direction = BridgeDirection.FAN_OUT self.scale = self.dest_dp_size // self.src_dp_size - self.gather_group_ranks = self._build_gather_groups( - iter_size=self.src_dp_size, - sibling_tp_size=self.dest_tp_size, - scale=self.scale, - rank_to_pos=self.rank_to_dest_pos, - ) + self.gather_group_ranks = self._build_fan_out_gather_groups() self.gather_pg, _ = dist.new_subgroups_by_enumeration( self.gather_group_ranks, backend='nccl' ) @@ -128,8 +123,9 @@ def _validate_grids(self): f"src={self.src_grid.rank_offset}, dest={self.dest_grid.rank_offset}" ) - # Per-grid dim checks: tp/dp required; pp and cp (if present) must be 1. - # CP>1 also corrupts dp_idx when iterating get_rank_enum(['tp']) groups. + # Per-grid dim checks: tp/dp required; cp (if present) must be 1. + # Src PP must be 1; dest PP>1 is allowed. CP>1 corrupts dp_idx when + # iterating get_rank_enum(['tp']) groups. 
for name, grid in [("src", self.src_grid), ("dest", self.dest_grid)]: for required in ('tp', 'dp'): if required not in grid.dim_names: @@ -137,14 +133,16 @@ def _validate_grids(self): f"{name} grid must have '{required}' dimension, " f"got dim_names={grid.dim_names}" ) - for singleton in ('pp', 'cp'): - if singleton in grid.dim_names: - size = grid.shape[grid.dim_names.index(singleton)] - if size != 1: - raise ValueError( - f"{name} {singleton.upper()} must be 1 for " - f"ColocatedBridgeCommunicator, got {size}" - ) + if 'cp' in grid.dim_names: + cp_size = grid.shape[grid.dim_names.index('cp')] + if cp_size != 1: + raise ValueError( + f"{name} CP must be 1 for ColocatedBridgeCommunicator, got {cp_size}" + ) + if 'pp' in self.src_grid.dim_names: + src_pp = self.src_grid.shape[self.src_grid.dim_names.index('pp')] + if src_pp != 1: + raise ValueError(f"src PP must be 1 for ColocatedBridgeCommunicator, got {src_pp}") src_dp = self.src_grid.shape[self.src_grid.dim_names.index('dp')] dest_dp = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')] @@ -158,20 +156,35 @@ def _extract_parallelism_info(self): self.src_dp_size = self.src_grid.shape[self.src_grid.dim_names.index('dp')] self.dest_tp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('tp')] self.dest_dp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')] + self.dest_pp_size = ( + self.dest_grid.shape[self.dest_grid.dim_names.index('pp')] + if 'pp' in self.dest_grid.dim_names + else 1 + ) def _build_rank_mappings(self): self.rank_to_src_pos: Dict[int, Tuple[int, int]] = {} self.rank_to_dest_pos: Dict[int, Tuple[int, int]] = {} + self.rank_to_dest_pp_pos: Dict[int, Tuple[int, int, int]] = {} + self.dest_pp_pos_to_rank: Dict[Tuple[int, int, int], int] = {} src_tp_groups = self.src_grid.get_rank_enum(['tp']) for dp_idx, tp_group in enumerate(src_tp_groups): for tp_idx, rank in enumerate(tp_group): self.rank_to_src_pos[rank] = (dp_idx, tp_idx) - dest_tp_groups = self.dest_grid.get_rank_enum(['tp']) - for dp_idx, tp_group in enumerate(dest_tp_groups): - for tp_idx, rank in enumerate(tp_group): + # Include destination PP when enumerating destination ranks so DP + # indices stay true DP coordinates instead of flattened (dp, pp) + # positions. Fan-out gather groups then stay within one PP stage. 
+ dest_group_dims = ['tp', 'pp'] if 'pp' in self.dest_grid.dim_names else ['tp'] + dest_tp_pp_groups = self.dest_grid.get_rank_enum(dest_group_dims) + for dp_idx, tp_pp_group in enumerate(dest_tp_pp_groups): + for local_idx, rank in enumerate(tp_pp_group): + pp_idx = local_idx // self.dest_tp_size if self.dest_pp_size > 1 else 0 + tp_idx = local_idx % self.dest_tp_size self.rank_to_dest_pos[rank] = (dp_idx, tp_idx) + self.rank_to_dest_pp_pos[rank] = (dp_idx, pp_idx, tp_idx) + self.dest_pp_pos_to_rank[(dp_idx, pp_idx, tp_idx)] = rank @staticmethod def _build_gather_groups( @@ -198,6 +211,21 @@ def _build_gather_groups( groups.append(group_ranks) return groups + def _build_fan_out_gather_groups(self) -> List[List[int]]: + """Build dest-side fan-out gather groups, preserving destination PP stage.""" + groups: List[List[int]] = [] + for src_dp_idx in range(self.src_dp_size): + sibling_dp_indices = range(src_dp_idx * self.scale, (src_dp_idx + 1) * self.scale) + for dest_pp_idx in range(self.dest_pp_size): + for dest_tp_idx in range(self.dest_tp_size): + group_ranks = [] + for dest_dp_idx in sibling_dp_indices: + group_ranks.append( + self.dest_pp_pos_to_rank[(dest_dp_idx, dest_pp_idx, dest_tp_idx)] + ) + groups.append(group_ranks) + return groups + def is_fan_in(self) -> bool: """True if src DP > dest DP (forward all-gathers).""" return self.direction is BridgeDirection.FAN_IN diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 411791f1e5c..389aa65a8d0 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -79,7 +79,7 @@ def build( Grids differ → NON_COLOCATED with PP-stage info per module. """ if module_to_grid_map is None or cls._all_grids_colocated(module_to_grid_map): - return cls._colocated(modality_module_names) + return cls._colocated(modality_module_names, module_to_grid_map) return cls._from_grid_map(module_to_grid_map) @staticmethod @@ -89,16 +89,30 @@ def _all_grids_colocated(module_to_grid_map: Dict[str, 'HyperCommGrid']) -> bool return all(g.rank_offset == first.rank_offset and g.size == first.size for g in grids[1:]) @classmethod - def _colocated(cls, modality_module_names: List[str]) -> 'RankRole': - """Colocated layout: every module on every rank, PP=1.""" + def _colocated( + cls, + modality_module_names: List[str], + module_to_grid_map: Optional[Dict[str, 'HyperCommGrid']] = None, + ) -> 'RankRole': + """Colocated layout: every module on every rank. + + When a grid map is supplied, per-module stage info is derived from + each grid's pp group (LLM PP>1 is allowed). With no grid map, every + module is both first and last stage. 
+ """ all_module_names = list(modality_module_names) + [MIMO_LANGUAGE_MODULE_KEY] - return cls( - modules={ - name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) - for name in all_module_names - }, - mode=ModuleLayout.COLOCATED, - ) + modules = {} + for name in all_module_names: + grid = module_to_grid_map.get(name) if module_to_grid_map else None + if grid is not None and 'pp' in grid.dim_names: + pp_group = grid.get_pg('pp') + pp_rank, pp_size = pp_group.rank(), pp_group.size() + modules[name] = ModuleStageInfo( + is_first_stage=(pp_rank == 0), is_last_stage=(pp_rank == pp_size - 1) + ) + else: + modules[name] = ModuleStageInfo(is_first_stage=True, is_last_stage=True) + return cls(modules=modules, mode=ModuleLayout.COLOCATED) @classmethod def _from_grid_map(cls, module_to_grid_map: Dict[str, HyperCommGrid]) -> 'RankRole': diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index bdfe4289dd0..e7695c8b4ba 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -67,6 +67,11 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - # in TP/DP within those ranks. self._build_colocated_communicators() + lang_info = self.role.modules.get(MIMO_LANGUAGE_MODULE_KEY) + self.lm_has_pp = lang_info is not None and not ( + lang_info.is_first_stage and lang_info.is_last_stage + ) + # Use special token IDs from the config self.special_token_ids = ( mimo_config.special_token_ids.copy() if mimo_config.special_token_ids else {} @@ -318,6 +323,7 @@ def forward( labels: Optional[torch.Tensor] = None, modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None, packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass through the multimodal model. @@ -362,6 +368,20 @@ def forward( input_tensors = getattr(self, 'input_tensors', None) if self.role.mode == ModuleLayout.COLOCATED: + if self.lm_has_pp and input_tensors is not None: + # PP>1 non-first stage: hidden states from P2P + lm_result = self._forward_language_module( + input_ids, + position_ids, + attention_mask, + labels, + {MIMO_LANGUAGE_MODULE_KEY: input_tensors}, + ) + # Unwrap dict for P2P (schedule uses plain tensors, not dicts) + if isinstance(lm_result, dict): + lm_result = lm_result[MIMO_LANGUAGE_MODULE_KEY] + return lm_result, loss_mask + return self._forward_all_modules( input_ids, position_ids, @@ -370,6 +390,7 @@ def forward( labels, modality_inputs, packing_kwargs, + encoder_embeddings=encoder_embeddings, ) if self.role.mode == ModuleLayout.NON_COLOCATED: @@ -519,7 +540,12 @@ def _build_colocated_communicators(self): ) def destroy(self) -> None: - """Release process groups owned by this MimoModel.""" + """Release process groups owned by this MimoModel. + + NCCL caps concurrent communicators, so long-lived or + repeatedly-rebuilt models leak subgroups without explicit + destroy. Tests should call this before ``destroy_all_grids()``. 
+ """ for comm in self.colocated_comms.values(): comm.destroy() self.colocated_comms.clear() @@ -535,6 +561,22 @@ def _apply_colocated_comms(self, modality_embeddings): ) return modality_embeddings + def encode_and_communicate(self, modality_inputs): + """Run encoder forward + colocated TP/DP transform (collective).""" + modality_embeddings = {} + for modality_name, submodule in self.modality_submodules.items(): + if ( + modality_inputs + and modality_name in modality_inputs + and modality_inputs[modality_name] is not None + ): + embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) + if embeddings is not None: + modality_embeddings[modality_name] = embeddings + if self.colocated_comms: + modality_embeddings = self._apply_colocated_comms(modality_embeddings) + return modality_embeddings + def _forward_all_modules( self, input_ids: torch.Tensor, @@ -544,6 +586,7 @@ def _forward_all_modules( labels: Optional[torch.Tensor], modality_inputs: Optional[Dict[str, Dict[str, Any]]], packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass when all modules are on all ranks (no multi-module PP). @@ -560,26 +603,12 @@ def _forward_all_modules( packed_seq_params.qkv_format = 'thd' logger.debug(f"Packed sequence parameters: {packed_seq_params}") - # 1. Process each modality to get embeddings - modality_embeddings = {} - - for modality_name, submodule in self.modality_submodules.items(): - if ( - modality_inputs - and modality_name in modality_inputs - and modality_inputs[modality_name] is not None - ): - logger.debug(f"Processing {modality_name} modality") - embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) - if embeddings is not None: - modality_embeddings[modality_name] = embeddings - logger.debug( - f"Generated embeddings for {modality_name} with shape {embeddings.shape}" - ) - - # Apply colocated communication if configured (no-op when colocated_comms is empty) - if self.colocated_comms: - modality_embeddings = self._apply_colocated_comms(modality_embeddings) + if encoder_embeddings is not None: + # PP>1 path: encoder forward + communicate already ran in Phase 1; + # reuse the precomputed embeddings for every LLM microbatch. + modality_embeddings = encoder_embeddings + else: + modality_embeddings = self.encode_and_communicate(modality_inputs) # Get text embeddings text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) diff --git a/tests/unit_tests/models/test_mimo_colocated_communicator.py b/tests/unit_tests/models/test_mimo_colocated_communicator.py index 67cee551a0f..5b253ee1d2e 100644 --- a/tests/unit_tests/models/test_mimo_colocated_communicator.py +++ b/tests/unit_tests/models/test_mimo_colocated_communicator.py @@ -234,8 +234,8 @@ def test_rank_offset_mismatch(self): "side,dim,expected", [ ("src", "pp", "src PP must be 1"), - ("dest", "pp", "dest PP must be 1"), ("src", "cp", "CP must be 1"), + ("dest", "cp", "CP must be 1"), ], ) def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): @@ -250,6 +250,13 @@ def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): with pytest.raises(ValueError, match=expected): make_comm(src_grid, dest_grid) + def test_dest_pp_gt_one_accepted(self): + # Dest PP>1 is valid: the three-phase colocated schedule handles + # the LLM pipeline orchestration. The bridge only needs src PP=1. 
+ src_grid = create_hypercomm_grid(tp=4, dp=2) + dest_grid = create_hypercomm_grid(tp=2, pp=2, dp=2) + make_comm(src_grid, dest_grid) + def test_dp_not_divisible(self): # 6-rank grids with DP sizes (3 vs 2) that neither divides the other. # Fits inside an 8-rank world (HyperCommGrid enforces size <= world - offset). diff --git a/tests/unit_tests/models/test_mimo_colocated_correctness.py b/tests/unit_tests/models/test_mimo_colocated_correctness.py index e2d91bdf83e..1432a9c839f 100644 --- a/tests/unit_tests/models/test_mimo_colocated_correctness.py +++ b/tests/unit_tests/models/test_mimo_colocated_correctness.py @@ -51,6 +51,7 @@ """ import os +import re from functools import partial import pytest @@ -61,8 +62,10 @@ import megatron.core.pipeline_parallel.schedules as schedule from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.models.mimo.colocated_schedule import colocated_forward_backward_with_pp from megatron.core.models.mimo.optimizer import get_mimo_optimizer from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.transformer.enums import ModelType from megatron.core.utils import unwrap_model from tests.unit_tests.models.test_mimo_1f1b_schedule import ( @@ -164,7 +167,7 @@ def _set_deterministic_env(): os.environ.pop('NVTE_UNFUSED_ATTN', None) -def _wire_training_hooks(mimo_model, language_pg, vision_pg): +def _wire_training_hooks(mimo_model, language_pg, vision_pg, llm_grid=None): """Attach no_sync / finalize_grads / grad_scale hooks to a MimoModel. The finalize hook implements the heterogeneous-DP grad-scaling story @@ -185,6 +188,12 @@ def _wire_training_hooks(mimo_model, language_pg, vision_pg): 3. Calls ``scale_gradients(1/N_global)`` on each side — lands the true global per-token mean uniformly on encoder and LLM grads. + ``llm_grid`` is required for LLM PP>1 callers: with PP>1 the inner + schedule only populates ``num_tokens`` on the last LLM PP stage; this + hook broadcasts it from the last PP rank to earlier stages before the + DP all-reduce so every rank arrives at the same ``N_global``. + Pass ``None`` (default) for PP=1, where the broadcast is a no-op. + Note: encoder has no loss_func (so nothing emits a per-encoder-DP ``num_tokens`` to feed ``finalize_model_grads``' internal all-reduce). Doing the all-reduce once ourselves and calling ``scale_gradients`` @@ -193,6 +202,7 @@ def _wire_training_hooks(mimo_model, language_pg, vision_pg): """ no_sync_func = build_no_sync_func(mimo_model) + pp_group = llm_grid.get_pg("pp") if llm_grid is not None else None def finalize_grads_func(model_list, num_tokens, force_all_reduce=False, **kwargs): # Schedule passes the per-rank sum-across-microbatches of what the @@ -203,6 +213,13 @@ def finalize_grads_func(model_list, num_tokens, force_all_reduce=False, **kwargs "TransformerConfig so the schedule forwards total_num_tokens; got None." ) + # PP>1: only the last LLM PP stage emits a non-zero num_tokens + # from the loss_func. Broadcast to earlier stages so every rank + # holds the same value before the DP all-reduce below. + if pp_group is not None and pp_group.size() > 1: + last_rank = dist.get_global_rank(pp_group, pp_group.size() - 1) + dist.broadcast(num_tokens, src=last_rank, group=pp_group) + # Phase 1: lift the all-reduce. 
After this, every rank (including # encoder-only replicas) has N_global = total non-padded tokens in # the global batch. @@ -828,6 +845,173 @@ def _assert_encoder_weights_match(ref_module, dist_module, rtol=1e-3, atol=1e-3) ) +_LLM_LAYER_RX = re.compile(r'^(.*decoder\.layers\.)(\d+)(\..*)$') + + +def _llm_pp_remap_name(name, pp_rank, layers_per_stage): + """Remap a dist LLM param name (local layer idx) to its ref PP=1 name (global idx). + + Dist's ``decoder.layers.{local_idx}`` on PP stage ``s`` corresponds to + ref's global layer ``s * layers_per_stage + local_idx``. Non-layer + params (embedding, final_layernorm, output_layer) are present only on + stages that own them and their names match exactly between ref and dist. + """ + m = _LLM_LAYER_RX.match(name) + if not m: + return name + prefix, local_idx_s, suffix = m.groups() + return f"{prefix}{pp_rank * layers_per_stage + int(local_idx_s)}{suffix}" + + +def _copy_llm_params_pp_aware(ref_module, dist_module, pp_rank, pp_size, num_layers): + """Copy LLM params ref (PP=1) → dist (PP>=1) with layer-index remapping. + + Assumes ``dist_llm_tp == ref_llm_tp`` so shards line up 1:1; callers + must verify (the consolidated correctness test only enables LLM PP-aware + copy/oracle when this holds). + """ + assert num_layers % pp_size == 0, ( + f"num_layers={num_layers} not divisible by pp_size={pp_size}; " + f"oracle requires even PP split." + ) + layers_per_stage = num_layers // pp_size + ref_params = dict(ref_module.named_parameters()) + + with torch.no_grad(): + for name, dist_param in dist_module.named_parameters(): + ref_name = _llm_pp_remap_name(name, pp_rank, layers_per_stage) + assert ref_name in ref_params, ( + f"LLM param '{name}' on PP stage {pp_rank} maps to ref name " + f"'{ref_name}' which does not exist in ref (ref has llm_pp=1)." + ) + ref_param = ref_params[ref_name] + assert ref_param.shape == dist_param.shape, ( + f"LLM param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)} — oracle requires " + f"dist_llm_tp == ref_llm_tp." + ) + dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) + + +def _copy_ref_llm_with_tp_and_pp_remap( + ref_module, dist_module, ref_tp_group, dist_tp_group, pp_rank, pp_size, num_layers +): + """Copy ref LLM (PP=1, ``ref_llm_tp``) → dist LLM (PP>=1, ``dist_llm_tp``). + + Combines the PP-aware layer-index remap (from + :func:`_llm_pp_remap_name`) with the TP reshard + (all-gather-across-ref-TP + slice-by-dist-TP). Needed when fan-out + PP>1 forces ``dist_llm_tp != enc_tp`` on a fixed rank count (e.g. + fan-out PP=2 on 8 GPUs). + + Two-phase to avoid cross-PP-stage collectives: + + * Phase 1 — gather full ref params across ``ref_tp_group``. Iterates + ``ref_module.named_parameters()``, which is identical on every + rank in the same ``ref_tp_group``, so the all-gather collective is + lockstep regardless of how dist's PP layout splits those ranks. + * Phase 2 — copy from ``full_ref`` into this rank's local dist params + (PP-staged) using the layer-index remap. No collectives. + + The naive "iterate dist_module and all-gather inside the loop" + approach hangs whenever dist's PP split spreads across a ref TP + group: ranks on different dist PP stages iterate different params + and never reach the same all-gather call together. + """ + assert num_layers % pp_size == 0, f"num_layers={num_layers} not divisible by pp_size={pp_size}." 
+ layers_per_stage = num_layers // pp_size + ref_tp_size = dist.get_world_size(ref_tp_group) + dist_tp_rank = dist.get_rank(dist_tp_group) + dist_tp_size = dist.get_world_size(dist_tp_group) + + # Phase 1: gather full ref params across ref_tp_group. Safe because + # every rank in ref_tp_group iterates ref_module.named_parameters() + # in the same order. + full_ref = {} + with torch.no_grad(): + for name, ref_param in ref_module.named_parameters(): + partition_dim = getattr(ref_param, 'partition_dim', -1) + if ref_tp_size <= 1 or partition_dim < 0: + full_ref[name] = ref_param.data.detach().clone() + continue + shards = [torch.empty_like(ref_param.data) for _ in range(ref_tp_size)] + dist.all_gather(shards, ref_param.data.contiguous(), group=ref_tp_group) + full_ref[name] = torch.cat(shards, dim=partition_dim) + + # Phase 2: per-rank local copy into dist's PP-staged params, with + # PP-aware layer-index remap and dist-TP slicing. No collectives. + with torch.no_grad(): + for name, dist_param in dist_module.named_parameters(): + ref_name = _llm_pp_remap_name(name, pp_rank, layers_per_stage) + assert ref_name in full_ref, ( + f"LLM param '{name}' on PP stage {pp_rank} maps to ref name " + f"'{ref_name}' which does not exist in ref (ref has llm_pp=1)." + ) + full_weight = full_ref[ref_name] + partition_dim = getattr(dist_param, 'partition_dim', -1) + + if dist_tp_size <= 1 or partition_dim < 0: + # Replicated on dist (or no TP): full ref weight should + # match dist's local shape. + assert full_weight.shape == dist_param.shape, ( + f"Param '{name}' (ref '{ref_name}'): full_ref.shape=" + f"{tuple(full_weight.shape)} != dist.shape=" + f"{tuple(dist_param.shape)} (dist_tp={dist_tp_size}, " + f"partition_dim={partition_dim})" + ) + dist_param.data.copy_(full_weight.to(dist_param.dtype)) + continue + + dist_slice = torch.tensor_split(full_weight, dist_tp_size, dim=partition_dim)[ + dist_tp_rank + ] + assert dist_slice.shape == dist_param.shape, ( + f"Param '{name}' (ref '{ref_name}'): sliced.shape=" + f"{tuple(dist_slice.shape)} != dist.shape=" + f"{tuple(dist_param.shape)} (ref_tp={ref_tp_size}, " + f"dist_tp={dist_tp_size}, partition_dim={partition_dim})" + ) + dist_param.data.copy_(dist_slice.to(dist_param.dtype)) + + +def _assert_llm_weights_match_pp_aware( + ref_module, dist_module, pp_rank, pp_size, num_layers, rtol=1e-2, atol=1e-2 +): + """Assert dist LLM shards match ref (PP=1) via the PP-aware layer-index remap. + + Counterpart to :func:`_copy_llm_params_pp_aware`. Non-layer params + (embedding, final_layernorm, output_layer) only exist on stages that + own them and their names are unchanged between ref and dist. + """ + layers_per_stage = num_layers // pp_size + ref_params = dict(ref_module.named_parameters()) + + mismatches = [] + for name, dist_param in dist_module.named_parameters(): + ref_name = _llm_pp_remap_name(name, pp_rank, layers_per_stage) + assert ref_name in ref_params, ( + f"LLM param '{name}' maps to ref '{ref_name}' which does not exist " + f"(ref has llm_pp=1)." + ) + ref_param = ref_params[ref_name] + assert ref_param.shape == dist_param.shape, ( + f"LLM param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)}." 
+ ) + try: + torch.testing.assert_close(dist_param.data, ref_param.data, rtol=rtol, atol=atol) + except AssertionError as e: + mismatches.append((name, ref_name, str(e))) + + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" {n} -> {rn}: {msg}" for n, rn, msg in mismatches) + raise AssertionError( + f"Rank {rank}: {len(mismatches)} LLM param(s) diverged between " + f"PP>1 dist model and PP=1 reference:\n{details}" + ) + + class _BatchIterator: """Minimal iterator over a pre-generated list of batches.""" @@ -857,17 +1041,39 @@ def _run_forward_backward( seq_length, num_microbatches, ): - """One forward/backward pass through the mimo schedule.""" - return schedule.forward_backward_no_pipelining( - forward_step_func=partial( - forward_step, encoder_grid=enc_grid, llm_grid=llm_grid, encoder_name=encoder_name - ), + """Dispatch to no-pipelining (LLM PP=1) or three-phase (LLM PP>1) schedule. + + PP=1 path uses :func:`forward_step` so per-rank slicing for fan-in/ + fan-out happens at forward-time inside the no-pipelining schedule. + PP>1 path uses :func:`colocated_forward_backward_with_pp`, which + applies the same fan-in/fan-out narrowing internally (encoder side + in ``_slice_for_encoder_dp``, LLM side in ``_build_lm_microbatches``). + """ + pp_size = llm_grid.get_pg("pp").size() if 'pp' in llm_grid.dim_names else 1 + if pp_size <= 1: + return schedule.forward_backward_no_pipelining( + forward_step_func=partial( + forward_step, encoder_grid=enc_grid, llm_grid=llm_grid, encoder_name=encoder_name + ), + data_iterator=_BatchIterator(batches), + model=[mimo_model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + pg_collection=language_pg, + ) + + return colocated_forward_backward_with_pp( + mimo_model=mimo_model, data_iterator=_BatchIterator(batches), - model=[mimo_model], num_microbatches=num_microbatches, + encoder_grid=enc_grid, + llm_grid=llm_grid, + encoder_name=encoder_name, seq_length=seq_length, micro_batch_size=micro_batch_size, - forward_only=False, + p2p_communicator=P2PCommunicator(pp_group=llm_grid.get_pg("pp"), config=mimo_model.config), pg_collection=language_pg, ) @@ -919,43 +1125,93 @@ def teardown_method(self): version.parse(torch.__version__) < version.parse("2.3.0"), reason="Requires PyTorch 2.3+" ) @pytest.mark.parametrize( - "enc_tp,enc_dp,llm_tp,llm_dp", [(2, 4, 4, 2), (4, 2, 2, 4)], ids=["fan_in", "fan_out"] + "enc_tp,enc_dp,llm_tp,llm_pp,llm_dp", + [ + (2, 4, 4, 1, 2), # fan-in, PP=1 + (4, 2, 2, 1, 4), # fan-out, PP=1 + (2, 4, 2, 2, 2), # fan-in, PP=2 (dist_llm_tp == enc_tp → LLM weight oracle on) + (4, 2, 1, 2, 4), # fan-out, PP=2 (dist_llm_tp != enc_tp → LLM weight oracle off) + ], + ids=["fan_in_pp1", "fan_out_pp1", "fan_in_pp2", "fan_out_pp2"], ) @pytest.mark.parametrize( "mask_pattern", ["uniform", "asymmetric"], ids=["uniform", "asymmetric"] ) @pytest.mark.parametrize("num_microbatches", [1, 4], ids=["mbs1", "mbs4"]) def test_dist_matches_dp1_reference_post_step_weights( - self, enc_tp, enc_dp, llm_tp, llm_dp, mask_pattern, num_microbatches + self, enc_tp, enc_dp, llm_tp, llm_pp, llm_dp, mask_pattern, num_microbatches ): - """Heterogeneous-DP dist post-step encoder weights match equal-DP reference. + """Heterogeneous-(TP/DP/PP) dist post-step weights match equal-DP PP=1 reference. 
Builds two MimoModels on every rank: - * Dist: the heterogeneous TP/DP config under test, with + * Dist: the heterogeneous TP/DP/PP config under test, with ``calculate_per_token_loss=True`` + custom finalize hook that pure-SUMs DDP and externally divides by ``N_global``. * Ref: equal-DP uniform with ``enc_tp=dist_enc_tp``, - ``enc_dp=dist_enc_dp``, ``llm_tp=dist_enc_tp``, - ``llm_dp=dist_enc_dp`` — bridge is - ``BridgeDirection.EQUAL`` (identity passthrough), and the - encoder TP sharding matches dist's exactly so shards line up - 1:1 for comparison. - - Both models run the same finalize wiring; both DDPs pure-SUM - across their own DP group, then divide uniformly by ``N_global``. - LLM TP differs between the two models, which introduces fp32 TP - accumulation-order drift in the gradient flowing back to the - encoder but does not change the per-token-mean invariant that the - post-step encoder oracle checks. + ``enc_dp=dist_enc_dp``, ``llm_tp=dist_enc_tp``, ``llm_pp=1``, + ``llm_dp=dist_enc_dp`` — bridge is ``BridgeDirection.EQUAL`` + (identity passthrough), and the encoder TP sharding matches + dist's exactly so shards line up 1:1 for comparison. + + For ``llm_pp == 1`` the dist side runs the no-pipelining schedule + with the existing ``forward_step`` (which narrows for fan-in/ + fan-out at forward time). For ``llm_pp > 1`` the dist side runs + :func:`colocated_forward_backward_with_pp` (three-phase: encoder + forward → LLM 1F1B → encoder backward), which applies the same + narrowing internally. Ref always runs no-pipelining (``llm_pp=1``). Reference weights are copied into the distributed model so both start from identical state. One Adam step later, the dist shards - should match the ref shards within fp32 precision. + should match the ref shards within fp32 precision. Oracles: + + * Always: encoder weights, first-layer encoder grads. + * ``llm_pp == 1``: LLM input + LLM logits (TP+DP-gathered, robust + to different LLM TP layouts). + * ``llm_pp > 1`` AND ``dist_llm_tp == enc_tp``: LLM weights via + PP-aware layer-index remap (shards align 1:1 only if dist and + ref share the same LLM TP). + + Fan-out PP>1 with ``dist_llm_tp == enc_tp`` is impossible on 8 + GPUs (``enc_tp * 2 * llm_dp = 8`` and ``llm_dp > enc_dp = 8/enc_tp`` + contradict), so the LLM weight oracle is skipped there — encoder + weight match alone is the gold-standard end-to-end signal because + the encoder's grads pass through the LLM forward + backward. """ if self.world_size != 8: pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + if num_microbatches < llm_pp: + pytest.skip( + f"PP={llm_pp} requires num_microbatches >= {llm_pp}; " f"got {num_microbatches}" + ) + # Wrap the entire test body so we can catch and PRINT any + # exception before pytest's distributed traceback formatter gets + # jumbled across 8 ranks. Without this, NCCL teardown across + # ranks that fail asymmetrically (some raise, some don't) tends + # to SIGABRT before pytest emits per-rank tracebacks. 
+ rank = dist.get_rank() + try: + self._run_test_body( + rank, enc_tp, enc_dp, llm_tp, llm_pp, llm_dp, mask_pattern, num_microbatches + ) + except Exception: + import traceback as _tb + + print( + f"\n=== rank {rank} TEST EXCEPTION ===\n" + f"config: enc_tp={enc_tp} enc_dp={enc_dp} llm_tp={llm_tp} " + f"llm_pp={llm_pp} llm_dp={llm_dp} mbs={num_microbatches} " + f"mask={mask_pattern}\n" + f"{_tb.format_exc()}\n" + f"=== end rank {rank} exception ===\n", + flush=True, + ) + raise + + def _run_test_body( + self, rank, enc_tp, enc_dp, llm_tp, llm_pp, llm_dp, mask_pattern, num_microbatches + ): _set_deterministic_env() torch.use_deterministic_algorithms(True) torch.backends.cudnn.deterministic = True @@ -963,17 +1219,21 @@ def test_dist_matches_dp1_reference_post_step_weights( encoder_name = "images" hidden_size, seq_length, vocab_size = 256, 64, 1000 + num_layers = 2 + # PP-aware param copy/oracle requires layers divisible by pp_size. + assert num_layers % llm_pp == 0 micro_batch_size = 2 # Global batch spans the larger DP side; dist pre-slices per rank - # before forward_step (which further slices encoder/LLM side). + # via _slice_global_batch_for_dist (LLM-DP-sized for fan-in, + # encoder-DP-sized for fan-out). global_batch_size = micro_batch_size * max(enc_dp, llm_dp) # Grids: dist is heterogeneous; ref is equal-DP uniform matching # dist's encoder so the bridge is identity and encoder shards # align 1:1 for direct comparison. dist_enc_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) - dist_llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=1, dp=llm_dp) + dist_llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) ref_enc_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) ref_llm_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) create_all_embedding_groups([dist_enc_grid, dist_llm_grid, ref_enc_grid, ref_llm_grid]) @@ -988,14 +1248,14 @@ def test_dist_matches_dp1_reference_post_step_weights( overlap_grad_reduce=True, bucket_size=10000, use_distributed_optimizer=True ) - # Build dist first (heterogeneous TP/DP). + # Build dist first (heterogeneous TP/DP/PP). torch.manual_seed(12345) dist_mimo, _, _, dist_language_pg, dist_vision_pg = get_mimo_model( encoder_name=encoder_name, encoder_grid=dist_enc_grid, llm_grid=dist_llm_grid, hidden_size=hidden_size, - num_layers=2, + num_layers=num_layers, vocab_size=vocab_size, seq_len=seq_length, ddp_config=ddp_config, @@ -1007,14 +1267,14 @@ def test_dist_matches_dp1_reference_post_step_weights( dist_mimo.model_type = ModelType.encoder_or_decoder self._mimo_models.append(dist_mimo) - # Reference with equal-DP uniform (enc_tp == llm_tp, enc_dp == llm_dp). + # Reference with equal-DP uniform (enc_tp == llm_tp, enc_dp == llm_dp, PP=1). torch.manual_seed(12345) ref_mimo, _, _, ref_language_pg, ref_vision_pg = get_mimo_model( encoder_name=encoder_name, encoder_grid=ref_enc_grid, llm_grid=ref_llm_grid, hidden_size=hidden_size, - num_layers=2, + num_layers=num_layers, vocab_size=vocab_size, seq_len=seq_length, ddp_config=ddp_config, @@ -1026,25 +1286,57 @@ def test_dist_matches_dp1_reference_post_step_weights( ref_mimo.model_type = ModelType.encoder_or_decoder self._mimo_models.append(ref_mimo) - # Force identical initial state: encoder shards already match - # (same TP layout), so the helper copies shard-to-shard. 
LLM - # shards don't match (ref_llm_tp=enc_tp, dist_llm_tp=llm_tp), so - # the helper all-gathers ref's shards across ref's TP group and - # re-slices for dist's TP group. + # Force identical initial state. Encoder shards already match + # (same TP layout), so the helper copies shard-to-shard. _copy_ref_params_to_dist( ref_mimo.modality_submodules[encoder_name].module, dist_mimo.modality_submodules[encoder_name].module, ref_enc_grid.get_pg("tp"), dist_enc_grid.get_pg("tp"), ) - _copy_ref_params_to_dist( - ref_mimo.language_model.module, - dist_mimo.language_model.module, - ref_llm_grid.get_pg("tp"), - dist_llm_grid.get_pg("tp"), - ) + if llm_pp == 1: + # LLM shards may not match (ref_llm_tp=enc_tp, dist_llm_tp=llm_tp); + # the helper all-gathers ref's shards across ref's TP group and + # re-slices for dist's TP group. + _copy_ref_params_to_dist( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + ref_llm_grid.get_pg("tp"), + dist_llm_grid.get_pg("tp"), + ) + elif llm_tp == enc_tp: + # PP>1 with matching TP: dist's local layers map to a slice of + # ref's global layers and shards align 1:1 post-remap. + _copy_llm_params_pp_aware( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=llm_pp, + num_layers=num_layers, + ) + else: + # PP>1 with mismatched TP (e.g. fan-out PP=2 on 8 GPUs): + # combine TP-reshard (all-gather ref's TP shards, slice for + # dist's TP) with PP-aware layer-index remap. + _copy_ref_llm_with_tp_and_pp_remap( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + ref_llm_grid.get_pg("tp"), + dist_llm_grid.get_pg("tp"), + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=llm_pp, + num_layers=num_layers, + ) - _wire_training_hooks(dist_mimo, dist_language_pg, dist_vision_pg) + # PP>1 dist needs the broadcast-from-last-PP-stage variant of the + # finalize hook so num_tokens lands consistently on every rank. + # Ref is always PP=1 (no broadcast needed). + _wire_training_hooks( + dist_mimo, + dist_language_pg, + dist_vision_pg, + llm_grid=dist_llm_grid if llm_pp > 1 else None, + ) _wire_training_hooks(ref_mimo, ref_language_pg, ref_vision_pg) # Distributed optimizers snapshot current param.data into fp32 master @@ -1061,7 +1353,7 @@ def test_dist_matches_dp1_reference_post_step_weights( dist_optimizer = get_mimo_optimizer(dist_mimo, opt_config) ref_optimizer = get_mimo_optimizer(ref_mimo, opt_config) - # Data: one deterministic global batch, identical on every rank. + # Data: deterministic global batches, identical on every rank. torch.manual_seed(99999) global_batches = _generate_and_broadcast_global_batches( global_mbs=global_batch_size, @@ -1083,16 +1375,23 @@ def test_dist_matches_dp1_reference_post_step_weights( ] ref_per_rank_batch_size = global_batch_size // enc_dp - # Logits capture: hook fires on every microbatch forward. - # Registered before forward/backward, removed right after so the - # hook doesn't leak across the second model's run. - dist_logits, dist_logits_hook = _register_logits_capture(dist_mimo) - ref_logits, ref_logits_hook = _register_logits_capture(ref_mimo) - dist_llm_input, dist_input_hook = _register_llm_input_capture(dist_mimo) - ref_llm_input, ref_input_hook = _register_llm_input_capture(ref_mimo) + # Capture hooks: only meaningful for PP=1 (output_layer / decoder + # captures fire on every microbatch; for PP>1 they fire only on + # specific PP stages of dist, breaking the per-microbatch + # alignment with ref's PP=1 captures). 
Skip registration for PP>1. + capture_hooks = [] + if llm_pp == 1: + dist_logits, dist_logits_hook = _register_logits_capture(dist_mimo) + ref_logits, ref_logits_hook = _register_logits_capture(ref_mimo) + dist_llm_input, dist_input_hook = _register_llm_input_capture(dist_mimo) + ref_llm_input, ref_input_hook = _register_llm_input_capture(ref_mimo) + capture_hooks = [dist_logits_hook, ref_logits_hook, dist_input_hook, ref_input_hook] + else: + dist_logits = ref_logits = dist_llm_input = ref_llm_input = None try: - # One optimizer step on dist (heterogeneous forward_step slicing). + # One optimizer step on dist (PP=1: no-pipelining + forward_step; + # PP>1: three-phase schedule with internal narrowing). dist_optimizer.zero_grad() _run_forward_backward( mimo_model=dist_mimo, @@ -1115,7 +1414,7 @@ def test_dist_matches_dp1_reference_post_step_weights( "silently zeroed by wrong scaling" ) - # One optimizer step on ref (enc_dp == llm_dp → forward_step skips slicing). + # One optimizer step on ref (always PP=1, equal-DP). ref_optimizer.zero_grad() _run_forward_backward( mimo_model=ref_mimo, @@ -1133,18 +1432,14 @@ def test_dist_matches_dp1_reference_post_step_weights( assert ref_success, "Ref optimizer step failed" assert ref_grad_norm is not None and ref_grad_norm > 0, f"Ref grad_norm={ref_grad_norm}" finally: - dist_logits_hook.remove() - ref_logits_hook.remove() - dist_input_hook.remove() - ref_input_hook.remove() - - # Run all three oracles regardless of individual failures so the - # diff-stats print covers every layer. Order: encoder weights / - # first-layer grads first (tightest — same encoder TP/DP layout - # → shards align 1:1), then LLM logits last (loosest — different - # LLM TP layout drives fp32 accumulation drift). Each oracle - # printed its own min/mean/p95/p99/max before its assertion ran, - # so the user sees the full drift distribution for every test. + for h in capture_hooks: + h.remove() + + # Run all oracles regardless of individual failures so the diff- + # stats print covers every layer. Order: encoder weights / first- + # layer grads first (tightest — same encoder TP/DP layout → shards + # align 1:1), then LLM oracles (looser — different LLM TP layout + # drives fp32 accumulation drift). failures = [] try: @@ -1164,20 +1459,60 @@ def test_dist_matches_dp1_reference_post_step_weights( except AssertionError as e: failures.append(('first_layer_grads', str(e))) - try: - _assert_llm_input_match( - ref_llm_input, dist_llm_input, ref_llm_grid, dist_llm_grid, rtol=1e-3, atol=1e-3 - ) - except AssertionError as e: - failures.append(('llm_input', str(e))) + if llm_pp == 1: + # LLM input + logits oracles use TP+DP all-gather, so they + # work for any LLM TP layout. They expect one capture per + # microbatch, which only PP=1 satisfies. + try: + _assert_llm_input_match( + ref_llm_input, dist_llm_input, ref_llm_grid, dist_llm_grid, rtol=1e-3, atol=1e-3 + ) + except AssertionError as e: + failures.append(('llm_input', str(e))) - try: - _assert_llm_logits_match( - ref_logits, dist_logits, ref_llm_grid, dist_llm_grid, rtol=1e-2, atol=1e-2 - ) - except AssertionError as e: - failures.append(('llm_logits', str(e))) + try: + _assert_llm_logits_match( + ref_logits, dist_logits, ref_llm_grid, dist_llm_grid, rtol=1e-2, atol=1e-2 + ) + except AssertionError as e: + failures.append(('llm_logits', str(e))) + elif llm_tp == enc_tp: + # PP>1 with matching TP: assert LLM weights match ref via + # PP-aware layer-index remap. 
(LLM forward differs between + # 1F1B and no-pipelining, plus TP shards may accumulate in + # different order; tolerances absorb that drift even in fp32.) + try: + _assert_llm_weights_match_pp_aware( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=llm_pp, + num_layers=num_layers, + rtol=1e-2, + atol=1e-2, + ) + except AssertionError as e: + failures.append(('llm_weights_pp_aware', str(e))) + # else: PP>1 with mismatched TP (fan-out on 8 GPUs). The init copy + # via _copy_ref_llm_with_tp_and_pp_remap aligns starting state, but + # post-step shape comparison would require the same TP-reshard of + # ref's PP=1 weights. Skipped here — encoder weight oracle alone + # is sufficient end-to-end (it requires a working LLM forward + + # backward + bridge for the encoder grads to land correctly). if failures: summary = "\n\n".join(f"== {oracle} ==\n{msg}" for oracle, msg in failures) + # Print before raising so the message lands in stdout even when + # post-test cleanup blows up (NCCL teardown across asymmetric + # pass/fail ranks can SIGABRT before pytest formats the + # traceback). + rank = dist.get_rank() + print( + f"\n=== rank {rank} test_dist_matches failures ===\n" + f"config: enc_tp={enc_tp} enc_dp={enc_dp} llm_tp={llm_tp} " + f"llm_pp={llm_pp} llm_dp={llm_dp} mbs={num_microbatches} mask={mask_pattern}\n" + f"{summary}\n" + f"=== end rank {rank} failures ===\n", + flush=True, + ) raise AssertionError(f"{len(failures)} oracle(s) failed:\n{summary}")
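
Usage note (illustrative only, not part of the patch): the sketch below mirrors how the test's `_run_forward_backward` dispatches to the three-phase schedule when the LLM grid has PP>1. The argument names and imports come from the code above; `mimo_model`, `data_iter`, `enc_grid`, `llm_grid`, and `pg_collection` are placeholders the caller is assumed to provide (e.g. via the test helpers `get_mimo_model` / `create_hypercomm_grid`).

from megatron.core.models.mimo.colocated_schedule import colocated_forward_backward_with_pp
from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator

def run_one_step(mimo_model, data_iter, enc_grid, llm_grid, pg_collection,
                 num_microbatches, seq_length, micro_batch_size):
    # Hypothetical driver: assumes mimo_model was built with colocated
    # communicators and an LLM grid with PP>1 (mimo_model.lm_has_pp is True),
    # and that num_microbatches >= LLM PP size.
    pp_size = llm_grid.get_pg("pp").size() if "pp" in llm_grid.dim_names else 1
    assert pp_size > 1 and num_microbatches >= pp_size

    return colocated_forward_backward_with_pp(
        mimo_model=mimo_model,
        data_iterator=data_iter,
        num_microbatches=num_microbatches,
        encoder_grid=enc_grid,
        llm_grid=llm_grid,
        encoder_name="images",
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        p2p_communicator=P2PCommunicator(pp_group=llm_grid.get_pg("pp"),
                                         config=mimo_model.config),
        pg_collection=pg_collection,
    )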