@@ -127,6 +127,17 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> "Gemma4ModelProv
127127 self ._is_dense = False
128128 return self ._build_moe_provider (hf_config )
129129
130+ def _text_config (self ) -> Any | None :
131+ """Return the text config used to dispatch dense vs MoE behavior."""
132+ return getattr (self , "hf_config" , None )
133+
134+ def _is_dense_config (self ) -> bool :
135+ """Return whether the current HF config describes a dense Gemma 4 model."""
136+ if getattr (self , "_is_dense" , False ):
137+ return True
138+ text_config = self ._text_config ()
139+ return text_config is not None and not getattr (text_config , "enable_moe_block" , False )
140+
130141 def _build_dense_provider (self , hf_config ) -> Gemma4DenseProvider :
131142 """Build a Gemma4DenseProvider from HF config."""
132143 rope_params = getattr (hf_config , "rope_parameters" , {}) or {}
@@ -269,13 +280,24 @@ def maybe_modify_loaded_hf_weight(
269280
270281 if k_name not in hf_state_dict and v_name not in hf_state_dict :
271282 q_weight = hf_state_dict [q_name ]
272- num_q_heads = getattr (self , "_dense_num_attention_heads" , 8 )
273- kv_head_dim = q_weight .shape [0 ] // num_q_heads
274- num_kv_heads = getattr (
275- self ,
276- "_dense_num_global_query_groups" ,
277- getattr (self , "_dense_num_query_groups" , 2 ),
283+ text_config = self ._text_config ()
284+ num_q_heads = getattr (
285+ text_config , "num_attention_heads" , getattr (self , "_dense_num_attention_heads" , 8 )
278286 )
287+ kv_head_dim = q_weight .shape [0 ] // num_q_heads
288+ num_kv_heads = getattr (text_config , "num_key_value_heads" , getattr (self , "_dense_num_query_groups" , 2 ))
289+ layer_match = re .search (r"layers\.(\d+)\." , q_name )
290+ layer_types = getattr (text_config , "layer_types" , None )
291+ if layer_match and layer_types :
292+ layer_idx = int (layer_match .group (1 ))
293+ if layer_idx < len (layer_types ) and layer_types [layer_idx ] == "full_attention" :
294+ num_kv_heads = getattr (
295+ text_config ,
296+ "num_global_key_value_heads" ,
297+ getattr (self , "_dense_num_global_query_groups" , num_kv_heads ),
298+ )
299+ elif hasattr (self , "_dense_num_global_query_groups" ):
300+ num_kv_heads = self ._dense_num_global_query_groups
279301 kv_shape = (num_kv_heads * kv_head_dim , q_weight .shape [1 ])
280302 k_zero = torch .zeros (kv_shape , dtype = q_weight .dtype , device = q_weight .device )
281303 return {"q" : q_weight , "k" : k_zero , "v" : torch .zeros_like (k_zero )}
@@ -340,7 +362,7 @@ def _fuse_shared_expert_prenorm(
340362 return hf_weights
341363
def mapping_registry(self) -> MegatronMappingRegistry:
    """Return the mapping registry for the current model variant.

    Dispatches on ``self._is_dense_config()``.

    Returns:
        The result of ``self._dense_mapping_registry()`` when the config is
        dense, otherwise the result of ``self._moe_mapping_registry()``.
    """
    # Uses _is_dense_config() rather than reading the _is_dense flag
    # directly, so the config is also consulted when the flag is not set.
    if self._is_dense_config():
        return self._dense_mapping_registry()
    return self._moe_mapping_registry()
346368
0 commit comments