@@ -100,11 +100,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider
         return provider

     def build_conversion_tasks(self, hf_pretrained, megatron_model):
-        """Override to store config before mapping_registry is called."""
-        from transformers import PretrainedConfig
-
-        # Store config on instance for use in mapping_registry
-        self._hf_config = hf_pretrained if isinstance(hf_pretrained, PretrainedConfig) else hf_pretrained.config
+        """Override to store HF state source before mapping_registry is called."""
         has_state = hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")
         self._hf_state_source = hf_pretrained.state.source if has_state else None
         self._hf_keys = list(self._hf_state_source.get_all_keys()) if self._hf_state_source else None
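The class no longer caches `_hf_config` here, and the `mapping_registry` hunk below reads the config through a `hf_config` attribute that is not shown in this diff. A minimal sketch of one plausible shape for it, assuming the config is stashed on the instance elsewhere (e.g. in `provider_bridge` above); this is an assumption, not the bridge's actual API:

    # Hypothetical sketch -- not the real bridge class. Assumes another
    # method (e.g. provider_bridge) sets self._hf_config when it has the
    # HF model in hand; the property degrades to None so mapping_registry
    # can skip the MTP mappings cleanly.
    class _BridgeSketch:
        @property
        def hf_config(self):
            return getattr(self, "_hf_config", None)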
@@ -208,10 +204,10 @@ def mapping_registry(self) -> MegatronMappingRegistry:
             ]
         )
         # optionally add MTP mappings
-        if not hasattr(self, "_hf_config"):
+        if self.hf_config is None:
             logger.warning("No HF config found, skipping MTP mappings.")
             return MegatronMappingRegistry(*mapping_list)
-        hf_config = self._hf_config
+        hf_config = self.hf_config
         num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
         num_transformer_layers = hf_config.num_hidden_layers
         for mtp_layer in range(num_mtp_layers):
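For orientation: `num_nextn_predict_layers` is a DeepSeek-V3-style config field, and in those checkpoints the MTP weights sit after the regular decoder layers, so each MTP layer's HF-side index is offset by `num_hidden_layers`. A standalone sketch of that index arithmetic; the `model.layers.{i}.` prefix is an assumed HF naming scheme, not necessarily what the mapping registry emits:

    # Hypothetical illustration of the MTP layer index offset. Values mimic
    # a 61-layer model with one MTP layer; the prefix format is assumed.
    num_transformer_layers = 61  # hf_config.num_hidden_layers
    num_mtp_layers = 1           # getattr(hf_config, "num_nextn_predict_layers", 0)

    for mtp_layer in range(num_mtp_layers):
        hf_layer_idx = num_transformer_layers + mtp_layer
        print(f"model.layers.{hf_layer_idx}.")  # -> "model.layers.61."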