docs/guides/checkpointing_solutions/convert_checkpoint.md
Lines changed: 5 additions & 2 deletions
@@ -70,7 +70,7 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
### Key Parameters

- `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7).
-- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information.
+- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information. **IMPORTANT:** This setting *must* match the `scan_layers` value used during model training or loading. A mismatch will cause PyTree loading errors (though MaxText will intercept these and raise a descriptive `ValueError` explaining the mismatch).
- `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local.
- `hardware=cpu`: The conversion script runs on a CPU machine.
@@ -239,7 +239,10 @@ Here is an example [PR to add support for gemma3 multi-modal model](https://gith

### Common Errors

-- "Type ShapeDtypeStruct is not a valid JAX type": Usually caused by a mismatch in the `scan_layers` flag.
+- "Type ShapeDtypeStruct is not a valid JAX type" or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`):
+This is almost always caused by a mismatch in the `scan_layers` configuration between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`).
+
+* **Solution:** Ensure the `scan_layers` flag is set to the exact same value (`True` or `False`) in both the conversion command and your training/execution command.

- If the converted checkpoint loads without errors but produces nonsensical output, the likely cause is an error in the Q/K/V weight reshaping logic during conversion.
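
The `scan_layers` consistency rule in the first item above can be checked before launching a job. The following is a minimal sketch, assuming both commands pass `scan_layers` as a `key=value` override. The helpers `get_scan_layers` and `check_scan_layers_match` are hypothetical and not part of MaxText, and the argument lists are placeholders rather than complete commands.

```python
from typing import Optional, Sequence

_TRUE = {"true", "1"}
_FALSE = {"false", "0"}


def get_scan_layers(argv: Sequence[str]) -> Optional[bool]:
    """Return the last scan_layers=<bool> override in argv, or None if absent.

    Hypothetical helper: it only inspects key=value arguments and does not
    read any YAML config, so it cannot know a project's default value.
    """
    value = None
    for arg in argv:
        key, sep, raw = arg.partition("=")
        if sep and key == "scan_layers":
            text = raw.strip().lower()
            if text in _TRUE:
                value = True
            elif text in _FALSE:
                value = False
            else:
                raise ValueError(f"Unrecognized scan_layers value: {raw!r}")
    return value


def check_scan_layers_match(convert_argv: Sequence[str], run_argv: Sequence[str]) -> None:
    """Raise ValueError unless both commands set the same scan_layers value."""
    convert_value = get_scan_layers(convert_argv)
    run_value = get_scan_layers(run_argv)
    if convert_value is None or run_value is None:
        raise ValueError("Set scan_layers explicitly in both commands before comparing.")
    if convert_value != run_value:
        raise ValueError(
            f"scan_layers mismatch: conversion used {convert_value}, run uses {run_value}."
        )


# Placeholder argument lists (not full MaxText commands).
check_scan_layers_match(
    ["to_maxtext.py", "scan_layers=false"],
    ["train.py", "scan_layers=False"],
)
```

Requiring an explicit value on both sides is a deliberate choice in this sketch: it avoids guessing a default that this section does not state.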
docs/reference/core_concepts/checkpoints.md
Lines changed: 7 additions & 0 deletions
@@ -66,6 +66,13 @@ Their difference can also be represented in the following pytree structure:
The stacked format is highly efficient but has one key requirement: all layers within the `scan` operation must have identical configurations. For models with heterogeneous layers (where layer configurations differ), stacking is not possible, and only unstacked checkpoints can be used.
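
The identical-configuration requirement can be illustrated with a small sketch that works on parameter shapes only. It assumes each layer is described by a dictionary of shape tuples. The function name, the parameter names `wq` and `mlp`, and the choice to place the stacking axis first are all illustrative assumptions, since this section does not state which axis MaxText stacks along.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]
LayerShapes = Dict[str, Shape]


def stack_layer_shapes(layers: List[LayerShapes]) -> LayerShapes:
    """Return the shapes of stacked parameters for a list of identical layers.

    Each parameter gains a new leading axis of size len(layers). The leading
    position is an assumption made for illustration only.
    """
    if not layers:
        raise ValueError("Need at least one layer to stack.")
    first = layers[0]
    for index, layer in enumerate(layers[1:], start=1):
        # Stacking is only defined when every layer has the same parameter
        # names and the same shapes as the first layer.
        if layer != first:
            raise ValueError(f"Layer {index} differs from layer 0; cannot stack.")
    return {name: (len(layers),) + shape for name, shape in first.items()}


# Four homogeneous layers stack cleanly.
uniform = [{"wq": (8, 8), "mlp": (8, 32)} for _ in range(4)]
stacked = stack_layer_shapes(uniform)

# A fifth layer with a wider MLP makes the model heterogeneous, so stacking
# is rejected and only the unstacked format remains usable.
mixed = uniform + [{"wq": (8, 8), "mlp": (8, 64)}]
```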

+In MaxText, the **`scan_layers`** configuration parameter is used to control this setting:
+- `scan_layers=true` tells MaxText to stack layer parameters (recommended for training).
+- `scan_layers=false` tells MaxText to keep layer parameters unstacked (often required for inference and certain model architectures).
+
+> [!IMPORTANT]
+> **PyTree Structure Compatibility:** Because JAX expects the loaded PyTree structure to exactly match the model's instantiated structure, the value of the `scan_layers` flag during execution (training, SFT, RL, DPO, or decoding) **must** match the format of the checkpoint being loaded. A mismatch will cause PyTree loading or shape/path mismatch errors (which MaxText will intercept to raise a descriptive `ValueError` pointing to the `scan_layers` setting).
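
To see why a mismatch surfaces as a path-count error, the sketch below compares the leaf paths of a toy stacked tree with those of a toy unstacked tree. It uses plain nested dictionaries rather than JAX or Orbax. The key names (`layers`, `layers_0`, `layers_1`, `wq`, `wo`) and the printed message are hypothetical and only loosely mirror the "X/Y paths matched" wording that Orbax reports.

```python
from typing import Any, Dict, Set, Tuple

Path = Tuple[str, ...]


def leaf_paths(tree: Dict[str, Any], prefix: Path = ()) -> Set[Path]:
    """Collect the key path of every leaf in a nested dictionary."""
    paths: Set[Path] = set()
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            paths |= leaf_paths(value, path)
        else:
            # Anything that is not a dict (here, a shape tuple) is a leaf.
            paths.add(path)
    return paths


def matched_paths(expected: Dict[str, Any], loaded: Dict[str, Any]) -> Tuple[int, int]:
    """Return (paths found in both trees, total paths the model expects)."""
    expected_paths = leaf_paths(expected)
    return len(expected_paths & leaf_paths(loaded)), len(expected_paths)


# Toy model with two layers; leaves are parameter shapes.
scanned = {
    "embedding": (100, 8),
    "layers": {"wq": (2, 8, 8), "wo": (2, 8, 8)},  # one stacked entry
}
unscanned = {
    "embedding": (100, 8),
    "layers_0": {"wq": (8, 8), "wo": (8, 8)},  # one entry per layer
    "layers_1": {"wq": (8, 8), "wo": (8, 8)},
}

# The model is built in stacked form but the checkpoint is unstacked.
matched, total = matched_paths(expected=scanned, loaded=unscanned)
print(f"{matched}/{total} paths matched")
```

Only the embedding path is shared between the two layouts, so every per-layer parameter is reported as missing even though the underlying weights are the same.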
**Note:** Because the model hasn't been properly trained, the output text will be random. To generate meaningful output, you need to load a trained checkpoint using the `load_parameters_path` argument.

+> [!NOTE]
+> **Checkpoints & `scan_layers` compatibility:** When loading an external or converted checkpoint via `load_parameters_path`, the `scan_layers` setting in your command **must** match the setting used to save the checkpoint. If the checkpoint was saved/converted with `scan_layers=False` (common for Hugging Face conversions and inference runs), you must specify `scan_layers=False` in your command. Otherwise, JAX/Orbax will raise PyTree structure mismatch errors.
+

### Running models using provided configs
MaxText provides many OSS model configs that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in `src/maxtext/configs/models` for TPU-oriented defaults, and `src/maxtext/configs/models/gpu` for GPU-oriented defaults.
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
+> * If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
+> * If `scan_layers` does not match, MaxText will raise a `ValueError`.
+> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
+

## Running DPO Training
You can run the DPO training using the specialized post-training script:
> The `scan_layers` setting during your RL training run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
+> * If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
+> * If `scan_layers` does not match, MaxText will raise a `ValueError`.
+> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
+

## Submit your RL workload via Pathways
See the **Troubleshooting** section for concise instructions on how to retry or
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
+> * If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
+> * If `scan_layers` does not match, MaxText will raise a `ValueError`.
+> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
+

## Run SFT on Hugging Face Dataset
Now you are ready to run SFT using the following command:
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
+> * If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
+> * If `scan_layers` does not match, MaxText will raise a `ValueError`.
+> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
+

## Submit workload on GKE cluster
This section provides the command to run SFT on a GKE cluster.