
Commit 9caffac

Add unified tool to inspect checkpoint structures
1 parent 03cabc3 commit 9caffac

4 files changed

Lines changed: 451 additions & 10 deletions

File tree

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 44 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -240,14 +240,42 @@ Max KL divergence for a single token in the set: 0.003497
240240

241241
______________________________________________________________________
242242

243-
## Troubleshooting and Development
243+
## Development and Troubleshooting
244+
245+
### Inspection Tools
246+
247+
We provide a unified inspection tool, [`inspect_checkpoint.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/inspect_checkpoint.py), to help you view model structures. It is useful for both development and troubleshooting.
248+
249+
- **Lightweight & low-overhead:** Operates on either CPU or TPU, with a near-zero memory footprint (RAM/HBM) and negligible compute costs.
250+
- **Detailed outputs:** Prints tensor keys and shapes. Optionally appends data types with the `--check_dtype=True` flag.
251+
- **Save to file:** Optionally saves the printed layout to a text file with the `--output_file=<path>` flag.
252+
253+
**Tool 1 (HF Inspector)**. To inspect the structure of a HuggingFace checkpoint stored as `.safetensors` or `.pth` files:
254+
255+
```
256+
python -m maxtext.checkpoint_conversion.inspect_checkpoint hf --path <local_hf_path> --format <safetensors | pth>
257+
```
258+
259+
**Tool 2 (MaxText Inspector)**. To inspect the MaxText model structure:
260+
261+
```
262+
python -m maxtext.checkpoint_conversion.inspect_checkpoint maxtext model_name=<maxtext_model_name> scan_layers=<True | False>
263+
```
264+
265+
**Tool 3 (Orbax Inspector)**. To inspect the Orbax checkpoint:
266+
267+
```
268+
python -m maxtext.checkpoint_conversion.inspect_checkpoint orbax --path <local_orbax_path | gcs_orbax_path>
269+
```
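To illustrate why listing keys and shapes can stay cheap, here is a minimal sketch that reads only the header of a `.safetensors` file. It is not the implementation of `inspect_checkpoint.py`; it relies on the published safetensors layout (an 8-byte little-endian header length followed by a JSON header), and the function names `read_safetensors_header` and `format_layout` are hypothetical.

```python
import json
import struct


def read_safetensors_header(path):
    """Return {tensor_name: (dtype, shape)} by reading only the file header."""
    with open(path, "rb") as f:
        # The first 8 bytes hold the JSON header length as a little-endian uint64.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional free-form entry, not a tensor.
    header.pop("__metadata__", None)
    return {name: (info["dtype"], tuple(info["shape"])) for name, info in header.items()}


def format_layout(layout, check_dtype=False):
    """Render one 'key: shape' line per tensor, optionally with the dtype."""
    lines = []
    for name in sorted(layout):
        dtype, shape = layout[name]
        suffix = f" {dtype}" if check_dtype else ""
        lines.append(f"{name}: {shape}{suffix}")
    return lines
```

Because the tensor data is never read, memory use depends on the number of tensors rather than on their size.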
244270

245271
### Adding New Models
246272

247273
To extend conversion support to a new model architecture, you must define its specific parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files.
248274

249275
1. **Add parameter mappings**:
250276

277+
- First, inspect the checkpoint structures. To see the HuggingFace checkpoint structure, use **Tool 1 (HF Inspector)**. To see the MaxText model structure, use **Tool 2 (MaxText Inspector)**.
278+
251279
- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings (`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). These are the one-to-one mappings of parameter names per layer.
252280

253281
- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
@@ -262,11 +290,22 @@ Here is an example [PR to add support for gemma3 multi-modal model](https://gith
262290

263291
### Common Errors
264292

265-
- "Type ShapeDtypeStruct is not a valid JAX type" or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`):
266-
This is almost always caused by a mismatch in the `scan_layers` configuration between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`).
293+
- **Error:** Loading a converted checkpoint fails with `Type ShapeDtypeStruct is not a valid JAX type`.
294+
295+
- **Cause (most common): Structure mismatch** between the converted checkpoint and the MaxText model.
296+
297+
- **Solution:** Compare the two structures. Use **Tool 2 (MaxText Inspector)** to see the MaxText model structure and **Tool 3 (Orbax Inspector)** to inspect the checkpoint, then look for keys or shapes that differ.
298+
299+
- **Error:** `Type ShapeDtypeStruct is not a valid JAX type` or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`).
300+
301+
- **Cause: Configuration mismatch** (e.g., `scan_layers`) between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`).
267302

268303
- **Solution:** Ensure the `scan_layers` flag is set to the exact same value (`True` or `False`) in both the conversion command and your training/execution command.
269304

270-
- If the converted checkpoint loads without errors but produces nonsensical output, likely an error in the Q/K/V weight reshaping logic during conversion.
305+
- **Error:** The converted checkpoint loads without errors but produces nonsensical output.
306+
307+
- **Cause:** There is likely an error in the Q/K/V weight reshaping logic during conversion.
308+
309+
- **Error:** The model generates repetitive text sequences.
271310

272-
- If the model generates repetitive text sequences, check if layer normalization parameters were mapped correctly.
311+
- **Cause:** Layer normalization parameters were likely not mapped correctly.
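To make the `X/Y paths matched` symptom concrete, the sketch below compares two nested structures the way one might compare the output of Tool 2 and Tool 3 by hand. It is an illustration only: the `layers` versus `layers_0` naming, the shapes, and the helper names `flatten` and `compare_structures` are assumptions chosen for the example, not MaxText or Orbax APIs.

```python
def flatten(tree, prefix=()):
    """Flatten a nested dict into {path_tuple: leaf}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def compare_structures(expected, actual):
    """Classify each path as matched, missing, unexpected, or shape-mismatched."""
    exp, act = flatten(expected), flatten(actual)
    return {
        "matched": sorted(p for p in exp if p in act and exp[p] == act[p]),
        "missing": sorted(p for p in exp if p not in act),
        "unexpected": sorted(p for p in act if p not in exp),
        "shape_mismatch": sorted(p for p in exp if p in act and exp[p] != act[p]),
    }


# Hypothetical model structure with layers stacked into one array (scanned).
model_tree = {
    "decoder": {
        "layers": {"mlp": {"wi": (8, 2, 16)}},
        "decoder_norm": {"scale": (8,)},
    }
}
# Hypothetical checkpoint structure with one subtree per layer (unscanned).
checkpoint_tree = {
    "decoder": {
        "layers_0": {"mlp": {"wi": (8, 16)}},
        "layers_1": {"mlp": {"wi": (8, 16)}},
        "decoder_norm": {"scale": (8,)},
    }
}

report = compare_structures(model_tree, checkpoint_tree)
print(f"{len(report['matched'])}/{len(flatten(model_tree))} paths matched")  # 1/2 paths matched
```

In this toy case the `missing` and `unexpected` lists point directly at the layer subtrees, which is the kind of pattern a `scan_layers` mismatch would be expected to produce.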

docs/guides/model_bringup.md

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ This step can be bypassed if the current MaxText codebase already supports all c
4242

4343
## 3. Checkpoint Conversion
4444

45-
While most open-source models are distributed in Safetensors or PyTorch formats, MaxText requires conversion to the [Orbax](https://orbax.readthedocs.io/en/latest/) format.
45+
While most open-source models are distributed in Safetensors or PyTorch formats, MaxText requires conversion to the [Orbax](https://orbax.readthedocs.io/en/latest) format.
4646

4747
There are [two primary formats](checkpoints) for Orbax checkpoints within MaxText, and while both are technically compatible with training and inference, we recommend following these performance-optimized guidelines:
4848

@@ -51,10 +51,27 @@ There are [two primary formats](checkpoints) for Orbax checkpoints within MaxTex
5151

5252
### 3.1 Create Mapping
5353

54-
Success starts with a clear map. You must align the parameter names from your source checkpoints (Safetensors/PyTorch) with the corresponding MaxText internal names.
54+
To successfully convert a model, you must define the exact mapping between the parameter names in your source checkpoints (Safetensors/PyTorch) and the corresponding MaxText internal names.
5555

56-
- You can print out the keys and shapes of your original `.safetensors` or `.pth` files.
57-
- To see the target structure, you can initiate a pre-training run to save a randomly initialized checkpoint for inspection.
56+
We provide a unified utility, [`inspect_checkpoint.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/inspect_checkpoint.py), to help you view and compare these structures.
57+
58+
(1) **HF Inspector**. Print the keys and shapes of your original `.safetensors` or `.pth` files to see the HuggingFace checkpoint structure:
59+
60+
```
61+
python -m maxtext.checkpoint_conversion.inspect_checkpoint hf --path <local_hf_path> --format <safetensors | pth>
62+
```
63+
64+
(2) **MaxText Inspector**. View the expected parameter structure of the target MaxText model:
65+
66+
```
67+
python -m maxtext.checkpoint_conversion.inspect_checkpoint maxtext model_name=<maxtext_model_name> scan_layers=<True | False>
68+
```
69+
70+
(3) **Orbax Inspector** (Optional). If you have already saved an Orbax checkpoint during pretraining, you can inspect its structure directly:
71+
72+
```
73+
python -m maxtext.checkpoint_conversion.inspect_checkpoint orbax --path <local_orbax_path | gcs_orbax_path>
74+
```
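Once both key lists are available, a quick coverage check can show which names a draft mapping still misses. The following is a minimal sketch under stated assumptions: the parameter names, the toy mapping, and the helpers `build_mapping` and `check_coverage` are hypothetical and do not reflect the actual contents of `param_mapping.py`.

```python
def build_mapping(num_layers):
    """Hypothetical MaxText -> HF name mapping for a toy model."""
    mapping = {"params-token_embedder-embedding": "model.embed_tokens.weight"}
    for i in range(num_layers):
        # One entry per layer; real models need one entry per parameter per layer.
        mapping[f"params-decoder-layers_{i}-mlp-wo-kernel"] = (
            f"model.layers.{i}.mlp.down_proj.weight"
        )
    return mapping


def check_coverage(mapping, maxtext_keys, hf_keys):
    """Return (MaxText keys without a mapping, HF keys no mapping points to)."""
    unmapped_maxtext = sorted(set(maxtext_keys) - set(mapping))
    unmapped_hf = sorted(set(hf_keys) - set(mapping.values()))
    return unmapped_maxtext, unmapped_hf


# Key lists as they might be copied from the inspector output (illustrative).
maxtext_keys = [
    "params-token_embedder-embedding",
    "params-decoder-layers_0-mlp-wo-kernel",
    "params-decoder-layers_1-mlp-wo-kernel",
]
hf_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.mlp.down_proj.weight",
    "model.layers.1.mlp.down_proj.weight",
    "lm_head.weight",
]

unmapped_maxtext, unmapped_hf = check_coverage(build_mapping(2), maxtext_keys, hf_keys)
print("MaxText keys without a mapping:", unmapped_maxtext)
print("HF keys without a mapping:", unmapped_hf)
```

An empty result on both sides suggests the name mapping is complete; it does not verify shapes or the per-layer transformations.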
5875

5976
### 3.2 Write Script
6077
