Skip to content

Commit 5f5697b

Browse files
Merge pull request #369 from AI-Hypercomputer:prisha/ltx2_fbs
PiperOrigin-RevId: 892728927
2 parents 9236e88 + e6adff6 commit 5f5697b

4 files changed

Lines changed: 26 additions & 1 deletion

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ data_sharding: ['data', 'fsdp', 'context', 'tensor']
5858

5959
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
6060
dcn_fsdp_parallelism: -1
61+
62+	flash_block_sizes: {
63+	  block_q: 2048,
64+	  block_kv: 2048,
65+	  block_kv_compute: 1024,
66+	  block_q_dkv: 2048,
67+	  block_kv_dkv: 2048,
68+	  block_kv_dkv_compute: 2048,
69+	  use_fused_bwd_kernel: True,
70+	}
6171
dcn_context_parallelism: 1
6272
dcn_tensor_parallelism: 1
6373
ici_data_parallelism: 1

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Array = common_types.Array
2424
Mesh = common_types.Mesh
2525
DType = common_types.DType
26+
BlockSizes = common_types.BlockSizes
2627

2728

2829
def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
@@ -345,6 +346,7 @@ def __init__(
345346
dtype: DType = jnp.float32,
346347
attention_kernel: str = "flash",
347348
rope_type: str = "interleaved",
349+
flash_block_sizes: BlockSizes = None,
348350
):
349351
self.heads = heads
350352
self.rope_type = rope_type
@@ -431,6 +433,7 @@ def __init__(
431433
dtype=dtype,
432434
axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV),
433435
axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV),
436+
flash_block_sizes=flash_block_sizes,
434437
)
435438

436439
def __call__(

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection
2525
from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType
2626
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
27+
from maxdiffusion.common_types import BlockSizes
2728

2829

2930
class LTX2AdaLayerNormSingle(nnx.Module):
@@ -105,6 +106,7 @@ def __init__(
105106
names_which_can_be_saved: list = [],
106107
names_which_can_be_offloaded: list = [],
107108
attention_kernel: str = "flash",
109+
flash_block_sizes: BlockSizes = None,
108110
):
109111
self.dim = dim
110112
self.norm_eps = norm_eps
@@ -134,6 +136,7 @@ def __init__(
134136
mesh=mesh,
135137
attention_kernel=self.attention_kernel,
136138
rope_type=rope_type,
139+
flash_block_sizes=flash_block_sizes,
137140
)
138141

139142
self.audio_norm1 = nnx.RMSNorm(
@@ -158,6 +161,7 @@ def __init__(
158161
mesh=mesh,
159162
attention_kernel=self.attention_kernel,
160163
rope_type=rope_type,
164+
flash_block_sizes=flash_block_sizes,
161165
)
162166

163167
# 2. Prompt Cross-Attention
@@ -184,6 +188,7 @@ def __init__(
184188
mesh=mesh,
185189
attention_kernel=self.attention_kernel,
186190
rope_type=rope_type,
191+
flash_block_sizes=flash_block_sizes,
187192
)
188193

189194
self.audio_norm2 = nnx.RMSNorm(
@@ -209,6 +214,7 @@ def __init__(
209214
mesh=mesh,
210215
attention_kernel=self.attention_kernel,
211216
rope_type=rope_type,
217+
flash_block_sizes=flash_block_sizes,
212218
)
213219

214220
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -235,6 +241,7 @@ def __init__(
235241
mesh=mesh,
236242
attention_kernel=self.attention_kernel,
237243
rope_type=rope_type,
244+
flash_block_sizes=flash_block_sizes,
238245
)
239246

240247
self.video_to_audio_norm = nnx.RMSNorm(
@@ -260,6 +267,7 @@ def __init__(
260267
mesh=mesh,
261268
attention_kernel=self.attention_kernel,
262269
rope_type=rope_type,
270+
flash_block_sizes=flash_block_sizes,
263271
)
264272

265273
# 4. Feed Forward
@@ -553,6 +561,7 @@ def __init__(
553561
scan_layers: bool = True,
554562
attention_kernel: str = "flash",
555563
qk_norm: str = "rms_norm_across_heads",
564+
flash_block_sizes: BlockSizes = None,
556565
**kwargs,
557566
):
558567
self.in_channels = in_channels
@@ -791,6 +800,7 @@ def init_block(rngs):
791800
names_which_can_be_saved=self.names_which_can_be_saved,
792801
names_which_can_be_offloaded=self.names_which_can_be_offloaded,
793802
attention_kernel=self.attention_kernel,
803+
flash_block_sizes=flash_block_sizes,
794804
)
795805

796806
if self.scan_layers:
@@ -822,6 +832,7 @@ def init_block(rngs):
822832
names_which_can_be_saved=self.names_which_can_be_saved,
823833
names_which_can_be_offloaded=self.names_which_can_be_offloaded,
824834
attention_kernel=self.attention_kernel,
835+
flash_block_sizes=flash_block_sizes,
825836
)
826837
blocks.append(block)
827838
self.transformer_blocks = nnx.List(blocks)

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from ...pyconfig import HyperParameters
4848
from ... import max_logging
4949
from ... import max_utils
50-	from ...max_utils import get_precision, device_put_replicated
50+	from ...max_utils import get_precision, device_put_replicated, get_flash_block_sizes
5151
from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs
5252

5353

@@ -124,6 +124,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
124124
ltx2_config["weights_dtype"] = config.weights_dtype
125125
ltx2_config["attention_kernel"] = config.attention
126126
ltx2_config["precision"] = get_precision(config)
127+
ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config)
127128
ltx2_config["remat_policy"] = config.remat_policy
128129
ltx2_config["names_which_can_be_saved"] = config.names_which_can_be_saved
129130
ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded

0 commit comments

Comments (0)