Skip to content

Commit 19b875b

Browse files
ninatu and martinarroyo committed
Wan training: Add checks for 'shape' and 'dtype' attributes during checkpoint loading.
Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent dec1690 commit 19b875b

3 files changed

Lines changed: 18 additions & 9 deletions

File tree

src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,9 +45,12 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4545
state = metadatas.wan_state
4646

4747
def add_sharding_to_struct(leaf_struct, sharding):
48-
return jax.ShapeDtypeStruct(
49-
shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
50-
)
48+
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49+
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50+
return jax.ShapeDtypeStruct(
51+
shape=struct.shape, dtype=struct.dtype, sharding=sharding
52+
)
53+
return struct
5154

5255
target_shardings = jax.tree_util.tree_map(
5356
lambda x: replicated_sharding, state

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,9 +45,12 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4545
state = metadatas.wan_state
4646

4747
def add_sharding_to_struct(leaf_struct, sharding):
48-
return jax.ShapeDtypeStruct(
49-
shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
50-
)
48+
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49+
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50+
return jax.ShapeDtypeStruct(
51+
shape=struct.shape, dtype=struct.dtype, sharding=sharding
52+
)
53+
return struct
5154

5255
target_shardings = jax.tree_util.tree_map(
5356
lambda x: replicated_sharding, state

src/maxdiffusion/checkpointing/wan_vace_checkpointer_2_1.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,9 +43,12 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
4343
state = metadatas.wan_state
4444

4545
def add_sharding_to_struct(leaf_struct, sharding):
46-
return jax.ShapeDtypeStruct(
47-
shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
48-
)
46+
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
47+
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
48+
return jax.ShapeDtypeStruct(
49+
shape=struct.shape, dtype=struct.dtype, sharding=sharding
50+
)
51+
return struct
4952

5053
target_shardings = jax.tree_util.tree_map(
5154
lambda x: replicated_sharding, state

0 commit comments

Comments
 (0)