|
29 | 29 | scan_layers: (bool) Whether the MaxText model was trained with scanned layers. |
30 | 30 | This must match the training configuration of the checkpoint. |
31 | 31 |
|
| 32 | +Optional Flags: |
| 33 | + --override_model_architecture: If set, overrides the HF model configuration |
| 34 | + with values from the MaxText configuration |
| 35 | + (e.g., num_heads, hidden_size) instead of failing. |
| 36 | +
|
32 | 37 | Environment Variables: |
33 | 38 | HF_AUTH_TOKEN: (Required) A HuggingFace authentication token. This is needed |
34 | 39 | to download the correct tokenizer configuration and to upload |
|
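To make the new flag concrete: when the two configs disagree and the flag is absent, the conversion aborts with the `ValueError` assembled further down in this diff. With invented head counts, that message would read:

```
Architecture mismatches detected between the standard Hugging Face config and the provided MaxText config:
 - num_attention_heads (HF=32 vs MaxText=16)

Action Required: Pass the flag `--override_model_architecture` to force the conversion using MaxText values.
```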
59 | 64 | from transformers import AutoTokenizer, AutoProcessor |
60 | 65 |
|
61 | 66 | from absl import app |
| 67 | +from absl import flags |
62 | 68 |
|
63 | 69 | from MaxText import pyconfig |
64 | 70 | from MaxText.utils.ckpt_conversion.utils.param_mapping import ( |
|
80 | 86 | from maxtext.utils import max_logging |
81 | 87 | from maxtext.utils import max_utils |
82 | 88 |
|
| 89 | +flags.DEFINE_bool( |
| 90 | + "override_model_architecture", |
| 91 | + False, |
| 92 | + "If True, overrides Hugging Face config architecture parameters (heads, layers, dims) " |
| 93 | + "with values from the MaxText config. If False, raises a ValueError on mismatch.", |
| 94 | +) |
| 95 | + |
| 96 | +FLAGS = flags.FLAGS |
| 97 | + |
83 | 98 |
|
84 | 99 | def _get_model_mappings( |
85 | 100 | model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters |
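For readers unfamiliar with absl: a flag defined at module scope is parsed by `app.run` before `main` executes, so `FLAGS.override_model_architecture` is already populated inside `main`. A minimal standalone sketch of that pattern (file name and output are illustrative, not part of this PR):

```python
# demo.py — hypothetical standalone sketch of the absl flag pattern used above.
from absl import app, flags

flags.DEFINE_bool("override_model_architecture", False, "Override HF config on mismatch.")
FLAGS = flags.FLAGS

def main(argv):
  # app.run() has already parsed sys.argv by the time main() runs.
  print(f"override_model_architecture={FLAGS.override_model_architecture}")

if __name__ == "__main__":
  app.run(main)  # e.g. `python demo.py --override_model_architecture`
```

Boolean absl flags also accept the negated form, e.g. `--nooverride_model_architecture`.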
@@ -109,6 +124,71 @@ def _get_model_mappings( |
109 | 124 | } |
110 | 125 |
|
111 | 126 |
|
| 127 | +def _validate_or_update_architecture(hf_config, max_config, override: bool): |
| 128 | + """Validates consistency between HF and MaxText configs or overrides HF config if requested. |
| 129 | +
|
| 130 | + Args: |
| 131 | + hf_config: The Hugging Face configuration object. |
| 132 | + max_config: The MaxText configuration object (HyperParameters). |
 | 133 | +    override: Boolean. If True, update hf_config with max_config values;
 | 134 | +        if False, raise a ValueError on mismatch.
| 135 | + """ |
| 136 | + # Mapping from Hugging Face config attribute -> MaxText config attribute |
| 137 | + # Note: We use derived MaxText attributes (e.g. emb_dim) which account for scale factors. |
| 138 | + attributes_to_check = [ |
| 139 | + ("num_attention_heads", "num_query_heads"), |
| 140 | + ("num_key_value_heads", "num_kv_heads"), |
| 141 | + ("head_dim", "head_dim"), |
| 142 | + ("hidden_size", "emb_dim"), |
| 143 | + ("intermediate_size", "mlp_dim"), |
| 144 | + ("num_hidden_layers", "num_decoder_layers"), |
| 145 | + ("vocab_size", "vocab_size"), |
| 146 | + ] |
| 147 | + |
| 148 | + mismatches = [] |
| 149 | + |
| 150 | + for hf_attr, mt_attr in attributes_to_check: |
 | 151 | +    # Skip checks if the HF config doesn't define this attribute (e.g. some configs omit an explicit head_dim)
| 152 | + if not hasattr(hf_config, hf_attr): |
| 153 | + continue |
| 154 | + |
| 155 | + # Skip checks if MaxText config doesn't have the attribute (shouldn't happen for valid configs) |
| 156 | + if not hasattr(max_config, mt_attr): |
| 157 | + continue |
| 158 | + |
| 159 | + hf_value = getattr(hf_config, hf_attr) |
| 160 | + mt_value = getattr(max_config, mt_attr) |
| 161 | + |
| 162 | + # Handle None values |
| 163 | + if hf_value is None or mt_value is None: |
| 164 | + continue |
| 165 | + |
| 166 | + # Compare values (with tolerance for floats) |
| 167 | + is_match = False |
| 168 | + if isinstance(hf_value, float) or isinstance(mt_value, float): |
| 169 | + try: |
| 170 | + is_match = abs(float(hf_value) - float(mt_value)) < 1e-6 |
| 171 | + except (ValueError, TypeError): |
| 172 | + is_match = hf_value == mt_value |
| 173 | + else: |
| 174 | + is_match = hf_value == mt_value |
| 175 | + |
| 176 | + if not is_match: |
| 177 | + if override: |
| 178 | + max_logging.log(f"⚠️ Overwriting HF Config '{hf_attr}': {hf_value} -> {mt_value} (from MaxText '{mt_attr}')") |
| 179 | + setattr(hf_config, hf_attr, mt_value) |
| 180 | + else: |
| 181 | + mismatches.append(f"{hf_attr} (HF={hf_value} vs MaxText={mt_value})") |
| 182 | + |
| 183 | + if mismatches: |
| 184 | + error_msg = ( |
 | 185 | +        "Architecture mismatches detected between the standard Hugging Face config and the provided MaxText config:\n - "
| 186 | + + "\n - ".join(mismatches) |
| 187 | + + "\n\nAction Required: Pass the flag `--override_model_architecture` to force the conversion using MaxText values." |
| 188 | + ) |
| 189 | + raise ValueError(error_msg) |
| 190 | + |
| 191 | + |
112 | 192 | def main(argv: Sequence[str]) -> None: |
113 | 193 | """Main function to convert a MaxText checkpoint to HuggingFace format. |
114 | 194 |
|
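As a sanity check on `_validate_or_update_architecture`, here is a hypothetical exercise of both modes with stand-in config objects (it assumes the function and its `max_logging` dependency are importable; all attribute values are invented):

```python
# Hypothetical check of both validator modes; attribute values are invented.
from types import SimpleNamespace

hf_cfg = SimpleNamespace(num_attention_heads=32, hidden_size=4096, vocab_size=32000)
mt_cfg = SimpleNamespace(num_query_heads=16, emb_dim=4096, vocab_size=32000)

# override=False: the head-count mismatch raises ValueError.
try:
  _validate_or_update_architecture(hf_cfg, mt_cfg, override=False)
except ValueError as err:
  print(err)  # lists: num_attention_heads (HF=32 vs MaxText=16)

# override=True: the HF config is mutated in place to match MaxText.
_validate_or_update_architecture(hf_cfg, mt_cfg, override=True)
assert hf_cfg.num_attention_heads == 16
```

Attributes missing from either side (here `head_dim`, `intermediate_size`, `num_hidden_layers`) are skipped by the `hasattr` guards, so the stand-ins only need the fields under test.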
@@ -151,6 +231,9 @@ def main(argv: Sequence[str]) -> None: |
151 | 231 | raise ValueError(f"Unsupported model name: {config.model_name}. Supported models are: {list(HF_IDS.keys())}") |
152 | 232 | hf_config_obj = HF_MODEL_CONFIGS[model_key] |
153 | 233 |
|
| 234 | + # Validate architecture consistency (raising ValueError on mismatch) or override HF config if specified. |
| 235 | + _validate_or_update_architecture(hf_config_obj, config, override=FLAGS.override_model_architecture) |
| 236 | + |
154 | 237 | # 2. Load Tokenizer |
155 | 238 | if model_key not in HF_IDS: |
156 | 239 | raise ValueError(f"HF Tokenizer ID not found for model key: {model_key}") |
|