Skip to content

Commit 394a7cf

Browse files
pavelgein and yaoyu-33 authored
[model] fix: use hf_config to check whether model is dense (#4414)
Signed-off-by: Pavel Gein <pavel.gein@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Co-authored-by: yaoyu-33 <yaoyu.094@gmail.com>
1 parent f4b1dff commit 394a7cf

3 files changed

Lines changed: 73 additions & 40 deletions

File tree

src/megatron/bridge/models/gemma/gemma4_bridge.py

Lines changed: 29 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,17 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> "Gemma4ModelProv
127127
self._is_dense = False
128128
return self._build_moe_provider(hf_config)
129129

130+
def _text_config(self) -> Any | None:
131+
"""Return the text config used to dispatch dense vs MoE behavior."""
132+
return getattr(self, "hf_config", None)
133+
134+
def _is_dense_config(self) -> bool:
135+
"""Return whether the current HF config describes a dense Gemma 4 model."""
136+
if getattr(self, "_is_dense", False):
137+
return True
138+
text_config = self._text_config()
139+
return text_config is not None and not getattr(text_config, "enable_moe_block", False)
140+
130141
def _build_dense_provider(self, hf_config) -> Gemma4DenseProvider:
131142
"""Build a Gemma4DenseProvider from HF config."""
132143
rope_params = getattr(hf_config, "rope_parameters", {}) or {}
@@ -269,13 +280,24 @@ def maybe_modify_loaded_hf_weight(
269280

270281
if k_name not in hf_state_dict and v_name not in hf_state_dict:
271282
q_weight = hf_state_dict[q_name]
272-
num_q_heads = getattr(self, "_dense_num_attention_heads", 8)
273-
kv_head_dim = q_weight.shape[0] // num_q_heads
274-
num_kv_heads = getattr(
275-
self,
276-
"_dense_num_global_query_groups",
277-
getattr(self, "_dense_num_query_groups", 2),
283+
text_config = self._text_config()
284+
num_q_heads = getattr(
285+
text_config, "num_attention_heads", getattr(self, "_dense_num_attention_heads", 8)
278286
)
287+
kv_head_dim = q_weight.shape[0] // num_q_heads
288+
num_kv_heads = getattr(text_config, "num_key_value_heads", getattr(self, "_dense_num_query_groups", 2))
289+
layer_match = re.search(r"layers\.(\d+)\.", q_name)
290+
layer_types = getattr(text_config, "layer_types", None)
291+
if layer_match and layer_types:
292+
layer_idx = int(layer_match.group(1))
293+
if layer_idx < len(layer_types) and layer_types[layer_idx] == "full_attention":
294+
num_kv_heads = getattr(
295+
text_config,
296+
"num_global_key_value_heads",
297+
getattr(self, "_dense_num_global_query_groups", num_kv_heads),
298+
)
299+
elif hasattr(self, "_dense_num_global_query_groups"):
300+
num_kv_heads = self._dense_num_global_query_groups
279301
kv_shape = (num_kv_heads * kv_head_dim, q_weight.shape[1])
280302
k_zero = torch.zeros(kv_shape, dtype=q_weight.dtype, device=q_weight.device)
281303
return {"q": q_weight, "k": k_zero, "v": torch.zeros_like(k_zero)}
@@ -340,7 +362,7 @@ def _fuse_shared_expert_prenorm(
340362
return hf_weights
341363

342364
def mapping_registry(self) -> MegatronMappingRegistry:
343-
if getattr(self, "_is_dense", False):
365+
if self._is_dense_config():
344366
return self._dense_mapping_registry()
345367
return self._moe_mapping_registry()
346368

src/megatron/bridge/models/gemma_vl/gemma4_vl_bridge.py

Lines changed: 1 addition & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -184,12 +184,6 @@ def _text_config(self):
184184
hf_config = getattr(self, "hf_config", None)
185185
return getattr(hf_config, "text_config", None)
186186

187-
def _is_dense_e4b_config(self) -> bool:
188-
if getattr(self, "_is_dense", False):
189-
return True
190-
text_config = self._text_config()
191-
return text_config is not None and not getattr(text_config, "enable_moe_block", True)
192-
193187
def _hf_layer_prefix(self) -> str:
194188
"""VLM text weights live under ``model.language_model.*``."""
195189
return "model.language_model."
@@ -238,35 +232,9 @@ def _fuse_shared_expert_prenorm(
238232
hf_weights[role] = fused.to(weight.dtype)
239233
return hf_weights
240234

241-
def maybe_modify_loaded_hf_weight(
242-
self, hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor]
243-
) -> torch.Tensor:
244-
"""Handle special weight loading for Gemma 4 VLM."""
245-
if self._is_dense_e4b_config() and isinstance(hf_param, dict) and "v" in hf_param:
246-
k_name = hf_param["k"]
247-
v_name = hf_param["v"]
248-
q_name = hf_param["q"]
249-
if k_name not in hf_state_dict and v_name not in hf_state_dict:
250-
q_weight = hf_state_dict[q_name]
251-
text_config = self._text_config()
252-
num_q_heads = getattr(text_config, "num_attention_heads", 8)
253-
num_kv_heads = getattr(text_config, "num_key_value_heads", 2)
254-
layer_match = re.search(r"layers\.(\d+)\.", q_name)
255-
layer_types = getattr(text_config, "layer_types", None)
256-
if layer_match and layer_types:
257-
layer_idx = int(layer_match.group(1))
258-
if layer_idx < len(layer_types) and layer_types[layer_idx] == "full_attention":
259-
num_kv_heads = getattr(text_config, "num_global_key_value_heads", num_kv_heads)
260-
kv_head_dim = q_weight.shape[0] // num_q_heads
261-
kv_shape = (num_kv_heads * kv_head_dim, q_weight.shape[1])
262-
k_zero = torch.zeros(kv_shape, dtype=q_weight.dtype, device=q_weight.device)
263-
return {"q": q_weight, "k": k_zero, "v": torch.zeros_like(k_zero)}
264-
265-
return super().maybe_modify_loaded_hf_weight(hf_param, hf_state_dict)
266-
267235
def mapping_registry(self) -> MegatronMappingRegistry:
268236
"""Dispatch to Dense or MoE VLM mappings."""
269-
if self._is_dense_e4b_config():
237+
if self._is_dense_config():
270238
if self._conversion_mode() == "text":
271239
return self._dense_mapping_registry(megatron_prefix="")
272240
return self._dense_vl_mapping_registry()

tests/unit_tests/models/gemma/test_gemma4_bridge.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -308,6 +308,25 @@ def test_kv_synthesis_uses_dense_provider_head_metadata(self, bridge, mock_pretr
308308
assert result["k"].shape == (4, 8)
309309
assert result["v"].shape == (4, 8)
310310

311+
def test_kv_synthesis_uses_hf_config_without_provider_bridge(self, bridge, mock_hf_config_dense):
312+
bridge.hf_config = mock_hf_config_dense
313+
mock_hf_config_dense.num_attention_heads = 6
314+
mock_hf_config_dense.num_key_value_heads = 5
315+
mock_hf_config_dense.num_global_key_value_heads = 3
316+
mock_hf_config_dense.layer_types = ["full_attention"]
317+
q_weight = torch.randn(24, 8)
318+
sd = {"model.layers.0.self_attn.q_proj.weight": q_weight}
319+
hf_param = {
320+
"q": "model.layers.0.self_attn.q_proj.weight",
321+
"k": "model.layers.0.self_attn.k_proj.weight",
322+
"v": "model.layers.0.self_attn.v_proj.weight",
323+
}
324+
325+
result = bridge.maybe_modify_loaded_hf_weight(hf_param, sd)
326+
327+
assert result["k"].shape == (12, 8)
328+
assert result["v"].shape == (12, 8)
329+
311330
def test_kv_passthrough_when_v_present(self, bridge):
312331
sd = self._make_sd()
313332
sd["model.layers.0.self_attn.v_proj.weight"] = torch.randn(4, 8)
@@ -486,6 +505,30 @@ def test_has_post_moe_layernorm(self, bridge):
486505
names = self._collect_names(bridge.mapping_registry())
487506
assert any("post_moe_layernorm" in n for n in names)
488507

508+
def test_selects_dense_registry_from_hf_config_without_provider_bridge(self, bridge, mock_hf_config_dense):
509+
bridge.hf_config = mock_hf_config_dense
510+
511+
names = self._collect_names(bridge.mapping_registry())
512+
513+
assert "per_layer_embedding.weight" in names
514+
assert "decoder.layers.*.mlp.router.weight" not in names
515+
516+
def test_selects_dense_registry_when_enable_moe_block_missing(self, bridge):
517+
bridge.hf_config = Mock(spec=[])
518+
519+
names = self._collect_names(bridge.mapping_registry())
520+
521+
assert "per_layer_embedding.weight" in names
522+
assert "decoder.layers.*.mlp.router.weight" not in names
523+
524+
def test_selects_moe_registry_from_hf_config_without_provider_bridge(self, bridge, mock_hf_config_moe):
525+
bridge.hf_config = mock_hf_config_moe
526+
527+
names = self._collect_names(bridge.mapping_registry())
528+
529+
assert "decoder.layers.*.mlp.router.weight" in names
530+
assert "per_layer_embedding.weight" not in names
531+
489532
def test_has_layer_scalar_mapping(self, bridge):
490533
names = self._collect_names(bridge.mapping_registry())
491534
assert any("layer_scalar" in n for n in names)

0 commit comments

Comments
 (0)