Skip to content

Commit 48c31b8

Browse files
committed
Support -Instruct models on to_huggingface
1 parent f44b423 commit 48c31b8

1 file changed

Lines changed: 25 additions & 7 deletions

File tree

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 25 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -118,11 +118,20 @@ def _get_model_mappings(
118118
Raises:
119119
ValueError: If mappings for the specified `model_name` are not found.
120120
"""
121-
if model_name not in PARAM_MAPPING or model_name not in HF_SHAPE or model_name not in HOOK_FNS:
122-
raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")
121+
target_model_name = model_name
122+
if target_model_name not in PARAM_MAPPING and target_model_name.endswith("-Instruct"):
123+
base_name = target_model_name[: -len("-Instruct")]
124+
if base_name in HF_IDS and base_name in PARAM_MAPPING:
125+
target_model_name = base_name
126+
max_logging.log(f"Mappings for '{model_name}' not found, falling back to base model '{target_model_name}'.")
127+
128+
if target_model_name not in PARAM_MAPPING or target_model_name not in HF_SHAPE or target_model_name not in HOOK_FNS:
129+
raise ValueError(
130+
f"Mappings not found for model: {model_name} (resolved to {target_model_name}). Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}"
131+
)
123132

124-
param_mapping = PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers)
125-
hook_fn_mapping = HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True)
133+
param_mapping = PARAM_MAPPING[target_model_name](hf_config_dict, maxtext_config, scan_layers)
134+
hook_fn_mapping = HOOK_FNS[target_model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True)
126135

127136
# Promote composite hook keys into param_mapping.
128137
# If HOOK_FNS defines a composite tuple key (e.g., (wi_0, wi_1) for MoE gate_up_proj),
@@ -138,7 +147,7 @@ def _get_model_mappings(
138147

139148
return {
140149
"param_mapping": param_mapping,
141-
"shape_mapping": HF_SHAPE[model_name](hf_config_dict),
150+
"shape_mapping": HF_SHAPE[target_model_name](hf_config_dict),
142151
"hook_fn_mapping": hook_fn_mapping,
143152
}
144153

@@ -228,7 +237,7 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
228237

229238
if not is_match:
230239
if override:
231-
max_logging.log(f"⚠️ Overwriting HF Config '{hf_attr}': {hf_value} -> {mt_value} (from MaxText '{mt_attr}')")
240+
max_logging.log(f"Overwriting HF Config '{hf_attr}': {hf_value} -> {mt_value} (from MaxText '{mt_attr}')")
232241
setattr(hf_config, hf_attr, mt_value)
233242
else:
234243
mismatches.append(f"{hf_attr} (HF={hf_value} vs MaxText={mt_value})")
@@ -275,9 +284,18 @@ def main(argv: Sequence[str]) -> None:
275284

276285
# 1. Get HuggingFace Model Configuration
277286
model_key = config.model_name
287+
278288
if model_key not in HF_IDS:
279289
raise ValueError(f"Unsupported model name: {config.model_name}. Supported models are: {list(HF_IDS.keys())}")
280-
hf_config_obj = HF_MODEL_CONFIGS[model_key]
290+
291+
config_key = model_key
292+
if config_key not in HF_MODEL_CONFIGS and config_key.endswith("-Instruct"):
293+
base_key = config_key[: -len("-Instruct")]
294+
if base_key in HF_MODEL_CONFIGS:
295+
max_logging.log(f"⚠️ Config for '{config_key}' not found, falling back to base '{base_key}'.")
296+
config_key = base_key
297+
298+
hf_config_obj = HF_MODEL_CONFIGS[config_key]
281299

282300
# Validate architecture consistency (raising ValueError on mismatch) or override HF config if specified.
283301
_validate_or_update_architecture(hf_config_obj, config, override=FLAGS.override_model_architecture)

0 commit comments

Comments
 (0)