Commit 490dd70
fix(ltx2): resolve flash attention block size mismatch and missing config overrides
This commit addresses two issues in the LTX-2 pipeline:

1. Pipeline config overrides: Fixed a bug in `ltx2_pipeline.py` where the `a2v_attention_kernel` and `v2a_attention_kernel` settings were ignored. Because these values were never mapped from the user config into `ltx2_config`, the model always fell back to the hardcoded `"flash"` kernel.

2. Flash attention padding mismatch: Fixed a `ValueError` (e.g., `kv_block_size=126 should divide kv_seq_len=128`) in `attention_flax.py` that occurred for specific video frame counts. A previous fix padded sequences to satisfy the `shard_map` context-dimension requirements, but `_select_flash_block_sizes` still computed block sizes from the unpadded length. The block-size calculation now runs *after* `_reshape_data_for_flash`, so the dynamic `min()` bounds align with the padded sequence lengths, keeping cross-attention optimizations intact and unit tests passing.
1 parent ad6391a commit 490dd70

2 files changed: 3 additions & 1 deletion

src/maxdiffusion/models/attention_flax.py (1 addition & 1 deletion)

```diff
@@ -287,11 +287,11 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""

-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
   num_context_shards = mesh.shape["context"]
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
   value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)

   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
```
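The ordering bug this hunk fixes can be reproduced in a minimal standalone sketch. The helper names below (`pad_to_multiple`, `select_block_size`) are simplified stand-ins, not the actual maxdiffusion functions; the point is only that a block size derived via `min()` from the *unpadded* length need not divide the *padded* length:

```python
def pad_to_multiple(seq_len: int, multiple: int) -> int:
    # Round seq_len up to the next multiple (what shard_map padding does).
    return -(-seq_len // multiple) * multiple

def select_block_size(seq_len: int, max_block: int = 128) -> int:
    # Simplified stand-in for the dynamic min() bound in
    # _select_flash_block_sizes: never exceed the sequence length.
    return min(max_block, seq_len)

unpadded = 126                            # a frame count that is not shard-aligned
padded = pad_to_multiple(unpadded, 64)    # shard_map padding -> 128

# Buggy order: block size computed from the unpadded length.
bad = select_block_size(unpadded)         # 126
assert padded % bad != 0                  # 126 does not divide 128 -> ValueError

# Fixed order: block size computed after padding.
good = select_block_size(padded)          # 128
assert padded % good == 0                 # divides evenly, kernel is happy
```

This is why the commit moves `_select_flash_block_sizes` to run after `_reshape_data_for_flash`: the query/key tensors it inspects then already carry the padded sequence lengths.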

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py (2 additions & 0 deletions)

```diff
@@ -127,6 +127,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
   ltx2_config["dtype"] = config.activations_dtype
   ltx2_config["weights_dtype"] = config.weights_dtype
   ltx2_config["attention_kernel"] = config.attention
+  ltx2_config["a2v_attention_kernel"] = getattr(config, "a2v_attention_kernel", "flash")
+  ltx2_config["v2a_attention_kernel"] = getattr(config, "v2a_attention_kernel", "dot_product")
   ltx2_config["precision"] = get_precision(config)
   ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config)
   ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096)
```
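The `getattr(config, name, default)` pattern used here maps a user-config field into `ltx2_config` when it is set, and otherwise falls back to a sensible default rather than silently dropping it. A minimal sketch (the `Config` class and its values are hypothetical, standing in for the real user config object):

```python
class Config:
    # Hypothetical user config: a2v kernel is overridden, v2a is left unset.
    attention = "flash"
    a2v_attention_kernel = "dot_product"
    # v2a_attention_kernel deliberately absent

config = Config()
ltx2_config = {}

ltx2_config["attention_kernel"] = config.attention
# Present on config -> the user's value wins over the "flash" default.
ltx2_config["a2v_attention_kernel"] = getattr(config, "a2v_attention_kernel", "flash")
# Absent from config -> falls back to the default instead of raising.
ltx2_config["v2a_attention_kernel"] = getattr(config, "v2a_attention_kernel", "dot_product")

print(ltx2_config["a2v_attention_kernel"])  # user override is respected
print(ltx2_config["v2a_attention_kernel"])  # default fallback is used
```

Before this fix, the two `a2v`/`v2a` keys were never written into `ltx2_config` at all, so the model-side default of `"flash"` applied regardless of what the user configured.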
