Skip to content

Commit 28fbabb

Browse files
ninatu and martinarroyo committed
Abstract common WAN training components into BaseWanTrainer
The following key functionalities have been moved from WanTrainer to the new `BaseWanTrainer` ABC: - Initialization and config handling - Scheduler creation - TFLOPs calculation - Core training and evaluation loops (`start_training`, `training_loop`, `eval`) - Abstract methods for checkpointer, data loading, sharding, and step functions. Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent f30daac commit 28fbabb

2 files changed

Lines changed: 391 additions & 325 deletions

File tree

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import abc
18+
from concurrent.futures import ThreadPoolExecutor
19+
from contextlib import nullcontext
20+
import datetime
21+
import os
22+
import pprint
23+
import threading
24+
from flax import nnx
25+
from flax.linen import partitioning as nn_partitioning
26+
from flax.training import train_state
27+
import jax
28+
from jax.experimental import multihost_utils
29+
import jax.numpy as jnp
30+
from maxdiffusion import max_logging, max_utils, train_utils
31+
from maxdiffusion.generate_wan import inference_generate_video
32+
from maxdiffusion.generate_wan import run as generate_wan
33+
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
34+
from maxdiffusion.schedulers import FlaxFlowMatchScheduler
35+
from maxdiffusion.train_utils import ( _metrics_queue,_tensorboard_writer_worker, load_next_batch)
36+
from maxdiffusion.utils import load_video
37+
from maxdiffusion.video_processor import VideoProcessor
38+
import numpy as np
39+
from skimage.metrics import structural_similarity as ssim
40+
41+
42+
class TrainState(train_state.TrainState):
  # NNX graph definition captured from nnx.split so the transformer module
  # can be reconstructed with nnx.merge after training.
  graphdef: nnx.GraphDef
  # Non-parameter NNX state (everything nnx.split left over after nnx.Param),
  # carried alongside params so merge restores the full module.
  rest_of_state: nnx.State
45+
46+
47+
def _to_array(x):
48+
if not isinstance(x, jax.Array):
49+
x = jnp.asarray(x)
50+
return x
51+
52+
53+
def generate_sample(config, pipeline, filename_prefix):
  """Generates a video to validate training did not corrupt the model.

  If the pipeline's VAE was previously deleted to save memory, it is
  reloaded before generation.
  """
  if not hasattr(pipeline, "vae"):
    rngs = nnx.Rngs(jax.random.key(config.seed))
    pipeline.vae, pipeline.vae_cache = WanPipeline.load_vae(pipeline.mesh.devices, pipeline.mesh, rngs, config)
  return generate_wan(config, pipeline, filename_prefix)
64+
65+
66+
def _load_video_as_uint8(video_path):
  """Load a video and return it as a uint8 array shaped (batch, T, H, W, C)."""
  video = VideoProcessor().preprocess_video(load_video(video_path))
  video = np.array(video)
  # preprocess_video yields (B, C, T, H, W); move channels last.
  video = np.transpose(video, (0, 2, 3, 4, 1))
  # Values come out in [-1, 1]; rescale to [0, 255] for SSIM with data_range=255.
  return np.uint8((video + 1) * 255 / 2)


def print_ssim(pretrained_video_path, posttrained_video_path):
  """Log the SSIM between the pre-training and post-training sample videos.

  Args:
    pretrained_video_path: sequence whose first element is the path of the
      video generated before training.
    posttrained_video_path: sequence whose first element is the path of the
      video generated after training.
  """
  pretrained_video = _load_video_as_uint8(pretrained_video_path[0])
  posttrained_video = _load_video_as_uint8(posttrained_video_path[0])

  # FIX: drop the deprecated/removed `multichannel=True` flag — modern
  # scikit-image rejects it; `channel_axis=-1` alone expresses the same intent.
  ssim_compare = ssim(pretrained_video[0], posttrained_video[0], channel_axis=-1, data_range=255)

  max_logging.log(f"SSIM score after training is {ssim_compare}")
83+
84+
85+
class BaseWanTrainer(abc.ABC):
  """Shared scaffolding for WAN diffusion-transformer trainers.

  Owns config handling, scheduler creation, TFLOPs estimation, and the core
  training/evaluation loops. Subclasses provide checkpointing, data loading,
  sharding specs, and the jitted train/eval step functions via the abstract
  methods below.
  """

  def __init__(self, config):
    if config.train_text_encoder:
      raise ValueError("this script currently doesn't support training text_encoders")
    self.config = config
    self.checkpointer = self._get_checkpointer()

  @abc.abstractmethod
  def _get_checkpointer(self):
    """Returns the checkpointer for the trainer."""

  def post_training_steps(self, pipeline, params, train_states, msg=""):
    # Optional subclass hook run after training; no-op by default.
    pass

  def create_scheduler(self):
    """Creates and initializes the Flow Match scheduler for training."""
    noise_scheduler = FlaxFlowMatchScheduler(dtype=jnp.float32)
    noise_scheduler_state = noise_scheduler.create_state()
    noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True)
    return noise_scheduler, noise_scheduler_state

  @staticmethod
  def calculate_tflops(pipeline):
    """Estimate per-device TFLOPs for one training pass of the transformer.

    Returns:
      (train_tflops, total_attn_flops, seq_len).
    """
    maxdiffusion_config = pipeline.config
    # Model configuration
    height = pipeline.config.height
    width = pipeline.config.width
    num_frames = pipeline.config.num_frames

    # Transformer dimensions
    transformer_config = pipeline.transformer.config
    num_layers = transformer_config.num_layers
    heads = pipeline.transformer.config.num_attention_heads
    head_dim = pipeline.transformer.config.attention_head_dim
    ffn_dim = transformer_config.ffn_dim
    # Latent sequence length: 8x spatial VAE downsampling in each dimension,
    # temporal compression by vae_scale_factor_temporal, then /4 — presumably
    # 2x2 patchification (TODO confirm against the transformer's patch size).
    seq_len = int(((height / 8) * (width / 8) * ((num_frames - 1) // pipeline.vae_scale_factor_temporal + 1)) / 4)
    text_encoder_dim = 512
    # Attention FLOPS
    # Self
    self_attn_qkv_proj_flops = 3 * (2 * seq_len * (heads * head_dim) ** 2)
    self_attn_qk_v_flops = 2 * (2 * seq_len**2 * (heads * head_dim))
    # Cross
    cross_attn_kv_proj_flops = 3 * (2 * text_encoder_dim * (heads * head_dim) ** 2)
    cross_attn_q_proj_flops = 1 * (2 * seq_len * (heads * head_dim) ** 2)
    cross_attention_qk_v_flops = 2 * (2 * seq_len * text_encoder_dim * (heads * head_dim))

    # Output_projection from attention
    attn_output_proj_flops = 2 * (2 * seq_len * (heads * head_dim) ** 2)

    total_attn_flops = (
        self_attn_qkv_proj_flops
        + self_attn_qk_v_flops
        + cross_attn_kv_proj_flops
        + cross_attn_q_proj_flops
        + cross_attention_qk_v_flops
        + attn_output_proj_flops
    )

    # FFN
    ffn_flops = 2 * (2 * seq_len * (heads * head_dim) * ffn_dim)

    flops_per_block = total_attn_flops + ffn_flops

    total_transformer_flops = flops_per_block * num_layers

    tflops = maxdiffusion_config.per_device_batch_size * total_transformer_flops / 1e12
    # 3x: forward pass plus (roughly 2x) backward pass.
    train_tflops = 3 * tflops

    max_logging.log(f"Calculated TFLOPs per pass: {train_tflops:.4f}")
    return train_tflops, total_attn_flops, seq_len

  @abc.abstractmethod
  def get_data_shardings(self, mesh):
    """Returns data shardings for training."""

  @abc.abstractmethod
  def get_eval_data_shardings(self, mesh):
    """Returns data shardings for evaluation."""

  @abc.abstractmethod
  def load_dataset(self, mesh, pipeline=None, is_training=True):
    """Loads the dataset."""

  @abc.abstractmethod
  def get_train_step(self, pipeline, mesh, state_shardings, data_shardings):
    """Returns the training step function."""

  @abc.abstractmethod
  def get_eval_step(self, pipeline, mesh, state_shardings, eval_data_shardings):
    """Returns the evaluation step function."""

  def start_training(self):
    """End-to-end entry point: restore checkpoint, train, optionally SSIM-check."""
    with nn_partitioning.axis_rules(self.config.logical_axis_rules):
      pipeline, opt_state, step = self.checkpointer.load_checkpoint()
      restore_args = {}
      if opt_state and step:
        restore_args = {"opt_state": opt_state, "step": step}
      del opt_state
      if self.config.enable_ssim:
        # Generate a sample before training to compare against generated sample after training.
        pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

      if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
        # save some memory.
        del pipeline.vae
        del pipeline.vae_cache

      mesh = pipeline.mesh
      train_data_iterator = self.load_dataset(mesh, pipeline=pipeline, is_training=True)

      # Load FlowMatch scheduler
      scheduler, scheduler_state = self.create_scheduler()
      pipeline.scheduler = scheduler
      pipeline.scheduler_state = scheduler_state
      optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer(
          pipeline.transformer, self.config, self.config.learning_rate
      )
      # Returns pipeline with trained transformer state
      pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args)

      if self.config.enable_ssim:
        posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
        print_ssim(pretrained_video_path, posttrained_video_path)

  def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
    """Run one full pass over the eval dataset and log mean loss per timestep."""
    eval_data_iterator = self.load_dataset(mesh, is_training=False)
    eval_rng = eval_rng_key
    eval_losses_by_timestep = {}
    # Loop indefinitely until the iterator is exhausted
    while True:
      try:
        eval_start_time = datetime.datetime.now()
        eval_batch = load_next_batch(eval_data_iterator, None, self.config)
        with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
          metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
          metrics["scalar"]["learning/eval_loss"].block_until_ready()
        losses = metrics["scalar"]["learning/eval_loss"]
        timesteps = eval_batch["timesteps"]
        # Gather per-example losses/timesteps from every host so process 0
        # can bucket them globally.
        gathered_losses = multihost_utils.process_allgather(losses, tiled=True)
        gathered_losses = jax.device_get(gathered_losses)
        gathered_timesteps = multihost_utils.process_allgather(timesteps, tiled=True)
        gathered_timesteps = jax.device_get(gathered_timesteps)
        if jax.process_index() == 0:
          for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
            eval_losses_by_timestep.setdefault(int(t), []).append(l)
        eval_duration = datetime.datetime.now() - eval_start_time
        max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.")
      except StopIteration:
        # This block is executed when the iterator has no more data
        break
    # Check if any evaluation was actually performed. Only process 0
    # accumulated losses above, so the redundant inner rank check that the
    # original carried is dropped.
    if eval_losses_by_timestep and jax.process_index() == 0:
      mean_per_timestep = []
      max_logging.log(f"Step {step}, calculating mean loss per timestep...")
      for timestep, losses in sorted(eval_losses_by_timestep.items()):
        losses = jnp.array(losses)
        # Cap samples per bucket so over-represented timesteps don't dominate.
        losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
        mean_loss = jnp.mean(losses)
        max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
        mean_per_timestep.append(mean_loss)
      final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
      max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
      if writer:
        writer.add_scalar("learning/eval_loss", final_eval_loss, step)

  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args=None):
    """Core training loop.

    Args:
      pipeline: WAN pipeline whose transformer is trained in place.
      optimizer: optax-style optimizer transformation.
      learning_rate_scheduler: callable step -> learning rate (for logging).
      train_data_iterator: iterator yielding training batches.
      restore_args: optional dict with "opt_state"/"step" to resume from.

    Returns:
      The pipeline with its transformer re-merged from the trained state.
    """
    # FIX: the original used a mutable default (`restore_args: dict = {}`)
    # and then mutated it (`del restore_args["opt_state"]`), corrupting the
    # shared default and the caller's dict. Normalize to a private copy.
    restore_args = dict(restore_args) if restore_args else {}
    mesh = pipeline.mesh
    graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)

    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
      state = TrainState.create(
          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
      )
      if restore_args:
        step = restore_args.get("step", 0)
        max_logging.log(f"Restoring optimizer and resuming from step {step}")
        # FIX: TrainState.replace returns a NEW state; the original discarded
        # the result, so the restored opt_state/step never took effect.
        state = state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0))
        # Drop the (potentially huge) restored opt_state reference promptly.
        del restore_args["opt_state"]
      del optimizer
      state = jax.tree.map(_to_array, state)
      state_spec = nnx.get_partition_spec(state)
      state = jax.lax.with_sharding_constraint(state, state_spec)
      state_shardings = nnx.get_named_sharding(state, mesh)
      if jax.process_index() == 0 and restore_args:
        max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
        pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
        max_logging.log(pretty_string)
        max_logging.log("------------------------------------------------")
    if self.config.hardware != "gpu":
      # Params now live (sharded) inside `state`; free the unsharded copies.
      max_utils.delete_pytree(params)
    data_shardings = self.get_data_shardings(mesh)
    eval_data_shardings = self.get_eval_data_shardings(mesh)

    writer = max_utils.initialize_summary_writer(self.config)
    # Metrics are written from a daemon thread fed via _metrics_queue.
    writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
    writer_thread.start()

    num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
    max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
    max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer)
    max_utils.add_config_to_summary_writer(self.config, writer)

    if jax.process_index() == 0:
      max_logging.log("***** Running training *****")
      max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}")
      max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.config.global_batch_size_to_train_on}")
      max_logging.log(f" Total optimization steps = {self.config.max_train_steps}")

    p_train_step = self.get_train_step(pipeline, mesh, state_shardings, data_shardings)
    p_eval_step = self.get_eval_step(pipeline, mesh, state_shardings, eval_data_shardings)

    rng = jax.random.key(self.config.seed)
    rng, eval_rng_key = jax.random.split(rng)
    start_step = 0
    last_step_completion = datetime.datetime.now()
    local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
    running_gcs_metrics = [] if self.config.gcs_metrics else None
    first_profiling_step = self.config.skip_first_n_steps_for_profiler
    if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps:
      raise ValueError("Profiling requested but initial profiling step set past training final step")
    last_profiling_step = np.clip(
        first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
    )
    if restore_args.get("step", 0):
      # FIX: log the resolved resume step rather than a variable bound in an
      # earlier branch.
      start_step = restore_args.get("step", 0)
      max_logging.log(f"Resuming training from step {start_step}")
    per_device_tflops, _, _ = BaseWanTrainer.calculate_tflops(pipeline)
    scheduler_state = pipeline.scheduler_state
    example_batch = load_next_batch(train_data_iterator, None, self.config)

    with ThreadPoolExecutor(max_workers=1) as executor:
      for step in np.arange(start_step, self.config.max_train_steps):
        if self.config.enable_profiler and step == first_profiling_step:
          max_utils.activate_profiler(self.config)
        start_step_time = datetime.datetime.now()

        # Prefetch the next batch on a worker thread while this step runs.
        next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
        with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
            self.config.logical_axis_rules
        ):
          state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)
          train_metric["scalar"]["learning/loss"].block_until_ready()
        last_step_completion = datetime.datetime.now()

        if self.config.enable_profiler and step == last_profiling_step:
          max_utils.deactivate_profiler(self.config)

        train_utils.record_scalar_metrics(
            train_metric, last_step_completion - start_step_time, per_device_tflops, learning_rate_scheduler(step)
        )
        if self.config.write_metrics:
          train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)

        if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
          if self.config.enable_generate_video_for_eval:
            pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
            inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
          # Re-create the iterator each time you start evaluation to reset it
          # This assumes your data loading logic can be called to get a fresh iterator.
          self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer)

        example_batch = next_batch_future.result()
        if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
          max_logging.log(f"Saving checkpoint for step {step}")
          if self.config.save_optimizer:
            self.checkpointer.save_checkpoint(step, pipeline, state)
          else:
            self.checkpointer.save_checkpoint(step, pipeline, state.params)

    # Signal the TensorBoard writer thread to drain and exit.
    _metrics_queue.put(None)
    writer_thread.join()
    if writer:
      writer.flush()
    if self.config.save_final_checkpoint:
      max_logging.log(f"Saving final checkpoint for step {step}")
      self.checkpointer.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params)
      self.checkpointer.checkpoint_manager.wait_until_finished()
    # load new state for trained transformer
    pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
    return pipeline

0 commit comments

Comments
 (0)