@@ -374,6 +374,106 @@ def test_default_layer_spec_default_case(self, mock_te_full_spec, mock_te_spec):
         mock_te_spec.assert_called_once_with(provider)
         assert result == "te_spec"
 
+    def test_mtp_block_spec_returns_none_when_mtp_disabled(self):
+        """mtp_block_spec returns None when mtp_num_layers is unset."""
+        from megatron.bridge.models.gpt_provider import mtp_block_spec
+
+        provider = GPTModelProvider(
+            num_layers=2,
+            hidden_size=128,
+            num_attention_heads=4,
+        )
+
+        assert mtp_block_spec(provider) is None
+
+    @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_mtp_block_spec")
+    def test_mtp_block_spec_uses_callable_spec_directly_when_layer_specs_nonempty(self, mock_get_mtp):
+        """When the callable spec returns a non-empty block spec, use it as-is."""
+        from megatron.bridge.models.gpt_provider import mtp_block_spec
+
+        provider = GPTModelProvider(
+            num_layers=2,
+            hidden_size=128,
+            num_attention_heads=4,
+            mtp_num_layers=1,
+        )
+
+        block_spec = Mock()
+        block_spec.layer_specs = ["layer_a", "layer_b"]
+        provider.transformer_layer_spec = lambda config: block_spec
+
+        mock_get_mtp.return_value = "mtp_spec"
+
+        result = mtp_block_spec(provider, vp_stage=None)
+
+        mock_get_mtp.assert_called_once_with(provider, block_spec, use_transformer_engine=True, vp_stage=None)
+        assert result == "mtp_spec"
+
+    @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_decoder_layer_specs")
+    @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_mtp_block_spec")
+    def test_mtp_block_spec_re_derives_last_decoder_spec_when_layer_specs_empty(
+        self, mock_get_mtp, mock_get_decoder_specs
+    ):
+        """When the last-stage spec has empty layer_specs (MoE block spec on the last PP stage),
+        re-derive all decoder layer specs and pass the last one to get_gpt_mtp_block_spec."""
+        from megatron.bridge.models.gpt_provider import mtp_block_spec
+
+        provider = GPTModelProvider(
+            num_layers=2,
+            hidden_size=128,
+            num_attention_heads=4,
+            mtp_num_layers=1,
+        )
+
+        empty_block_spec = Mock()
+        empty_block_spec.layer_specs = []
+        provider.transformer_layer_spec = lambda config: empty_block_spec
+
+        dense_layer_spec = Mock(name="dense_layer_spec")
+        moe_layer_spec = Mock(name="moe_layer_spec")
+        mock_get_decoder_specs.return_value = [dense_layer_spec, moe_layer_spec]
+        mock_get_mtp.return_value = "mtp_spec"
+
+        result = mtp_block_spec(provider, vp_stage=2)
+
+        mock_get_decoder_specs.assert_called_once_with(
+            provider,
+            use_transformer_engine=True,
+            normalization=provider.normalization,
+            qk_l2_norm=provider.qk_l2_norm,
+        )
+        mock_get_mtp.assert_called_once_with(provider, moe_layer_spec, use_transformer_engine=True, vp_stage=2)
+        assert result == "mtp_spec"
+
+    @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_mtp_block_spec")
+    def test_mtp_block_spec_passes_vp_stage_to_callable_spec(self, mock_get_mtp):
+        """When the transformer_layer_spec callable accepts vp_stage, it is forwarded."""
+        from megatron.bridge.models.gpt_provider import mtp_block_spec
+
+        provider = GPTModelProvider(
+            num_layers=2,
+            hidden_size=128,
+            num_attention_heads=4,
+            mtp_num_layers=1,
+        )
+
+        block_spec = Mock()
+        block_spec.layer_specs = ["layer_a"]
+        received_vp_stage = {}
+
+        def spec_fn(config, vp_stage=None):
+            received_vp_stage["vp_stage"] = vp_stage
+            return block_spec
+
+        provider.transformer_layer_spec = spec_fn
+        mock_get_mtp.return_value = "mtp_spec"
+
+        result = mtp_block_spec(provider, vp_stage=3)
+
+        assert received_vp_stage["vp_stage"] == 3
+        mock_get_mtp.assert_called_once_with(provider, block_spec, use_transformer_engine=True, vp_stage=3)
+        assert result == "mtp_spec"
+
     def test_dense_grouped_gemm_defaults_to_false(self):
         """GPTModelProvider.dense_grouped_gemm defaults to False."""
         provider = GPTModelProvider(
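
Taken together, these four tests pin down the control flow of mtp_block_spec. The sketch below reconstructs that flow purely from the expectations the tests encode; every name in it appears in the tests except the signature-inspection check for vp_stage support, which is an assumption, so the real implementation in megatron/bridge/models/gpt_provider.py may differ in detail.

# Hedged sketch only: reconstructed from the tests above, not the actual source.
import inspect

from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_layer_specs,
    get_gpt_mtp_block_spec,
)


def mtp_block_spec(provider, vp_stage=None):
    # No MTP layers requested: there is no block spec to build.
    if provider.mtp_num_layers is None:
        return None

    # Resolve the decoder layer spec, forwarding vp_stage only when the
    # callable accepts it (assumption: signature inspection is how the
    # real code decides this).
    spec_fn = provider.transformer_layer_spec
    if "vp_stage" in inspect.signature(spec_fn).parameters:
        spec = spec_fn(provider, vp_stage=vp_stage)
    else:
        spec = spec_fn(provider)

    # An MoE block spec on the last PP stage can come back with empty
    # layer_specs; in that case re-derive the decoder layer specs and
    # hand the last one to get_gpt_mtp_block_spec.
    if getattr(spec, "layer_specs", None) == []:
        spec = get_gpt_decoder_layer_specs(
            provider,
            use_transformer_engine=True,
            normalization=provider.normalization,
            qk_l2_norm=provider.qk_l2_norm,
        )[-1]

    return get_gpt_mtp_block_spec(provider, spec, use_transformer_engine=True, vp_stage=vp_stage)

The re-derivation branch is the interesting case: on the last pipeline stage an MoE block spec can arrive with no decoder layers of its own, so the MTP block borrows the spec of the final decoder layer instead.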