We provide a unified inspection tool in [`inspect_checkpoint.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/inspect_checkpoint.py) to help you view model structures. This is highly useful for both development and troubleshooting.
+
+- **Lightweight & low-overhead:** Operates on either CPU or TPU, with a near-zero memory footprint (RAM/HBM) and negligible compute costs.
+- **Detailed outputs:** Prints tensor keys and shapes. Optionally appends data types with the `--check_dtype=True` flag.
+- **Save to file:** Optionally saves the printed layout structure to a text file with the `--output_file=<path>` flag.
+
+**Tool 1 (HF Inspector)**. To inspect the HuggingFace checkpoint structure, from either `.safetensors` or `.pth` files:
 To extend conversion support to a new model architecture, you must define its specific parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files.

 1. **Add parameter mappings**:

+- As the first step, inspect the checkpoint structures. To see the HuggingFace checkpoint structure, use **Tool 1 (HF Inspector)**. To see the MaxText model structure, use **Tool 2 (MaxText Inspector)**.
+
 - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings (`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). These are the 1-to-1 mappings of parameter names per layer.

 - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
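
The two pieces described above (a name mapping plus a per-parameter hook) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in `utils/param_mapping.py`: the function names `demo_param_mapping`, `demo_param_hook_fn`, and `convert`, the parameter paths, and the choice of a transpose as the hook are all hypothetical.

```python
# Hypothetical key names and helpers, for illustration only.
# The real mapping functions in utils/param_mapping.py may look different.


def demo_param_mapping(num_layers):
    """1-to-1 name mapping: MaxText parameter path -> HuggingFace key."""
    mapping = {"params-token_embedder-embedding": "model.embed_tokens.weight"}
    for i in range(num_layers):
        mapping[f"params-decoder-layers_{i}-mlp-wi-kernel"] = (
            f"model.layers.{i}.mlp.up_proj.weight"
        )
    return mapping


def transpose(matrix):
    """Transpose a 2-D weight stored as a list of rows."""
    return [list(column) for column in zip(*matrix)]


def demo_param_hook_fn(num_layers):
    """Per-parameter transformation; names without a hook are copied as-is."""
    return {
        f"params-decoder-layers_{i}-mlp-wi-kernel": transpose
        for i in range(num_layers)
    }


def convert(maxtext_params, mapping, hooks):
    """Rename every parameter and apply its hook, if one is registered."""
    converted = {}
    for maxtext_name, value in maxtext_params.items():
        hook = hooks.get(maxtext_name, lambda x: x)
        converted[mapping[maxtext_name]] = hook(value)
    return converted


maxtext_params = {
    "params-token_embedder-embedding": [[0.1, 0.2]],
    "params-decoder-layers_0-mlp-wi-kernel": [[1, 2, 3], [4, 5, 6]],
}
hf_params = convert(maxtext_params, demo_param_mapping(1), demo_param_hook_fn(1))
print(sorted(hf_params))
```

Keeping the names and the transformations in two separate tables mirrors the split described above: a renaming mistake and a reshaping mistake can then be checked independently.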
@@ -262,11 +290,22 @@ Here is an example [PR to add support for gemma3 multi-modal model](https://gith

 ### Common Errors

-- "Type ShapeDtypeStruct is not a valid JAX type" or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`):
-This is almost always caused by a mismatch in the `scan_layers` configuration between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`).
+- **Error:** When loading a converted checkpoint, `Type ShapeDtypeStruct is not a valid JAX type`.
+
+- **Cause (most common): Structure mismatch** between the converted checkpoint and the MaxText model.
+
+- **Solution:** To see the MaxText model structure, use **Tool 2 (MaxText Inspector)**. To inspect the checkpoint, use **Tool 3 (Orbax Inspector)**.
+
+- **Error:** `Type ShapeDtypeStruct is not a valid JAX type` or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`).
+
+- **Cause: Configuration mismatch** (e.g., `scan_layers`) between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`).

 - **Solution:** Ensure the `scan_layers` flag is set to the exact same value (`True` or `False`) in both the conversion command and your training/execution command.

-- If the converted checkpoint loads without errors but produces nonsensical output, likely an error in the Q/K/V weight reshaping logic during conversion.
+- **Error:** The converted checkpoint loads without errors but produces nonsensical output.
+
+- **Cause:** There is likely an error in the Q/K/V weight reshaping logic during conversion.
+
+- **Error:** The model generates repetitive text sequences.

-- If the model generates repetitive text sequences, check if layer normalization parameters were mapped correctly.
+- **Cause:** Layer normalization parameters were likely not mapped correctly.
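
To make the structure-mismatch case concrete, the sketch below compares two flattened `path -> shape` tables and reports how many paths match, in the spirit of the `X/Y paths matched` message. It is only an illustration: `compare_structures`, the paths, and the shapes are hypothetical, and the example assumes that one side keeps one entry per layer while the other stacks all layers into a single entry, which is one way a `scan_layers` disagreement could show up.

```python
def compare_structures(expected, found):
    """Diff two {path: shape} tables and group the differences."""
    shared = set(expected) & set(found)
    return {
        "matched": sorted(k for k in shared if expected[k] == found[k]),
        "shape_mismatch": sorted(k for k in shared if expected[k] != found[k]),
        "missing_from_checkpoint": sorted(set(expected) - set(found)),
        "unexpected_in_checkpoint": sorted(set(found) - set(expected)),
    }


# Hypothetical model structure: layers stacked into one entry.
model_structure = {
    "token_embedder/embedding": (32, 8),
    "decoder/layers/mlp/wi/kernel": (8, 2, 16),
}

# Hypothetical checkpoint structure: one entry per layer.
checkpoint_structure = {
    "token_embedder/embedding": (32, 8),
    "decoder/layers_0/mlp/wi/kernel": (8, 16),
    "decoder/layers_1/mlp/wi/kernel": (8, 16),
}

report = compare_structures(model_structure, checkpoint_structure)
summary = f"{len(report['matched'])}/{len(model_structure)} paths matched"
print(summary)
for group, paths in report.items():
    print(group, paths)
```

Feeding this kind of comparison with the layouts printed by the inspectors should point directly at the paths that differ, rather than leaving only a match count to interpret.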
docs/guides/model_bringup.md (21 additions & 4 deletions)
@@ -42,7 +42,7 @@ This step can be bypassed if the current MaxText codebase already supports all c

 ## 3. Checkpoint Conversion

-While most open-source models are distributed in Safetensors or PyTorch formats, MaxText requires conversion to the [Orbax](https://orbax.readthedocs.io/en/latest/) format.
+While most open-source models are distributed in Safetensors or PyTorch formats, MaxText requires conversion to the [Orbax](https://orbax.readthedocs.io/en/latest) format.

 There are [two primary formats](checkpoints) for Orbax checkpoints within MaxText, and while both are technically compatible with training and inference, we recommend following these performance-optimized guidelines:

@@ -51,10 +51,27 @@ There are [two primary formats](checkpoints) for Orbax checkpoints within MaxTex

 ### 3.1 Create Mapping

-Success starts with a clear map. You must align the parameter names from your source checkpoints (Safetensors/PyTorch) with the corresponding MaxText internal names.
+To successfully convert a model, you must define the exact mapping between the parameter names in your source checkpoints (Safetensors/PyTorch) and the corresponding MaxText internal names.

-- You can print out the keys and shapes of your original `.safetensors` or `.pth` files.
-- To see the target structure, you can initiate a pre-training run to save a randomly initialized checkpoint for inspection.
+We provide a unified utility [`inspect_checkpoint.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/inspect_checkpoint.py) to help you view and compare these structures.
+
+(1) **HF Inspector**. To see the HuggingFace checkpoint structure, you can print out the keys and shapes of your original `.safetensors` or `.pth` files.
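
For intuition about why printing keys and shapes can be so cheap, here is a minimal standard-library sketch that reads only the header of a `.safetensors` file: an 8-byte little-endian length followed by a JSON table of tensor names, dtypes, and shapes. It is not how `inspect_checkpoint.py` is implemented (that is not shown here); `read_safetensors_layout`, `write_demo_file`, and the tensor name are hypothetical, and `.pth` files are not covered.

```python
import json
import os
import struct
import tempfile


def read_safetensors_layout(path):
    """Return {tensor_name: (dtype, shape)} by reading only the file header."""
    with open(path, "rb") as f:
        # The file starts with an unsigned 64-bit little-endian header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional free-form entry, not a tensor.
    return {
        name: (info["dtype"], tuple(info["shape"]))
        for name, info in header.items()
        if name != "__metadata__"
    }


def write_demo_file(path):
    """Write a tiny .safetensors file holding one 2x3 float32 tensor of zeros."""
    header = {
        "model.embed_tokens.weight": {
            "dtype": "F32",
            "shape": [2, 3],
            "data_offsets": [0, 24],  # 6 values * 4 bytes each
        }
    }
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(bytes(24))


with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.safetensors")
    write_demo_file(demo_path)
    layout = read_safetensors_layout(demo_path)

for name, (dtype, shape) in layout.items():
    print(name, dtype, shape)
```

Because the tensor data itself is never read, the memory cost stays tiny regardless of checkpoint size.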