
Commit 7d8fdfb

ninatu and martinarroyo committed
Add WAN-VACE training functionality
Introduces training support for WAN-VACE models. New files:

- train_wan_vace.py: Main training script.
- wan_vace_trainer.py: Trainer class for WAN-VACE.
- wan_vace_checkpointing_2_1.py: Checkpointing logic for WAN-VACE.

Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 28fbabb commit 7d8fdfb

4 files changed: 522 additions & 9 deletions

src/maxdiffusion/checkpointing/wan_vace_checkpointing_2_1.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
"""Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
from typing import Optional, Tuple
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
import numpy as np
import orbax.checkpoint as ocp
from .. import max_logging
from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1


class WanVaceCheckpointer2_1(WanCheckpointer):

  def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest WAN checkpoint step: {step}")
      if step is None:
        max_logging.log("No WAN checkpoint found.")
        return None, None
    max_logging.log(f"Loading WAN checkpoint from step {step}")

    # Restore on host CPU devices with a fully replicated sharding; the
    # trainer reshards parameters onto the accelerator mesh afterwards.
    cpu_devices = np.array(jax.devices(backend="cpu"))
    mesh = Mesh(cpu_devices, axis_names=("data",))
    replicated_sharding = NamedSharding(mesh, P())

    metadatas = self.checkpoint_manager.item_metadata(step)
    state = metadatas.wan_state

    def add_sharding_to_struct(leaf_struct, sharding):
      return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)

    target_shardings = jax.tree_util.tree_map(lambda x: replicated_sharding, state)

    with mesh:
      abstract_train_state_with_sharding = jax.tree_util.tree_map(add_sharding_to_struct, state, target_shardings)

    max_logging.log("Restoring WAN checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        step=step,
        args=ocp.args.Composite(
            wan_config=ocp.args.JsonRestore(),
            wan_state=ocp.args.StandardRestore(abstract_train_state_with_sharding),
        ),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
    return restored_checkpoint, step

  def load_diffusers_checkpoint(self):
    pipeline = VaceWanPipeline2_1.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, Optional[dict], Optional[int]]:
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
    opt_state = None
    if restored_checkpoint:
      max_logging.log("Loading WAN pipeline from checkpoint")
      pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
      if "opt_state" in restored_checkpoint.wan_state.keys():
        opt_state = restored_checkpoint.wan_state["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading default pipeline.")
      pipeline = self.load_diffusers_checkpoint()

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")

    # Save the transformer config (as JSON) and the train state together.
    self.checkpoint_manager.save(
        train_step,
        args=ocp.args.Composite(
            wan_config=ocp.args.JsonSave(config_to_json(pipeline.transformer)),
            wan_state=ocp.args.StandardSave(train_states),
        ),
    )

    max_logging.log(f"Checkpoint for step {train_step} is saved.")
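Below is a minimal usage sketch, not part of the commit, showing how this checkpointer could slot into a training loop. The constructor signature is assumed from the WanCheckpointer base class, and config, train_step, and train_states stand in for the trainer's own objects.

from maxdiffusion.checkpointing.wan_vace_checkpointing_2_1 import WanVaceCheckpointer2_1

# Assumed base-class signature: pyconfig config plus a checkpoint-type tag.
checkpointer = WanVaceCheckpointer2_1(config, checkpoint_type="WAN_CHECKPOINT")

# Restores the latest Orbax checkpoint if one exists; otherwise falls back
# to diffusers weights via VaceWanPipeline2_1.from_pretrained.
pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)

# After a training interval, persist the transformer config and train state.
checkpointer.save_checkpoint(train_step, pipeline, train_states)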

src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py

Lines changed: 63 additions & 9 deletions
@@ -338,7 +338,14 @@ def load_transformer(
     return wan_transformer
 
   @classmethod
-  def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
+  def _load_and_init(
+      cls,
+      config: HyperParameters,
+      restored_checkpoint=None,
+      vae_only=False,
+      load_transformer=True,
+      load_common_components=True,
+  ):
     devices_array = max_utils.create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
     rng = jax.random.key(config.seed)
@@ -348,20 +355,31 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
     scheduler = None
     scheduler_state = None
     text_encoder = None
+    wan_vae = None
+    vae_cache = None
+
     if not vae_only:
       if load_transformer:
         with mesh:
           transformer = cls.load_transformer(
-              devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer"
+              devices_array=devices_array,
+              mesh=mesh,
+              rngs=rngs,
+              config=config,
+              restored_checkpoint=restored_checkpoint,
+              subfolder="transformer",
           )
+      if load_common_components:
+        text_encoder = cls.load_text_encoder(config=config)
+        tokenizer = cls.load_tokenizer(config=config)
 
-      text_encoder = cls.load_text_encoder(config=config)
-      tokenizer = cls.load_tokenizer(config=config)
-
-      scheduler, scheduler_state = cls.load_scheduler(config=config)
+      scheduler, scheduler_state = cls.load_scheduler(config=config)
 
-    with mesh:
-      wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
+    if load_common_components:
+      with mesh:
+        wan_vae, vae_cache = cls.load_vae(
+            devices_array=devices_array, mesh=mesh, rngs=rngs, config=config
+        )
 
     pipeline = cls(
         tokenizer=tokenizer,
@@ -376,7 +394,43 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
         config=config,
     )
 
-    pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh)
+    return pipeline
+
+  @classmethod
+  def from_pretrained(
+      cls,
+      config: HyperParameters,
+      vae_only=False,
+      load_transformer=True,
+      load_common_components=True,
+  ):
+    pipeline = cls._load_and_init(
+        config, None, vae_only, load_transformer, load_common_components
+    )
+    pipeline.transformer = cls.quantize_transformer(
+        config, pipeline.transformer, pipeline, pipeline.mesh
+    )
+    return pipeline
+
+  @classmethod
+  def from_checkpoint(
+      cls,
+      config: HyperParameters,
+      restored_checkpoint=None,
+      vae_only=False,
+      load_transformer=True,
+      load_common_components=True,
+  ):
+    pipeline = cls._load_and_init(
+        config,
+        restored_checkpoint,
+        vae_only,
+        load_transformer,
+        load_common_components,
+    )
+    pipeline.transformer = cls.quantize_transformer(
+        config, pipeline.transformer, pipeline, pipeline.mesh
+    )
     return pipeline
 
   def check_inputs(
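A short sketch, not part of the commit, contrasting the two entry points this refactor introduces. Here config is the pyconfig HyperParameters object and restored_checkpoint is the Orbax composite returned by the checkpointer above.

from maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1

# Build the full pipeline from the configured pretrained weights.
pipeline = VaceWanPipeline2_1.from_pretrained(config)

# Same construction path, but load_transformer receives the restored
# checkpoint, so transformer weights come from Orbax instead.
pipeline = VaceWanPipeline2_1.from_checkpoint(config, restored_checkpoint)

# Transformer and scheduler only: skips the VAE, text encoder, and
# tokenizer, e.g. when latents and text embeddings are precomputed.
slim_pipeline = VaceWanPipeline2_1.from_checkpoint(
    config, restored_checkpoint, load_common_components=False
)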

src/maxdiffusion/train_wan_vace.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Sequence

import jax
from absl import app
from maxdiffusion import max_logging, pyconfig
from maxdiffusion.train_utils import validate_train_config
import flax


def train(config):
  # Deferred import: pull in the trainer (and its heavy dependencies)
  # only after configuration has been parsed and validated.
  from maxdiffusion.trainers.wan_vace_trainer import WanVaceTrainer

  trainer = WanVaceTrainer(config)
  trainer.start_training()


def main(argv: Sequence[str]) -> None:
  pyconfig.initialize(argv, validate_training=True)
  config = pyconfig.config
  validate_train_config(config)
  max_logging.log(f"Found {jax.device_count()} devices.")
  try:
    # Older flax releases do not define this flag; ignore if it is absent.
    flax.config.update("flax_always_shard_variable", False)
  except LookupError:
    pass
  train(config)


if __name__ == "__main__":
  app.run(main)
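A hypothetical launch sketch. pyconfig conventionally reads argv[1] as a YAML config file and later tokens as key=value overrides; the config path and override values here are illustrative placeholders, not files added by this commit.

from maxdiffusion.train_wan_vace import main

main([
    "train_wan_vace.py",                          # argv[0]: program name
    "src/maxdiffusion/configs/base_wan_14b.yml",  # assumed config file
    "run_name=wan_vace_smoke_test",
    "output_dir=/tmp/wan_vace_output",
])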
