Skip to content

Commit 46e86b9

Browse files
committed
[bridge] fix: Fix dual-pool MoE EP export and simplify offset mappings
Fix an Expert Parallelism (EP > 1) conversion bug for ERNIE 4.5 VL dual-pool MoE vision experts. The previous _DualPoolExpertMixin used a wrong offset formula, and _DualPoolAutoMapping was dead code because AutoMapping's internal delegation bypassed the mixin override.

Simplify from 5 classes to 2 classes + 2 helper functions:
- _OffsetGatedMLPMapping: for vision gate/up_proj (GatedMLP)
- _OffsetRowParallelMapping: for vision down_proj (replaces AutoMapping to avoid the delegation bypass)
- _offset_gather_from_ep_ranks(): shared EP gather with pool offset
- _resolve_with_offset(): shared wildcard resolution with HF-side shift

Also fix the test tokenizer fallback to use the publicly available Qwen/Qwen3-0.6B.
1 parent 7f1c6b1 commit 46e86b9

2 files changed

Lines changed: 137 additions & 149 deletions

File tree

src/megatron/bridge/models/ernie_vl/ernie45_vl_bridge.py

Lines changed: 136 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
GatedMLPMapping,
9898
QKVMapping,
9999
ReplicatedMapping,
100+
RowParallelMapping,
100101
)
101102
from megatron.bridge.models.ernie_vl.ernie45_vl_provider import Ernie45VLModelProvider
102103
from megatron.bridge.models.ernie_vl.modeling_ernie45_vl.model import Ernie45VLModel
@@ -112,186 +113,173 @@
112113

113114

114115
# ---------------------------------------------------------------------------
115-
# Dual-pool expert EP export support
116+
# Vision pool expert offset mappings
116117
# ---------------------------------------------------------------------------
117118
# In ERNIE VL's dual-pool MoE, vision expert j maps to HF flat expert
118-
# (j + num_text_experts). The _Offset*Mapping classes handle the HF→Megatron
119-
# direction by shifting expert indices during resolve(). For the Megatron→HF
120-
# export direction with EP > 1, gather_from_ep_ranks must reconstruct the
121-
# pool-offset HF expert indices. The mixin below overrides that method so
122-
# the offset logic stays in this file instead of the base param_mapping.py.
119+
# (j + num_text_experts). Two offset-aware mapping classes handle both
120+
# directions:
121+
# - resolve(): shifts the expert index wildcard for the HF side only
122+
# - gather_from_ep_ranks(): reconstructs offset HF indices during EP export
123123
# ---------------------------------------------------------------------------
124124

125125

126-
class _DualPoolExpertMixin:
127-
"""Mixin that adds pool-offset awareness to ``gather_from_ep_ranks``.
126+
def _offset_gather_from_ep_ranks(
127+
mapping,
128+
megatron_weights: Optional[torch.Tensor],
129+
megatron_module,
130+
hf_param_name: Optional[str] = None,
131+
) -> Dict[str, torch.Tensor]:
132+
"""EP all-gather with pool offset for dual-pool MoE vision experts.
128133
129-
For dual-pool MoE, the resolved HF param name already carries the offset
130-
(e.g. ``experts.67.weight`` for vision expert 3 with 64 text experts).
131-
The mixin computes ``pool_offset = hf_index - megatron_local_index`` and
132-
uses it when constructing the HF names for every EP rank.
134+
Per EP rank *i* the HF expert index is:
135+
expert_offset + local_expert_number + num_experts_per_rank * i
133136
"""
134-
135-
def gather_from_ep_ranks(
136-
self,
137-
megatron_weights: Optional[torch.Tensor],
138-
megatron_module,
139-
hf_param_name: Optional[str] = None,
140-
) -> Dict[str, torch.Tensor]:
141-
if self.ep_size == 1:
142-
return {str(hf_param_name): megatron_weights}
143-
144-
if megatron_module is None:
145-
num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank")
146-
else:
147-
model_config = self._get_config(megatron_module)
148-
num_experts = model_config.num_moe_experts
149-
num_experts_per_rank = num_experts // self.ep_size
150-
num_experts_per_rank = self.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank")
151-
152-
global_expert_number = extract_expert_number_from_param(self.megatron_param)
153-
local_expert_number = global_expert_number % num_experts_per_rank
154-
155-
# Compute pool offset from the resolved HF param name.
156-
hf_expert_match = re.search(r"experts\.(\d+)", str(hf_param_name))
157-
if hf_expert_match:
158-
hf_expert_number = int(hf_expert_match.group(1))
159-
pool_offset = hf_expert_number - local_expert_number
137+
if mapping.ep_size == 1:
138+
return {str(hf_param_name): megatron_weights}
139+
140+
if megatron_module is None:
141+
num_experts_per_rank = mapping.broadcast_obj_from_pp_rank(None, "num_experts_per_rank")
142+
else:
143+
model_config = mapping._get_config(megatron_module)
144+
num_experts = model_config.num_moe_experts
145+
num_experts_per_rank = num_experts // mapping.ep_size
146+
num_experts_per_rank = mapping.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank")
147+
148+
global_expert_number = extract_expert_number_from_param(mapping.megatron_param)
149+
local_expert_number = global_expert_number % num_experts_per_rank
150+
151+
gathered_expert_param_names = [
152+
re.sub(
153+
r"experts\.(\d+)",
154+
f"experts.{mapping._expert_offset + local_expert_number + num_experts_per_rank * i}",
155+
str(hf_param_name),
156+
)
157+
for i in range(mapping.ep_size)
158+
]
159+
assert str(hf_param_name) in gathered_expert_param_names, (
160+
f"hf_param_name {hf_param_name} not in gathered_expert_param_names {gathered_expert_param_names}"
161+
)
162+
163+
gathered_weights = [torch.empty_like(megatron_weights) for _ in range(mapping.ep_size)]
164+
torch.distributed.all_gather(gathered_weights, megatron_weights, group=mapping.ep_group)
165+
166+
weights_dict: Dict[str, torch.Tensor] = {}
167+
for i, param_name in enumerate(gathered_expert_param_names):
168+
if param_name in weights_dict:
169+
weights_dict[param_name] = torch.cat([weights_dict[param_name], gathered_weights[i].unsqueeze(0)], dim=0)
160170
else:
161-
pool_offset = 0
162-
163-
gathered_expert_param_names = [
164-
re.sub(
165-
r"experts\.(\d+)",
166-
f"experts.{pool_offset + int(local_expert_number) + num_experts_per_rank * i}",
167-
str(hf_param_name),
168-
)
169-
for i in range(self.ep_size)
170-
]
171-
assert str(hf_param_name) in gathered_expert_param_names
172-
173-
gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)]
174-
torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group)
175-
176-
weights_dict: Dict[str, torch.Tensor] = {}
177-
for i, param_name in enumerate(gathered_expert_param_names):
178-
if param_name in weights_dict:
179-
weights_dict[param_name] = torch.cat(
180-
[weights_dict[param_name], gathered_weights[i].unsqueeze(0)], dim=0
181-
)
182-
else:
183-
weights_dict[param_name] = gathered_weights[i].unsqueeze(0)
184-
for param_name in weights_dict:
185-
weights_dict[param_name] = weights_dict[param_name].squeeze()
186-
return weights_dict
187-
188-
189-
class _DualPoolGatedMLPMapping(_DualPoolExpertMixin, GatedMLPMapping):
190-
"""GatedMLPMapping with pool-offset-aware EP export."""
191-
192-
pass
193-
194-
195-
class _DualPoolAutoMapping(_DualPoolExpertMixin, AutoMapping):
196-
"""AutoMapping with pool-offset-aware EP export."""
171+
weights_dict[param_name] = gathered_weights[i].unsqueeze(0)
172+
for param_name in weights_dict:
173+
weights_dict[param_name] = weights_dict[param_name].squeeze()
174+
return weights_dict
175+
176+
177+
def _resolve_with_offset(
178+
megatron_pattern: str,
179+
hf_pattern,
180+
captures: Tuple[str, ...],
181+
expert_offset: int,
182+
) -> Tuple[str, ...]:
183+
"""Resolve wildcard captures, shifting the 2nd capture (expert index) for HF side."""
184+
if expert_offset and len(captures) >= 2:
185+
shifted_expert = str(int(captures[1]) + expert_offset)
186+
hf_captures = (captures[0], shifted_expert) + captures[2:]
187+
else:
188+
hf_captures = captures
189+
190+
resolved_megatron = megatron_pattern
191+
idx = 0
192+
while "**" in resolved_megatron and idx < len(captures):
193+
resolved_megatron = resolved_megatron.replace("**", captures[idx], 1)
194+
idx += 1
195+
while "*" in resolved_megatron and idx < len(captures):
196+
resolved_megatron = resolved_megatron.replace("*", captures[idx], 1)
197+
idx += 1
198+
199+
if isinstance(hf_pattern, dict):
200+
resolved_hf: dict | str = {}
201+
for k, v in hf_pattern.items():
202+
resolved_v = v
203+
idx = 0
204+
while "**" in resolved_v and idx < len(hf_captures):
205+
resolved_v = resolved_v.replace("**", hf_captures[idx], 1)
206+
idx += 1
207+
while "*" in resolved_v and idx < len(hf_captures):
208+
resolved_v = resolved_v.replace("*", hf_captures[idx], 1)
209+
idx += 1
210+
resolved_hf[k] = resolved_v
211+
else:
212+
resolved_hf = hf_pattern
213+
idx = 0
214+
while "**" in resolved_hf and idx < len(hf_captures):
215+
resolved_hf = resolved_hf.replace("**", hf_captures[idx], 1)
216+
idx += 1
217+
while "*" in resolved_hf and idx < len(hf_captures):
218+
resolved_hf = resolved_hf.replace("*", hf_captures[idx], 1)
219+
idx += 1
197220

198-
pass
221+
return resolved_megatron, resolved_hf
199222

200223

201224
class _OffsetGatedMLPMapping(GatedMLPMapping):
202-
"""GatedMLPMapping that adds a fixed offset to the expert index wildcard.
225+
"""GatedMLPMapping with expert index offset for vision pool.
203226
204-
Used for vision MoE experts where Megatron expert index j maps to
205-
HF flat expert index (j + offset). The offset is the number of text
206-
experts, so vision expert 0 maps to HF expert N, etc.
207-
208-
The Megatron pattern has two ``*`` wildcards: layer index and expert index.
209-
The HF patterns also have two ``*`` wildcards. During ``resolve()``, the
210-
second capture (expert index) is shifted by ``expert_offset`` **only** for
211-
HF patterns. The Megatron side keeps the original (unshifted) expert index.
227+
Handles both directions:
228+
- resolve(): shifts expert index for HF side only
229+
- gather_from_ep_ranks(): reconstructs offset HF indices during EP export
212230
"""
213231

214232
def __init__(self, megatron_param: str, gate: str, up: str, expert_offset: int = 0):
215233
super().__init__(megatron_param=megatron_param, gate=gate, up=up)
216234
self._expert_offset = expert_offset
217235

218236
def resolve(self, captures: Tuple[str, ...]):
219-
"""Override resolve to apply the expert index offset only to HF side."""
220-
if self._expert_offset and len(captures) >= 2:
221-
# Use original captures for Megatron, shifted for HF
222-
shifted_expert = str(int(captures[1]) + self._expert_offset)
223-
hf_captures = (captures[0], shifted_expert) + captures[2:]
224-
else:
225-
hf_captures = captures
226-
# Resolve Megatron side with original captures
227-
resolved_megatron_param = self.megatron_param
228-
idx = 0
229-
while "**" in resolved_megatron_param and idx < len(captures):
230-
resolved_megatron_param = resolved_megatron_param.replace("**", captures[idx], 1)
231-
idx += 1
232-
while "*" in resolved_megatron_param and idx < len(captures):
233-
resolved_megatron_param = resolved_megatron_param.replace("*", captures[idx], 1)
234-
idx += 1
235-
# Resolve HF side with shifted captures
236-
resolved_hf_param = {}
237-
for k, v in self.hf_param.items():
238-
resolved_v = v
239-
idx = 0
240-
while "**" in resolved_v and idx < len(hf_captures):
241-
resolved_v = resolved_v.replace("**", hf_captures[idx], 1)
242-
idx += 1
243-
while "*" in resolved_v and idx < len(hf_captures):
244-
resolved_v = resolved_v.replace("*", hf_captures[idx], 1)
245-
idx += 1
246-
resolved_hf_param[k] = resolved_v
247-
return _DualPoolGatedMLPMapping(
248-
megatron_param=resolved_megatron_param,
249-
gate=resolved_hf_param["gate"],
250-
up=resolved_hf_param["up"],
237+
resolved_megatron, resolved_hf = _resolve_with_offset(
238+
self.megatron_param,
239+
self.hf_param,
240+
captures,
241+
self._expert_offset,
242+
)
243+
return _OffsetGatedMLPMapping(
244+
megatron_param=resolved_megatron,
245+
gate=resolved_hf["gate"],
246+
up=resolved_hf["up"],
247+
expert_offset=self._expert_offset,
251248
)
252249

250+
def gather_from_ep_ranks(self, megatron_weights, megatron_module, hf_param_name=None):
251+
return _offset_gather_from_ep_ranks(self, megatron_weights, megatron_module, hf_param_name)
253252

254-
class _OffsetAutoMapping(AutoMapping):
255-
"""AutoMapping that adds a fixed offset to the expert index wildcard.
256253

257-
Same concept as _OffsetGatedMLPMapping but for simple 1:1 mappings
258-
(e.g. down_proj). The offset is applied only to HF side, not Megatron.
254+
class _OffsetRowParallelMapping(RowParallelMapping):
255+
"""RowParallelMapping with expert index offset for vision pool.
256+
257+
Used for vision expert down_proj (linear_fc2), which is always
258+
row-parallel in SequentialMLP. Using explicit RowParallelMapping
259+
avoids the AutoMapping delegation issue where the delegate's
260+
gather_from_ep_ranks bypasses offset logic.
259261
"""
260262

261263
def __init__(self, megatron_param: str, hf_param: str, expert_offset: int = 0):
262264
super().__init__(megatron_param=megatron_param, hf_param=hf_param)
263265
self._expert_offset = expert_offset
264266

265267
def resolve(self, captures: Tuple[str, ...]):
266-
"""Override resolve to apply the expert index offset only to HF side."""
267-
if self._expert_offset and len(captures) >= 2:
268-
shifted_expert = str(int(captures[1]) + self._expert_offset)
269-
hf_captures = (captures[0], shifted_expert) + captures[2:]
270-
else:
271-
hf_captures = captures
272-
# Resolve Megatron side with original captures
273-
resolved_megatron_param = self.megatron_param
274-
idx = 0
275-
while "**" in resolved_megatron_param and idx < len(captures):
276-
resolved_megatron_param = resolved_megatron_param.replace("**", captures[idx], 1)
277-
idx += 1
278-
while "*" in resolved_megatron_param and idx < len(captures):
279-
resolved_megatron_param = resolved_megatron_param.replace("*", captures[idx], 1)
280-
idx += 1
281-
# Resolve HF side with shifted captures
282-
resolved_hf_param = self.hf_param
283-
idx = 0
284-
while "**" in resolved_hf_param and idx < len(hf_captures):
285-
resolved_hf_param = resolved_hf_param.replace("**", hf_captures[idx], 1)
286-
idx += 1
287-
while "*" in resolved_hf_param and idx < len(hf_captures):
288-
resolved_hf_param = resolved_hf_param.replace("*", hf_captures[idx], 1)
289-
idx += 1
290-
return _DualPoolAutoMapping(
291-
megatron_param=resolved_megatron_param,
292-
hf_param=resolved_hf_param,
268+
resolved_megatron, resolved_hf = _resolve_with_offset(
269+
self.megatron_param,
270+
self.hf_param,
271+
captures,
272+
self._expert_offset,
273+
)
274+
return _OffsetRowParallelMapping(
275+
megatron_param=resolved_megatron,
276+
hf_param=resolved_hf,
277+
expert_offset=self._expert_offset,
293278
)
294279

280+
def gather_from_ep_ranks(self, megatron_weights, megatron_module, hf_param_name=None):
281+
return _offset_gather_from_ep_ranks(self, megatron_weights, megatron_module, hf_param_name)
282+
295283

296284
class _ConcatBiasMapping(AutoMapping):
297285
"""Mapping for the concatenated text+vision expert bias tensor.
@@ -881,7 +869,7 @@ def mapping_registry(self) -> MegatronMappingRegistry:
881869
up="model.layers.*.mlp.experts.*.up_proj.weight",
882870
expert_offset=num_experts,
883871
),
884-
_OffsetAutoMapping(
872+
_OffsetRowParallelMapping(
885873
megatron_param=(
886874
"language_model.decoder.layers.*.mlp.vision_moe_layer"
887875
".experts.local_experts.*.linear_fc2.weight"

tests/functional_tests/test_groups/models/ernie_vl/test_ernie45_vl_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def ernie45_vl_toy_model_path(self, tmp_path_factory):
155155
if _local_tokenizer_path.exists():
156156
tokenizer = AutoTokenizer.from_pretrained(str(_local_tokenizer_path))
157157
else:
158-
tokenizer = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-VL-28B-A3B-Instruct")
158+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
159159
tokenizer.save_pretrained(model_dir)
160160
except (OSError, ValueError):
161161
# Create a functional dummy tokenizer from a readily available model

0 commit comments

Comments
 (0)