@@ -118,11 +118,20 @@ def _get_model_mappings(
118118 Raises:
119119 ValueError: If mappings for the specified `model_name` are not found.
120120 """
121- if model_name not in PARAM_MAPPING or model_name not in HF_SHAPE or model_name not in HOOK_FNS :
122- raise ValueError (f"Mappings not found for model: { model_name } . Available PARAM_MAPPING keys: { PARAM_MAPPING .keys ()} " )
121+ target_model_name = model_name
122+ if target_model_name not in PARAM_MAPPING and target_model_name .endswith ("-Instruct" ):
123+ base_name = target_model_name [: - len ("-Instruct" )]
124+ if base_name in HF_IDS and base_name in PARAM_MAPPING :
125+ target_model_name = base_name
126+ max_logging .log (f"Mappings for '{ model_name } ' not found, falling back to base model '{ target_model_name } '." )
127+
128+ if target_model_name not in PARAM_MAPPING or target_model_name not in HF_SHAPE or target_model_name not in HOOK_FNS :
129+ raise ValueError (
130+ f"Mappings not found for model: { model_name } (resolved to { target_model_name } ). Available PARAM_MAPPING keys: { PARAM_MAPPING .keys ()} "
131+ )
123132
124- param_mapping = PARAM_MAPPING [model_name ](hf_config_dict , maxtext_config , scan_layers )
125- hook_fn_mapping = HOOK_FNS [model_name ](hf_config_dict , maxtext_config , scan_layers , saving_to_hf = True )
133+ param_mapping = PARAM_MAPPING [target_model_name ](hf_config_dict , maxtext_config , scan_layers )
134+ hook_fn_mapping = HOOK_FNS [target_model_name ](hf_config_dict , maxtext_config , scan_layers , saving_to_hf = True )
126135
127136 # Promote composite hook keys into param_mapping.
128137 # If HOOK_FNS defines a composite tuple key (e.g., (wi_0, wi_1) for MoE gate_up_proj),
@@ -138,7 +147,7 @@ def _get_model_mappings(
138147
139148 return {
140149 "param_mapping" : param_mapping ,
141- "shape_mapping" : HF_SHAPE [model_name ](hf_config_dict ),
150+ "shape_mapping" : HF_SHAPE [target_model_name ](hf_config_dict ),
142151 "hook_fn_mapping" : hook_fn_mapping ,
143152 }
144153
@@ -228,7 +237,7 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
228237
229238 if not is_match :
230239 if override :
231- max_logging .log (f"⚠️ Overwriting HF Config '{ hf_attr } ': { hf_value } -> { mt_value } (from MaxText '{ mt_attr } ')" )
240+ max_logging .log (f"Overwriting HF Config '{ hf_attr } ': { hf_value } -> { mt_value } (from MaxText '{ mt_attr } ')" )
232241 setattr (hf_config , hf_attr , mt_value )
233242 else :
234243 mismatches .append (f"{ hf_attr } (HF={ hf_value } vs MaxText={ mt_value } )" )
@@ -275,9 +284,18 @@ def main(argv: Sequence[str]) -> None:
275284
276285 # 1. Get HuggingFace Model Configuration
277286 model_key = config .model_name
287+
278288 if model_key not in HF_IDS :
279289 raise ValueError (f"Unsupported model name: { config .model_name } . Supported models are: { list (HF_IDS .keys ())} " )
280- hf_config_obj = HF_MODEL_CONFIGS [model_key ]
290+
291+ config_key = model_key
292+ if config_key not in HF_MODEL_CONFIGS and config_key .endswith ("-Instruct" ):
293+ base_key = config_key [: - len ("-Instruct" )]
294+ if base_key in HF_MODEL_CONFIGS :
295+ max_logging .log (f"⚠️ Config for '{ config_key } ' not found, falling back to base '{ base_key } '." )
296+ config_key = base_key
297+
298+ hf_config_obj = HF_MODEL_CONFIGS [config_key ]
281299
282300 # Validate architecture consistency (raising ValueError on mismatch) or override HF config if specified.
283301 _validate_or_update_architecture (hf_config_obj , config , override = FLAGS .override_model_architecture )
0 commit comments