Skip to content

Commit f9f9506

Browse files
ninatu and martinarroyo committed
Improve Wan model training and add Wan-VACE training support.
Key changes include: 1. Bug fixes: * Resolved training mode bug when dropout > 0 (e.g., ensured rngs parameter is passed to layer_forward for gradient checkpointing with dropout) * Fixed prepare_sample_fn usage for 'tfrecord' dataset type. * Addressed checkpoint loading issues with larger TPU slices and different topologies for Wan 2.1. * Corrected timestep sampling for continuous sampling 2. Config updates: * Ensured adam_weight_decay is a float. * Added tensorboard_dir parameter for logging. * Now uses config.learning_rate instead of a hardcoded value. * Set default dropout to 0.0 in WAN configs (instead of 0.1). 3. Wan-VACE Support: * Refactoring: Common training components (initialization, scheduler, TFLOPs calculation, training/eval loops) have been abstracted into a new BaseWanTrainer ABC to improve code structure and reusability. * Added new scripts (train_wan_vace.py), trainer (wan_vace_trainer.py), and checkpointing logic (wan_vace_checkpointing_2_1.py) to enable training of WAN-VACE models. 4. New Features: * Introduced config.disable_training_weights to optionally disable mid-point loss weighting. * Added logging for max_grad_norm and max_abs_grad. Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 4e213a1 commit f9f9506

19 files changed

Lines changed: 1147 additions & 402 deletions

src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
"""
1616

1717
import json
18-
import jax
19-
import numpy as np
2018
from typing import Optional, Tuple
21-
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
22-
from .. import max_logging
23-
import orbax.checkpoint as ocp
2419
from etils import epath
20+
import jax
21+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2522
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
23+
import numpy as np
24+
import orbax.checkpoint as ocp
25+
from .. import max_logging
26+
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
2627

2728

2829
class WanCheckpointer2_1(WanCheckpointer):
@@ -35,13 +36,35 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
cpu_devices = np.array(jax.devices(backend="cpu"))
41+
mesh = Mesh(cpu_devices, axis_names=("data",))
42+
replicated_sharding = NamedSharding(mesh, P())
43+
3844
metadatas = self.checkpoint_manager.item_metadata(step)
39-
transformer_metadata = metadatas.wan_state
40-
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
45+
state = metadatas.wan_state
46+
47+
def add_sharding_to_struct(leaf_struct, sharding):
48+
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49+
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50+
return jax.ShapeDtypeStruct(
51+
shape=struct.shape, dtype=struct.dtype, sharding=sharding
52+
)
53+
return struct
54+
55+
target_shardings = jax.tree_util.tree_map(
56+
lambda x: replicated_sharding, state
57+
)
58+
59+
with mesh:
60+
abstract_train_state_with_sharding = jax.tree_util.tree_map(
61+
add_sharding_to_struct, state, target_shardings
62+
)
63+
4164
params_restore = ocp.args.PyTreeRestore(
4265
restore_args=jax.tree.map(
43-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
44-
abstract_tree_structure_params,
66+
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
67+
abstract_train_state_with_sharding,
4568
)
4669
)
4770

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
"""
1616

1717
import json
18-
import jax
19-
import numpy as np
2018
from typing import Optional, Tuple
21-
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
22-
from .. import max_logging
23-
import orbax.checkpoint as ocp
2419
from etils import epath
20+
import jax
21+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2522
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
23+
import numpy as np
24+
import orbax.checkpoint as ocp
25+
from .. import max_logging
26+
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
2627

2728

2829
class WanCheckpointerI2V_2_1(WanCheckpointer):
@@ -35,13 +36,35 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
3536
max_logging.log("No WAN checkpoint found.")
3637
return None, None
3738
max_logging.log(f"Loading WAN checkpoint from step {step}")
39+
40+
cpu_devices = np.array(jax.devices(backend="cpu"))
41+
mesh = Mesh(cpu_devices, axis_names=("data",))
42+
replicated_sharding = NamedSharding(mesh, P())
43+
3844
metadatas = self.checkpoint_manager.item_metadata(step)
39-
transformer_metadata = metadatas.wan_state
40-
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
45+
state = metadatas.wan_state
46+
47+
def add_sharding_to_struct(leaf_struct, sharding):
48+
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
49+
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
50+
return jax.ShapeDtypeStruct(
51+
shape=struct.shape, dtype=struct.dtype, sharding=sharding
52+
)
53+
return struct
54+
55+
target_shardings = jax.tree_util.tree_map(
56+
lambda x: replicated_sharding, state
57+
)
58+
59+
with mesh:
60+
abstract_train_state_with_sharding = jax.tree_util.tree_map(
61+
add_sharding_to_struct, state, target_shardings
62+
)
63+
4164
params_restore = ocp.args.PyTreeRestore(
4265
restore_args=jax.tree.map(
43-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
44-
abstract_tree_structure_params,
66+
lambda _: ocp.RestoreArgs(restore_type=jax.Array),
67+
abstract_train_state_with_sharding,
4568
)
4669
)
4770

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""Copyright 2025 Google LLC
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
import json
17+
from typing import Optional, Tuple
18+
import jax
19+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
20+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
21+
import numpy as np
22+
import orbax.checkpoint as ocp
23+
from .. import max_logging
24+
from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1
25+
26+
27+
class WanVaceCheckpointer2_1(WanCheckpointer):
28+
29+
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
30+
if step is None:
31+
step = self.checkpoint_manager.latest_step()
32+
max_logging.log(f"Latest WAN checkpoint step: {step}")
33+
if step is None:
34+
max_logging.log("No WAN checkpoint found.")
35+
return None, None
36+
max_logging.log(f"Loading WAN checkpoint from step {step}")
37+
38+
cpu_devices = np.array(jax.devices(backend="cpu"))
39+
mesh = Mesh(cpu_devices, axis_names=("data",))
40+
replicated_sharding = NamedSharding(mesh, P())
41+
42+
metadatas = self.checkpoint_manager.item_metadata(step)
43+
state = metadatas.wan_state
44+
45+
def add_sharding_to_struct(leaf_struct, sharding):
46+
struct = ocp.utils.to_shape_dtype_struct(leaf_struct)
47+
if hasattr(struct, "shape") and hasattr(struct, "dtype"):
48+
return jax.ShapeDtypeStruct(
49+
shape=struct.shape, dtype=struct.dtype, sharding=sharding
50+
)
51+
return struct
52+
53+
target_shardings = jax.tree_util.tree_map(
54+
lambda x: replicated_sharding, state
55+
)
56+
57+
with mesh:
58+
abstract_train_state_with_sharding = jax.tree_util.tree_map(
59+
add_sharding_to_struct, state, target_shardings
60+
)
61+
62+
max_logging.log("Restoring WAN checkpoint")
63+
restored_checkpoint = self.checkpoint_manager.restore(
64+
step=step,
65+
args=ocp.args.Composite(
66+
wan_config=ocp.args.JsonRestore(),
67+
wan_state=ocp.args.StandardRestore(
68+
abstract_train_state_with_sharding
69+
),
70+
),
71+
)
72+
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
73+
max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
74+
max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
75+
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
76+
return restored_checkpoint, step
77+
78+
def load_diffusers_checkpoint(self):
79+
pipeline = VaceWanPipeline2_1.from_pretrained(self.config)
80+
return pipeline
81+
82+
def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, Optional[dict], Optional[int]]:
83+
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
84+
opt_state = None
85+
if restored_checkpoint:
86+
max_logging.log("Loading WAN pipeline from checkpoint")
87+
pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
88+
if "opt_state" in restored_checkpoint.wan_state.keys():
89+
opt_state = restored_checkpoint.wan_state["opt_state"]
90+
else:
91+
max_logging.log("No checkpoint found, loading default pipeline.")
92+
pipeline = self.load_diffusers_checkpoint()
93+
94+
return pipeline, opt_state, step
95+
96+
def save_checkpoint(
97+
self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict
98+
):
99+
"""Saves the training state and model configurations."""
100+
101+
def config_to_json(model_or_config):
102+
return json.loads(model_or_config.to_json_string())
103+
104+
max_logging.log(f"Saving checkpoint for step {train_step}")
105+
106+
# Save the checkpoint
107+
self.checkpoint_manager.save(
108+
train_step,
109+
args=ocp.args.Composite(
110+
wan_config=ocp.args.JsonSave(config_to_json(pipeline.transformer)),
111+
wan_state=ocp.args.StandardSave(train_states),
112+
),
113+
)
114+
115+
max_logging.log(f"Checkpoint for step {train_step} is saved.")

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ mask_padding_tokens: True
7272
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
7373
# in cross attention q.
7474
attention_sharding_uniform: True
75-
dropout: 0.1
75+
dropout: 0.0
7676

7777
flash_block_sizes: {
7878
"block_q" : 512,
@@ -281,6 +281,7 @@ output_dir: 'sdxl-model-finetuned'
281281
per_device_batch_size: 1.0
282282
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
283283
global_batch_size: 0
284+
disable_training_weights: False # if True, disables the use of mid-point loss weighting
284285

285286
# For creating tfrecords from dataset
286287
tfrecords_dir: ''
@@ -300,7 +301,7 @@ save_optimizer: False
300301
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
301302
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
302303
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
303-
adam_weight_decay: 0 # AdamW Weight decay
304+
adam_weight_decay: 0.0 # AdamW Weight decay
304305
max_grad_norm: 1.0
305306

306307
enable_profiler: False

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ mask_padding_tokens: True
7272
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
7373
# in cross attention q.
7474
attention_sharding_uniform: True
75-
dropout: 0.1
75+
dropout: 0.0
7676

7777
flash_block_sizes: {
7878
"block_q" : 512,
@@ -237,6 +237,7 @@ output_dir: 'sdxl-model-finetuned'
237237
per_device_batch_size: 1.0
238238
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
239239
global_batch_size: 0
240+
disable_training_weights: False # if True, disables the use of mid-point loss weighting
240241

241242
# For creating tfrecords from dataset
242243
tfrecords_dir: ''
@@ -256,7 +257,7 @@ save_optimizer: False
256257
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
257258
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
258259
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
259-
adam_weight_decay: 0 # AdamW Weight decay
260+
adam_weight_decay: 0.0 # AdamW Weight decay
260261
max_grad_norm: 1.0
261262

262263
enable_profiler: False

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ mask_padding_tokens: True
7171
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
7272
# in cross attention q.
7373
attention_sharding_uniform: True
74-
dropout: 0.1
74+
dropout: 0.0
7575

7676
flash_block_sizes: {
7777
"block_q" : 512,
@@ -267,7 +267,7 @@ save_optimizer: False
267267
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
268268
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
269269
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
270-
adam_weight_decay: 0 # AdamW Weight decay
270+
adam_weight_decay: 0.0 # AdamW Weight decay
271271
max_grad_norm: 1.0
272272

273273
enable_profiler: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ from_pt: True
6262
split_head_dim: True
6363
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
6464
flash_min_seq_length: 4096
65-
dropout: 0.1
65+
dropout: 0.0
6666

6767
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
6868
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
@@ -262,7 +262,7 @@ save_optimizer: False
262262
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
263263
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
264264
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
265-
adam_weight_decay: 0 # AdamW Weight decay
265+
adam_weight_decay: 0.0 # AdamW Weight decay
266266
max_grad_norm: 1.0
267267

268268
enable_profiler: False

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ from_pt: True
6262
split_head_dim: True
6363
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
6464
flash_min_seq_length: 4096
65-
dropout: 0.1
65+
dropout: 0.0
6666

6767
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
6868
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
@@ -263,7 +263,7 @@ save_optimizer: False
263263
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
264264
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
265265
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
266-
adam_weight_decay: 0 # AdamW Weight decay
266+
adam_weight_decay: 0.0 # AdamW Weight decay
267267
max_grad_norm: 1.0
268268

269269
enable_profiler: False

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,11 @@ def _make_tfrecord_iterator(
113113
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
114114
}
115115

116-
used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description
116+
used_feature_description = (
117+
feature_description_fn
118+
if (make_cached_tfrecord_iterator or config.dataset_type == "tfrecord")
119+
else feature_description
120+
)
117121

118122
def _parse_tfrecord_fn(example):
119123
return tf.io.parse_single_example(example, used_feature_description)
@@ -141,7 +145,11 @@ def prepare_sample(features):
141145
ds = ds.concatenate(padding_ds)
142146
max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
143147

144-
used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
148+
used_prepare_sample = (
149+
prepare_sample_fn
150+
if (make_cached_tfrecord_iterator or config.dataset_type == "tfrecord")
151+
else prepare_sample
152+
)
145153
ds = (
146154
ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
147155
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,10 @@ def __call__(
12631263

12641264
with jax.named_scope("proj_attn"):
12651265
hidden_states = self.proj_attn(attn_output)
1266-
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
1266+
if self.drop_out.rate > 0:
1267+
hidden_states = self.drop_out(
1268+
hidden_states, deterministic=deterministic, rngs=rngs
1269+
)
12671270
return hidden_states
12681271

12691272

0 commit comments

Comments (0)