|
| 1 | +""" |
| 2 | +Copyright 2025 Google LLC |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +import abc |
| 18 | +from concurrent.futures import ThreadPoolExecutor |
| 19 | +from contextlib import nullcontext |
| 20 | +import datetime |
| 21 | +import os |
| 22 | +import pprint |
| 23 | +import threading |
| 24 | +from flax import nnx |
| 25 | +from flax.linen import partitioning as nn_partitioning |
| 26 | +from flax.training import train_state |
| 27 | +import jax |
| 28 | +from jax.experimental import multihost_utils |
| 29 | +import jax.numpy as jnp |
| 30 | +from maxdiffusion import max_logging, max_utils, train_utils |
| 31 | +from maxdiffusion.generate_wan import inference_generate_video |
| 32 | +from maxdiffusion.generate_wan import run as generate_wan |
| 33 | +from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline |
| 34 | +from maxdiffusion.schedulers import FlaxFlowMatchScheduler |
| 35 | +from maxdiffusion.train_utils import ( _metrics_queue,_tensorboard_writer_worker, load_next_batch) |
| 36 | +from maxdiffusion.utils import load_video |
| 37 | +from maxdiffusion.video_processor import VideoProcessor |
| 38 | +import numpy as np |
| 39 | +from skimage.metrics import structural_similarity as ssim |
| 40 | + |
| 41 | + |
class TrainState(train_state.TrainState):
  """Flax TrainState extended to carry the nnx pieces of the transformer.

  The transformer is split with ``nnx.split(module, nnx.Param, ...)``; the
  parameters live in the inherited ``params`` field, while the static graph
  definition and the remaining (non-parameter) state are stored here so the
  module can be rebuilt later with ``nnx.merge(graphdef, params,
  rest_of_state)``.
  """

  # Static module structure produced by nnx.split(); needed by nnx.merge().
  graphdef: nnx.GraphDef
  # Non-parameter nnx state (everything nnx.split captured beyond nnx.Param).
  rest_of_state: nnx.State
| 46 | + |
def _to_array(x):
  """Coerce ``x`` to a ``jax.Array``.

  Leaves values that are already ``jax.Array`` untouched (returned as the
  same object); anything else is converted with ``jnp.asarray``.
  """
  return x if isinstance(x, jax.Array) else jnp.asarray(x)
| 51 | + |
| 52 | + |
def generate_sample(config, pipeline, filename_prefix):
  """Generates a video to validate training did not corrupt the model.

  If the pipeline's VAE was previously dropped to save memory, it is lazily
  reloaded (together with its cache) before delegating to the standard WAN
  generation entrypoint.
  """
  if not hasattr(pipeline, "vae"):
    rngs = nnx.Rngs(jax.random.key(config.seed))
    pipeline.vae, pipeline.vae_cache = WanPipeline.load_vae(
        pipeline.mesh.devices, pipeline.mesh, rngs, config
    )
  return generate_wan(config, pipeline, filename_prefix)
| 64 | + |
| 65 | + |
def _video_to_uint8_frames(video_path):
  """Load one video and return it as uint8 frames with channels last.

  The preprocessed float video (values in [-1, 1]) is transposed so the
  channel axis comes last and rescaled to [0, 255].
  # NOTE(review): assumes preprocess_video returns a (B, C, T, H, W) layout
  # in [-1, 1] — confirm against VideoProcessor.
  """
  video = load_video(video_path)
  video = VideoProcessor().preprocess_video(video)
  video = np.array(video)
  video = np.transpose(video, (0, 2, 3, 4, 1))
  return np.uint8((video + 1) * 255 / 2)


def print_ssim(pretrained_video_path, posttrained_video_path):
  """Log the SSIM between a pre-training and a post-training sample video.

  Args:
    pretrained_video_path: sequence whose first element is the path of the
      video generated before training.
    posttrained_video_path: sequence whose first element is the path of the
      video generated after training.
  """
  # The two videos go through the identical load/convert pipeline.
  pretrained_video = _video_to_uint8_frames(pretrained_video_path[0])
  posttrained_video = _video_to_uint8_frames(posttrained_video_path[0])

  ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255)

  max_logging.log(f"SSIM score after training is {ssim_compare}")
| 83 | + |
| 84 | + |
class BaseWanTrainer(abc.ABC):
  """Abstract base class for WAN model trainers.

  Subclasses supply checkpointing, dataset loading, data shardings, and the
  compiled train/eval step functions; this class owns the end-to-end
  training loop, periodic evaluation, profiling, and checkpoint saving.
  """

  def __init__(self, config):
    # The loop below only updates the transformer; text-encoder training is
    # not wired in.
    if config.train_text_encoder:
      raise ValueError("this script currently doesn't support training text_encoders")
    self.config = config
    self.checkpointer = self._get_checkpointer()

  @abc.abstractmethod
  def _get_checkpointer(self):
    """Returns the checkpointer for the trainer."""

  def post_training_steps(self, pipeline, params, train_states, msg=""):
    """Optional hook run after training; no-op by default."""
    pass

  def create_scheduler(self):
    """Creates and initializes the Flow Match scheduler for training."""
    noise_scheduler = FlaxFlowMatchScheduler(dtype=jnp.float32)
    noise_scheduler_state = noise_scheduler.create_state()
    noise_scheduler_state = noise_scheduler.set_timesteps(noise_scheduler_state, num_inference_steps=1000, training=True)
    return noise_scheduler, noise_scheduler_state

  @staticmethod
  def calculate_tflops(pipeline):
    """Estimates per-device TFLOPs for one training pass of the transformer.

    Counts attention and FFN matmul FLOPs (2*M*N*K per matmul) per block,
    scales by layer count and per-device batch size, and multiplies by 3 to
    approximate forward + backward cost.

    Returns:
      Tuple of (train_tflops, total_attn_flops, seq_len).
    """
    maxdiffusion_config = pipeline.config
    # Model configuration
    height = pipeline.config.height
    width = pipeline.config.width
    num_frames = pipeline.config.num_frames

    # Transformer dimensions
    transformer_config = pipeline.transformer.config
    num_layers = transformer_config.num_layers
    heads = pipeline.transformer.config.num_attention_heads
    head_dim = pipeline.transformer.config.attention_head_dim
    ffn_dim = transformer_config.ffn_dim
    # Latent tokens: spatial downsample by 8, temporal downsample by the
    # VAE's temporal factor, then /4 — presumably a 2x2 patchify; confirm.
    seq_len = int(((height / 8) * (width / 8) * ((num_frames - 1) // pipeline.vae_scale_factor_temporal + 1)) / 4)
    # NOTE(review): assumes a text-encoder context length of 512 — confirm.
    text_encoder_dim = 512
    # Attention FLOPS
    # Self
    self_attn_qkv_proj_flops = 3 * (2 * seq_len * (heads * head_dim) ** 2)
    self_attn_qk_v_flops = 2 * (2 * seq_len**2 * (heads * head_dim))
    # Cross
    # NOTE(review): factor of 3 for a K/V projection (2 matrices) looks like
    # a copy-paste from the QKV line above — verify before relying on the
    # absolute TFLOP number.
    cross_attn_kv_proj_flops = 3 * (2 * text_encoder_dim * (heads * head_dim) ** 2)
    cross_attn_q_proj_flops = 1 * (2 * seq_len * (heads * head_dim) ** 2)
    cross_attention_qk_v_flops = 2 * (2 * seq_len * text_encoder_dim * (heads * head_dim))

    # Output_projection from attention (one for self, one for cross)
    attn_output_proj_flops = 2 * (2 * seq_len * (heads * head_dim) ** 2)

    total_attn_flops = (
        self_attn_qkv_proj_flops
        + self_attn_qk_v_flops
        + cross_attn_kv_proj_flops
        + cross_attn_q_proj_flops
        + cross_attention_qk_v_flops
        + attn_output_proj_flops
    )

    # FFN
    ffn_flops = 2 * (2 * seq_len * (heads * head_dim) * ffn_dim)

    flops_per_block = total_attn_flops + ffn_flops

    total_transformer_flops = flops_per_block * num_layers

    tflops = maxdiffusion_config.per_device_batch_size * total_transformer_flops / 1e12
    # Rule of thumb: backward pass ~2x forward, so training is ~3x forward.
    train_tflops = 3 * tflops

    max_logging.log(f"Calculated TFLOPs per pass: {train_tflops:.4f}")
    return train_tflops, total_attn_flops, seq_len

  @abc.abstractmethod
  def get_data_shardings(self, mesh):
    """Returns data shardings for training."""

  @abc.abstractmethod
  def get_eval_data_shardings(self, mesh):
    """Returns data shardings for evaluation."""

  @abc.abstractmethod
  def load_dataset(self, mesh, pipeline=None, is_training=True):
    """Loads the dataset."""

  @abc.abstractmethod
  def get_train_step(self, pipeline, mesh, state_shardings, data_shardings):
    """Returns the training step function."""

  @abc.abstractmethod
  def get_eval_step(self, pipeline, mesh, state_shardings, eval_data_shardings):
    """Returns the evaluation step function."""

  def start_training(self):
    """Entry point: restore checkpoint, train, and optionally compare SSIM."""
    with nn_partitioning.axis_rules(self.config.logical_axis_rules):
      pipeline, opt_state, step = self.checkpointer.load_checkpoint()
      restore_args = {}
      if opt_state and step:
        restore_args = {"opt_state": opt_state, "step": step}
      del opt_state
      if self.config.enable_ssim:
        # Generate a sample before training to compare against generated sample after training.
        pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

      if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
        # The VAE is only needed for sample generation; drop it to save
        # memory (generate_sample reloads it lazily if needed later).
        del pipeline.vae
        del pipeline.vae_cache

      mesh = pipeline.mesh
      train_data_iterator = self.load_dataset(mesh, pipeline=pipeline, is_training=True)

      # Load FlowMatch scheduler
      scheduler, scheduler_state = self.create_scheduler()
      pipeline.scheduler = scheduler
      pipeline.scheduler_state = scheduler_state
      optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer(
          pipeline.transformer, self.config, self.config.learning_rate
      )
      # Returns pipeline with trained transformer state
      pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args)

      if self.config.enable_ssim:
        posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
        print_ssim(pretrained_video_path, posttrained_video_path)

  def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
    """Runs one full pass over the eval dataset and logs per-timestep losses.

    Losses are gathered across hosts; only process 0 accumulates, averages
    per timestep bucket (capped at eval_max_number_of_samples_in_bucket
    samples each), and writes the final mean to the summary writer.
    """
    eval_data_iterator = self.load_dataset(mesh, is_training=False)
    eval_rng = eval_rng_key
    eval_losses_by_timestep = {}
    # Loop indefinitely until the iterator is exhausted
    while True:
      try:
        eval_start_time = datetime.datetime.now()
        eval_batch = load_next_batch(eval_data_iterator, None, self.config)
        with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
          metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
          metrics["scalar"]["learning/eval_loss"].block_until_ready()
        losses = metrics["scalar"]["learning/eval_loss"]
        timesteps = eval_batch["timesteps"]
        gathered_losses = multihost_utils.process_allgather(losses, tiled=True)
        gathered_losses = jax.device_get(gathered_losses)
        gathered_timesteps = multihost_utils.process_allgather(timesteps, tiled=True)
        gathered_timesteps = jax.device_get(gathered_timesteps)
        if jax.process_index() == 0:
          # Bucket each per-sample loss by its (integer) diffusion timestep.
          for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
            timestep = int(t)
            if timestep not in eval_losses_by_timestep:
              eval_losses_by_timestep[timestep] = []
            eval_losses_by_timestep[timestep].append(l)
        eval_end_time = datetime.datetime.now()
        eval_duration = eval_end_time - eval_start_time
        max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.")
      except StopIteration:
        # This block is executed when the iterator has no more data
        break
    # Check if any evaluation was actually performed. Buckets are only
    # filled on process 0, so the redundant nested rank check was dropped.
    if eval_losses_by_timestep and jax.process_index() == 0:
      mean_per_timestep = []
      max_logging.log(f"Step {step}, calculating mean loss per timestep...")
      for timestep, losses in sorted(eval_losses_by_timestep.items()):
        losses = jnp.array(losses)
        # Cap the number of samples contributing to each bucket's mean.
        losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
        mean_loss = jnp.mean(losses)
        max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
        mean_per_timestep.append(mean_loss)
      final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
      max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
      if writer:
        writer.add_scalar("learning/eval_loss", final_eval_loss, step)

  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args: dict = None):
    """Runs the training loop and returns the pipeline with trained weights.

    Args:
      pipeline: WAN pipeline whose transformer is trained in place.
      optimizer: optax optimizer used to build the TrainState.
      learning_rate_scheduler: step -> learning rate, for metric logging.
      train_data_iterator: training batch iterator.
      restore_args: optional dict with "opt_state" and "step" to resume from.
    """
    # BUG FIX: the default used to be a mutable {} that the body mutates.
    restore_args = {} if restore_args is None else restore_args
    mesh = pipeline.mesh
    graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)

    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
      state = TrainState.create(
          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
      )
      if restore_args:
        step = restore_args.get("step", 0)
        max_logging.log(f"Restoring optimizer and resuming from step {step}")
        # BUG FIX: TrainState.replace returns a NEW state; the previous code
        # discarded it, so the restored opt_state/step were never applied.
        state = state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0))
        # Drop the restored copy so its memory can be reclaimed.
        restore_args.pop("opt_state", None)
      del optimizer
      state = jax.tree.map(_to_array, state)
      state_spec = nnx.get_partition_spec(state)
      state = jax.lax.with_sharding_constraint(state, state_spec)
      state_shardings = nnx.get_named_sharding(state, mesh)
      if jax.process_index() == 0 and restore_args:
        max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
        pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
        max_logging.log(pretty_string)
        max_logging.log("------------------------------------------------")
    if self.config.hardware != "gpu":
      # The split-out params were copied into `state`; free the originals.
      max_utils.delete_pytree(params)
    data_shardings = self.get_data_shardings(mesh)
    eval_data_shardings = self.get_eval_data_shardings(mesh)

    # Metrics are written from a background thread fed by _metrics_queue.
    writer = max_utils.initialize_summary_writer(self.config)
    writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
    writer_thread.start()

    num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
    max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
    max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer)
    max_utils.add_config_to_summary_writer(self.config, writer)

    if jax.process_index() == 0:
      max_logging.log("***** Running training *****")
      max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}")
      max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.config.global_batch_size_to_train_on}")
      max_logging.log(f" Total optimization steps = {self.config.max_train_steps}")

    p_train_step = self.get_train_step(pipeline, mesh, state_shardings, data_shardings)
    p_eval_step = self.get_eval_step(pipeline, mesh, state_shardings, eval_data_shardings)

    rng = jax.random.key(self.config.seed)
    rng, eval_rng_key = jax.random.split(rng)
    # BUG FIX: the old code logged a `step` variable that was only bound when
    # a checkpoint restore happened earlier in the function.
    start_step = restore_args.get("step", 0)
    if start_step:
      max_logging.log(f"Resuming training from step {start_step}")
    last_step_completion = datetime.datetime.now()
    local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
    running_gcs_metrics = [] if self.config.gcs_metrics else None
    first_profiling_step = self.config.skip_first_n_steps_for_profiler
    if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps:
      raise ValueError("Profiling requested but initial profiling step set past training final step")
    last_profiling_step = np.clip(
        first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
    )
    per_device_tflops, _, _ = BaseWanTrainer.calculate_tflops(pipeline)
    scheduler_state = pipeline.scheduler_state
    example_batch = load_next_batch(train_data_iterator, None, self.config)

    with ThreadPoolExecutor(max_workers=1) as executor:
      for step in np.arange(start_step, self.config.max_train_steps):
        if self.config.enable_profiler and step == first_profiling_step:
          max_utils.activate_profiler(self.config)
        start_step_time = datetime.datetime.now()

        # Prefetch the next batch on a worker thread while this step runs.
        next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
        with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
            self.config.logical_axis_rules
        ):
          state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)
          train_metric["scalar"]["learning/loss"].block_until_ready()
        last_step_completion = datetime.datetime.now()

        if self.config.enable_profiler and step == last_profiling_step:
          max_utils.deactivate_profiler(self.config)

        train_utils.record_scalar_metrics(
            train_metric, last_step_completion - start_step_time, per_device_tflops, learning_rate_scheduler(step)
        )
        if self.config.write_metrics:
          train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)

        if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
          if self.config.enable_generate_video_for_eval:
            pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
            inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
          # Re-create the iterator each time you start evaluation to reset it
          # This assumes your data loading logic can be called to get a fresh iterator.
          self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer)

        example_batch = next_batch_future.result()
        if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
          max_logging.log(f"Saving checkpoint for step {step}")
          if self.config.save_optimizer:
            self.checkpointer.save_checkpoint(step, pipeline, state)
          else:
            self.checkpointer.save_checkpoint(step, pipeline, state.params)

    # Signal the tensorboard writer thread to drain and exit.
    _metrics_queue.put(None)
    writer_thread.join()
    if writer:
      writer.flush()
    # BUG FIX: the metrics file was opened but never closed.
    if local_metrics_file:
      local_metrics_file.close()
    if self.config.save_final_checkpoint:
      # BUG FIX: log the step actually being saved rather than the loop
      # variable (which may differ, or be unbound if no steps ran).
      final_step = self.config.max_train_steps - 1
      max_logging.log(f"Saving final checkpoint for step {final_step}")
      self.checkpointer.save_checkpoint(final_step, pipeline, state.params)
      self.checkpointer.checkpoint_manager.wait_until_finished()
    # load new state for trained transformer
    pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state)
    return pipeline
0 commit comments