|
97 | 97 | GatedMLPMapping, |
98 | 98 | QKVMapping, |
99 | 99 | ReplicatedMapping, |
| 100 | + RowParallelMapping, |
100 | 101 | ) |
101 | 102 | from megatron.bridge.models.ernie_vl.ernie45_vl_provider import Ernie45VLModelProvider |
102 | 103 | from megatron.bridge.models.ernie_vl.modeling_ernie45_vl.model import Ernie45VLModel |
|
112 | 113 |
|
113 | 114 |
|
114 | 115 | # --------------------------------------------------------------------------- |
115 | | -# Dual-pool expert EP export support |
| 116 | +# Vision pool expert offset mappings |
116 | 117 | # --------------------------------------------------------------------------- |
117 | 118 | # In ERNIE VL's dual-pool MoE, vision expert j maps to HF flat expert |
118 | | -# (j + num_text_experts). The _Offset*Mapping classes handle the HF→Megatron |
119 | | -# direction by shifting expert indices during resolve(). For the Megatron→HF |
120 | | -# export direction with EP > 1, gather_from_ep_ranks must reconstruct the |
121 | | -# pool-offset HF expert indices. The mixin below overrides that method so |
122 | | -# the offset logic stays in this file instead of the base param_mapping.py. |
| 119 | +# (j + num_text_experts). Two offset-aware mapping classes handle both |
| 120 | +# directions: |
| 121 | +# - resolve(): shifts the expert index wildcard for the HF side only |
| 122 | +# - gather_from_ep_ranks(): reconstructs offset HF indices during EP export |
123 | 123 | # --------------------------------------------------------------------------- |
124 | 124 |
|
125 | 125 |
|
126 | | -class _DualPoolExpertMixin: |
127 | | - """Mixin that adds pool-offset awareness to ``gather_from_ep_ranks``. |
| 126 | +def _offset_gather_from_ep_ranks( |
| 127 | + mapping, |
| 128 | + megatron_weights: Optional[torch.Tensor], |
| 129 | + megatron_module, |
| 130 | + hf_param_name: Optional[str] = None, |
| 131 | +) -> Dict[str, torch.Tensor]: |
| 132 | + """EP all-gather with pool offset for dual-pool MoE vision experts. |
128 | 133 |
|
129 | | - For dual-pool MoE, the resolved HF param name already carries the offset |
130 | | - (e.g. ``experts.67.weight`` for vision expert 3 with 64 text experts). |
131 | | - The mixin computes ``pool_offset = hf_index - megatron_local_index`` and |
132 | | - uses it when constructing the HF names for every EP rank. |
| 134 | + Per EP rank *i* the HF expert index is: |
| 135 | + expert_offset + local_expert_number + num_experts_per_rank * i |
133 | 136 | """ |
134 | | - |
135 | | - def gather_from_ep_ranks( |
136 | | - self, |
137 | | - megatron_weights: Optional[torch.Tensor], |
138 | | - megatron_module, |
139 | | - hf_param_name: Optional[str] = None, |
140 | | - ) -> Dict[str, torch.Tensor]: |
141 | | - if self.ep_size == 1: |
142 | | - return {str(hf_param_name): megatron_weights} |
143 | | - |
144 | | - if megatron_module is None: |
145 | | - num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank") |
146 | | - else: |
147 | | - model_config = self._get_config(megatron_module) |
148 | | - num_experts = model_config.num_moe_experts |
149 | | - num_experts_per_rank = num_experts // self.ep_size |
150 | | - num_experts_per_rank = self.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank") |
151 | | - |
152 | | - global_expert_number = extract_expert_number_from_param(self.megatron_param) |
153 | | - local_expert_number = global_expert_number % num_experts_per_rank |
154 | | - |
155 | | - # Compute pool offset from the resolved HF param name. |
156 | | - hf_expert_match = re.search(r"experts\.(\d+)", str(hf_param_name)) |
157 | | - if hf_expert_match: |
158 | | - hf_expert_number = int(hf_expert_match.group(1)) |
159 | | - pool_offset = hf_expert_number - local_expert_number |
| 137 | + if mapping.ep_size == 1: |
| 138 | + return {str(hf_param_name): megatron_weights} |
| 139 | + |
| 140 | + if megatron_module is None: |
| 141 | + num_experts_per_rank = mapping.broadcast_obj_from_pp_rank(None, "num_experts_per_rank") |
| 142 | + else: |
| 143 | + model_config = mapping._get_config(megatron_module) |
| 144 | + num_experts = model_config.num_moe_experts |
| 145 | + num_experts_per_rank = num_experts // mapping.ep_size |
| 146 | + num_experts_per_rank = mapping.broadcast_obj_from_pp_rank(num_experts_per_rank, "num_experts_per_rank") |
| 147 | + |
| 148 | + global_expert_number = extract_expert_number_from_param(mapping.megatron_param) |
| 149 | + local_expert_number = global_expert_number % num_experts_per_rank |
| 150 | + |
| 151 | + gathered_expert_param_names = [ |
| 152 | + re.sub( |
| 153 | + r"experts\.(\d+)", |
| 154 | + f"experts.{mapping._expert_offset + local_expert_number + num_experts_per_rank * i}", |
| 155 | + str(hf_param_name), |
| 156 | + ) |
| 157 | + for i in range(mapping.ep_size) |
| 158 | + ] |
| 159 | + assert str(hf_param_name) in gathered_expert_param_names, ( |
| 160 | + f"hf_param_name {hf_param_name} not in gathered_expert_param_names {gathered_expert_param_names}" |
| 161 | + ) |
| 162 | + |
| 163 | + gathered_weights = [torch.empty_like(megatron_weights) for _ in range(mapping.ep_size)] |
| 164 | + torch.distributed.all_gather(gathered_weights, megatron_weights, group=mapping.ep_group) |
| 165 | + |
| 166 | + weights_dict: Dict[str, torch.Tensor] = {} |
| 167 | + for i, param_name in enumerate(gathered_expert_param_names): |
| 168 | + if param_name in weights_dict: |
| 169 | + weights_dict[param_name] = torch.cat([weights_dict[param_name], gathered_weights[i].unsqueeze(0)], dim=0) |
160 | 170 | else: |
161 | | - pool_offset = 0 |
162 | | - |
163 | | - gathered_expert_param_names = [ |
164 | | - re.sub( |
165 | | - r"experts\.(\d+)", |
166 | | - f"experts.{pool_offset + int(local_expert_number) + num_experts_per_rank * i}", |
167 | | - str(hf_param_name), |
168 | | - ) |
169 | | - for i in range(self.ep_size) |
170 | | - ] |
171 | | - assert str(hf_param_name) in gathered_expert_param_names |
172 | | - |
173 | | - gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)] |
174 | | - torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group) |
175 | | - |
176 | | - weights_dict: Dict[str, torch.Tensor] = {} |
177 | | - for i, param_name in enumerate(gathered_expert_param_names): |
178 | | - if param_name in weights_dict: |
179 | | - weights_dict[param_name] = torch.cat( |
180 | | - [weights_dict[param_name], gathered_weights[i].unsqueeze(0)], dim=0 |
181 | | - ) |
182 | | - else: |
183 | | - weights_dict[param_name] = gathered_weights[i].unsqueeze(0) |
184 | | - for param_name in weights_dict: |
185 | | - weights_dict[param_name] = weights_dict[param_name].squeeze() |
186 | | - return weights_dict |
187 | | - |
188 | | - |
189 | | -class _DualPoolGatedMLPMapping(_DualPoolExpertMixin, GatedMLPMapping): |
190 | | - """GatedMLPMapping with pool-offset-aware EP export.""" |
191 | | - |
192 | | - pass |
193 | | - |
194 | | - |
195 | | -class _DualPoolAutoMapping(_DualPoolExpertMixin, AutoMapping): |
196 | | - """AutoMapping with pool-offset-aware EP export.""" |
| 171 | + weights_dict[param_name] = gathered_weights[i].unsqueeze(0) |
| 172 | + for param_name in weights_dict: |
| 173 | + weights_dict[param_name] = weights_dict[param_name].squeeze() |
| 174 | + return weights_dict |
| 175 | + |
| 176 | + |
| 177 | +def _resolve_with_offset( |
| 178 | + megatron_pattern: str, |
| 179 | + hf_pattern, |
| 180 | + captures: Tuple[str, ...], |
| 181 | + expert_offset: int, |
| 182 | +) -> Tuple[str, ...]: |
| 183 | + """Resolve wildcard captures, shifting the 2nd capture (expert index) for HF side.""" |
| 184 | + if expert_offset and len(captures) >= 2: |
| 185 | + shifted_expert = str(int(captures[1]) + expert_offset) |
| 186 | + hf_captures = (captures[0], shifted_expert) + captures[2:] |
| 187 | + else: |
| 188 | + hf_captures = captures |
| 189 | + |
| 190 | + resolved_megatron = megatron_pattern |
| 191 | + idx = 0 |
| 192 | + while "**" in resolved_megatron and idx < len(captures): |
| 193 | + resolved_megatron = resolved_megatron.replace("**", captures[idx], 1) |
| 194 | + idx += 1 |
| 195 | + while "*" in resolved_megatron and idx < len(captures): |
| 196 | + resolved_megatron = resolved_megatron.replace("*", captures[idx], 1) |
| 197 | + idx += 1 |
| 198 | + |
| 199 | + if isinstance(hf_pattern, dict): |
| 200 | + resolved_hf: dict | str = {} |
| 201 | + for k, v in hf_pattern.items(): |
| 202 | + resolved_v = v |
| 203 | + idx = 0 |
| 204 | + while "**" in resolved_v and idx < len(hf_captures): |
| 205 | + resolved_v = resolved_v.replace("**", hf_captures[idx], 1) |
| 206 | + idx += 1 |
| 207 | + while "*" in resolved_v and idx < len(hf_captures): |
| 208 | + resolved_v = resolved_v.replace("*", hf_captures[idx], 1) |
| 209 | + idx += 1 |
| 210 | + resolved_hf[k] = resolved_v |
| 211 | + else: |
| 212 | + resolved_hf = hf_pattern |
| 213 | + idx = 0 |
| 214 | + while "**" in resolved_hf and idx < len(hf_captures): |
| 215 | + resolved_hf = resolved_hf.replace("**", hf_captures[idx], 1) |
| 216 | + idx += 1 |
| 217 | + while "*" in resolved_hf and idx < len(hf_captures): |
| 218 | + resolved_hf = resolved_hf.replace("*", hf_captures[idx], 1) |
| 219 | + idx += 1 |
197 | 220 |
|
198 | | - pass |
| 221 | + return resolved_megatron, resolved_hf |
199 | 222 |
|
200 | 223 |
|
201 | 224 | class _OffsetGatedMLPMapping(GatedMLPMapping): |
202 | | - """GatedMLPMapping that adds a fixed offset to the expert index wildcard. |
| 225 | + """GatedMLPMapping with expert index offset for vision pool. |
203 | 226 |
|
204 | | - Used for vision MoE experts where Megatron expert index j maps to |
205 | | - HF flat expert index (j + offset). The offset is the number of text |
206 | | - experts, so vision expert 0 maps to HF expert N, etc. |
207 | | -
|
208 | | - The Megatron pattern has two ``*`` wildcards: layer index and expert index. |
209 | | - The HF patterns also have two ``*`` wildcards. During ``resolve()``, the |
210 | | - second capture (expert index) is shifted by ``expert_offset`` **only** for |
211 | | - HF patterns. The Megatron side keeps the original (unshifted) expert index. |
| 227 | + Handles both directions: |
| 228 | + - resolve(): shifts expert index for HF side only |
| 229 | + - gather_from_ep_ranks(): reconstructs offset HF indices during EP export |
212 | 230 | """ |
213 | 231 |
|
214 | 232 | def __init__(self, megatron_param: str, gate: str, up: str, expert_offset: int = 0): |
215 | 233 | super().__init__(megatron_param=megatron_param, gate=gate, up=up) |
216 | 234 | self._expert_offset = expert_offset |
217 | 235 |
|
218 | 236 | def resolve(self, captures: Tuple[str, ...]): |
219 | | - """Override resolve to apply the expert index offset only to HF side.""" |
220 | | - if self._expert_offset and len(captures) >= 2: |
221 | | - # Use original captures for Megatron, shifted for HF |
222 | | - shifted_expert = str(int(captures[1]) + self._expert_offset) |
223 | | - hf_captures = (captures[0], shifted_expert) + captures[2:] |
224 | | - else: |
225 | | - hf_captures = captures |
226 | | - # Resolve Megatron side with original captures |
227 | | - resolved_megatron_param = self.megatron_param |
228 | | - idx = 0 |
229 | | - while "**" in resolved_megatron_param and idx < len(captures): |
230 | | - resolved_megatron_param = resolved_megatron_param.replace("**", captures[idx], 1) |
231 | | - idx += 1 |
232 | | - while "*" in resolved_megatron_param and idx < len(captures): |
233 | | - resolved_megatron_param = resolved_megatron_param.replace("*", captures[idx], 1) |
234 | | - idx += 1 |
235 | | - # Resolve HF side with shifted captures |
236 | | - resolved_hf_param = {} |
237 | | - for k, v in self.hf_param.items(): |
238 | | - resolved_v = v |
239 | | - idx = 0 |
240 | | - while "**" in resolved_v and idx < len(hf_captures): |
241 | | - resolved_v = resolved_v.replace("**", hf_captures[idx], 1) |
242 | | - idx += 1 |
243 | | - while "*" in resolved_v and idx < len(hf_captures): |
244 | | - resolved_v = resolved_v.replace("*", hf_captures[idx], 1) |
245 | | - idx += 1 |
246 | | - resolved_hf_param[k] = resolved_v |
247 | | - return _DualPoolGatedMLPMapping( |
248 | | - megatron_param=resolved_megatron_param, |
249 | | - gate=resolved_hf_param["gate"], |
250 | | - up=resolved_hf_param["up"], |
| 237 | + resolved_megatron, resolved_hf = _resolve_with_offset( |
| 238 | + self.megatron_param, |
| 239 | + self.hf_param, |
| 240 | + captures, |
| 241 | + self._expert_offset, |
| 242 | + ) |
| 243 | + return _OffsetGatedMLPMapping( |
| 244 | + megatron_param=resolved_megatron, |
| 245 | + gate=resolved_hf["gate"], |
| 246 | + up=resolved_hf["up"], |
| 247 | + expert_offset=self._expert_offset, |
251 | 248 | ) |
252 | 249 |
|
| 250 | + def gather_from_ep_ranks(self, megatron_weights, megatron_module, hf_param_name=None): |
| 251 | + return _offset_gather_from_ep_ranks(self, megatron_weights, megatron_module, hf_param_name) |
253 | 252 |
|
254 | | -class _OffsetAutoMapping(AutoMapping): |
255 | | - """AutoMapping that adds a fixed offset to the expert index wildcard. |
256 | 253 |
|
257 | | - Same concept as _OffsetGatedMLPMapping but for simple 1:1 mappings |
258 | | - (e.g. down_proj). The offset is applied only to HF side, not Megatron. |
| 254 | +class _OffsetRowParallelMapping(RowParallelMapping): |
| 255 | + """RowParallelMapping with expert index offset for vision pool. |
| 256 | +
|
| 257 | + Used for vision expert down_proj (linear_fc2), which is always |
| 258 | + row-parallel in SequentialMLP. Using explicit RowParallelMapping |
| 259 | + avoids the AutoMapping delegation issue where the delegate's |
| 260 | + gather_from_ep_ranks bypasses offset logic. |
259 | 261 | """ |
260 | 262 |
|
261 | 263 | def __init__(self, megatron_param: str, hf_param: str, expert_offset: int = 0): |
262 | 264 | super().__init__(megatron_param=megatron_param, hf_param=hf_param) |
263 | 265 | self._expert_offset = expert_offset |
264 | 266 |
|
265 | 267 | def resolve(self, captures: Tuple[str, ...]): |
266 | | - """Override resolve to apply the expert index offset only to HF side.""" |
267 | | - if self._expert_offset and len(captures) >= 2: |
268 | | - shifted_expert = str(int(captures[1]) + self._expert_offset) |
269 | | - hf_captures = (captures[0], shifted_expert) + captures[2:] |
270 | | - else: |
271 | | - hf_captures = captures |
272 | | - # Resolve Megatron side with original captures |
273 | | - resolved_megatron_param = self.megatron_param |
274 | | - idx = 0 |
275 | | - while "**" in resolved_megatron_param and idx < len(captures): |
276 | | - resolved_megatron_param = resolved_megatron_param.replace("**", captures[idx], 1) |
277 | | - idx += 1 |
278 | | - while "*" in resolved_megatron_param and idx < len(captures): |
279 | | - resolved_megatron_param = resolved_megatron_param.replace("*", captures[idx], 1) |
280 | | - idx += 1 |
281 | | - # Resolve HF side with shifted captures |
282 | | - resolved_hf_param = self.hf_param |
283 | | - idx = 0 |
284 | | - while "**" in resolved_hf_param and idx < len(hf_captures): |
285 | | - resolved_hf_param = resolved_hf_param.replace("**", hf_captures[idx], 1) |
286 | | - idx += 1 |
287 | | - while "*" in resolved_hf_param and idx < len(hf_captures): |
288 | | - resolved_hf_param = resolved_hf_param.replace("*", hf_captures[idx], 1) |
289 | | - idx += 1 |
290 | | - return _DualPoolAutoMapping( |
291 | | - megatron_param=resolved_megatron_param, |
292 | | - hf_param=resolved_hf_param, |
| 268 | + resolved_megatron, resolved_hf = _resolve_with_offset( |
| 269 | + self.megatron_param, |
| 270 | + self.hf_param, |
| 271 | + captures, |
| 272 | + self._expert_offset, |
| 273 | + ) |
| 274 | + return _OffsetRowParallelMapping( |
| 275 | + megatron_param=resolved_megatron, |
| 276 | + hf_param=resolved_hf, |
| 277 | + expert_offset=self._expert_offset, |
293 | 278 | ) |
294 | 279 |
|
| 280 | + def gather_from_ep_ranks(self, megatron_weights, megatron_module, hf_param_name=None): |
| 281 | + return _offset_gather_from_ep_ranks(self, megatron_weights, megatron_module, hf_param_name) |
| 282 | + |
295 | 283 |
|
296 | 284 | class _ConcatBiasMapping(AutoMapping): |
297 | 285 | """Mapping for the concatenated text+vision expert bias tensor. |
@@ -881,7 +869,7 @@ def mapping_registry(self) -> MegatronMappingRegistry: |
881 | 869 | up="model.layers.*.mlp.experts.*.up_proj.weight", |
882 | 870 | expert_offset=num_experts, |
883 | 871 | ), |
884 | | - _OffsetAutoMapping( |
| 872 | + _OffsetRowParallelMapping( |
885 | 873 | megatron_param=( |
886 | 874 | "language_model.decoder.layers.*.mlp.vision_moe_layer" |
887 | 875 | ".experts.local_experts.*.linear_fc2.weight" |
|
0 commit comments