We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e1d2e70 commit 9af3c53Copy full SHA for 9af3c53
1 file changed
src/maxtext/utils/model_creation_utils.py
@@ -345,6 +345,15 @@ def create_sharded_state():
345
else:
346
checkpoint = restored["params"]["params"]
347
348
+ loaded_count = len(jax.tree_util.tree_leaves(checkpoint))
349
+ expected_count = len(jax.tree_util.tree_leaves(target_for_restore))
350
+ if loaded_count < expected_count:
351
+ raise ValueError(
352
+ f"Checkpoint at '{config.load_parameters_path}' loaded only {loaded_count} of {expected_count} "
353
+ "expected parameter arrays. This usually means a scanned (stacked-layers) checkpoint was provided "
354
+ "where an unscanned checkpoint is required. Please convert the checkpoint to unscanned format first."
355
+ )
356
+
357
if checkpoint:
358
model_arrays = jax.tree.map(
359
lambda v: v.value,
0 commit comments