
Commit b161ce3

Wan inference and dummy training code. (#185)
- Implements inference for Wan 2.1.
- Some training code is there, but it's very WIP and not fully set up.
1 parent: 3d5ef04

20 files changed: 3,727 additions & 1,259 deletions
Lines changed: 66 additions & 0 deletions (new file)

```python
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC

from flax import nnx

from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline import WanPipeline
from .. import max_logging, max_utils

WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type

    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, model, config, learning_rate):
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return nnx.Optimizer(model, tx), learning_rate_scheduler

  def load_wan_configs_from_orbax(self, step):
    max_logging.log("Restoring Wan configs")
    if step is None:
      step = self.checkpoint_manager.latest_step()
      if step is None:
        return None

  def load_diffusers_checkpoint(self):
    pipeline = WanPipeline.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None):
    model_configs = self.load_wan_configs_from_orbax(step)

    if model_configs:
      raise NotImplementedError("model configs should not exist in orbax")
    else:
      pipeline = self.load_diffusers_checkpoint()

    return pipeline
```
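
A minimal usage sketch of the checkpointer, for orientation. The import path is hypothetical (the extracted diff does not show the new file's name), and config stands in for a parsed MaxDiffusion config with the fields the constructor reads (checkpoint_dir, dataset_type):

```python
# Hypothetical usage sketch; the module path is assumed, not shown in the diff.
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT

checkpointer = WanCheckpointer(config, WAN_CHECKPOINT)
# No model configs exist in Orbax yet, so load_checkpoint() falls through to
# the diffusers-format weights via WanPipeline.from_pretrained(config).
pipeline = checkpointer.load_checkpoint()
```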
Lines changed: 32 additions & 38 deletions

```diff
@@ -18,25 +18,19 @@ run_name: ''
 metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 gcs_metrics: False
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 100

 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'

-# Flux params
-flux_name: "flux-dev"
-max_sequence_length: 512
-time_shift: True
-base_shift: 0.5
-max_shift: 1.15
-# offloads t5 encoder after text encoding to save memory.
-offload_encoders: True
-
-
 unet_checkpoint: ''
-revision: 'refs/pr/95'
+revision: ''
 # This will convert the weights to this dtype.
 # When running inference on TPUv5e, use weights_dtype: 'bfloat16'
 weights_dtype: 'bfloat16'
@@ -59,24 +53,9 @@ split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

 flash_block_sizes: {}
-# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
-# flash_block_sizes: {
-#   "block_q" : 1536,
-#   "block_kv_compute" : 1536,
-#   "block_kv" : 1536,
-#   "block_q_dkv" : 1536,
-#   "block_kv_dkv" : 1536,
-#   "block_kv_dkv_compute" : 1536,
-#   "block_q_dq" : 1536,
-#   "block_kv_dq" : 1536
-# }
 # GroupNorm groups
 norm_num_groups: 32

-# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
-# else they will be loaded from pretrained_model_name_or_path
-train_new_unet: False
-
 # train text_encoder - Currently not supported for SDXL
 train_text_encoder: False
 text_encoder_learning_rate: 4.25e-6
@@ -133,15 +112,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
   ['batch', 'data'],
+  ['activation_heads', 'fsdp'],
   ['activation_batch', ['data','fsdp']],
-  ['activation_heads', 'tensor'],
   ['activation_kv', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
   ['heads', 'tensor'],
+  ['norm', 'fsdp'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
+  ['conv_in', 'fsdp']
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]

@@ -152,8 +133,8 @@ data_sharding: [['data', 'fsdp', 'tensor']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
-ici_data_parallelism: -1
-ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_data_parallelism: 1
+ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1

 # Dataset
@@ -192,17 +173,19 @@ checkpoint_every: -1
 enable_single_replica_ckpt_restoring: False

 # Training loop
-learning_rate: 4.e-7
+learning_rate: 1.e-5
 scale_lr: False
 max_train_samples: -1
 # max_train_steps takes priority over num_train_epochs.
-max_train_steps: 200
+max_train_steps: 1500
 num_train_epochs: 1
 seed: 0
 output_dir: 'sdxl-model-finetuned'
 per_device_batch_size: 1
+# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
+global_batch_size: 0

-warmup_steps_fraction: 0.0
+warmup_steps_fraction: 0.1
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
@@ -212,7 +195,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
 adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
 adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
-adam_weight_decay: 1.e-2 # AdamW Weight decay
+adam_weight_decay: 0 # AdamW Weight decay
 max_grad_norm: 1.0

 enable_profiler: False
@@ -222,14 +205,25 @@ skip_first_n_steps_for_profiler: 5
 profiler_steps: 10

 # Generation parameters
-prompt: "A magical castle in the middle of a forest, artistic drawing"
-prompt_2: "A magical castle in the middle of a forest, artistic drawing"
-negative_prompt: "purple, red"
+prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
 do_classifier_free_guidance: True
-guidance_scale: 3.5
+height: 480
+width: 832
+num_frames: 81
+guidance_scale: 5.0
+flow_shift: 3.0
+
+# skip layer guidance
+slg_layers: [9]
+slg_start: 0.2
+slg_end: 1.0
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
-num_inference_steps: 50
+num_inference_steps: 30
+fps: 24
+save_final_checkpoint: False

 # SDXL Lightning parameters
 lightning_from_pt: True
```
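
The retargeting above swaps the Flux-specific parameters for Wan 2.1 generation settings (height, width, num_frames, flow_shift, and skip-layer guidance applied over the slg_start/slg_end fraction of steps) and flips the ICI auto-sharding default from data parallelism to FSDP. A quick sanity check of the latent shape implied by the new defaults; the VAE scale factors are assumptions taken from the Wan 2.1 release, not values in this diff:

```python
# Latent-shape arithmetic implied by the new config defaults. The VAE scale
# factors (temporal=4, spatial=8) are assumed from Wan 2.1, not from this diff.
height, width, num_frames = 480, 832, 81
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8

latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
latent_height = height // vae_scale_factor_spatial
latent_width = width // vae_scale_factor_spatial
print(latent_frames, latent_height, latent_width)  # 21 60 104
```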

src/maxdiffusion/generate_wan.py

Lines changed: 103 additions & 0 deletions (new file)

```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence
import jax
import time
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from maxdiffusion import pyconfig, max_logging
from absl import app
from maxdiffusion.utils import export_to_video


def run(config):
  print("seed: ", config.seed)
  pipeline = WanPipeline.from_pretrained(config)
  s0 = time.perf_counter()

  # Skip layer guidance
  slg_layers = config.slg_layers
  slg_start = config.slg_start
  slg_end = config.slg_end
  # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
  global_batch_size = config.global_batch_size
  if global_batch_size != 0:
    batch_multiplier = global_batch_size
  else:
    batch_multiplier = jax.device_count() * config.per_device_batch_size

  prompt = [config.prompt] * batch_multiplier
  negative_prompt = [config.negative_prompt] * batch_multiplier

  max_logging.log(
      f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
  )

  videos = pipeline(
      prompt=prompt,
      negative_prompt=negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
      guidance_scale=config.guidance_scale,
      slg_layers=slg_layers,
      slg_start=slg_start,
      slg_end=slg_end,
  )

  print("compile time: ", (time.perf_counter() - s0))
  for i in range(len(videos)):
    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)

  s0 = time.perf_counter()
  videos = pipeline(
      prompt=prompt,
      negative_prompt=negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
      guidance_scale=config.guidance_scale,
      slg_layers=slg_layers,
      slg_start=slg_start,
      slg_end=slg_end,
  )
  print("generation time: ", (time.perf_counter() - s0))
  for i in range(len(videos)):
    export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)

  s0 = time.perf_counter()
  with jax.profiler.trace("/tmp/trace/"):
    videos = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=config.height,
        width=config.width,
        num_frames=config.num_frames,
        num_inference_steps=config.num_inference_steps,
        guidance_scale=config.guidance_scale,
        slg_layers=slg_layers,
        slg_start=slg_start,
        slg_end=slg_end,
    )
  print("generation time: ", (time.perf_counter() - s0))


def main(argv: Sequence[str]) -> None:
  pyconfig.initialize(argv)
  run(pyconfig.config)


if __name__ == "__main__":
  app.run(main)
```
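
The script runs the pipeline three times: a first call that includes XLA compilation (logged as "compile time"), a second call that measures steady-state generation, and a third under jax.profiler.trace that writes a trace to /tmp/trace/. When global_batch_size is set, it overrides the default jax.device_count() * per_device_batch_size prompt batch. Invocation follows pyconfig's convention of a config file followed by key=value overrides; a hypothetical programmatic equivalent, with an assumed config path:

```python
# Hypothetical programmatic equivalent of the CLI invocation; the config path
# below is an assumption and may differ in the repo.
from maxdiffusion import pyconfig
from maxdiffusion.generate_wan import run

pyconfig.initialize([
    "generate_wan.py",                            # argv[0], ignored
    "src/maxdiffusion/configs/base_wan_14b.yml",  # assumed config path
    "num_inference_steps=30",
    "num_frames=81",
])
run(pyconfig.config)
```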

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 28 additions & 0 deletions

```diff
@@ -287,6 +287,34 @@ def get_dummy_flux_inputs(config, pipeline, batch_size):
   return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states)


+def get_dummy_wan_inputs(config, pipeline, batch_size):
+  latents = pipeline.prepare_latents(
+      batch_size,
+      vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
+      vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
+      height=config.height,
+      width=config.width,
+      num_frames=config.num_frames,
+      num_channels_latents=pipeline.transformer.config.in_channels,
+  )
+  bsz = latents.shape[0]
+  prompt_embeds = jax.random.normal(jax.random.key(config.seed), (batch_size, 512, 4096))
+  timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
+  return (latents, prompt_embeds, timesteps)
+
+
+def calculate_wan_tflops(config, pipeline, batch_size, rngs, train):
+  """
+  Calculates Wan tflops.
+  batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't
+  cache the compilation when flash is enabled.
+  """
+  (latents, prompt_embeds, timesteps) = get_dummy_wan_inputs(config, pipeline, batch_size)
+  return max_utils.calculate_model_tflops(
+      pipeline.transformer,
+  )
+
+
 def calculate_flux_tflops(config, pipeline, batch_size, rngs, train):
   """
   Calculates jflux tflops.
```
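
calculate_wan_tflops is still a stub: it builds the dummy inputs but does not yet pass them (or rngs/train) through to the cost calculation. As background, a minimal sketch of one way to estimate per-step TFLOPs from dummy inputs using XLA's compile-time cost analysis; estimate_tflops is illustrative and is not the max_utils.calculate_model_tflops API:

```python
# Illustrative sketch only; not the max_utils.calculate_model_tflops API.
# Estimates the TFLOPs of one forward pass from XLA's cost analysis.
import jax

def estimate_tflops(apply_fn, *dummy_inputs):
  compiled = jax.jit(apply_fn).lower(*dummy_inputs).compile()
  cost = compiled.cost_analysis()
  if isinstance(cost, (list, tuple)):  # older JAX returns a per-device list
    cost = cost[0]
  return cost["flops"] / 1e12
```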
