Skip to content

Commit a9eabd3

Browse files
authored
[training] fix: Route batches to standalone MTP stages (#4208)
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
1 parent afeb2e0 commit a9eabd3

2 files changed

Lines changed: 221 additions & 14 deletions

File tree

src/megatron/bridge/training/gpt_step.py

Lines changed: 59 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,13 @@
2121
from megatron.core import parallel_state
2222
from megatron.core.models.gpt import GPTModel
2323
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
24+
from megatron.core.transformer.enums import LayerType
25+
from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout
2426
from megatron.core.utils import (
2527
get_batch_on_this_cp_rank,
2628
get_model_config,
29+
get_pg_rank,
30+
get_pg_size,
2731
is_te_min_version,
2832
unwrap_model,
2933
)
@@ -57,6 +61,50 @@ def _middle_pp_stage_needs_batch(cfg: ConfigContainer) -> bool:
5761
return uses_custom_attention_mask or _uses_packed_sequence_metadata(cfg)
5862

5963

64+
def _layout_stage_has_mtp(layout, *, pp_rank: int, pp_size: int, vp_stage: int) -> bool:
    """Check whether one PP/VPP stage of a pipeline layout contains an MTP layer.

    Args:
        layout: Pipeline layout given as a string (parsed with
            ``PipelineParallelLayerLayout.from_str``), an already parsed
            ``PipelineParallelLayerLayout``, or a list indexed by
            ``vp_stage * pp_size + pp_rank``. Any other type yields ``False``.
        pp_rank: Pipeline-parallel rank of the stage to inspect.
        pp_size: Pipeline-parallel size; used to parse string layouts and to
            index list layouts.
        vp_stage: Virtual pipeline stage index of the stage to inspect.

    Returns:
        ``True`` when at least one layer of the selected stage is recognised as MTP.
    """
    parsed = PipelineParallelLayerLayout.from_str(layout, pp_size) if isinstance(layout, str) else layout

    if isinstance(parsed, PipelineParallelLayerLayout):
        layers = parsed.layout[pp_rank][vp_stage]
    elif isinstance(parsed, list):
        layers = parsed[vp_stage * pp_size + pp_rank]
    else:
        return False

    # A layer counts as MTP if it is the literal "mtp", the LayerType.mtp enum
    # member, or any object whose ``name`` attribute equals "mtp".
    for entry in layers:
        if entry == "mtp":
            return True
        if entry == LayerType.mtp:
            return True
        if getattr(entry, "name", None) == "mtp":
            return True
    return False
79+
80+
81+
def _current_pp_stage_has_mtp(cfg: ConfigContainer, *, pg_collection) -> bool:
    """Tell whether the calling PP/VPP stage holds MTP layers per the configured layout.

    Args:
        cfg: Configuration object; ``cfg.model.pipeline_model_parallel_layout`` is
            read if present.
        pg_collection: Object whose ``pp`` attribute (if any) is the
            pipeline-parallel process group.

    Returns:
        ``False`` when no layout is configured; otherwise the result of
        ``_layout_stage_has_mtp`` for the current rank and virtual stage.
    """
    layout = getattr(getattr(cfg, "model", None), "pipeline_model_parallel_layout", None)
    if layout is None:
        return False

    group = getattr(pg_collection, "pp", None)
    rank = get_pg_rank(group)
    world = get_pg_size(group)

    # A missing virtual pipeline rank is treated as virtual stage 0.
    virtual_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
    stage_index = 0 if virtual_rank is None else virtual_rank

    return _layout_stage_has_mtp(layout, pp_rank=rank, pp_size=world, vp_stage=stage_index)
96+
97+
98+
def _current_pp_stage_needs_mtp_inputs(cfg: ConfigContainer, *, pg_collection, is_last: bool) -> bool:
    """Decide whether this stage must receive token ids for the MTP embedding lookup.

    Args:
        cfg: Configuration object; ``cfg.model.pipeline_model_parallel_layout`` is
            read if present.
        pg_collection: Process-group collection forwarded to
            ``_current_pp_stage_has_mtp``.
        is_last: Whether the current stage is the last pipeline stage.

    Returns:
        ``is_last`` when no explicit pipeline layout is configured; otherwise
        whether the current stage owns MTP layers according to the layout.
    """
    configured_layout = getattr(getattr(cfg, "model", None), "pipeline_model_parallel_layout", None)
    if configured_layout is not None:
        return _current_pp_stage_has_mtp(cfg, pg_collection=pg_collection)
    return is_last
106+
107+
60108
def _partition_packed_batch_for_cp(batch: dict[str, torch.Tensor], cp_size: int) -> dict[str, torch.Tensor]:
61109
"""Partition THD/packed batches across context-parallel ranks.
62110
@@ -105,7 +153,7 @@ def _partition_packed_batch_for_cp(batch: dict[str, torch.Tensor], cp_size: int)
105153

106154
def get_batch_from_iterator(
107155
data_iterator: Iterable,
108-
use_mtp: bool = False,
156+
include_mtp_inputs: bool = False,
109157
skip_getting_attention_mask_from_dataset: bool = True,
110158
*,
111159
is_first_pp_stage: bool,
@@ -116,7 +164,7 @@ def get_batch_from_iterator(
116164
117165
Args:
118166
data_iterator: The data iterator to get the batch from.
119-
use_mtp: Whether Multi-Token Prediction layers are enabled.
167+
include_mtp_inputs: Whether this PP stage needs Multi-Token Prediction input tensors.
120168
skip_getting_attention_mask_from_dataset: If set, the dataset will pass a None attention mask.
121169
include_full_batch_fields: Whether to include all standard training tensors regardless of PP stage.
122170
@@ -143,7 +191,7 @@ def get_batch_from_iterator(
143191
required_host_keys.add("cu_seqlens_unpadded_argmin")
144192

145193
if not include_full_batch_fields:
146-
if is_first_pp_stage or use_mtp:
194+
if is_first_pp_stage or include_mtp_inputs:
147195
required_device_keys.update(("tokens", "position_ids"))
148196
if is_last_pp_stage:
149197
required_device_keys.update(("labels", "loss_mask"))
@@ -191,13 +239,18 @@ def get_batch(
191239
is_last = is_pp_last_stage(pg_collection.pp)
192240
is_middle = (not is_first) and (not is_last)
193241
include_full_batch_fields = is_middle and _middle_pp_stage_needs_batch(cfg)
194-
if is_middle and not include_full_batch_fields:
242+
include_mtp_inputs = use_mtp and _current_pp_stage_needs_mtp_inputs(
243+
cfg, pg_collection=pg_collection, is_last=is_last
244+
)
245+
if is_middle and not include_full_batch_fields and not include_mtp_inputs:
195246
return None, None, None, None, None, None, None, None, None, None
196247

197248
batch = get_batch_from_iterator(
198249
data_iterator,
199-
use_mtp,
200-
getattr(cfg.dataset, "skip_getting_attention_mask_from_dataset", True),
250+
include_mtp_inputs=include_mtp_inputs,
251+
skip_getting_attention_mask_from_dataset=getattr(
252+
cfg.dataset, "skip_getting_attention_mask_from_dataset", True
253+
),
201254
is_first_pp_stage=is_first,
202255
is_last_pp_stage=is_last,
203256
include_full_batch_fields=include_full_batch_fields,

tests/unit_tests/training/test_gpt_step.py

Lines changed: 162 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,23 +46,25 @@ def __next__(self):
4646

4747

4848
class _MockProcessGroup:
49+
def __init__(self, rank=0, size=1):
50+
self._rank = rank
51+
self._size = size
52+
4953
def rank(self):
50-
return 0
54+
return self._rank
5155

5256
def size(self):
53-
return 1
57+
return self._size
5458

5559

5660
class _MockPGCollection:
57-
def __init__(self, cp_size=1):
58-
self.pp = _MockProcessGroup()
61+
def __init__(self, cp_size=1, pp_rank=0, pp_size=1):
62+
self.pp = _MockProcessGroup(rank=pp_rank, size=pp_size)
5963
self._cp_size = cp_size
6064

6165
@property
6266
def cp(self):
63-
pg = _MockProcessGroup()
64-
pg.size = lambda: self._cp_size
65-
return pg
67+
return _MockProcessGroup(size=self._cp_size)
6668

6769

6870
class _NoCudaTensor(torch.Tensor):
@@ -74,7 +76,14 @@ def _as_nocuda(tensor):
7476
return tensor.as_subclass(_NoCudaTensor)
7577

7678

77-
def _make_cfg(*, packed_sequence_specs=None, skip_getting_attention_mask_from_dataset=True):
79+
def _make_cfg(
80+
*,
81+
packed_sequence_specs=None,
82+
skip_getting_attention_mask_from_dataset=True,
83+
pipeline_model_parallel_layout=None,
84+
pipeline_model_parallel_size=1,
85+
mtp_num_layers=0,
86+
):
7887
cfg = type("Cfg", (), {})()
7988
cfg.dataset = type(
8089
"D",
@@ -84,6 +93,15 @@ def _make_cfg(*, packed_sequence_specs=None, skip_getting_attention_mask_from_da
8493
"skip_getting_attention_mask_from_dataset": skip_getting_attention_mask_from_dataset,
8594
},
8695
)()
96+
cfg.model = type(
97+
"M",
98+
(),
99+
{
100+
"pipeline_model_parallel_layout": pipeline_model_parallel_layout,
101+
"pipeline_model_parallel_size": pipeline_model_parallel_size,
102+
"mtp_num_layers": mtp_num_layers,
103+
},
104+
)()
87105
return cfg
88106

89107

@@ -92,6 +110,15 @@ def _set_middle_pp_stage(monkeypatch):
92110
monkeypatch.setattr("megatron.bridge.training.gpt_step.is_pp_last_stage", lambda pg: False)
93111

94112

113+
def _set_last_pp_stage(monkeypatch):
114+
monkeypatch.setattr("megatron.bridge.training.gpt_step.is_pp_first_stage", lambda pg: False)
115+
monkeypatch.setattr("megatron.bridge.training.gpt_step.is_pp_last_stage", lambda pg: True)
116+
117+
118+
def _set_distributed_initialized(monkeypatch):
119+
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
120+
121+
95122
class _NoopTimer:
96123
def __call__(self, *args, **kwargs):
97124
return self
@@ -187,6 +214,133 @@ def test_middle_pp_stage_keeps_non_packed_fast_path(self, monkeypatch):
187214
assert result == (None, None, None, None, None, None, None, None, None, None)
188215
data_iterator.__next__.assert_not_called()
189216

217+
def test_middle_pp_stage_without_mtp_keeps_fast_path_when_mtp_enabled(self, monkeypatch):
    """Global MTP does not force ordinary middle PP stages to load a batch.

    PP rank 1 of a 4-stage layout owns only decoder layers, so ``get_batch``
    must return the all-``None`` tuple without touching the data iterator even
    though ``use_mtp`` is enabled.
    """
    _set_middle_pp_stage(monkeypatch)
    _set_distributed_initialized(monkeypatch)
    # Pin the virtual pipeline rank like the sibling standalone-MTP tests do, so
    # the stage lookup does not depend on global parallel state left behind by
    # other tests.
    monkeypatch.setattr(
        "megatron.bridge.training.gpt_step.parallel_state.get_virtual_pipeline_model_parallel_rank",
        lambda: None,
    )
    data_iterator = MagicMock()

    result = get_batch(
        data_iterator,
        _make_cfg(
            pipeline_model_parallel_layout=[["embedding", "decoder"], ["decoder"], ["mtp"], ["loss"]],
            pipeline_model_parallel_size=4,
            mtp_num_layers=1,
        ),
        use_mtp=True,
        pg_collection=_MockPGCollection(pp_rank=1, pp_size=4),
    )

    assert result == (None, None, None, None, None, None, None, None, None, None)
    data_iterator.__next__.assert_not_called()
236+
237+
def test_standalone_mtp_middle_pp_stage_loads_tokens_and_position_ids(self, monkeypatch):
    """A middle PP stage that owns the MTP block gets tokens and position ids only."""

    def _passthrough_cp(batch, is_hybrid_cp=False, cp_group=None, hybrid_cp_group_func=None):
        return batch

    _set_middle_pp_stage(monkeypatch)
    _set_distributed_initialized(monkeypatch)
    monkeypatch.setattr("megatron.bridge.training.gpt_step.get_batch_on_this_cp_rank", _passthrough_cp)
    monkeypatch.setattr(
        "megatron.bridge.training.gpt_step.parallel_state.get_virtual_pipeline_model_parallel_rank",
        lambda: None,
    )

    tokens = _as_nocuda(torch.tensor([[1, 2, 3, 4]]))
    labels = _as_nocuda(torch.tensor([[2, 3, 4, 5]]))
    loss_mask = _as_nocuda(torch.ones(1, 4))
    position_ids = _as_nocuda(torch.arange(4).unsqueeze(0))
    batch = dict(
        tokens=tokens,
        labels=labels,
        loss_mask=loss_mask,
        attention_mask=None,
        position_ids=position_ids,
    )

    cfg = _make_cfg(
        pipeline_model_parallel_layout=[["embedding", "decoder"], ["decoder"], ["mtp"], ["loss"]],
        pipeline_model_parallel_size=4,
        mtp_num_layers=1,
    )
    # PP rank 2 is the stage that holds the "mtp" entry in the layout above.
    pg_collection = _MockPGCollection(pp_rank=2, pp_size=4)

    (
        got_tokens,
        got_labels,
        got_loss_mask,
        got_attention_mask,
        got_position_ids,
        got_cu_seqlens,
        got_cu_seqlens_argmin,
        got_max_seqlen,
        got_cu_seqlens_unpadded,
        got_cu_seqlens_unpadded_argmin,
    ) = get_batch(_Iterator(batch), cfg, use_mtp=True, pg_collection=pg_collection)

    assert torch.equal(got_tokens, tokens)
    assert got_labels is None
    assert got_loss_mask is None
    assert got_attention_mask is None
    assert torch.equal(got_position_ids, position_ids)
    assert got_cu_seqlens is None
    assert got_cu_seqlens_argmin is None
    assert got_max_seqlen is None
    assert got_cu_seqlens_unpadded is None
    assert got_cu_seqlens_unpadded_argmin is None
294+
295+
def test_standalone_mtp_loss_stage_skips_mtp_inputs(self, monkeypatch):
    """With standalone MTP, the final loss-only PP stage loads labels and loss mask but no token ids."""

    def _passthrough_cp(batch, is_hybrid_cp=False, cp_group=None, hybrid_cp_group_func=None):
        return batch

    _set_last_pp_stage(monkeypatch)
    _set_distributed_initialized(monkeypatch)
    monkeypatch.setattr("megatron.bridge.training.gpt_step.get_batch_on_this_cp_rank", _passthrough_cp)
    monkeypatch.setattr(
        "megatron.bridge.training.gpt_step.parallel_state.get_virtual_pipeline_model_parallel_rank",
        lambda: None,
    )

    tokens = _as_nocuda(torch.tensor([[1, 2, 3, 4]]))
    labels = _as_nocuda(torch.tensor([[2, 3, 4, 5]]))
    loss_mask = _as_nocuda(torch.ones(1, 4))
    position_ids = _as_nocuda(torch.arange(4).unsqueeze(0))
    batch = dict(
        tokens=tokens,
        labels=labels,
        loss_mask=loss_mask,
        attention_mask=None,
        position_ids=position_ids,
    )

    cfg = _make_cfg(
        pipeline_model_parallel_layout=[["embedding", "decoder"], ["decoder"], ["mtp"], ["loss"]],
        pipeline_model_parallel_size=4,
        mtp_num_layers=1,
    )
    # PP rank 3 is the stage that holds only the "loss" entry in the layout above.
    pg_collection = _MockPGCollection(pp_rank=3, pp_size=4)

    (
        got_tokens,
        got_labels,
        got_loss_mask,
        got_attention_mask,
        got_position_ids,
        *_,
    ) = get_batch(_Iterator(batch), cfg, use_mtp=True, pg_collection=pg_collection)

    assert got_tokens is None
    assert torch.equal(got_labels, labels)
    assert torch.equal(got_loss_mask, loss_mask)
    assert got_attention_mask is None
    assert got_position_ids is None
343+
190344
def test_forward_common_passes_packed_seq_params_on_middle_pp_stage(self, monkeypatch):
191345
"""Forward path must pass packed metadata on middle PP stages."""
192346
sentinel_packed_seq_params = object()

0 commit comments

Comments
 (0)