Skip to content

Commit 79fd839

Browse files
committed
LTX2 Performance enhancements
1 parent ad6391a commit 79fd839

7 files changed

Lines changed: 314 additions & 82 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22
hardware: 'tpu'
33
skip_jax_distributed_system: False
44
attention: 'flash'
5-
a2v_attention_kernel: 'flash'
5+
a2v_attention_kernel: 'dot_product'
66
v2a_attention_kernel: 'dot_product'
77
attention_sharding_uniform: True
88
precision: 'bf16'
9+
10+
# For scanning transformer layers
911
scan_layers: True
12+
13+
# For scanning diffusion loop
14+
scan_diffusion_loop: True
15+
1016
names_which_can_be_saved: []
1117
names_which_can_be_offloaded: []
1218
remat_policy: "NONE"

src/maxdiffusion/models/attention_flax.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,15 @@ def _tpu_flash_attention(
287287
) -> jax.Array:
288288
"""TPU Flash Attention"""
289289

290-
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
291294
num_context_shards = mesh.shape["context"]
292295
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
293296
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
294297
value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
298+
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
295299

296300
q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
297301
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
@@ -892,7 +896,7 @@ def __init__(
892896
dtype=dtype,
893897
param_dtype=weights_dtype,
894898
precision=precision,
895-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
899+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")),
896900
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
897901
)
898902
self.act = get_activation(activation_fn)
@@ -904,8 +908,8 @@ def __init__(
904908
dtype=dtype,
905909
param_dtype=weights_dtype,
906910
precision=precision,
907-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
908-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
911+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)),
912+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
909913
)
910914

911915
def __call__(self, hidden_states: Array) -> Array:

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import jax.numpy as jnp
2121
from ... import common_types
2222
from ..attention_flax import NNXAttentionOp
23+
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
2324

2425
Array = common_types.Array
2526
Mesh = common_types.Mesh
@@ -349,23 +350,40 @@ def __init__(
349350
rope_type: str = "interleaved",
350351
flash_block_sizes: BlockSizes = None,
351352
flash_min_seq_length: int = 4096,
353+
qkv_sharding_spec: Optional[tuple] = None,
354+
out_sharding_spec: Optional[tuple] = None,
355+
out_bias_sharding_spec: Optional[tuple] = None,
352356
):
353357
self.heads = heads
354358
self.rope_type = rope_type
355359
self.dim_head = dim_head
356360
self.inner_dim = dim_head * heads
357361
self.dropout_rate = dropout
358362

363+
# Auto-detect hardware for sharding specs if not overridden
364+
tpu_type = get_tpu_type()
365+
is_ironwood = tpu_type == TpuType.TPU_7X
366+
367+
# Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
368+
# to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
369+
# This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
370+
if qkv_sharding_spec is None:
371+
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
372+
if out_sharding_spec is None:
373+
out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed")
374+
if out_bias_sharding_spec is None:
375+
out_bias_sharding_spec = (None,) if is_ironwood else ("embed",)
376+
359377
# 1. Define Partitioned Initializers (Logical Axes)
360378
# Q, K, V kernels: [in_features (embed), out_features (heads)]
361-
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
379+
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec)
362380
# Q, K, V biases: [out_features (heads)]
363381
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
364382

365383
# Out kernel: [in_features (heads), out_features (embed)]
366-
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
384+
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec)
367385
# Out bias: [out_features (embed)]
368-
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
386+
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec)
369387

370388
# Norm scales
371389
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))

src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,12 @@ def __init__(self, in_channels: int, mid_channels: int = 1024, scale: float = 2.
165165
in_channels, (num**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), rngs=rngs
166166
)
167167
self.pixel_shuffle = PixelShuffleND(dims=2, upscale_factors=(num, num))
168-
self.blur = BlurDownsample(dims=2, stride=den)
168+
self.blur_down = BlurDownsample(dims=2, stride=den)
169169

170170
def __call__(self, x: jax.Array) -> jax.Array:
171171
x = self.conv(x)
172172
x = self.pixel_shuffle(x)
173-
x = self.blur(x)
173+
x = self.blur_down(x)
174174
return x
175175

176176

0 commit comments

Comments (0)