Commit 8c82c74

Merge branch 'main' into ninatu/wan_training

2 parents f08effb + 1d5d773

52 files changed: 11,121 additions & 367 deletions


README.md

Lines changed: 31 additions & 0 deletions

````diff
@@ -17,6 +17,8 @@
 [![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2026/03/25`**: Wan2.1 and Wan2.2 Magcache inference is now supported
+- **`2026/03/25`**: LTX-2 Video Inference is now supported
 - **`2026/01/29`**: Wan LoRA for inference is now supported
 - **`2026/01/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported
 - **`2025/11/11`**: Wan2.2 txt2vid generation is now supported
@@ -49,6 +51,7 @@ MaxDiffusion supports
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x, 2.x.
 * LTX-Video text2vid, img2vid (inference).
+* LTX-2 Video text2vid (inference).
 * Wan2.1 text2vid (training and inference).
 * Wan2.2 text2vid (inference).
 
@@ -73,6 +76,7 @@ MaxDiffusion supports
 - [Inference](#inference)
   - [Wan](#wan-models)
   - [LTX-Video](#ltx-video)
+  - [LTX-2 Video](#ltx-2-video)
   - [Flux](#flux)
   - [Fused Attention for GPU](#fused-attention-for-gpu)
   - [SDXL](#stable-diffusion-xl)
@@ -497,6 +501,33 @@ To generate images, run the following command:
 
 Add conditioning image path as conditioning_media_paths in the form of ["IMAGE_PATH"] along with other generation parameters in the ltx_video.yml file. Then follow same instruction as above.
 
+## LTX-2 Video
+
+Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
+
+The following command will run LTX-2 T2V:
+
+```bash
+HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
+LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
+--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
+--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
+--xla_tpu_overlap_compute_collective_tc=true \
+--xla_enable_async_all_reduce=true" \
+HF_HUB_ENABLE_HF_TRANSFER=1 \
+python src/maxdiffusion/generate_ltx2.py \
+src/maxdiffusion/configs/ltx2_video.yml \
+attention="flash" \
+num_inference_steps=40 \
+num_frames=121 \
+width=768 \
+height=512 \
+per_device_batch_size=.125 \
+ici_data_parallelism=2 \
+ici_context_parallelism=4 \
+run_name=ltx2-inference
+```
+
 ## Wan Models
 
 Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
````
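The `per_device_batch_size=.125` in the LTX-2 command above is fractional because each sample is sharded, not replicated, across chips by context parallelism. A minimal sketch of the implied batch arithmetic, assuming the common MaxText-style convention that the global batch equals the per-device batch times the chip count (the helper name is hypothetical, not a MaxDiffusion API):

```python
# Sketch of the batch arithmetic implied by the LTX-2 generation flags.
# Assumption (not stated in this diff): global_batch = per_device_batch_size * num_devices.

def implied_global_batch(per_device_batch_size: float,
                         ici_data_parallelism: int,
                         ici_context_parallelism: int) -> float:
    # The ICI mesh spans all chips in the slice: data-parallel groups times
    # context-parallel shards per group.
    num_devices = ici_data_parallelism * ici_context_parallelism
    return per_device_batch_size * num_devices

# With the command's flags: 0.125 * (2 * 4) = one video per step. Context
# parallelism splits each sample's token sequence across 4 chips, which is
# why a fractional per-device batch size is legal here.
print(implied_global_batch(0.125, 2, 4))
```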

requirements.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -31,7 +31,7 @@ opencv-python-headless==4.10.0.84
 orbax-checkpoint
 tokenizers==0.21.0
 huggingface_hub>=0.30.2
-transformers==4.48.1
+transformers==4.51.0
 einops==0.8.0
 sentencepiece
 aqtp
```

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -30,7 +30,7 @@ opencv-python-headless==4.10.0.84
 orbax-checkpoint
 tokenizers==0.21.0
 huggingface_hub>=0.30.2
-transformers==4.48.1
+transformers==4.51.0
 tokamax
 einops==0.8.0
 sentencepiece
```
Lines changed: 113 additions & 0 deletions

```python
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import jax
import numpy as np
from typing import Optional, Tuple
from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline
from maxdiffusion import max_logging
from maxdiffusion.checkpointing.checkpointing_utils import create_orbax_checkpoint_manager
import orbax.checkpoint as ocp
from etils import epath

LTX2_CHECKPOINT = "LTX2_CHECKPOINT"


class LTX2Checkpointer:

  def __init__(self, config, checkpoint_type: str = LTX2_CHECKPOINT):
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.opt_state = None

    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
        getattr(self.config, "checkpoint_dir", ""),
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=getattr(config, "dataset_type", None),
    )

  def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    if self.checkpoint_manager is None:
      max_logging.log("No checkpoint manager configured, skipping Orbax load.")
      return None, None

    if step is None:
      step = self.checkpoint_manager.latest_step()
      max_logging.log(f"Latest LTX2 checkpoint step: {step}")
      if step is None:
        max_logging.log("No LTX2 checkpoint found.")
        return None, None
    max_logging.log(f"Loading LTX2 checkpoint from step {step}")
    metadatas = self.checkpoint_manager.item_metadata(step)
    transformer_metadata = metadatas.ltx2_state
    abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
    params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree_structure_params,
        )
    )

    max_logging.log("Restoring LTX2 checkpoint")
    restored_checkpoint = self.checkpoint_manager.restore(
        directory=epath.Path(self.config.checkpoint_dir),
        step=step,
        args=ocp.args.Composite(
            ltx2_state=params_restore,
            ltx2_config=ocp.args.JsonRestore(),
        ),
    )
    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
    max_logging.log(f"restored checkpoint ltx2_state {restored_checkpoint.ltx2_state.keys()}")
    max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.ltx2_state.keys()}")
    return restored_checkpoint, step

  def load_checkpoint(
      self, step=None, vae_only=False, load_transformer=True
  ) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
    restored_checkpoint, step = self.load_ltx2_configs_from_orbax(step)
    opt_state = None

    if restored_checkpoint:
      max_logging.log("Loading LTX2 pipeline from checkpoint")
      pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer)
      if "opt_state" in restored_checkpoint.ltx2_state.keys():
        opt_state = restored_checkpoint.ltx2_state["opt_state"]
    else:
      max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
      pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)

    return pipeline, opt_state, step

  def save_checkpoint(self, train_step, pipeline: LTX2Pipeline, train_states: dict):
    """Saves the training state and model configurations."""

    def config_to_json(model_or_config):
      return json.loads(model_or_config.to_json_string())

    max_logging.log(f"Saving checkpoint for step {train_step}")
    items = {
        "ltx2_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
    }

    items["ltx2_state"] = ocp.args.PyTreeSave(train_states)

    # Save the checkpoint
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
    max_logging.log(f"Checkpoint for step {train_step} saved.")
```
src/maxdiffusion/common_types.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -51,6 +51,7 @@
 
 WAN2_1 = "wan2.1"
 WAN2_2 = "wan2.2"
+LTX2_VIDEO = "ltx2_video"
 
 WAN_MODEL = WAN2_1
```

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 10 additions & 0 deletions

```diff
@@ -325,6 +325,16 @@ num_frames: 81
 guidance_scale: 5.0
 flow_shift: 3.0
 
+# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
+# Skips the unconditional forward pass on ~35% of steps via residual compensation.
+# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
+use_cfg_cache: False
+use_magcache: False
+magcache_thresh: 0.12
+magcache_K: 2
+retention_ratio: 0.2
+mag_ratios_base: [1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]
+
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30
```

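The `magcache_*` knobs added above follow the MagCache idea: a table of precomputed magnitude ratios estimates how much the transformer output drifts at each denoising step, and steps whose accumulated drift stays small can reuse a cached residual instead of running the model. A hedged sketch of how such knobs typically interact (illustrative only, not the repo's implementation; the function name is made up):

```python
def magcache_schedule(mag_ratios, thresh=0.12, K=2, retention_ratio=0.2):
    """Return a per-step skip/compute plan from calibrated magnitude ratios."""
    n = len(mag_ratios)
    skips, err, consecutive = [], 0.0, 0
    for step, ratio in enumerate(mag_ratios):
        if step < retention_ratio * n:
            skip = False                 # retention window: always run the model
        else:
            err += abs(1.0 - ratio)      # estimated drift from reusing the cache
            # Skip only while accumulated error is under magcache_thresh and we
            # have not skipped more than magcache_K steps in a row.
            skip = err < thresh and consecutive < K
        if skip:
            consecutive += 1
        else:
            err, consecutive = 0.0, 0    # a full forward pass resets the error
        skips.append(skip)
    return skips

# Ratios near 1.0 (typically the later denoising steps) are the ones skipped.
sched = magcache_schedule([1.0, 1.0, 0.99, 0.99, 0.99, 0.99, 0.98, 0.98, 0.97, 0.96])
print(sum(sched), "of", len(sched), "steps skipped")
```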
src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -281,6 +281,9 @@ num_frames: 81
 guidance_scale: 5.0
 flow_shift: 3.0
 
+# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
+use_cfg_cache: False
+
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30
```
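The `use_cfg_cache` switch refers to a FasterCache-style mechanism: on designated skip steps the unconditional forward pass is not run; instead it is approximated as the conditional output plus a residual cached at the last full step, so classifier-free guidance needs only one model call on those steps. A rough sketch under those assumptions (all names below are illustrative, not the repo's API):

```python
def cfg_denoise(cond_out, uncond_out, guidance_scale, cache, skip_uncond):
    """One CFG step; on skip steps, uncond_out is reconstructed from the cache."""
    if skip_uncond and cache is not None:
        uncond_out = cond_out + cache        # compensate with the cached residual
    else:
        cache = uncond_out - cond_out        # refresh the residual on full steps
    guided = uncond_out + guidance_scale * (cond_out - uncond_out)
    return guided, cache

# Toy loop: steps marked True would skip the real unconditional forward pass.
cache = None
for step, skip in enumerate([False, True, True, False, True]):
    cond = 1.0 + 0.01 * step                 # stand-in "model outputs"
    uncond = 0.5 + 0.01 * step
    out, cache = cfg_denoise(cond, uncond, 5.0, cache, skip)
```

Because the conditional/unconditional gap drifts slowly across adjacent steps, the cached residual is a close substitute, which is why skipping only a fraction of steps costs little quality.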

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 6 additions & 0 deletions

```diff
@@ -302,6 +302,12 @@ guidance_scale_high: 4.0
 # timestep to switch between low noise and high noise transformer
 boundary_ratio: 0.875
 
+# Diffusion CFG cache (FasterCache-style)
+use_cfg_cache: False
+# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
+# when predicted output change (based on accumulated latent/timestep drift) is small
+use_sen_cache: False
+
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30
```

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 9 additions & 0 deletions

```diff
@@ -286,6 +286,15 @@ num_frames: 81
 guidance_scale: 5.0
 flow_shift: 5.0
 
+# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
+use_cfg_cache: False
+use_magcache: False
+magcache_thresh: 0.12
+magcache_K: 2
+retention_ratio: 0.2
+mag_ratios_base_720p: [1.0, 1.0, 0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768]
+mag_ratios_base_480p: [1.0, 1.0, 0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616]
+
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 50
```

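The i2v config above ships two calibrated ratio tables, one per target resolution, since magnitude drift differs between 480p and 720p runs. A plausible selection rule keyed on the output height (an assumption for illustration; the repo's actual dispatch logic is not shown in this diff):

```python
def pick_mag_ratios(height: int, ratios_480p, ratios_720p):
    # Assumed rule: 720p-and-above outputs use the 720p calibration,
    # smaller outputs fall back to the 480p table.
    return ratios_720p if height >= 720 else ratios_480p
```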
src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -298,6 +298,9 @@ guidance_scale_high: 4.0
 # timestep to switch between low noise and high noise transformer
 boundary_ratio: 0.875
 
+# Diffusion CFG cache (FasterCache-style)
+use_cfg_cache: False
+
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 50
```

0 commit comments