Commit 0fc0e61

Authored by gdengk (Gao Deng) and huvunvidia
[DSV3] Fix the ckpt loading issue when no MoE layer on the mtp rank (#3315)
Signed-off-by: Gao Deng <gdeng@login-lyris02.lyris.clusters.nvidia.com>
Signed-off-by: Gao <gdeng@nvidia.com>
Co-authored-by: Gao Deng <gdeng@login-lyris02.lyris.clusters.nvidia.com>
Co-authored-by: Huy Vu <86480512+huvunvidia@users.noreply.github.com>
1 parent 956b8d4 commit 0fc0e61

2 files changed: 111 additions & 1 deletion

src/megatron/bridge/models/gpt_provider.py

Lines changed: 11 additions & 1 deletion
@@ -361,7 +361,17 @@ def mtp_block_spec(config: "GPTModelProvider", vp_stage: Optional[int] = None) -
         if hasattr(spec, "layer_specs") and len(spec.layer_specs) == 0:
             # Get the decoder layer spec explicitly if no decoder layer in the last stage,
             # Only happens with block spec (TransformerBlockSubmodules) when using MoE.
-            spec = default_layer_spec(config)
+            # Re-derive all decoder layer specs and use the last one to get the correct
+            # layer type (dense vs MoE) for the MTP transformer layer.
+            from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_layer_specs
+
+            decoder_layer_specs = get_gpt_decoder_layer_specs(
+                config,
+                use_transformer_engine=True,
+                normalization=config.normalization,
+                qk_l2_norm=config.qk_l2_norm,
+            )
+            spec = decoder_layer_specs[-1]
         return get_gpt_mtp_block_spec(config, spec, use_transformer_engine=True, vp_stage=vp_stage)
     else:
         return None
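
For context, the sketch below is not part of the commit; LayerSpec and pick_mtp_layer_spec are hypothetical stand-ins for megatron.core's ModuleSpec machinery. It illustrates why the old fallback broke checkpoint loading: default_layer_spec(config) always produced a dense layer spec, but a DeepSeek-V3-style decoder ends with MoE layers, so an MTP layer built from the dense spec did not match the MoE parameter layout saved in the checkpoint. Selecting the last re-derived decoder spec yields the correct layer type:

# Hypothetical, self-contained illustration; pick_mtp_layer_spec mirrors the
# commit's `spec = decoder_layer_specs[-1]` selection.
from dataclasses import dataclass


@dataclass
class LayerSpec:
    kind: str  # "dense" or "moe"


def pick_mtp_layer_spec(decoder_layer_specs: list[LayerSpec]) -> LayerSpec:
    # The MTP transformer layer must match the type of the *last* decoder
    # layer, which in a DeepSeek-V3-style config is an MoE layer.
    return decoder_layer_specs[-1]


# DeepSeek-V3-style layout: a few leading dense layers, MoE for the rest.
specs = [LayerSpec("dense")] * 3 + [LayerSpec("moe")] * 58
assert pick_mtp_layer_spec(specs).kind == "moe"  # old dense fallback mismatched

This is the behavior the new regression test below pins down: when the last pipeline stage holds no decoder layers, get_gpt_mtp_block_spec must receive the MoE spec, not a dense one.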

tests/unit_tests/models/test_gpt_provider.py

Lines changed: 100 additions & 0 deletions
@@ -374,6 +374,106 @@ def test_default_layer_spec_default_case(self, mock_te_full_spec, mock_te_spec):
         mock_te_spec.assert_called_once_with(provider)
         assert result == "te_spec"
 
+    def test_mtp_block_spec_returns_none_when_mtp_disabled(self):
+        """mtp_block_spec returns None when mtp_num_layers is unset."""
+        from megatron.bridge.models.gpt_provider import mtp_block_spec
+
+        provider = GPTModelProvider(
+            num_layers=2,
+            hidden_size=128,
+            num_attention_heads=4,
+        )
+
+        assert mtp_block_spec(provider) is None
+
+    @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_mtp_block_spec")
+    def test_mtp_block_spec_uses_callable_spec_directly_when_layer_specs_nonempty(self, mock_get_mtp):
+        """When the callable spec returns a non-empty block spec, use it as-is."""
+        from megatron.bridge.models.gpt_provider import mtp_block_spec
+
+        provider = GPTModelProvider(
+            num_layers=2,
+            hidden_size=128,
+            num_attention_heads=4,
+            mtp_num_layers=1,
+        )
+
+        block_spec = Mock()
+        block_spec.layer_specs = ["layer_a", "layer_b"]
+        provider.transformer_layer_spec = lambda config: block_spec
+
+        mock_get_mtp.return_value = "mtp_spec"
+
+        result = mtp_block_spec(provider, vp_stage=None)
+
+        mock_get_mtp.assert_called_once_with(provider, block_spec, use_transformer_engine=True, vp_stage=None)
+        assert result == "mtp_spec"
+
+    @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_decoder_layer_specs")
+    @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_mtp_block_spec")
+    def test_mtp_block_spec_re_derives_last_decoder_spec_when_layer_specs_empty(
+        self, mock_get_mtp, mock_get_decoder_specs
+    ):
+        """When the last-stage spec has empty layer_specs (MoE block spec on the last PP stage),
+        re-derive all decoder layer specs and pass the last one to get_gpt_mtp_block_spec."""
+        from megatron.bridge.models.gpt_provider import mtp_block_spec
+
+        provider = GPTModelProvider(
+            num_layers=2,
+            hidden_size=128,
+            num_attention_heads=4,
+            mtp_num_layers=1,
+        )
+
+        empty_block_spec = Mock()
+        empty_block_spec.layer_specs = []
+        provider.transformer_layer_spec = lambda config: empty_block_spec
+
+        dense_layer_spec = Mock(name="dense_layer_spec")
+        moe_layer_spec = Mock(name="moe_layer_spec")
+        mock_get_decoder_specs.return_value = [dense_layer_spec, moe_layer_spec]
+        mock_get_mtp.return_value = "mtp_spec"
+
+        result = mtp_block_spec(provider, vp_stage=2)
+
+        mock_get_decoder_specs.assert_called_once_with(
+            provider,
+            use_transformer_engine=True,
+            normalization=provider.normalization,
+            qk_l2_norm=provider.qk_l2_norm,
+        )
+        mock_get_mtp.assert_called_once_with(provider, moe_layer_spec, use_transformer_engine=True, vp_stage=2)
+        assert result == "mtp_spec"
+
+    @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_mtp_block_spec")
+    def test_mtp_block_spec_passes_vp_stage_to_callable_spec(self, mock_get_mtp):
+        """When the transformer_layer_spec callable accepts vp_stage, it is forwarded."""
+        from megatron.bridge.models.gpt_provider import mtp_block_spec
+
+        provider = GPTModelProvider(
+            num_layers=2,
+            hidden_size=128,
+            num_attention_heads=4,
+            mtp_num_layers=1,
+        )
+
+        block_spec = Mock()
+        block_spec.layer_specs = ["layer_a"]
+        received_vp_stage = {}
+
+        def spec_fn(config, vp_stage=None):
+            received_vp_stage["vp_stage"] = vp_stage
+            return block_spec
+
+        provider.transformer_layer_spec = spec_fn
+        mock_get_mtp.return_value = "mtp_spec"
+
+        result = mtp_block_spec(provider, vp_stage=3)
+
+        assert received_vp_stage["vp_stage"] == 3
+        mock_get_mtp.assert_called_once_with(provider, block_spec, use_transformer_engine=True, vp_stage=3)
+        assert result == "mtp_spec"
+
     def test_dense_grouped_gemm_defaults_to_false(self):
         """GPTModelProvider.dense_grouped_gemm defaults to False."""
         provider = GPTModelProvider(
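
Assuming the repo's standard pytest setup, the new cases can be exercised in isolation with pytest's -k keyword filter:

pytest tests/unit_tests/models/test_gpt_provider.py -k mtp_block_spec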
