Commit 1d5d773

Merge pull request #365 from AI-Hypercomputer:prisha/magcache
PiperOrigin-RevId: 891764377
2 parents b4f9573 + aba5f6a

10 files changed: 648 additions & 92 deletions

README.md

Lines changed: 31 additions & 0 deletions
````diff
@@ -17,6 +17,8 @@
 [![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2026/03/25`**: Wan2.1 and Wan2.2 MagCache inference is now supported
+- **`2026/03/25`**: LTX-2 Video inference is now supported
 - **`2026/01/29`**: Wan LoRA for inference is now supported
 - **`2026/01/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported
 - **`2025/11/11`**: Wan2.2 txt2vid generation is now supported
@@ -49,6 +51,7 @@ MaxDiffusion supports
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x, 2.x.
 * LTX-Video text2vid, img2vid (inference).
+* LTX-2 Video text2vid (inference).
 * Wan2.1 text2vid (training and inference).
 * Wan2.2 text2vid (inference).
@@ -73,6 +76,7 @@ MaxDiffusion supports
 - [Inference](#inference)
   - [Wan](#wan-models)
   - [LTX-Video](#ltx-video)
+  - [LTX-2 Video](#ltx-2-video)
   - [Flux](#flux)
   - [Fused Attention for GPU](#fused-attention-for-gpu)
   - [SDXL](#stable-diffusion-xl)
@@ -497,6 +501,33 @@ To generate images, run the following command:
 
 Add the conditioning image path as conditioning_media_paths in the form ["IMAGE_PATH"], along with the other generation parameters, in the ltx_video.yml file. Then follow the same instructions as above.
 
+## LTX-2 Video
+
+Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
+
+The following command will run LTX-2 T2V:
+
+```bash
+HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
+LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
+--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
+--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
+--xla_tpu_overlap_compute_collective_tc=true \
+--xla_enable_async_all_reduce=true" \
+HF_HUB_ENABLE_HF_TRANSFER=1 \
+python src/maxdiffusion/generate_ltx2.py \
+  src/maxdiffusion/configs/ltx2_video.yml \
+  attention="flash" \
+  num_inference_steps=40 \
+  num_frames=121 \
+  width=768 \
+  height=512 \
+  per_device_batch_size=.125 \
+  ici_data_parallelism=2 \
+  ici_context_parallelism=4 \
+  run_name=ltx2-inference
+```
+
 ## Wan Models
 
 Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
````
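As a quick sanity check on the fractional batch size in the LTX-2 command: `per_device_batch_size` is multiplied by the number of chips in the mesh, so `.125` on the 2 × 4 mesh above (`ici_data_parallelism` × `ici_context_parallelism` = 8 chips) yields one video per step. A minimal sketch of that arithmetic, with a hypothetical helper name not taken from the repo:

```python
def global_batch_size(per_device_batch_size: float, num_devices: int) -> int:
    """Hypothetical helper: map a (possibly fractional) per-device batch size
    to the global batch size on a TPU mesh."""
    total = per_device_batch_size * num_devices
    if total < 1 or total != int(total):
        raise ValueError("per_device_batch_size * num_devices must be a whole number >= 1")
    return int(total)

# The LTX-2 command uses 2 * 4 = 8 chips, so one video is sharded across them:
print(global_batch_size(0.125, 8))  # -> 1
```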

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 5 additions & 0 deletions
```diff
@@ -328,6 +328,11 @@ flow_shift: 3.0
 # Skips the unconditional forward pass on ~35% of steps via residual compensation.
 # See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
 use_cfg_cache: False
+use_magcache: False
+magcache_thresh: 0.12
+magcache_K: 2
+retention_ratio: 0.2
+mag_ratios_base: [1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]
 
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
```
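The four new keys drive a MagCache-style step-skipping schedule (MagCache, Ma et al. 2025): steps are skipped while the accumulated drift of the residual-magnitude ratios stays under `magcache_thresh`, at most `magcache_K` consecutive skips, and the first `retention_ratio` fraction of steps always runs. A rough sketch of how such a planner could combine them; this is a hypothetical illustration of the heuristic, not the repo's implementation:

```python
import math

def plan_skips(mag_ratios, num_steps, thresh=0.12, K=2, retention_ratio=0.2):
    """Sketch of a MagCache-style skip schedule. Returns a bool per step:
    True means reuse the cached residual instead of running the blocks."""
    skips = []
    acc_error = 0.0       # accumulated drift of the magnitude ratio from 1.0
    consecutive = 0       # consecutive skipped steps so far
    retained = int(math.ceil(retention_ratio * num_steps))  # always-run prefix
    for step in range(num_steps):
        acc_error += abs(1.0 - mag_ratios[step])
        can_skip = step >= retained and acc_error < thresh and consecutive < K
        if can_skip:
            consecutive += 1
        else:
            # Full forward pass: the cache is refreshed, so the error resets.
            acc_error, consecutive = 0.0, 0
        skips.append(can_skip)
    return skips
```

With the config defaults above, ratios near 1.0 (early/middle steps) allow frequent skips, while the sharply decaying tail of `mag_ratios_base` forces full passes near the end of sampling.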

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 6 additions & 0 deletions
```diff
@@ -288,6 +288,12 @@ flow_shift: 5.0
 
 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
 use_cfg_cache: False
+use_magcache: False
+magcache_thresh: 0.12
+magcache_K: 2
+retention_ratio: 0.2
+mag_ratios_base_720p: [1.0, 1.0, 0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768]
+mag_ratios_base_480p: [1.0, 1.0, 0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616]
 
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
```
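The i2v config carries two calibrated ratio lists because the residual magnitudes were profiled per output resolution. A plausible runtime treatment is to pick the list matching the target resolution and resample the fixed-length calibration to the active step count; the helper below is a hypothetical sketch (the height threshold and nearest-neighbour resampling are assumptions, not the repo's code):

```python
import numpy as np

def select_and_resample(ratios_720p, ratios_480p, height, num_steps):
    """Pick the calibration profile for the target resolution (assumed:
    height >= 720 selects the 720p list), then nearest-neighbour resample
    the fixed-length calibration onto the active sampling schedule."""
    src = np.asarray(ratios_720p if height >= 720 else ratios_480p)
    idx = np.round(np.linspace(0, len(src) - 1, num_steps)).astype(int)
    return src[idx]
```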

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -80,7 +80,7 @@ dataset_name: ''
 train_split: 'train'
 dataset_type: 'tfrecord'
 cache_latents_text_encoder_outputs: True
-per_device_batch_size: 0.125
+per_device_batch_size: 1.0
 compile_topology_num_slices: -1
 quantization_local_shard_count: -1
 use_qwix_quantization: False
```

src/maxdiffusion/generate_wan.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -100,6 +100,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
         num_frames=config.num_frames,
         num_inference_steps=config.num_inference_steps,
         guidance_scale=config.guidance_scale,
+        use_magcache=config.use_magcache,
+        magcache_thresh=config.magcache_thresh,
+        magcache_K=config.magcache_K,
+        retention_ratio=config.retention_ratio,
     )
   elif model_key == WAN2_2:
     return pipeline(
@@ -127,6 +131,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
         num_inference_steps=config.num_inference_steps,
         guidance_scale=config.guidance_scale,
         use_cfg_cache=config.use_cfg_cache,
+        use_magcache=config.use_magcache,
+        magcache_thresh=config.magcache_thresh,
+        magcache_K=config.magcache_K,
+        retention_ratio=config.retention_ratio,
     )
   elif model_key == WAN2_2:
     return pipeline(
```

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 63 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -593,8 +593,11 @@ def __call__(
593593
return_dict: bool = True,
594594
attention_kwargs: Optional[Dict[str, Any]] = None,
595595
deterministic: bool = True,
596-
rngs: nnx.Rngs = None,
597-
) -> Union[jax.Array, Dict[str, jax.Array]]:
596+
rngs: Optional[nnx.Rngs] = None,
597+
skip_blocks: Optional[jax.Array] = None,
598+
cached_residual: Optional[jax.Array] = None,
599+
return_residual: bool = False,
600+
) -> Union[jax.Array, Tuple[jax.Array, jax.Array], Dict[str, jax.Array]]:
598601
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
599602
batch_size, _, num_frames, height, width = hidden_states.shape
600603
p_t, p_h, p_w = self.config.patch_size
@@ -628,52 +631,69 @@ def __call__(
628631
encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
629632
encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype)
630633

631-
if self.scan_layers:
632-
633-
def scan_fn(carry, block):
634-
hidden_states_carry, rngs_carry = carry
635-
hidden_states = block(
636-
hidden_states_carry,
637-
encoder_hidden_states,
638-
timestep_proj,
639-
rotary_emb,
640-
deterministic,
641-
rngs_carry,
642-
encoder_attention_mask,
643-
)
644-
new_carry = (hidden_states, rngs_carry)
645-
return new_carry, None
646-
647-
rematted_block_forward = self.gradient_checkpoint.apply(
648-
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
649-
)
650-
initial_carry = (hidden_states, rngs)
651-
final_carry, _ = nnx.scan(
652-
rematted_block_forward,
653-
length=self.num_layers,
654-
in_axes=(nnx.Carry, 0),
655-
out_axes=(nnx.Carry, 0),
656-
)(initial_carry, self.blocks)
657-
658-
hidden_states, _ = final_carry
659-
else:
660-
for block in self.blocks:
634+
def _run_all_blocks(h):
635+
if self.scan_layers:
661636

662-
def layer_forward(hidden_states):
663-
return block(
664-
hidden_states,
637+
def scan_fn(carry, block):
638+
hidden_states_carry, rngs_carry = carry
639+
hidden_states = block(
640+
hidden_states_carry,
665641
encoder_hidden_states,
666642
timestep_proj,
667643
rotary_emb,
668644
deterministic,
669-
rngs,
670-
encoder_attention_mask=encoder_attention_mask,
645+
rngs_carry,
646+
encoder_attention_mask,
671647
)
648+
new_carry = (hidden_states, rngs_carry)
649+
return new_carry, None
672650

673-
rematted_layer_forward = self.gradient_checkpoint.apply(
674-
layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
651+
rematted_block_forward = self.gradient_checkpoint.apply(
652+
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
675653
)
676-
hidden_states = rematted_layer_forward(hidden_states)
654+
initial_carry = (h, rngs)
655+
final_carry, _ = nnx.scan(
656+
rematted_block_forward,
657+
length=self.num_layers,
658+
in_axes=(nnx.Carry, 0),
659+
out_axes=(nnx.Carry, 0),
660+
)(initial_carry, self.blocks)
661+
662+
h_out, _ = final_carry
663+
else:
664+
h_out = h
665+
for block in self.blocks:
666+
667+
def layer_forward(hidden_states):
668+
return block(
669+
hidden_states,
670+
encoder_hidden_states,
671+
timestep_proj,
672+
rotary_emb,
673+
deterministic,
674+
rngs,
675+
encoder_attention_mask=encoder_attention_mask,
676+
)
677+
678+
rematted_layer_forward = self.gradient_checkpoint.apply(
679+
layer_forward,
680+
self.names_which_can_be_saved,
681+
self.names_which_can_be_offloaded,
682+
prevent_cse=not self.scan_layers,
683+
)
684+
h_out = rematted_layer_forward(h_out)
685+
return h_out
686+
687+
hidden_states_before_blocks = hidden_states
688+
689+
if skip_blocks:
690+
if cached_residual is None:
691+
raise ValueError("cached_residual must be provided when skip_blocks is True")
692+
hidden_states = hidden_states + cached_residual
693+
else:
694+
hidden_states = _run_all_blocks(hidden_states)
695+
696+
residual_x = hidden_states - hidden_states_before_blocks
677697

678698
shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
679699
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
@@ -685,4 +705,7 @@ def layer_forward(hidden_states):
685705
)
686706
hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
687707
hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width)
708+
709+
if return_residual:
710+
return hidden_states, residual_x
688711
return hidden_states
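The new `skip_blocks` / `cached_residual` / `return_residual` plumbing lets a denoising loop trade a full pass through the transformer blocks for a single cached-residual add on skipped steps. A toy sketch of that caching contract; the stand-in transformer and loop below are illustrative only, not the repo's pipeline code:

```python
def toy_transformer(x, skip_blocks=False, cached_residual=None, return_residual=False):
    """Toy stand-in for the model call: 'running the blocks' is x -> x + 0.1 * x."""
    if skip_blocks:
        if cached_residual is None:
            raise ValueError("cached_residual must be provided when skip_blocks is True")
        out = x + cached_residual   # reuse the last residual, no block compute
    else:
        out = x + 0.1 * x           # pretend the transformer blocks ran
    if return_residual:
        return out, out - x         # residual_x = output - input, as in the commit
    return out

def denoise(x, skip_plan):
    """Run one toy sampling loop, skipping steps according to skip_plan."""
    cached = None
    for skip in skip_plan:
        if skip and cached is not None:
            x = toy_transformer(x, skip_blocks=True, cached_residual=cached)
        else:
            x, cached = toy_transformer(x, return_residual=True)
    return x
```

Note that in the actual commit the cached residual is taken before the final norm/unpatchify, so the real pipeline caches the tuple's second element rather than recomputing it from the decoded output.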
