
fix: allow flash attention in encoder when DTW is enabled #3704

Open
Acelogic wants to merge 1 commit into ggml-org:master from Acelogic:fix/dtw-flash-attn-encoder

Conversation

@Acelogic

Summary

  • DTW token timestamps are no longer silently disabled when flash attention is enabled
  • Flash attention remains active for encoder self-attention and decoder self-attention
  • Only cross-attention falls back to the non-flash path (needed to extract KQ_soft_max weights for DTW)

Details

Previously, whisper_init_with_params_no_state disabled DTW entirely when flash_attn was set:

if (params.flash_attn && params.dtw_token_timestamps) {
    params.dtw_token_timestamps = false; // DTW silently disabled
}

DTW only needs the explicit cross-attention weights (KQ_soft_max) from the decoder, which the flash attention path doesn't produce (it fuses the K·Q, softmax, and KQV steps into a single operation). Encoder self-attention and decoder self-attention don't interact with DTW at all.

The fix introduces a flash_cross flag (flash_attn && !dtw_token_timestamps) used at two matching locations:

  1. Encoder (cross-attention KV storage layout) — determines how K/V are stored in kv_cross
  2. Decoder (cross-attention computation) — determines whether to use the fused flash path or the explicit K*Q → softmax → KQV path

Both locations must use the same condition to keep the KV cache layout consistent.

Test plan

  • Builds cleanly on macOS with Metal
  • Functional testing requires a whisper model with DTW-compatible alignment heads

Fixes #3662

Previously, enabling DTW token timestamps with flash attention
caused DTW to be silently disabled entirely. DTW only needs the
explicit cross-attention weights (KQ_soft_max) from the decoder,
so flash attention can remain enabled for:
- encoder self-attention
- decoder self-attention

Only the cross-attention path in both the encoder (KV storage) and
decoder (KQ computation) needs to fall back to the non-flash path
when DTW is active, since flash attention fuses the entire attention
operation and never materializes KQ_soft_max.

This allows DTW timestamps to work alongside flash attention with
no encoder performance penalty.

Fixes ggml-org#3662
