Skip to content

Commit f833914

Browse files
committed
docs & feat: clarify and improve scan_layers mismatch error handling in conversion, training and checkpoints
1 parent 2e6cd11 commit f833914

14 files changed

Lines changed: 301 additions & 74 deletions

File tree

docs/development.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ hidden:
77
---
88
development/update_dependencies.md
99
development/contribute_docs.md
10+
development/hlo_diff_testing.md
1011
```

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
7070
### Key Parameters
7171

7272
- `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7).
73-
- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information.
73+
- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information. **IMPORTANT:** This setting *must* match the `scan_layers` value used during model training or loading. A mismatch will cause PyTree loading errors (though MaxText will intercept these and raise a descriptive `ValueError` explaining the mismatch).
7474
- `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
7575
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local.
7676
- `hardware=cpu`: The conversion script runs on a CPU machine.
@@ -239,7 +239,10 @@ Here is an example [PR to add support for gemma3 multi-modal model](https://gith
239239

240240
### Common Errors
241241

242-
- "Type ShapeDtypeStruct is not a valid JAX type": Usually caused by a mismatch in the `scan_layers` flag.
242+
- "Type ShapeDtypeStruct is not a valid JAX type" or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`):
243+
This is almost always caused by a mismatch in the `scan_layers` configuration between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`).
244+
245+
- **Solution:** Ensure the `scan_layers` flag is set to the exact same value (`True` or `False`) in both the conversion command and your training/execution command.
243246

244247
- If the converted checkpoint loads without errors but produces nonsensical output, likely an error in the Q/K/V weight reshaping logic during conversion.
245248

docs/guides/data_input_pipeline.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,6 @@ hidden:
6464
data_input_pipeline/data_input_grain
6565
data_input_pipeline/data_input_hf
6666
data_input_pipeline/data_input_tfds
67+
data_input_pipeline/olmo_grain
6768
data_input_pipeline/data_pipeline_perf.md
6869
```

docs/guides/optimization.md

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,37 +18,39 @@
1818

1919
Explore techniques for maximizing performance, including model customization, sharding strategies, Pallas kernels, and benchmarking.
2020

21-
::::{grid} 1 2 2 2
22-
:gutter: 2
21+
````{grid} 1 2 2 2
22+
---
23+
gutter: 2
24+
---
2325
24-
:::{grid-item-card} 🛠️ Customizing Model Configs
26+
```{grid-item-card} 🛠️ Customizing Model Configs
2527
:link: optimization/custom_model
2628
:link-type: doc
2729
2830
Optimize and customize your LLM model configurations for higher performance (MFU) on TPUs.
29-
:::
31+
```
3032
31-
:::{grid-item-card} 🥞 Sharding Strategies
33+
```{grid-item-card} 🥞 Sharding Strategies
3234
:link: optimization/sharding
3335
:link-type: doc
3436
3537
Choose efficient sharding strategies (FSDP, TP, EP, PP) using Roofline Analysis and understand arithmetic intensity.
36-
:::
38+
```
3739
38-
:::{grid-item-card} ⚡ Pallas Kernels
40+
```{grid-item-card} ⚡ Pallas Kernels
3941
:link: optimization/pallas_kernels_performance
4042
:link-type: doc
4143
4244
Optimize with Pallas kernels for fine-grained control.
43-
:::
45+
```
4446
45-
:::{grid-item-card} 📈 Benchmarking & Tuning
47+
```{grid-item-card} 📈 Benchmarking & Tuning
4648
:link: optimization/benchmark_and_performance
4749
:link-type: doc
4850
4951
Guide to setting up benchmarks, performing performance tuning, and analyzing metrics.
50-
:::
51-
::::
52+
```
53+
````
5254

5355
```{toctree}
5456
---
@@ -57,6 +59,7 @@ maxdepth: 1
5759
---
5860
optimization/custom_model.md
5961
optimization/sharding.md
62+
optimization/custom_mesh_and_rule.md
6063
optimization/pallas_kernels_performance.md
6164
optimization/benchmark_and_performance.md
6265
```

docs/reference/core_concepts/checkpoints.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ Their difference can also be represented in the following pytree structure:
6666

6767
The stacked format is highly efficient but has one key requirement: all layers within the `scan` operation must have identical configurations. For models with heterogeneous layers (where layer configurations differ), stacking is not possible, and only unstacked checkpoints can be used.
6868

69+
In MaxText, the **`scan_layers`** configuration parameter is used to control this setting:
70+
71+
- `scan_layers=true` tells MaxText to stack layer parameters (recommended for training).
72+
- `scan_layers=false` tells MaxText to keep layer parameters unstacked (often required for inference and certain model architectures).
73+
74+
> [!IMPORTANT]
75+
> **PyTree Structure Compatibility:** Because JAX expects the loaded PyTree structure to exactly match the model's instantiated structure, the value of the `scan_layers` flag during execution (training, SFT, RL, DPO, or decoding) **must** match the format of the checkpoint being loaded. A mismatch will cause PyTree loading or shape/path mismatch errors (which MaxText will intercept to raise a descriptive `ValueError` pointing to the scan_layers setting).
76+
6977
### Takeaways
7078

7179
To summarize the four checkpoint types:

docs/run_maxtext/run_maxtext_localhost.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ python3 -m maxtext.inference.decode \
6565

6666
**Note:** Because the model hasn't been properly trained, the output text will be random. To generate meaningful output, you need to load a trained checkpoint using the `load_parameters_path` argument.
6767

68+
> [!NOTE]
69+
> **Checkpoints & `scan_layers` compatibility:** When loading an external or converted checkpoint via `load_parameters_path`, the `scan_layers` setting in your command **must** match the setting used to save the checkpoint. If the checkpoint was saved/converted with `scan_layers=False` (common for Hugging Face conversions and inference runs), you must specify `scan_layers=False` in your command. Otherwise, JAX/Orbax will raise PyTree structure mismatch errors.
70+
6871
### Running models using provided configs
6972

7073
MaxText provides many OSS model configs that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in `src/maxtext/configs/models` for TPU-oriented defaults, and `src/maxtext/configs/models/gpu` for GPU-oriented defaults.

docs/tutorials/posttraining/dpo.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ Refer to the steps in [Hugging Face to MaxText](../../guides/checkpointing_solut
103103
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
104104
```
105105

106+
> [!IMPORTANT]
107+
> **Matching the `scan_layers` Parameter:**
108+
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
109+
>
110+
> - If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
111+
> - If `scan_layers` does not match, MaxText will raise a `ValueError`.
112+
> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
113+
106114
## Running DPO Training
107115

108116
You can run the DPO training using the specialized post-training script:

docs/tutorials/posttraining/rl_on_multi_host.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ Refer to the steps in [Hugging Face to MaxText](../../guides/checkpointing_solut
148148
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
149149
```
150150

151+
> [!IMPORTANT]
152+
> **Matching the `scan_layers` Parameter:**
153+
> The `scan_layers` setting during your RL training run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
154+
>
155+
> - If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
156+
> - If `scan_layers` does not match, MaxText will raise a `ValueError`.
157+
> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
158+
151159
## Submit your RL workload via Pathways
152160

153161
See the **Troubleshooting** section for concise instructions on how to retry or

docs/tutorials/posttraining/sft.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solution
8888
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
8989
```
9090

91+
> [!IMPORTANT]
92+
> **Matching the `scan_layers` Parameter:**
93+
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
94+
>
95+
> - If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
96+
> - If `scan_layers` does not match, MaxText will raise a `ValueError`.
97+
> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
98+
9199
## Run SFT on Hugging Face Dataset
92100

93101
Now you are ready to run SFT using the following command:

docs/tutorials/posttraining/sft_on_multi_host.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solution
139139
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # gs://my-bucket/my-checkpoint-directory/0/items
140140
```
141141

142+
> [!IMPORTANT]
143+
> **Matching the `scan_layers` Parameter:**
144+
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
145+
>
146+
> - If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
147+
> - If `scan_layers` does not match, MaxText will raise a `ValueError`.
148+
> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
149+
142150
## Submit workload on GKE cluster
143151

144152
This section provides the command to run SFT on a GKE cluster.

0 commit comments

Comments
 (0)