Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ data_sharding: ['data', 'fsdp', 'context', 'tensor']

dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1

# Block-size overrides for the "flash" attention kernel used by LTX2.
# The LTX2 pipeline reads this entry through get_flash_block_sizes(config) and
# passes the result down to every transformer block's attention layers as the
# `flash_block_sizes` argument.
# NOTE(review): the *_dkv keys and use_fused_bwd_kernel presumably tune the
# backward (gradient) pass — confirm against get_flash_block_sizes.
flash_block_sizes: {
block_q: 2048,
block_kv: 2048,
block_kv_compute: 1024,
block_q_dkv: 2048,
block_kv_dkv: 2048,
block_kv_dkv_compute: 2048,
use_fused_bwd_kernel: True,
}
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
Expand Down
15 changes: 6 additions & 9 deletions src/maxdiffusion/models/ltx2/attention_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Array = common_types.Array
Mesh = common_types.Mesh
DType = common_types.DType
BlockSizes = common_types.BlockSizes


def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
Expand Down Expand Up @@ -193,9 +194,7 @@ def prepare_video_coords(
# pixel_coords[:, 0, ...] selects Frame dimension.
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
frame_coords = pixel_coords[:, 0, ...]
frame_coords = jnp.clip(
frame_coords + self.causal_offset - self.scale_factors[0], min=0
)
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)

return pixel_coords
Expand All @@ -212,16 +211,12 @@ def prepare_audio_coords(
# 2. Start timestamps
audio_scale_factor = self.scale_factors[0]
grid_start_mel = grid_f * audio_scale_factor
grid_start_mel = jnp.clip(
grid_start_mel + self.causal_offset - audio_scale_factor, min=0
)
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate

# 3. End timestamps
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
grid_end_mel = jnp.clip(
grid_end_mel + self.causal_offset - audio_scale_factor, min=0
)
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate

# Stack [num_patches, 2]
Expand Down Expand Up @@ -351,6 +346,7 @@ def __init__(
dtype: DType = jnp.float32,
attention_kernel: str = "flash",
rope_type: str = "interleaved",
flash_block_sizes: BlockSizes = None,
):
self.heads = heads
self.rope_type = rope_type
Expand Down Expand Up @@ -437,6 +433,7 @@ def __init__(
dtype=dtype,
axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV),
axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV),
flash_block_sizes=flash_block_sizes,
)

def __call__(
Expand Down
11 changes: 11 additions & 0 deletions src/maxdiffusion/models/ltx2/transformer_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from maxdiffusion.models.embeddings_flax import NNXPixArtAlphaCombinedTimestepSizeEmbeddings, NNXPixArtAlphaTextProjection
from maxdiffusion.models.gradient_checkpoint import GradientCheckpointType
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
from maxdiffusion.common_types import BlockSizes


class LTX2AdaLayerNormSingle(nnx.Module):
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
names_which_can_be_saved: list = [],
names_which_can_be_offloaded: list = [],
attention_kernel: str = "flash",
flash_block_sizes: BlockSizes = None,
):
self.dim = dim
self.norm_eps = norm_eps
Expand Down Expand Up @@ -134,6 +136,7 @@ def __init__(
mesh=mesh,
attention_kernel=self.attention_kernel,
rope_type=rope_type,
flash_block_sizes=flash_block_sizes,
)

self.audio_norm1 = nnx.RMSNorm(
Expand All @@ -158,6 +161,7 @@ def __init__(
mesh=mesh,
attention_kernel=self.attention_kernel,
rope_type=rope_type,
flash_block_sizes=flash_block_sizes,
)

# 2. Prompt Cross-Attention
Expand All @@ -184,6 +188,7 @@ def __init__(
mesh=mesh,
attention_kernel=self.attention_kernel,
rope_type=rope_type,
flash_block_sizes=flash_block_sizes,
)

self.audio_norm2 = nnx.RMSNorm(
Expand All @@ -209,6 +214,7 @@ def __init__(
mesh=mesh,
attention_kernel=self.attention_kernel,
rope_type=rope_type,
flash_block_sizes=flash_block_sizes,
)

# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
Expand All @@ -235,6 +241,7 @@ def __init__(
mesh=mesh,
attention_kernel=self.attention_kernel,
rope_type=rope_type,
flash_block_sizes=flash_block_sizes,
)

self.video_to_audio_norm = nnx.RMSNorm(
Expand All @@ -260,6 +267,7 @@ def __init__(
mesh=mesh,
attention_kernel=self.attention_kernel,
rope_type=rope_type,
flash_block_sizes=flash_block_sizes,
)

# 4. Feed Forward
Expand Down Expand Up @@ -553,6 +561,7 @@ def __init__(
scan_layers: bool = True,
attention_kernel: str = "flash",
qk_norm: str = "rms_norm_across_heads",
flash_block_sizes: BlockSizes = None,
**kwargs,
):
self.in_channels = in_channels
Expand Down Expand Up @@ -791,6 +800,7 @@ def init_block(rngs):
names_which_can_be_saved=self.names_which_can_be_saved,
names_which_can_be_offloaded=self.names_which_can_be_offloaded,
attention_kernel=self.attention_kernel,
flash_block_sizes=flash_block_sizes,
)

if self.scan_layers:
Expand Down Expand Up @@ -822,6 +832,7 @@ def init_block(rngs):
names_which_can_be_saved=self.names_which_can_be_saved,
names_which_can_be_offloaded=self.names_which_can_be_offloaded,
attention_kernel=self.attention_kernel,
flash_block_sizes=flash_block_sizes,
)
blocks.append(block)
self.transformer_blocks = nnx.List(blocks)
Expand Down
3 changes: 2 additions & 1 deletion src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ...pyconfig import HyperParameters
from ... import max_logging
from ... import max_utils
from ...max_utils import get_precision, device_put_replicated
from ...max_utils import get_precision, device_put_replicated, get_flash_block_sizes
from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs


Expand Down Expand Up @@ -124,6 +124,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
ltx2_config["weights_dtype"] = config.weights_dtype
ltx2_config["attention_kernel"] = config.attention
ltx2_config["precision"] = get_precision(config)
ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config)
ltx2_config["remat_policy"] = config.remat_policy
ltx2_config["names_which_can_be_saved"] = config.names_which_can_be_saved
ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
Expand Down
Loading