Skip to content

Commit 9af3c53

Browse files
committed
nice error
1 parent e1d2e70 commit 9af3c53

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

src/maxtext/utils/model_creation_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,15 @@ def create_sharded_state():
345345
else:
346346
checkpoint = restored["params"]["params"]
347347

348+
loaded_count = len(jax.tree_util.tree_leaves(checkpoint))
349+
expected_count = len(jax.tree_util.tree_leaves(target_for_restore))
350+
if loaded_count < expected_count:
351+
raise ValueError(
352+
f"Checkpoint at '{config.load_parameters_path}' loaded only {loaded_count} of {expected_count} "
353+
"expected parameter arrays. This usually means a scanned (stacked-layers) checkpoint was provided "
354+
"where an unscanned checkpoint is required. Please convert the checkpoint to unscanned format first."
355+
)
356+
348357
if checkpoint:
349358
model_arrays = jax.tree.map(
350359
lambda v: v.value,

0 commit comments

Comments
 (0)