 shared experts, expert bias for aux-free load balancing).
 """

+import torch.nn.functional as F
 from megatron.core.models.gpt.gpt_model import GPTModel

 from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
     QKVMapping,
     ReplicatedMapping,
 )
-from megatron.bridge.models.ernie.ernie_45_provider import Ernie45ModelProvider
+from megatron.bridge.models.gpt_provider import GPTModelProvider
+
+
+def _ernie45_decoder_block_spec(config: "GPTModelProvider", vp_stage: int | None = None):
+    """Create a decoder block spec that respects ``moe_layer_freq``.
+
+    The default ``GPTModelProvider.transformer_layer_spec`` calls
+    ``get_gpt_layer_with_transformer_engine_spec``, which returns a single
+    MoE layer spec applied uniformly to ALL layers, ignoring
+    ``moe_layer_freq``.
+
+    ERNIE 4.5 has mixed dense/MoE layers (layer 0 is dense, layers 1-N
+    are MoE). This function uses ``get_gpt_decoder_block_spec``, which
+    calls ``get_gpt_decoder_layer_specs`` — the code path that parses
+    ``config.moe_layer_freq`` and creates per-layer specs (dense for
+    pattern=0, MoE for pattern=1).
+    """
+    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
+
+    return get_gpt_decoder_block_spec(
+        config=config,
+        use_transformer_engine=True,
+        vp_stage=vp_stage,
+    )


 # HF class name string; avoids requiring the HF modeling module at import time.
@@ -109,7 +133,7 @@ def megatron_to_hf(self, megatron_weights, megatron_module):
 @MegatronModelBridge.register_bridge(
     source=_ERNIE45_MOE_HF_CLASS_NAME,
     target=GPTModel,
-    provider=Ernie45ModelProvider,
+    provider=GPTModelProvider,
     model_type="ernie4_5_moe",
 )
 class Ernie45Bridge(MegatronModelBridge):
@@ -146,16 +170,31 @@ def _get_num_experts(hf_config) -> int:
         return int(raw)

     def provider_bridge(self, hf_pretrained):
-        """Convert HuggingFace ERNIE 4.5 MoE config to Ernie45ModelProvider.
+        """Convert HuggingFace ERNIE 4.5 MoE config to GPTModelProvider.

         Uses super().provider_bridge() for standard CONFIG_MAPPING fields
         (hidden_size, num_layers, rope_theta, tie_word_embeddings, etc.)
-        and then overrides ERNIE-specific MoE settings that use non-standard
-        HF config field names (moe_num_experts, moe_k, moe_intermediate_size).
+        and then overrides ERNIE-specific settings.
         """
         provider = super().provider_bridge(hf_pretrained)
         hf_config = hf_pretrained.config

+        # --- Architecture overrides ---
+        provider.normalization = "RMSNorm"
+        provider.activation_func = F.silu
+        provider.gated_linear_unit = True
+        provider.add_bias_linear = False
+        provider.add_qkv_bias = False
+        provider.hidden_dropout = 0.0
+        provider.position_embedding_type = "rope"
+        provider.rotary_base = 500000.0
+        provider.rotary_interleaved = True
+        provider.moe_router_load_balancing_type = "aux_loss"
+        # Mixed dense/MoE layers (layer 0 dense, rest MoE): use a decoder
+        # block spec that parses moe_layer_freq per-layer instead of the
+        # default spec, which applies MoE uniformly to all layers.
+        provider.transformer_layer_spec = _ernie45_decoder_block_spec
+
         # --- MoE settings (ERNIE uses non-standard HF config field names) ---
         num_experts = self._get_num_experts(hf_config)
         provider.num_moe_experts = num_experts
@@ -179,17 +218,19 @@ def provider_bridge(self, hf_pretrained):
         # Router settings
         provider.moe_aux_loss_coeff = getattr(hf_config, "router_aux_loss_coef", 0.001)

-        # MoE runtime settings
-        # NOTE: moe_grouped_gemm=False uses SequentialMLP (per-expert forward);
-        # True uses TEGroupedMLP which can produce NaN with certain TE versions.
-        provider.moe_grouped_gemm = False
+        # MoE runtime settings — same as DeepSeek V3 (sigmoid routing + expert bias)
+        provider.moe_grouped_gemm = True
         provider.moe_router_pre_softmax = False
         provider.moe_router_score_function = "sigmoid"
         provider.moe_router_enable_expert_bias = True
         provider.moe_router_dtype = "fp32"
         provider.moe_token_dispatcher_type = "alltoall"
-        provider.moe_permute_fusion = False
-        provider.gradient_accumulation_fusion = False
+        provider.moe_permute_fusion = True
+        # gradient_accumulation_fusion: use the auto-detected default from
+        # GPTModelProvider (checks for APEX or TE availability) rather than
+        # overriding it here. For conversion jobs (no backward pass) the
+        # flag is irrelevant; for training it will be enabled whenever
+        # the required extensions are present.
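+        # (Downstream configs can still override it explicitly, e.g.
+        # ``provider.gradient_accumulation_fusion = False``, should the
+        # fused path be undesirable in a given environment.)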

         # Disable MTP (Multi-Token Prediction) for inference -- the ERNIE HF
         # model stores num_nextn_predict_layers in config but does not ship