Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2317,7 +2317,11 @@ static struct ggml_cgraph * whisper_build_graph_cross(
struct ggml_tensor * k;
struct ggml_tensor * v;

if (wctx.params.flash_attn) {
// Use non-flash layout for cross-attention KV when DTW is enabled,
// since DTW needs the explicit cross-attention weights (KQ_soft_max)
const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;

if (flash_cross) {
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));

Expand Down Expand Up @@ -2677,7 +2681,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
0, 2, 1, 3);

if (wctx.params.flash_attn) {
// Use non-flash path for cross-attention when DTW is enabled,
// since DTW needs the explicit cross-attention weights (KQ_soft_max)
const bool flash_cross = wctx.params.flash_attn && !wctx.params.dtw_token_timestamps;

if (flash_cross) {
struct ggml_tensor * Kcross =
ggml_view_3d(ctx0, wstate.kv_cross.k,
n_state_head, n_audio_ctx_pad, n_head,
Expand Down Expand Up @@ -3706,8 +3714,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
ggml_time_init();

if (params.flash_attn && params.dtw_token_timestamps) {
WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
params.dtw_token_timestamps = false;
WHISPER_LOG_INFO("%s: flash_attn with dtw - disabling flash attention for cross-attention only\n", __func__);
}

WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
Expand Down