Commit 490dd70
fix(ltx2): resolve flash attention block size mismatch and missing config overrides
This commit addresses two issues in the LTX-2 pipeline:
1. Pipeline Config Overrides:
Fixed a bug in `ltx2_pipeline.py` where `a2v_attention_kernel` and `v2a_attention_kernel` configurations were ignored. The model previously hardcoded a fallback to "flash" because these values were not mapped from the user config to `ltx2_config`.
2. Flash Attention Padding Mismatch:
Fixed a `ValueError` (e.g., `kv_block_size=126 should divide kv_seq_len=128`) in `attention_flax.py` that occurred for specific video frame counts. A previous fix padded sequences to satisfy `shard_map` context dimension requirements, but `_select_flash_block_sizes` was calculating block sizes based on the unpadded length. Moved the block size calculation to occur *after* `_reshape_data_for_flash` so that the dynamic `min()` bounds correctly align with the newly padded sequence lengths, keeping cross-attention optimizations intact and unit tests passing.
1 parent ad6391a
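The override fix in item 1 can be sketched as follows. This is a minimal illustration, not the actual code: the config keys `a2v_attention_kernel` and `v2a_attention_kernel` come from the commit message, but the dict-based config shape and the function name `apply_attention_kernel_overrides` are assumptions.

```python
# Hypothetical sketch of the mapping that was missing: before the fix,
# these keys were never copied from the user config into ltx2_config,
# so the model always fell back to the hardcoded "flash" kernel.

def apply_attention_kernel_overrides(user_config: dict, ltx2_config: dict) -> dict:
    """Copy user-selected attention kernels into the model config."""
    for key in ("a2v_attention_kernel", "v2a_attention_kernel"):
        if key in user_config:
            ltx2_config[key] = user_config[key]
        else:
            # Previous behavior, now only a default rather than an override.
            ltx2_config.setdefault(key, "flash")
    return ltx2_config
```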
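The ordering bug in item 2 can be reproduced with a small sketch. The helper names mirror those in the commit message, but the bodies below are simplified assumptions: block selection is modeled as "largest block not exceeding a cap that divides the sequence length", and padding as rounding up to a multiple.

```python
# Illustrative sketch of why computing block sizes before padding fails.

def pad_to_multiple(seq_len: int, multiple: int) -> int:
    """Round seq_len up to a multiple (what the shard_map padding step does)."""
    return ((seq_len + multiple - 1) // multiple) * multiple

def select_flash_block_size(kv_seq_len: int, max_block: int = 128) -> int:
    """Pick the largest block size <= max_block that divides kv_seq_len,
    modeling the dynamic min() bound described in the commit."""
    block = min(max_block, kv_seq_len)
    while kv_seq_len % block != 0:
        block -= 1
    return block

unpadded = 126
padded = pad_to_multiple(unpadded, 128)        # 128

# Buggy order: block size from the unpadded length does not divide
# the padded length the kernel actually sees (126 does not divide 128).
bad_block = select_flash_block_size(unpadded)   # 126

# Fixed order: compute the block size after padding, so it divides evenly.
good_block = select_flash_block_size(padded)    # 128
```

Computing the block size after padding guarantees `kv_block_size` divides `kv_seq_len`, which is the invariant the original `ValueError` was enforcing.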
2 files changed
Lines changed: 3 additions & 1 deletion