diff --git a/src/whisper.cpp b/src/whisper.cpp
index 796bccfb45d..41a29e868ff 100644
--- a/src/whisper.cpp
+++ b/src/whisper.cpp
@@ -2317,7 +2317,11 @@ static struct ggml_cgraph * whisper_build_graph_cross(
         struct ggml_tensor * k;
         struct ggml_tensor * v;
 
-        if (wctx.params.flash_attn) {
+        // Use non-flash layout for cross-attention KV when DTW is enabled,
+        // since DTW needs the explicit cross-attention weights (KQ_soft_max)
+        const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;
+
+        if (flash_cross) {
             k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
                     (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
 
@@ -2677,7 +2681,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                             ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
                             0, 2, 1, 3);
 
-                if (wctx.params.flash_attn) {
+                // Use non-flash path for cross-attention when DTW is enabled,
+                // since DTW needs the explicit cross-attention weights (KQ_soft_max)
+                const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;
+
+                if (flash_cross) {
                     struct ggml_tensor * Kcross =
                             ggml_view_3d(ctx0, wstate.kv_cross.k,
                                     n_state_head, n_audio_ctx_pad, n_head,
@@ -3706,8 +3714,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
     ggml_time_init();
 
     if (params.flash_attn && params.dtw_token_timestamps) {
-        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
-        params.dtw_token_timestamps = false;
+        WHISPER_LOG_INFO("%s: flash_attn with dtw - disabling flash attention for cross-attention only\n", __func__);
     }
 
     WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);