Commit c832537

docs & feat: clarify and improve scan_layers mismatch error handling in conversion, training and checkpoints
1 parent 2e6cd11 commit c832537

11 files changed

Lines changed: 278 additions & 57 deletions

docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 5 additions & 2 deletions
@@ -70,7 +70,7 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
7070
### Key Parameters
7171

7272
- `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7).
73-
- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information.
73+
- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information. **IMPORTANT:** This setting *must* match the `scan_layers` value used during model training or loading. A mismatch will cause PyTree loading errors (though MaxText will intercept these and raise a descriptive `ValueError` explaining the mismatch).
7474
- `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
7575
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local.
7676
- `hardware=cpu`: The conversion script runs on a CPU machine.
@@ -239,7 +239,10 @@ Here is an example [PR to add support for gemma3 multi-modal model](https://gith
239239

240240
### Common Errors
241241

242-
- "Type ShapeDtypeStruct is not a valid JAX type": Usually caused by a mismatch in the `scan_layers` flag.
242+
- "Type ShapeDtypeStruct is not a valid JAX type" or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`):
243+
This is almost always caused by a mismatch in the `scan_layers` configuration between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`).
244+
245+
* **Solution:** Ensure the `scan_layers` flag is set to the exact same value (`True` or `False`) in both the conversion command and your training/execution command.
243246

244247
- If the converted checkpoint loads without errors but produces nonsensical output, the likely cause is an error in the Q/K/V weight reshaping logic during conversion.
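
For the `scan_layers` mismatch described above, a pre-flight check can compare the two commands before anything is launched. The following is a minimal sketch, not part of MaxText: the helper names `scan_layers_setting` and `check_scan_layers_match` are hypothetical, and it assumes both commands pass the flag as a plain `scan_layers=<value>` override.

```python
def scan_layers_setting(argv):
    """Return True/False if a scan_layers=<value> override appears in argv, else None."""
    for token in argv:
        key, sep, value = token.partition("=")
        if sep and key.strip().lower() == "scan_layers":
            return value.strip().lower() == "true"
    return None


def check_scan_layers_match(convert_argv, run_argv):
    """Raise ValueError unless both commands set scan_layers to the same value."""
    convert = scan_layers_setting(convert_argv)
    run = scan_layers_setting(run_argv)
    if convert is None or run is None:
        # Hypothetical policy: require an explicit value rather than rely on defaults.
        raise ValueError("scan_layers must be set explicitly in both commands.")
    if convert != run:
        raise ValueError(
            f"scan_layers mismatch: conversion used {convert}, run uses {run}."
        )


# Placeholder argument lists; only the scan_layers override matters here.
conversion_args = ["model_name=<MODEL>", "scan_layers=false"]
training_args = ["load_parameters_path=<CKPT_PATH>", "scan_layers=true"]

try:
    check_scan_layers_match(conversion_args, training_args)
except ValueError as err:
    print(f"Pre-flight check failed: {err}")
```

Treating a missing flag as an error is a design choice in this sketch; it avoids guessing which default each script would apply.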
245248

docs/reference/core_concepts/checkpoints.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ Their difference can also be represented in the following pytree structure:
6666

6767
The stacked format is highly efficient but has one key requirement: all layers within the `scan` operation must have identical configurations. For models with heterogeneous layers (where layer configurations differ), stacking is not possible, and only unstacked checkpoints can be used.
6868

69+
In MaxText, the **`scan_layers`** configuration parameter is used to control this setting:
70+
- `scan_layers=true` tells MaxText to stack layer parameters (recommended for training).
71+
- `scan_layers=false` tells MaxText to keep layer parameters unstacked (often required for inference and certain model architectures).
72+
73+
> [!IMPORTANT]
74+
> **PyTree Structure Compatibility:** Because JAX expects the loaded PyTree structure to exactly match the model's instantiated structure, the value of the `scan_layers` flag during execution (training, SFT, RL, DPO, or decoding) **must** match the format of the checkpoint being loaded. A mismatch will cause PyTree loading or shape/path mismatch errors (which MaxText will intercept to raise a descriptive `ValueError` pointing to the scan_layers setting).
75+
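
To see why a stacked model cannot consume an unstacked checkpoint, it helps to compare the parameter paths each layout produces. The sketch below uses plain dictionaries of shapes instead of real arrays. The key names (`layers`, `layers_0`, ...) and the position of the layer axis are illustrative assumptions, not MaxText's exact layout.

```python
# Illustrative only: key names and axis placement are assumptions.
NUM_LAYERS = 3
EMBED_DIM = 4


def stacked_layout(num_layers, embed_dim):
    """One entry per parameter; the layer count becomes an extra array axis."""
    return {("decoder", "layers", "mlp", "kernel"): (num_layers, embed_dim, embed_dim)}


def unstacked_layout(num_layers, embed_dim):
    """One entry per parameter per layer; each array has no layer axis."""
    return {
        ("decoder", f"layers_{i}", "mlp", "kernel"): (embed_dim, embed_dim)
        for i in range(num_layers)
    }


def matched_paths(checkpoint_layout, model_layout):
    """Paths present in both layouts with identical shapes."""
    return sorted(
        path
        for path, shape in checkpoint_layout.items()
        if model_layout.get(path) == shape
    )


checkpoint = unstacked_layout(NUM_LAYERS, EMBED_DIM)  # as if saved with scan_layers=false
model = stacked_layout(NUM_LAYERS, EMBED_DIM)  # as if instantiated with scan_layers=true
common = matched_paths(checkpoint, model)
print(f"{len(common)}/{len(model)} paths matched")
```

Neither the paths nor the shapes line up between the two layouts, which is the situation behind the "paths matched" and shape mismatch errors mentioned above.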
6976
### Takeaways
7077

7178
To summarize the four checkpoint types:

docs/run_maxtext/run_maxtext_localhost.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ python3 -m maxtext.inference.decode \
6565

6666
**Note:** Because the model hasn't been properly trained, the output text will be random. To generate meaningful output, you need to load a trained checkpoint using the `load_parameters_path` argument.
6767

68+
> [!NOTE]
69+
> **Checkpoints & `scan_layers` compatibility:** When loading an external or converted checkpoint via `load_parameters_path`, the `scan_layers` setting in your command **must** match the setting used to save the checkpoint. If the checkpoint was saved/converted with `scan_layers=False` (common for Hugging Face conversions and inference runs), you must specify `scan_layers=False` in your command. Otherwise, JAX/Orbax will raise PyTree structure mismatch errors.
70+
6871
### Running models using provided configs
6972

7073
MaxText provides many OSS model configs that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in `src/maxtext/configs/models` for TPU-oriented defaults, and `src/maxtext/configs/models/gpu` for GPU-oriented defaults.

docs/tutorials/posttraining/dpo.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ Refer to the steps in [Hugging Face to MaxText](../../guides/checkpointing_solut
103103
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
104104
```
105105

106+
> [!IMPORTANT]
107+
> **Matching the `scan_layers` Parameter:**
108+
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
109+
> * If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
110+
> * If `scan_layers` does not match, MaxText will raise a `ValueError`.
111+
> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
112+
106113
## Running DPO Training
107114

108115
You can run the DPO training using the specialized post-training script:

docs/tutorials/posttraining/rl_on_multi_host.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,13 @@ Refer to the steps in [Hugging Face to MaxText](../../guides/checkpointing_solut
148148
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
149149
```
150150

151+
> [!IMPORTANT]
152+
> **Matching the `scan_layers` Parameter:**
153+
> The `scan_layers` setting during your RL training run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
154+
> * If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
155+
> * If `scan_layers` does not match, MaxText will raise a `ValueError`.
156+
> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
157+
151158
## Submit your RL workload via Pathways
152159

153160
See the **Troubleshooting** section for concise instructions on how to retry or

docs/tutorials/posttraining/sft.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solution
8888
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
8989
```
9090

91+
> [!IMPORTANT]
92+
> **Matching the `scan_layers` Parameter:**
93+
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
94+
> * If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
95+
> * If `scan_layers` does not match, MaxText will raise a `ValueError`.
96+
> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
97+
9198
## Run SFT on Hugging Face Dataset
9299

93100
Now you are ready to run SFT using the following command:

docs/tutorials/posttraining/sft_on_multi_host.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solution
139139
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # gs://my-bucket/my-checkpoint-directory/0/items
140140
```
141141

142+
> [!IMPORTANT]
143+
> **Matching the `scan_layers` Parameter:**
144+
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
145+
> * If the checkpoint was converted or saved with `scan_layers=False` (which is common for Hugging Face conversions and inference-ready models), you **must also provide `scan_layers=False` in the MaxText command.**
146+
> * If `scan_layers` does not match, MaxText will raise a `ValueError`.
147+
> See the [Checkpoints concept guide](../../reference/core_concepts/checkpoints.md) for more details.
148+
142149
## Submit workload on GKE cluster
143150

144151
This section provides the command to run SFT on a GKE cluster.

src/maxtext/common/checkpointing.py

Lines changed: 129 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Any, Optional
1919

2020
from absl import flags
21+
import contextlib
2122
import datetime
2223
from etils import epath
2324
from flax import nnx
@@ -639,12 +640,14 @@ def _restore_grain_iterator(
639640
if isinstance(data_iterator, RemoteIteratorWrapper):
640641
grain_restore_args = GrainCheckpointRestore(item=data_iterator)
641642
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
643+
_assert_no_shaped_dtype_struct(restored_state)
642644
return (restored_state, None)
643645

644646
# ElasticIterator: one shared `process_0.json` regardless of shard count.
645647
if not isinstance(data_iterator, list) and isinstance(data_iterator.local_iterator, ElasticIterator):
646648
grain_restore_args = GrainCheckpointRestore(item=data_iterator.local_iterator)
647649
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
650+
_assert_no_shaped_dtype_struct(restored_state)
648651
return (restored_state, None)
649652

650653
directory = checkpoint_manager.directory / str(step) / "iter"
@@ -693,9 +696,68 @@ def _restore_grain_iterator(
693696

694697
# Call restore once with the composed arguments
695698
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
699+
_assert_no_shaped_dtype_struct(restored_state)
696700
return (restored_state, None)
697701

698702

+def _is_structural_or_shape_mismatch(e: Exception) -> bool:
+  """Helper to check if an exception is likely a PyTree structure or shape mismatch."""
+  if not isinstance(e, (ValueError, TypeError)):
+    return False
+  msg = str(e).lower()
+  mismatch_keywords = [
+      "mismatch",
+      "structure",
+      "shape",
+      "tree",
+      "leaf",
+      "leaves",
+      "paths matched",
+      "shapedtypestruct",
+      "invalid type",
+  ]
+  return any(kw in msg for kw in mismatch_keywords)
+
+
+def _assert_no_shaped_dtype_struct(pytree):
+  """Asserts that there are no jax.ShapeDtypeStruct leaves in the restored pytree."""
+  if isinstance(pytree, jax.ShapeDtypeStruct):
+    raise ValueError(
+        f"Some parameters in the restored state remained as ShapeDtypeStruct: {pytree}. "
+        "This indicates a structural mismatch between the checkpoint and the model configuration. "
+        "Usually this is due to 'scan_layers' configuration mismatch."
+    )
+
+  if hasattr(pytree, "keys") and hasattr(pytree, "__getitem__"):
+    for k in pytree.keys():
+      _assert_no_shaped_dtype_struct(pytree[k])
+  elif isinstance(pytree, (list, tuple)):
+    for v in pytree:
+      _assert_no_shaped_dtype_struct(v)
+  else:
+    leaves = jax.tree_util.tree_leaves(pytree)
+    if len(leaves) == 1 and leaves[0] is pytree:
+      return
+    for leaf in leaves:
+      _assert_no_shaped_dtype_struct(leaf)
+
+
+@contextlib.contextmanager
+def _handle_checkpoint_mismatch(context_name: str, path: str):
+  """Context manager to intercept PyTree/shape mismatches and raise descriptive errors."""
+  try:
+    yield
+  except Exception as e:
+    if _is_structural_or_shape_mismatch(e):
+      raise ValueError(
+          f"Failed to {context_name} from {path}. "
+          "This is often caused by a mismatch in the 'scan_layers' configuration "
+          "(stacked vs unstacked) between your current execution command and "
+          f"the saved checkpoint. Original error: {e}"
+      ) from e
+    raise
+
+
699761
def load_state_if_possible(
700762
checkpoint_manager: CheckpointManager | None,
701763
data_iterator: MultiHostDataLoadIterator | list[MultiHostDataLoadIterator] | None,
@@ -777,13 +839,15 @@ def map_to_pspec(data):
777839
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
778840
):
779841
checkpoint_path = str(checkpoint_manager.directory / str(step) / "items")
780-
restored_nnx = _load_linen_checkpoint_into_nnx(
781-
checkpoint_path,
782-
abstract_unboxed_pre_state,
783-
checkpoint_storage_concurrent_gb,
784-
use_ocdbt,
785-
use_zarr3,
786-
)
842+
with _handle_checkpoint_mismatch("restore NNX checkpoint", checkpoint_path):
843+
restored_nnx = _load_linen_checkpoint_into_nnx(
844+
checkpoint_path,
845+
abstract_unboxed_pre_state,
846+
checkpoint_storage_concurrent_gb,
847+
use_ocdbt,
848+
use_zarr3,
849+
)
850+
_assert_no_shaped_dtype_struct(restored_nnx)
787851
return ({"items": restored_nnx}, None)
788852

789853
# Convert nnx.State to pure dict to match how checkpoints are saved for NNX
@@ -798,64 +862,74 @@ def map_to_pspec(data):
798862
partial_restore=True,
799863
)
800864

801-
match (checkpoint_manager, dataset_type, data_iterator):
802-
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
803-
# or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
804-
# 'data_iterator' can be any value and aren't used in this pattern.
805-
case (checkpoint_manager, _, _) if isinstance(
806-
checkpoint_manager,
807-
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
808-
):
809-
return (
810-
checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state,
811-
None,
812-
)
813-
# Case 2: Matches if dataset type is "grain" and the data iterator is not a
814-
# PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
815-
case (
816-
checkpoint_manager,
817-
dataset_type,
818-
data_iterator,
819-
) if (
820-
dataset_type == "grain"
821-
and data_iterator
822-
and not isinstance(data_iterator, PlaceHolderDataIterator)
823-
and (checkpoint_manager.directory / str(step) / "iter").exists()
824-
):
825-
return _restore_grain_iterator(
826-
checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data
827-
)
828-
# Case 3: Default/Fallback case.
829-
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
830-
case _:
831-
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
865+
checkpoint_path = str(checkpoint_manager.directory / str(step))
866+
with _handle_checkpoint_mismatch("restore checkpoint", checkpoint_path):
867+
match (checkpoint_manager, dataset_type, data_iterator):
868+
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
869+
# or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
870+
# 'data_iterator' can be any value and aren't used in this pattern.
871+
case (checkpoint_manager, _, _) if isinstance(
872+
checkpoint_manager,
873+
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
874+
):
875+
restored = checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state
876+
_assert_no_shaped_dtype_struct(restored)
877+
return (
878+
restored,
879+
None,
880+
)
881+
# Case 2: Matches if dataset type is "grain" and the data iterator is not a
882+
# PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
883+
case (
884+
checkpoint_manager,
885+
dataset_type,
886+
data_iterator,
887+
) if (
888+
dataset_type == "grain"
889+
and data_iterator
890+
and not isinstance(data_iterator, PlaceHolderDataIterator)
891+
and (checkpoint_manager.directory / str(step) / "iter").exists()
892+
):
893+
return _restore_grain_iterator(
894+
checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data
895+
)
896+
# Case 3: Default/Fallback case.
897+
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
898+
case _:
899+
restored = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args))
900+
_assert_no_shaped_dtype_struct(restored)
901+
return (restored, None)
832902

833903
if load_parameters_from_path != "":
834904
if isinstance(abstract_unboxed_pre_state, nnx.State):
835905
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
836906
else:
837907
params = abstract_unboxed_pre_state.params
838908

839-
restored_params = load_params_from_path(
840-
load_parameters_from_path,
841-
params,
842-
checkpoint_storage_concurrent_gb,
843-
use_ocdbt=use_ocdbt,
844-
use_zarr3=use_zarr3,
845-
)
909+
with _handle_checkpoint_mismatch("load parameters", load_parameters_from_path):
910+
restored_params = load_params_from_path(
911+
load_parameters_from_path,
912+
params,
913+
checkpoint_storage_concurrent_gb,
914+
use_ocdbt=use_ocdbt,
915+
use_zarr3=use_zarr3,
916+
)
917+
_assert_no_shaped_dtype_struct(restored_params)
846918
return None, restored_params
847919
elif load_full_state_from_path != "":
848920
max_logging.log(f"Loading full state from path: {load_full_state_from_path}")
849-
restored_state = _load_full_state_from_path(
850-
path=load_full_state_from_path,
851-
abstract_unboxed_pre_state=abstract_unboxed_pre_state,
852-
enable_orbax_v1=enable_orbax_v1,
853-
checkpoint_conversion_fn=checkpoint_conversion_fn,
854-
source_checkpoint_layout=source_checkpoint_layout,
855-
checkpoint_storage_concurrent_gb=checkpoint_storage_concurrent_gb,
856-
use_ocdbt=use_ocdbt,
857-
use_zarr3=use_zarr3,
858-
)
921+
with _handle_checkpoint_mismatch("load full state", load_full_state_from_path):
922+
restored_state = _load_full_state_from_path(
923+
path=load_full_state_from_path,
924+
abstract_unboxed_pre_state=abstract_unboxed_pre_state,
925+
enable_orbax_v1=enable_orbax_v1,
926+
checkpoint_conversion_fn=checkpoint_conversion_fn,
927+
source_checkpoint_layout=source_checkpoint_layout,
928+
checkpoint_storage_concurrent_gb=checkpoint_storage_concurrent_gb,
929+
use_ocdbt=use_ocdbt,
930+
use_zarr3=use_zarr3,
931+
)
932+
_assert_no_shaped_dtype_struct(restored_state)
859933
return {"items": restored_state}, None
860934
else:
861935
max_logging.log("No existing checkpoints found, not restoring checkpoint.")
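
The interception pattern this file introduces can be tried in isolation. The sketch below is a simplified stand-in rather than the MaxText code: `looks_like_structure_mismatch`, `explain_mismatch`, and `fake_restore` are hypothetical names, the keyword list is shortened, and the restore call is simulated so the example runs without JAX or Orbax.

```python
import contextlib

# Shortened, illustrative keyword list (assumption: substring matching on the message).
MISMATCH_KEYWORDS = ("mismatch", "structure", "shape", "tree", "paths matched", "shapedtypestruct")


def looks_like_structure_mismatch(exc):
    """Heuristic: only ValueError/TypeError with a telltale substring qualify."""
    if not isinstance(exc, (ValueError, TypeError)):
        return False
    message = str(exc).lower()
    return any(keyword in message for keyword in MISMATCH_KEYWORDS)


@contextlib.contextmanager
def explain_mismatch(action, path):
    """Re-raise likely structure mismatches with a scan_layers hint; pass others through."""
    try:
        yield
    except Exception as exc:
        if looks_like_structure_mismatch(exc):
            raise ValueError(
                f"Failed to {action} from {path}. "
                f"Possible scan_layers mismatch. Original error: {exc}"
            ) from exc
        raise


def fake_restore(kind):
    """Simulated restore call standing in for a real checkpoint load."""
    if kind == "structure":
        raise TypeError("Type ShapeDtypeStruct is not a valid JAX type")
    if kind == "missing":
        raise FileNotFoundError("no such checkpoint")
    return {"params": "ok"}


def try_restore(kind):
    """Return the restored value, or the name of the exception type that escaped."""
    try:
        with explain_mismatch("restore checkpoint", "/tmp/ckpt/0"):
            return fake_restore(kind)
    except Exception as exc:
        return type(exc).__name__


for kind in ("structure", "missing", "ok"):
    print(kind, "->", try_restore(kind))
```

Because the heuristic matches substrings such as "shape" and "tree", it can also fire on unrelated `ValueError`s, which is presumably why the wrapped message says "often caused by" rather than asserting the cause.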

src/maxtext/utils/model_creation_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,14 @@ def _walk_align(ckpt, model_arr, axes):
10601060
)
10611061

10621062
except Exception as e:
1063+
from maxtext.common.checkpointing import _is_structural_or_shape_mismatch
1064+
if _is_structural_or_shape_mismatch(e):
1065+
raise ValueError(
1066+
f"Checkpoint loading failed from '{config.load_parameters_path}'. "
1067+
"This is often caused by a mismatch in the 'scan_layers' configuration "
1068+
"(stacked vs unstacked) between your current execution command and "
1069+
f"the saved checkpoint. Original error: {e}"
1070+
) from e
10631071
raise ValueError(f"Checkpoint loading failed: {e}") from e
10641072

10651073
if wrap_with_tunix_adapter:
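
Both raise sites in this hunk use `raise ... from e`, so the underlying loader error stays attached to the new `ValueError`. A minimal sketch of what a caller can still inspect, using a simulated failure and a hypothetical `load_with_context` function:

```python
def load_with_context(path):
    """Simulate a loader failure and re-raise it with added context."""
    try:
        # Stand-in for the real checkpoint load; the message mirrors the documented error.
        raise TypeError("Type ShapeDtypeStruct is not a valid JAX type")
    except Exception as e:
        raise ValueError(
            f"Checkpoint loading failed from '{path}'. Original error: {e}"
        ) from e


try:
    load_with_context("/tmp/ckpt")
except ValueError as err:
    cause = err.__cause__  # the original exception, preserved by "from e"
    print(type(cause).__name__, "->", cause)
```

The traceback printed for an uncaught error would show both exceptions, so the descriptive hint does not hide the original failure.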
