Commit c9de967
fix(ltx2): resolve flash attention block size mismatch and missing config overrides
This commit addresses two issues in the LTX-2 pipeline:
1. Pipeline Config Overrides:
Fixed a bug in `ltx2_pipeline.py` where `a2v_attention_kernel` and `v2a_attention_kernel` configurations were ignored. The model previously hardcoded a fallback to "flash" because these values were not mapped from the user config to `ltx2_config`.
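The missing mapping can be sketched as follows. This is a hypothetical illustration, not the actual code: the class and function names (`LTX2Config`, `build_ltx2_config`) are assumptions, since only the config keys `a2v_attention_kernel` and `v2a_attention_kernel` appear in the commit message.

```python
from dataclasses import dataclass


@dataclass
class LTX2Config:
    # "flash" is the default kernel; before the fix it was effectively
    # hardcoded because the user-provided values were never copied over.
    a2v_attention_kernel: str = "flash"
    v2a_attention_kernel: str = "flash"


def build_ltx2_config(user_config: dict) -> LTX2Config:
    # The fix amounts to actually reading these two keys from the user
    # config instead of always falling back to the dataclass defaults.
    return LTX2Config(
        a2v_attention_kernel=user_config.get("a2v_attention_kernel", "flash"),
        v2a_attention_kernel=user_config.get("v2a_attention_kernel", "flash"),
    )
```

With the mapping in place, a user-selected kernel survives into `ltx2_config`; omitting the keys still yields the "flash" default.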
2. Flash Attention Padding Mismatch:
Fixed a `ValueError` (e.g., `kv_block_size=126 should divide kv_seq_len=128`) in `attention_flax.py` that occurred for specific video frame counts (like 121 frames). A previous fix padded sequences to satisfy `shard_map` context dimension requirements, but `_select_flash_block_sizes` was still capping `block_kv` to the unpadded sequence length. Removed the `min()` bounds so block sizes align cleanly with the padded sequence lengths passed to the Splash attention kernel.

Parent: ad6391a
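The mismatch can be reproduced with a minimal sketch. The helper names below are assumptions (the real logic lives in `_select_flash_block_sizes`); only the 126/128 values and the error message come from the commit. The fixed variant bounds `block_kv` by the padded length the kernel actually sees, rather than removing the bound entirely.

```python
def pick_block_kv_buggy(preferred: int, unpadded_kv_len: int, padded_kv_len: int) -> int:
    """Pre-fix behavior: cap block_kv to the *unpadded* sequence length."""
    block_kv = min(preferred, unpadded_kv_len)  # e.g. min(128, 126) -> 126
    if padded_kv_len % block_kv != 0:
        # Splash attention requires block_kv to divide the kv sequence
        # length it is given, which is the *padded* length.
        raise ValueError(
            f"kv_block_size={block_kv} should divide kv_seq_len={padded_kv_len}"
        )
    return block_kv


def pick_block_kv_fixed(preferred: int, padded_kv_len: int) -> int:
    """Post-fix behavior: bound only by the padded length the kernel sees."""
    block_kv = min(preferred, padded_kv_len)  # e.g. min(128, 128) -> 128
    if padded_kv_len % block_kv != 0:
        raise ValueError(
            f"kv_block_size={block_kv} should divide kv_seq_len={padded_kv_len}"
        )
    return block_kv
```

For 121 frames, the kv sequence might come out at 126 tokens and get padded to 128 for `shard_map`; the buggy variant then selects `block_kv=126` and trips the divisibility check, while the fixed variant selects 128.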
2 files changed: 7 additions & 5 deletions
[Diff hunk 1: lines 222–234 (5 additions, 5 deletions); the changed code was not preserved in this extract.]
[Diff hunk 2: lines 127–134 (2 additions); the changed code was not preserved in this extract.]