
Commit a817a99

Author: root
[bridge] refactor: Address review feedback for ERNIE 4.5 MoE and VL bridges
- Remove Ernie45ModelProvider class, fold config into provider_bridge()
- Fix NaN issue: enable moe_grouped_gemm + moe_permute_fusion (aligned with DeepSeek V3)
- Remove gradient_accumulation_fusion override, use auto-detected default
- Move pool_offset logic from base param_mapping.py to _DualPoolExpertMixin in ernie45_vl_bridge.py
- Move VL modeling files into modeling_ernie45_vl/ subpackage
- Move verification scripts from tests/functional_tests/ to examples/models/vlm/ernie_vl/
- Improve task=None comment in model_bridge.py to explain MTP layers scenario
- Clean up is_expert docstring in param_mapping.py (remove ERNIE-specific references)
1 parent a0e4541 commit a817a99

22 files changed

Lines changed: 188 additions & 141 deletions

tests/functional_tests/test_groups/models/ernie_vl/ernie45_vl_fwd_bwd.py renamed to examples/models/vlm/ernie_vl/ernie45_vl_fwd_bwd.py

File renamed without changes.

tests/functional_tests/test_groups/models/ernie_vl/ernie45_vl_logit_compare.py renamed to examples/models/vlm/ernie_vl/ernie45_vl_logit_compare.py

File renamed without changes.

tests/functional_tests/test_groups/models/ernie_vl/ernie45_vl_vit_compare.py renamed to examples/models/vlm/ernie_vl/ernie45_vl_vit_compare.py

File renamed without changes.

tests/functional_tests/test_groups/models/ernie_vl/ernie45_vl_vit_debug.py renamed to examples/models/vlm/ernie_vl/ernie45_vl_vit_debug.py

Lines changed: 0 additions & 8 deletions
@@ -184,7 +184,6 @@ def main():
    mg_emb = torch.cat((mg_rotary.reshape(total_patches, 1, 1, -1),
                        mg_rotary.reshape(total_patches, 1, 1, -1)), dim=-1)
    mg_cos = mg_emb.cos().flatten()
-   mg_sin = mg_emb.sin().flatten()

    cos_sim_rope = F.cosine_similarity(
        hf_rotary.float().flatten().unsqueeze(0),
@@ -308,13 +307,6 @@ def main():
    mg_qkv_out, _ = mg_linear_qkv(mg_patch_out[:, None])  # [N, 1, 3*hidden]
    mg_qkv_flat = mg_qkv_out.squeeze(1)

-   cos_sim_ln = F.cosine_similarity(
-       hf_normed.float().flatten().unsqueeze(0),
-       # We can't easily extract just the LN output from the fused module
-       # so just compare QKV output
-       torch.zeros(1).unsqueeze(0),  # placeholder
-   ).item()
-
    cos_sim_qkv = F.cosine_similarity(
        hf_qkv.float().flatten().unsqueeze(0),
        mg_qkv_flat.float().flatten().unsqueeze(0)
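
For context on what these comparison scripts measure: each check flattens the HuggingFace and Megatron activations and reports a single cosine similarity, where a value near 1.0 means the two implementations agree. A minimal, self-contained sketch of that pattern (the tensor names are illustrative, not the script's actual variables):

import torch
import torch.nn.functional as F

def cosine_parity(hf_out: torch.Tensor, mg_out: torch.Tensor) -> float:
    # Flatten both activations to [1, N] and compare their directions.
    a = hf_out.float().flatten().unsqueeze(0)
    b = mg_out.float().flatten().unsqueeze(0)
    return F.cosine_similarity(a, b).item()

# Two nearly identical activations score close to 1.0.
x = torch.randn(16, 128)
print(cosine_parity(x, x + 1e-4 * torch.randn_like(x)))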

tests/functional_tests/test_groups/models/ernie_vl/hf_loss_check.py renamed to examples/models/vlm/ernie_vl/hf_loss_check.py

File renamed without changes.

src/megatron/bridge/models/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -34,7 +34,6 @@
 )
 from megatron.bridge.models.ernie import (
     Ernie45Bridge,
-    Ernie45ModelProvider,
 )
 from megatron.bridge.models.ernie_vl import (
     Ernie45VLBridge,
@@ -175,7 +174,6 @@
     "DeepSeekV3Bridge",
     # ERNIE Text-Only Models
     "Ernie45Bridge",
-    "Ernie45ModelProvider",
     # ERNIE VL Models
     "Ernie45VLBridge",
     "Ernie45VLModel",

src/megatron/bridge/models/conversion/model_bridge.py

Lines changed: 7 additions & 1 deletion
@@ -925,7 +925,13 @@ def load_weights_hf_to_megatron(

        _hf_import_cache: Dict[str, torch.Tensor] = {}
        for task in self._with_progress_tracking(hf_to_megatron_tasks, description):
-           # None task means no mapping exists for this param (e.g. MTP layers without bridge mappings)
+           # A None task means the Megatron model has a parameter for which no
+           # HF↔Megatron mapping was registered. This is expected when the HF
+           # config declares optional layers (e.g. num_nextn_predict_layers for
+           # MTP) but the HF checkpoint ships no weights for them; the bridge
+           # intentionally omits mappings so these layers keep their default
+           # (random-init) weights. Skipping here is safe — it is NOT a
+           # missing-mapping bug.
            if task is None:
                continue
            # None means megatron module not on current rank, skip if this task is not going to happen
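
To make the documented behaviour concrete, the loop above amounts to filtering a task list in which unmapped Megatron parameters are represented by None. A standalone sketch of that control flow (the task structure below is illustrative, not the bridge's actual type):

from typing import Optional

# Hypothetical task records: a dict per mapped parameter, or None when no
# HF<->Megatron mapping was registered (e.g. optional MTP layers whose
# weights the HF checkpoint does not ship).
tasks: list[Optional[dict]] = [
    {"hf_param": "model.layers.0.mlp.weight", "megatron_param": "decoder.layers.0.mlp.weight"},
    None,  # an unmapped parameter: keep its random-init default weights
    {"hf_param": "model.layers.1.mlp.weight", "megatron_param": "decoder.layers.1.mlp.weight"},
]

for task in tasks:
    if task is None:
        continue  # expected case, not a missing-mapping bug
    print(f"copy {task['hf_param']} -> {task['megatron_param']}")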

src/megatron/bridge/models/conversion/param_mapping.py

Lines changed: 3 additions & 17 deletions
@@ -173,8 +173,8 @@ def is_expert(self) -> bool:

        Matches both TEGroupedMLP (.experts.linear_fc) and
        SequentialMLP (.experts.local_experts.*.linear_fc) patterns.
-       Also matches dual-pool MoE patterns where an intermediate module name
-       appears between .mlp. and .experts. (e.g. .mlp.text_moe_layer.experts.).
+       Uses ``.experts.`` rather than ``.mlp.experts.`` so models with an
+       intermediate sub-module (e.g. ``.mlp.<pool>.experts.``) are matched too.
        """
        return ".experts.linear_fc" in self.megatron_param or ".experts.local_experts." in self.megatron_param

@@ -664,9 +664,6 @@ def gather_from_ep_ranks(
          Rank 0: [0, 1, 2, 3], Rank 1: [4, 5, 6, 7].
          If the local index L = 0 (derived from the param name), this returns:
          {"...experts.0.weight": tensor_from_rank0, "...experts.4.weight": tensor_from_rank1}
-       - Dual-pool MoE with pool offset P (e.g., P=64 for vision pool):
-         Vision expert L=0 has HF index P+0=64. With S=2, E/S=32:
-         {"...experts.64.weight": tensor_from_rank0, "...experts.96.weight": tensor_from_rank1}

        Args:
            megatron_weights (Optional[torch.Tensor]): The local expert weight tensor
@@ -697,22 +694,11 @@ def gather_from_ep_ranks(
        global_expert_number = extract_expert_number_from_param(self.megatron_param)
        local_expert_number = global_expert_number % num_experts_per_rank

-       # Compute pool offset from HF param name. For dual-pool MoE (e.g., ERNIE VL),
-       # vision expert 3 maps to HF expert 67 (offset=64). The HF param name already
-       # contains the correct offset-shifted index from _OffsetMapping.resolve().
-       # For standard single-pool MoE, pool_offset is always 0.
-       hf_expert_match = re.search(r"experts\.(\d+)", str(hf_param_name))
-       if hf_expert_match:
-           hf_expert_number = int(hf_expert_match.group(1))
-           pool_offset = hf_expert_number - local_expert_number
-       else:
-           pool_offset = 0
-
        # Compute global expert numbers for all EP ranks
        # use regex to replace the local expert number with the global expert number
        gathered_expert_param_names = [
            re.sub(
-               r"experts\.(\d+)", f"experts.{pool_offset + int(local_expert_number) + num_experts_per_rank * i}", str(hf_param_name)
+               r"experts\.(\d+)", f"experts.{int(local_expert_number) + num_experts_per_rank * i}", str(hf_param_name)
            )
            for i in range(self.ep_size)
        ]
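
With the pool offset removed, the gather path derives each EP rank's HF expert name purely from the local expert index and the expert-parallel layout; the dual-pool offset now lives in the VL bridge's _DualPoolExpertMixin. A standalone sketch of the naming arithmetic (an illustration of the idea, not the bridge's API):

import re

def gathered_expert_names(hf_param_name: str, global_expert_number: int,
                          num_experts: int, ep_size: int) -> list[str]:
    # For one Megatron expert param, list the HF expert params owned by each EP rank.
    num_experts_per_rank = num_experts // ep_size
    local_expert_number = global_expert_number % num_experts_per_rank
    return [
        re.sub(r"experts\.(\d+)",
               f"experts.{local_expert_number + num_experts_per_rank * i}",
               hf_param_name)
        for i in range(ep_size)
    ]

# 8 experts over 2 EP ranks: rank 0 owns experts 0-3, rank 1 owns experts 4-7.
# Local expert 0 therefore corresponds to HF experts 0 and 4.
print(gathered_expert_names("model.layers.1.mlp.experts.0.weight",
                            global_expert_number=0, num_experts=8, ep_size=2))
# ['model.layers.1.mlp.experts.0.weight', 'model.layers.1.mlp.experts.4.weight']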

src/megatron/bridge/models/ernie/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -13,10 +13,8 @@
 # limitations under the License.

 from megatron.bridge.models.ernie.ernie_45_bridge import Ernie45Bridge
-from megatron.bridge.models.ernie.ernie_45_provider import Ernie45ModelProvider


 __all__ = [
     "Ernie45Bridge",
-    "Ernie45ModelProvider",
 ]

src/megatron/bridge/models/ernie/ernie_45_bridge.py

Lines changed: 52 additions & 11 deletions
@@ -20,6 +20,7 @@
 shared experts, expert bias for aux-free load balancing).
 """

+import torch.nn.functional as F
 from megatron.core.models.gpt.gpt_model import GPTModel

 from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
@@ -30,7 +31,30 @@
     QKVMapping,
     ReplicatedMapping,
 )
-from megatron.bridge.models.ernie.ernie_45_provider import Ernie45ModelProvider
+from megatron.bridge.models.gpt_provider import GPTModelProvider
+
+
+def _ernie45_decoder_block_spec(config: "GPTModelProvider", vp_stage: int | None = None):
+    """Create a decoder block spec that respects ``moe_layer_freq``.
+
+    The default ``GPTModelProvider.transformer_layer_spec`` calls
+    ``get_gpt_layer_with_transformer_engine_spec`` which returns a single
+    MoE layer spec applied uniformly to ALL layers, ignoring
+    ``moe_layer_freq``.
+
+    ERNIE 4.5 has mixed dense/MoE layers (layer 0 is dense, layers 1-N
+    are MoE). This function uses ``get_gpt_decoder_block_spec`` which
+    calls ``get_gpt_decoder_layer_specs`` — the code path that parses
+    ``config.moe_layer_freq`` and creates per-layer specs (dense for
+    pattern=0, MoE for pattern=1).
+    """
+    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
+
+    return get_gpt_decoder_block_spec(
+        config=config,
+        use_transformer_engine=True,
+        vp_stage=vp_stage,
+    )


 # HF class name string; avoids requiring the HF modeling module at import time.
@@ -109,7 +133,7 @@ def megatron_to_hf(self, megatron_weights, megatron_module):
 @MegatronModelBridge.register_bridge(
     source=_ERNIE45_MOE_HF_CLASS_NAME,
     target=GPTModel,
-    provider=Ernie45ModelProvider,
+    provider=GPTModelProvider,
     model_type="ernie4_5_moe",
 )
 class Ernie45Bridge(MegatronModelBridge):
@@ -146,16 +170,31 @@ def _get_num_experts(hf_config) -> int:
        return int(raw)

    def provider_bridge(self, hf_pretrained):
-       """Convert HuggingFace ERNIE 4.5 MoE config to Ernie45ModelProvider.
+       """Convert HuggingFace ERNIE 4.5 MoE config to GPTModelProvider.

        Uses super().provider_bridge() for standard CONFIG_MAPPING fields
        (hidden_size, num_layers, rope_theta, tie_word_embeddings, etc.)
-       and then overrides ERNIE-specific MoE settings that use non-standard
-       HF config field names (moe_num_experts, moe_k, moe_intermediate_size).
+       and then overrides ERNIE-specific settings.
        """
        provider = super().provider_bridge(hf_pretrained)
        hf_config = hf_pretrained.config

+       # --- Architecture overrides ---
+       provider.normalization = "RMSNorm"
+       provider.activation_func = F.silu
+       provider.gated_linear_unit = True
+       provider.add_bias_linear = False
+       provider.add_qkv_bias = False
+       provider.hidden_dropout = 0.0
+       provider.position_embedding_type = "rope"
+       provider.rotary_base = 500000.0
+       provider.rotary_interleaved = True
+       provider.moe_router_load_balancing_type = "aux_loss"
+       # Mixed dense/MoE layers (layer 0 dense, rest MoE): use decoder
+       # block spec that parses moe_layer_freq per-layer instead of the
+       # default spec which applies MoE uniformly to all layers.
+       provider.transformer_layer_spec = _ernie45_decoder_block_spec
+
        # --- MoE settings (ERNIE uses non-standard HF config field names) ---
        num_experts = self._get_num_experts(hf_config)
        provider.num_moe_experts = num_experts
@@ -179,17 +218,19 @@ def provider_bridge(self, hf_pretrained):
        # Router settings
        provider.moe_aux_loss_coeff = getattr(hf_config, "router_aux_loss_coef", 0.001)

-       # MoE runtime settings
-       # NOTE: moe_grouped_gemm=False uses SequentialMLP (per-expert forward);
-       # True uses TEGroupedMLP which can produce NaN with certain TE versions.
-       provider.moe_grouped_gemm = False
+       # MoE runtime settings — same as DeepSeek V3 (sigmoid routing + expert bias)
+       provider.moe_grouped_gemm = True
        provider.moe_router_pre_softmax = False
        provider.moe_router_score_function = "sigmoid"
        provider.moe_router_enable_expert_bias = True
        provider.moe_router_dtype = "fp32"
        provider.moe_token_dispatcher_type = "alltoall"
-       provider.moe_permute_fusion = False
-       provider.gradient_accumulation_fusion = False
+       provider.moe_permute_fusion = True
+       # gradient_accumulation_fusion: use the auto-detected default from
+       # GPTModelProvider (checks for APEX or TE availability) rather than
+       # overriding it here. For conversion jobs (no backward pass) the
+       # flag is irrelevant; for training it will be enabled whenever
+       # the required extensions are present.

        # Disable MTP (Multi-Token Prediction) for inference -- the ERNIE HF
        # model stores num_nextn_predict_layers in config but does not ship
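
The mixed dense/MoE layout that _ernie45_decoder_block_spec enables hinges on moe_layer_freq being interpreted per layer. A rough sketch of how such a 0/1 pattern can be expanded (a simplified illustration, assuming the integer form means "every N-th layer is MoE"; Megatron-Core's actual parsing lives in get_gpt_decoder_layer_specs):

def expand_moe_layer_pattern(num_layers: int, moe_layer_freq) -> list[int]:
    # Returns one flag per layer: 0 = dense MLP, 1 = MoE layer.
    if isinstance(moe_layer_freq, int):
        # Assumed integer semantics: every N-th layer (starting at layer 0) is MoE.
        return [1 if i % moe_layer_freq == 0 else 0 for i in range(num_layers)]
    assert len(moe_layer_freq) == num_layers, "explicit pattern must cover every layer"
    return list(moe_layer_freq)

# ERNIE 4.5 MoE style: layer 0 dense, the remaining layers MoE.
num_layers = 4
print(expand_moe_layer_pattern(num_layers, [0] + [1] * (num_layers - 1)))  # [0, 1, 1, 1]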
