Skip to content

Commit f08effb

Browse files
ninatumartinarroyo
andcommitted
Wan training: Change checkpoint restore_type to jax.Array
- Address review comments: change the restore_type from np.ndarray to jax.Array. This is necessary because JAX sharding has no effect on np.ndarray, using jax.Array ensures that the specified sharding is respected during checkpoint restoration. Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 19b875b commit f08effb

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def add_sharding_to_struct(leaf_struct, sharding):
6363

6464
params_restore = ocp.args.PyTreeRestore(
6565
restore_args=jax.tree.map(
66-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
66+
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
6767
abstract_train_state_with_sharding,
6868
)
6969
)

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def add_sharding_to_struct(leaf_struct, sharding):
6363

6464
params_restore = ocp.args.PyTreeRestore(
6565
restore_args=jax.tree.map(
66-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
66+
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
6767
abstract_train_state_with_sharding,
6868
)
6969
)

0 commit comments

Comments
 (0)