Skip to content

Commit 55b57ff

Browse files
Merge pull request #3094 from AI-Hypercomputer:agagik-to-hf
PiperOrigin-RevId: 868328440
2 parents 86ef13f + 0342fb2 commit 55b57ff

1 file changed

Lines changed: 83 additions & 0 deletions

File tree

src/MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
scan_layers: (bool) Whether the MaxText model was trained with scanned layers.
3030
This must match the training configuration of the checkpoint.
3131
32+
Optional Flags:
  --override_model_architecture: If set, overrides the HF model configuration
    with values from the MaxText configuration (e.g., num_heads, hidden_size)
    instead of failing.
36+
3237
Environment Variables:
3338
HF_AUTH_TOKEN: (Required) A HuggingFace authentication token. This is needed
3439
to download the correct tokenizer configuration and to upload
@@ -59,6 +64,7 @@
5964
from transformers import AutoTokenizer, AutoProcessor
6065

6166
from absl import app
67+
from absl import flags
6268

6369
from MaxText import pyconfig
6470
from MaxText.utils.ckpt_conversion.utils.param_mapping import (
@@ -80,6 +86,15 @@
8086
from maxtext.utils import max_logging
8187
from maxtext.utils import max_utils
8288

89+
# Command-line flag controlling how architecture mismatches between the
# Hugging Face config and the MaxText config are handled during conversion:
# False (default) fails fast with a ValueError; True overwrites the HF config
# with the MaxText-derived values. Consumed in main() via FLAGS.
flags.DEFINE_bool(
    "override_model_architecture",
    False,
    "If True, overrides Hugging Face config architecture parameters (heads, layers, dims) "
    "with values from the MaxText config. If False, raises a ValueError on mismatch.",
)

# Module-level handle to the parsed absl flag values.
FLAGS = flags.FLAGS
8398

8499
def _get_model_mappings(
85100
model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters
@@ -109,6 +124,71 @@ def _get_model_mappings(
109124
}
110125

111126

127+
def _validate_or_update_architecture(hf_config, max_config, override: bool):
128+
"""Validates consistency between HF and MaxText configs or overrides HF config if requested.
129+
130+
Args:
131+
hf_config: The Hugging Face configuration object.
132+
max_config: The MaxText configuration object (HyperParameters).
133+
override: Boolean, if True, update hf_config with max_config values.
134+
If False, raise error on mismatch.
135+
"""
136+
# Mapping from Hugging Face config attribute -> MaxText config attribute
137+
# Note: We use derived MaxText attributes (e.g. emb_dim) which account for scale factors.
138+
attributes_to_check = [
139+
("num_attention_heads", "num_query_heads"),
140+
("num_key_value_heads", "num_kv_heads"),
141+
("head_dim", "head_dim"),
142+
("hidden_size", "emb_dim"),
143+
("intermediate_size", "mlp_dim"),
144+
("num_hidden_layers", "num_decoder_layers"),
145+
("vocab_size", "vocab_size"),
146+
]
147+
148+
mismatches = []
149+
150+
for hf_attr, mt_attr in attributes_to_check:
151+
# Skip checks if the HF config doesn't have this attribute (e.g. layer_norm_eps vs rms_norm_eps)
152+
if not hasattr(hf_config, hf_attr):
153+
continue
154+
155+
# Skip checks if MaxText config doesn't have the attribute (shouldn't happen for valid configs)
156+
if not hasattr(max_config, mt_attr):
157+
continue
158+
159+
hf_value = getattr(hf_config, hf_attr)
160+
mt_value = getattr(max_config, mt_attr)
161+
162+
# Handle None values
163+
if hf_value is None or mt_value is None:
164+
continue
165+
166+
# Compare values (with tolerance for floats)
167+
is_match = False
168+
if isinstance(hf_value, float) or isinstance(mt_value, float):
169+
try:
170+
is_match = abs(float(hf_value) - float(mt_value)) < 1e-6
171+
except (ValueError, TypeError):
172+
is_match = hf_value == mt_value
173+
else:
174+
is_match = hf_value == mt_value
175+
176+
if not is_match:
177+
if override:
178+
max_logging.log(f"⚠️ Overwriting HF Config '{hf_attr}': {hf_value} -> {mt_value} (from MaxText '{mt_attr}')")
179+
setattr(hf_config, hf_attr, mt_value)
180+
else:
181+
mismatches.append(f"{hf_attr} (HF={hf_value} vs MaxText={mt_value})")
182+
183+
if mismatches:
184+
error_msg = (
185+
"Architecture mismatches detected between standard Hugging Face config and provided MaxText config:\n - "
186+
+ "\n - ".join(mismatches)
187+
+ "\n\nAction Required: Pass the flag `--override_model_architecture` to force the conversion using MaxText values."
188+
)
189+
raise ValueError(error_msg)
190+
191+
112192
def main(argv: Sequence[str]) -> None:
113193
"""Main function to convert a MaxText checkpoint to HuggingFace format.
114194
@@ -151,6 +231,9 @@ def main(argv: Sequence[str]) -> None:
151231
raise ValueError(f"Unsupported model name: {config.model_name}. Supported models are: {list(HF_IDS.keys())}")
152232
hf_config_obj = HF_MODEL_CONFIGS[model_key]
153233

234+
# Validate architecture consistency (raising ValueError on mismatch) or override HF config if specified.
235+
_validate_or_update_architecture(hf_config_obj, config, override=FLAGS.override_model_architecture)
236+
154237
# 2. Load Tokenizer
155238
if model_key not in HF_IDS:
156239
raise ValueError(f"HF Tokenizer ID not found for model key: {model_key}")

0 commit comments

Comments
 (0)