@@ -46,23 +46,25 @@ def __next__(self):
4646
4747
class _MockProcessGroup:
    """Minimal process-group stand-in exposing ``rank()`` and ``size()``.

    Args:
        rank: Value returned by ``rank()``. Defaults to 0.
        size: Value returned by ``size()``. Defaults to 1.
    """

    def __init__(self, rank=0, size=1):
        self._rank = rank
        self._size = size

    def rank(self):
        """Return the rank supplied at construction."""
        return self._rank

    def size(self):
        """Return the size supplied at construction."""
        return self._size
5458
5559
class _MockPGCollection:
    """Stand-in for a process-group collection exposing ``pp`` and ``cp``.

    Args:
        cp_size: Size reported by the group returned from ``cp``.
        pp_rank: Rank reported by the ``pp`` group.
        pp_size: Size reported by the ``pp`` group.
    """

    def __init__(self, cp_size=1, pp_rank=0, pp_size=1):
        self.pp = _MockProcessGroup(rank=pp_rank, size=pp_size)
        self._cp_size = cp_size

    @property
    def cp(self):
        """Return a fresh mock group whose ``size()`` is ``cp_size`` (rank 0)."""
        return _MockProcessGroup(size=self._cp_size)
6668
6769
6870class _NoCudaTensor (torch .Tensor ):
@@ -74,7 +76,14 @@ def _as_nocuda(tensor):
7476 return tensor .as_subclass (_NoCudaTensor )
7577
7678
77- def _make_cfg (* , packed_sequence_specs = None , skip_getting_attention_mask_from_dataset = True ):
79+ def _make_cfg (
80+ * ,
81+ packed_sequence_specs = None ,
82+ skip_getting_attention_mask_from_dataset = True ,
83+ pipeline_model_parallel_layout = None ,
84+ pipeline_model_parallel_size = 1 ,
85+ mtp_num_layers = 0 ,
86+ ):
7887 cfg = type ("Cfg" , (), {})()
7988 cfg .dataset = type (
8089 "D" ,
@@ -84,6 +93,15 @@ def _make_cfg(*, packed_sequence_specs=None, skip_getting_attention_mask_from_da
8493 "skip_getting_attention_mask_from_dataset" : skip_getting_attention_mask_from_dataset ,
8594 },
8695 )()
96+ cfg .model = type (
97+ "M" ,
98+ (),
99+ {
100+ "pipeline_model_parallel_layout" : pipeline_model_parallel_layout ,
101+ "pipeline_model_parallel_size" : pipeline_model_parallel_size ,
102+ "mtp_num_layers" : mtp_num_layers ,
103+ },
104+ )()
87105 return cfg
88106
89107
@@ -92,6 +110,15 @@ def _set_middle_pp_stage(monkeypatch):
92110 monkeypatch .setattr ("megatron.bridge.training.gpt_step.is_pp_last_stage" , lambda pg : False )
93111
94112
def _set_last_pp_stage(monkeypatch):
    """Patch the gpt_step stage predicates so the rank looks like the last PP stage.

    ``is_pp_first_stage`` is replaced with a callable that always returns
    False and ``is_pp_last_stage`` with one that always returns True; both
    ignore the process group they are given.
    """
    monkeypatch.setattr("megatron.bridge.training.gpt_step.is_pp_first_stage", lambda pg: False)
    monkeypatch.setattr("megatron.bridge.training.gpt_step.is_pp_last_stage", lambda pg: True)
116+
117+
def _set_distributed_initialized(monkeypatch):
    """Patch ``torch.distributed.is_initialized`` so it always returns True."""
    monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
120+
121+
95122class _NoopTimer :
96123 def __call__ (self , * args , ** kwargs ):
97124 return self
@@ -187,6 +214,133 @@ def test_middle_pp_stage_keeps_non_packed_fast_path(self, monkeypatch):
187214 assert result == (None , None , None , None , None , None , None , None , None , None )
188215 data_iterator .__next__ .assert_not_called ()
189216
def test_middle_pp_stage_without_mtp_keeps_fast_path_when_mtp_enabled(self, monkeypatch):
    """Global MTP does not force ordinary middle PP stages to load a batch."""
    _set_middle_pp_stage(monkeypatch)
    _set_distributed_initialized(monkeypatch)
    data_iterator = MagicMock()

    # pp_rank=1 lines up with the ["decoder"] entry of the layout below
    # (assuming get_batch indexes the layout by PP rank -- confirm against
    # gpt_step), i.e. a middle stage that holds no MTP layer.
    result = get_batch(
        data_iterator,
        _make_cfg(
            pipeline_model_parallel_layout=[["embedding", "decoder"], ["decoder"], ["mtp"], ["loss"]],
            pipeline_model_parallel_size=4,
            mtp_num_layers=1,
        ),
        use_mtp=True,
        pg_collection=_MockPGCollection(pp_rank=1, pp_size=4),
    )

    # Every output slot is None and the iterator was never advanced.
    assert result == (None, None, None, None, None, None, None, None, None, None)
    data_iterator.__next__.assert_not_called()
236+
def test_standalone_mtp_middle_pp_stage_loads_tokens_and_position_ids(self, monkeypatch):
    """A middle PP stage that owns MTP receives input ids for MCore MTP."""
    _set_middle_pp_stage(monkeypatch)
    _set_distributed_initialized(monkeypatch)
    # Replace CP slicing with an identity function so the batch passes through unchanged.
    monkeypatch.setattr(
        "megatron.bridge.training.gpt_step.get_batch_on_this_cp_rank",
        lambda batch, is_hybrid_cp=False, cp_group=None, hybrid_cp_group_func=None: batch,
    )
    # Report no virtual pipeline rank.
    monkeypatch.setattr(
        "megatron.bridge.training.gpt_step.parallel_state.get_virtual_pipeline_model_parallel_rank",
        lambda: None,
    )

    tokens = _as_nocuda(torch.tensor([[1, 2, 3, 4]]))
    labels = _as_nocuda(torch.tensor([[2, 3, 4, 5]]))
    loss_mask = _as_nocuda(torch.ones(1, 4))
    position_ids = _as_nocuda(torch.arange(4).unsqueeze(0))
    batch = {
        "tokens": tokens,
        "labels": labels,
        "loss_mask": loss_mask,
        "attention_mask": None,
        "position_ids": position_ids,
    }

    # pp_rank=2 lines up with the ["mtp"] entry of the layout below
    # (assuming the layout is indexed by PP rank -- confirm against gpt_step).
    (
        out_tokens,
        out_labels,
        out_loss_mask,
        out_attention_mask,
        out_position_ids,
        out_cu_seqlens,
        out_cu_seqlens_argmin,
        out_max_seqlen,
        out_cu_seqlens_unpadded,
        out_cu_seqlens_unpadded_argmin,
    ) = get_batch(
        _Iterator(batch),
        _make_cfg(
            pipeline_model_parallel_layout=[["embedding", "decoder"], ["decoder"], ["mtp"], ["loss"]],
            pipeline_model_parallel_size=4,
            mtp_num_layers=1,
        ),
        use_mtp=True,
        pg_collection=_MockPGCollection(pp_rank=2, pp_size=4),
    )

    # Only tokens and position ids come back; labels, loss mask, attention
    # mask and all packed-sequence metadata are None.
    assert torch.equal(out_tokens, tokens)
    assert out_labels is None
    assert out_loss_mask is None
    assert out_attention_mask is None
    assert torch.equal(out_position_ids, position_ids)
    assert out_cu_seqlens is None
    assert out_cu_seqlens_argmin is None
    assert out_max_seqlen is None
    assert out_cu_seqlens_unpadded is None
    assert out_cu_seqlens_unpadded_argmin is None
294+
def test_standalone_mtp_loss_stage_skips_mtp_inputs(self, monkeypatch):
    """The loss-only final PP stage does not load token ids for standalone MTP."""
    _set_last_pp_stage(monkeypatch)
    _set_distributed_initialized(monkeypatch)
    # Replace CP slicing with an identity function so the batch passes through unchanged.
    monkeypatch.setattr(
        "megatron.bridge.training.gpt_step.get_batch_on_this_cp_rank",
        lambda batch, is_hybrid_cp=False, cp_group=None, hybrid_cp_group_func=None: batch,
    )
    # Report no virtual pipeline rank.
    monkeypatch.setattr(
        "megatron.bridge.training.gpt_step.parallel_state.get_virtual_pipeline_model_parallel_rank",
        lambda: None,
    )

    tokens = _as_nocuda(torch.tensor([[1, 2, 3, 4]]))
    labels = _as_nocuda(torch.tensor([[2, 3, 4, 5]]))
    loss_mask = _as_nocuda(torch.ones(1, 4))
    position_ids = _as_nocuda(torch.arange(4).unsqueeze(0))
    batch = {
        "tokens": tokens,
        "labels": labels,
        "loss_mask": loss_mask,
        "attention_mask": None,
        "position_ids": position_ids,
    }

    # pp_rank=3 lines up with the ["loss"] entry of the layout below
    # (assuming the layout is indexed by PP rank -- confirm against gpt_step).
    # The trailing packed-sequence outputs are not checked here.
    (
        out_tokens,
        out_labels,
        out_loss_mask,
        out_attention_mask,
        out_position_ids,
        *_,
    ) = get_batch(
        _Iterator(batch),
        _make_cfg(
            pipeline_model_parallel_layout=[["embedding", "decoder"], ["decoder"], ["mtp"], ["loss"]],
            pipeline_model_parallel_size=4,
            mtp_num_layers=1,
        ),
        use_mtp=True,
        pg_collection=_MockPGCollection(pp_rank=3, pp_size=4),
    )

    # Labels and loss mask come back; tokens, attention mask and position ids are None.
    assert out_tokens is None
    assert torch.equal(out_labels, labels)
    assert torch.equal(out_loss_mask, loss_mask)
    assert out_attention_mask is None
    assert out_position_ids is None
343+
190344 def test_forward_common_passes_packed_seq_params_on_middle_pp_stage (self , monkeypatch ):
191345 """Forward path must pass packed metadata on middle PP stages."""
192346 sentinel_packed_seq_params = object ()
0 commit comments