Skip to content

Commit c9de967

Browse files
committed
fix(ltx2): resolve flash attention block size mismatch and missing config overrides
This commit addresses two issues in the LTX-2 pipeline:

1. Pipeline config overrides: fixed a bug in `ltx2_pipeline.py` where the `a2v_attention_kernel` and `v2a_attention_kernel` configuration values were ignored. The model previously fell back to the hardcoded "flash" default because these values were never copied from the user config into `ltx2_config`.

2. Flash attention padding mismatch: fixed a `ValueError` (e.g. `kv_block_size=126 should divide kv_seq_len=128`) in `attention_flax.py` that occurred for specific video frame counts (such as 121 frames). A previous fix padded sequences to satisfy `shard_map` context-dimension requirements, but `_select_flash_block_sizes` still capped `block_kv` to the unpadded sequence length. The `min()` bounds were removed so block sizes align cleanly with the padded sequence lengths passed to the Splash attention kernel.
1 parent ad6391a commit c9de967

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -222,13 +222,13 @@ def _select_flash_block_sizes(
   block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
   return splash_attention_kernel.BlockSizes(
       block_q=block_size_q,
-      block_kv_compute=min(kv_max_block_size, key_seq_len),
-      block_kv=min(kv_max_block_size, key_seq_len),
+      block_kv_compute=kv_max_block_size,
+      block_kv=kv_max_block_size,
       block_q_dkv=block_size_q,
-      block_kv_dkv=min(kv_max_block_size, key_seq_len),
-      block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
+      block_kv_dkv=kv_max_block_size,
+      block_kv_dkv_compute=kv_max_block_size,
       block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
-      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else kv_max_block_size,
       use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
   )
```

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -127,6 +127,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
   ltx2_config["dtype"] = config.activations_dtype
   ltx2_config["weights_dtype"] = config.weights_dtype
   ltx2_config["attention_kernel"] = config.attention
+  ltx2_config["a2v_attention_kernel"] = getattr(config, "a2v_attention_kernel", "flash")
+  ltx2_config["v2a_attention_kernel"] = getattr(config, "v2a_attention_kernel", "dot_product")
   ltx2_config["precision"] = get_precision(config)
   ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config)
   ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096)
```

0 commit comments

Comments
 (0)