Skip to content

Commit b224fd2

Browse files
Merge pull request #385 from AI-Hypercomputer:prisha/ltx2_opt
PiperOrigin-RevId: 903293227
Commit b224fd2 (2 parents: 0b6410b, 77973e3)

7 files changed

Lines changed: 308 additions & 82 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22
hardware: 'tpu'
33
skip_jax_distributed_system: False
44
attention: 'flash'
5-
a2v_attention_kernel: 'flash'
5+
a2v_attention_kernel: 'dot_product'
66
v2a_attention_kernel: 'dot_product'
77
attention_sharding_uniform: True
88
precision: 'bf16'
9+
10+
# For scanning transformer layers
911
scan_layers: True
12+
13+
# For scanning diffusion loop
14+
scan_diffusion_loop: True
15+
1016
names_which_can_be_saved: []
1117
names_which_can_be_offloaded: []
1218
remat_policy: "NONE"

src/maxdiffusion/models/attention_flax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ def __init__(
939939
dtype=dtype,
940940
param_dtype=weights_dtype,
941941
precision=precision,
942-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
942+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")),
943943
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
944944
)
945945
self.act = get_activation(activation_fn)
@@ -951,8 +951,8 @@ def __init__(
951951
dtype=dtype,
952952
param_dtype=weights_dtype,
953953
precision=precision,
954-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
955-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
954+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)),
955+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)),
956956
)
957957

958958
def __call__(self, hidden_states: Array) -> Array:

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import jax.numpy as jnp
2121
from ... import common_types
2222
from ..attention_flax import NNXAttentionOp
23+
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
2324

2425
Array = common_types.Array
2526
Mesh = common_types.Mesh
@@ -349,23 +350,40 @@ def __init__(
349350
rope_type: str = "interleaved",
350351
flash_block_sizes: BlockSizes = None,
351352
flash_min_seq_length: int = 4096,
353+
qkv_sharding_spec: Optional[tuple] = None,
354+
out_sharding_spec: Optional[tuple] = None,
355+
out_bias_sharding_spec: Optional[tuple] = None,
352356
):
353357
self.heads = heads
354358
self.rope_type = rope_type
355359
self.dim_head = dim_head
356360
self.inner_dim = dim_head * heads
357361
self.dropout_rate = dropout
358362

363+
# Auto-detect hardware for sharding specs if not overridden
364+
tpu_type = get_tpu_type()
365+
is_ironwood = tpu_type == TpuType.TPU_7X
366+
367+
# Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
368+
# to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
369+
# This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
370+
if qkv_sharding_spec is None:
371+
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
372+
if out_sharding_spec is None:
373+
out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed")
374+
if out_bias_sharding_spec is None:
375+
out_bias_sharding_spec = (None,) if is_ironwood else ("embed",)
376+
359377
# 1. Define Partitioned Initializers (Logical Axes)
360378
# Q, K, V kernels: [in_features (embed), out_features (heads)]
361-
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
379+
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec)
362380
# Q, K, V biases: [out_features (heads)]
363381
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
364382

365383
# Out kernel: [in_features (heads), out_features (embed)]
366-
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
384+
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec)
367385
# Out bias: [out_features (embed)]
368-
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
386+
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec)
369387

370388
# Norm scales
371389
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))

src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,12 @@ def __init__(self, in_channels: int, mid_channels: int = 1024, scale: float = 2.
165165
in_channels, (num**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), rngs=rngs
166166
)
167167
self.pixel_shuffle = PixelShuffleND(dims=2, upscale_factors=(num, num))
168-
self.blur = BlurDownsample(dims=2, stride=den)
168+
self.blur_down = BlurDownsample(dims=2, stride=den)
169169

170170
def __call__(self, x: jax.Array) -> jax.Array:
171171
x = self.conv(x)
172172
x = self.pixel_shuffle(x)
173-
x = self.blur(x)
173+
x = self.blur_down(x)
174174
return x
175175

176176

0 commit comments

Comments (0)