Skip to content

Commit 20430ab

Browse files
committed
Merge remote-tracking branch 'origin/main' into fixbiassharding
2 parents 7a6ab88 + 384d211 commit 20430ab

38 files changed

Lines changed: 2461 additions & 403 deletions

Whitespace-only changes.

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ __pycache__/
66
*$py.class
77
# C extensions
88
*.so
9+
Gemini.md
910

1011
# tests and logs
1112
tests/fixtures/cached_*_text.txt

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
22
absl-py
3+
accelerate
34
aqtp
45
chex
56
datasets
@@ -14,6 +15,7 @@ imageio-ffmpeg
1415
imageio
1516
jax
1617
jaxlib
18+
jaxopt
1719
Jinja2
1820
opencv-python-headless
1921
optax

dependencies/requirements/generated_requirements/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# If you need to modify dependencies, please do so in the host requirements file and run seed-env again.
33

44
absl-py>=2.3.1
5+
accelerate>=1.13.0
56
aiofiles>=25.1.0
67
aiohappyeyeballs>=2.6.1
78
aiohttp>=3.13.3
@@ -80,6 +81,7 @@ isort>=8.0.1
8081
jaraco-functools>=4.4.0
8182
jax>=0.9.0
8283
jaxlib>=0.9.0
84+
jaxopt>=0.8.5
8385
jaxtyping>=0.3.9
8486
jinja2>=3.1.6
8587
keras>=3.13.1

src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
"""
1616

1717
import json
18-
import jax
19-
import numpy as np
2018
from typing import Optional, Tuple
21-
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
22-
from .. import max_logging
23-
import orbax.checkpoint as ocp
2419
from etils import epath
20+
import jax
21+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2522
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
23+
import numpy as np
24+
import orbax.checkpoint as ocp
25+
from .. import max_logging
26+
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
2627

2728

2829
class WanCheckpointer2_1(WanCheckpointer):
@@ -35,13 +36,29 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
cpu_devices = np.array(jax.devices(backend="cpu"))
41+
mesh = Mesh(cpu_devices, axis_names=("data",))
42+
replicated_sharding = NamedSharding(mesh, P())
43+
3844
metadatas = self.checkpoint_manager.item_metadata(step)
39-
transformer_metadata = metadatas.wan_state
40-
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
45+
state = metadatas.wan_state
46+
47+
def add_sharding_to_struct(leaf_struct, sharding):
48+
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49+
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50+
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51+
return struct
52+
53+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
54+
55+
with mesh:
56+
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
57+
4158
params_restore = ocp.args.PyTreeRestore(
4259
restore_args=jax.tree.map(
43-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
44-
abstract_tree_structure_params,
60+
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61+
abstract_train_state_with_sharding,
4562
)
4663
)
4764

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
"""
1616

1717
import json
18-
import jax
19-
import numpy as np
2018
from typing import Optional, Tuple
21-
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
22-
from .. import max_logging
23-
import orbax.checkpoint as ocp
2419
from etils import epath
20+
import jax
21+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2522
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
23+
import numpy as np
24+
import orbax.checkpoint as ocp
25+
from .. import max_logging
26+
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
2627

2728

2829
class WanCheckpointerI2V_2_1(WanCheckpointer):
@@ -35,13 +36,29 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
cpu_devices = np.array(jax.devices(backend="cpu"))
41+
mesh = Mesh(cpu_devices, axis_names=("data",))
42+
replicated_sharding = NamedSharding(mesh, P())
43+
3844
metadatas = self.checkpoint_manager.item_metadata(step)
39-
transformer_metadata = metadatas.wan_state
40-
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
45+
state = metadatas.wan_state
46+
47+
def add_sharding_to_struct(leaf_struct, sharding):
48+
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49+
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50+
return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
51+
return struct
52+
53+
target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)
54+
55+
with mesh:
56+
abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)
57+
4158
params_restore = ocp.args.PyTreeRestore(
4259
restore_args=jax.tree.map(
43-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
44-
abstract_tree_structure_params,
60+
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
61+
abstract_train_state_with_sharding,
4562
)
4663
)
4764

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Copyright 2025 Google LLC
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
import json
17+
from typing import Optional, Tuple
18+
import jax
19+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
20+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
21+
import numpy as np
22+
import orbax.checkpoint as ocp
23+
from .. import max_logging
24+
from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1
25+
26+
27+
class WanVaceCheckpointer2_1(WanCheckpointer):
  """Checkpointer for the VACE WAN 2.1 pipeline.

  Extends WanCheckpointer with Orbax-based save/restore of the transformer
  state and config, restoring arrays as replicated jax.Arrays on CPU.
  """

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restore the WAN checkpoint (config + state) for *step* from Orbax.

    Args:
      step: Checkpoint step to restore. If None, the latest step known to
        the checkpoint manager is used.

    Returns:
      (restored_checkpoint, step) on success, or (None, None) when no
      checkpoint exists.
    """
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")

    # Build a 1-D CPU mesh and a fully-replicated sharding (empty PartitionSpec)
    # so the restored arrays land on host memory rather than accelerators.
    cpu_devices = np.array(jax.devices(backend="cpu"))
    mesh = Mesh(cpu_devices, axis_names=("data",))
    replicated_sharding = NamedSharding(mesh, P())

    metadatas = self.checkpoint_manager.item_metadata(step)
    state = metadatas.wan_state

    def add_sharding_to_struct(leaf_struct, sharding):
      # Convert checkpoint metadata into an abstract ShapeDtypeStruct and
      # attach the target sharding; leaves without shape/dtype (non-array
      # metadata) are passed through unchanged.
      struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
      if hasattr(struct, "shape") and hasattr(struct, "dtype"):
        return jax.ShapeDtypeStruct(shape=struct.shape, dtype=struct.dtype, sharding=sharding)
      return struct

    # Every leaf gets the same replicated CPU sharding.
    target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)

    with mesh:
      abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)

    max_logging.log("Restoring WAN checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        step=step,
        args=ocp.args.Composite(
            wan_config=ocp.args.JsonRestore(),
            wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
        ),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
    # NOTE(review): self.opt_state is presumably set by the WanCheckpointer
    # base class — confirm it exists before this method runs.
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    """Load the default pretrained pipeline from diffusers-format weights."""
    pipeline = VaceWanPipeline2_1.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, Optional[dict], Optional[int]]:
    """Load the pipeline from an Orbax checkpoint, or fall back to pretrained.

    Args:
      step: Optional checkpoint step; None selects the latest.

    Returns:
      (pipeline, opt_state, step) — opt_state is None when the checkpoint
      has no "opt_state" entry or no checkpoint was found.
    """
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
      if "opt_state" in restored_checkpoint.wan_state.keys():
        opt_state = restored_checkpoint.wan_state["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      # Round-trip through the model's JSON string to get a plain dict
      # that Orbax's JsonSave can serialize.
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")

    # Save the checkpoint
    self.checkpoint_manager.save(
        train_step,
        args=ocp.args.Composite(
            wan_config=ocp.args.JsonSave(config_to_json(pipeline.transformer)),
            wan_state=ocp.args.StandardSave(train_states),
        ),
    )

    max_logging.log(f"Checkpoint for step {train_step} is saved.")

src/maxdiffusion/configs/base14.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
206206
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
207207
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
208208
adam_weight_decay: 1.e-2 # AdamW Weight decay
209+
opt_enable_grad_clipping: False
210+
max_grad_value: 1.0
211+
opt_enable_grad_global_norm_clipping: False
209212
max_grad_norm: 1.0
210213

211214
enable_profiler: False

src/maxdiffusion/configs/base21.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
211211
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
212212
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
213213
adam_weight_decay: 1.e-2 # AdamW Weight decay
214+
opt_enable_grad_clipping: False
215+
max_grad_value: 1.0
216+
opt_enable_grad_global_norm_clipping: False
214217
max_grad_norm: 1.0
215218

216219
enable_profiler: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
221221
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
222222
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
223223
adam_weight_decay: 1.e-2 # AdamW Weight decay
224+
opt_enable_grad_clipping: False
225+
max_grad_value: 1.0
226+
opt_enable_grad_global_norm_clipping: False
224227
max_grad_norm: 1.0
225228

226229
enable_profiler: False

0 commit comments

Comments
 (0)