Skip to content

Commit 8c19a42

Browse files
committed
simplify sinks application
1 parent 353c85f commit 8c19a42

1 file changed

Lines changed: 3 additions & 17 deletions

File tree

ggml/src/ggml-cpu/ops.cpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8239,15 +8239,16 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
82398239
}
82408240
}
82418241

8242-
// sinks - skip when writing partials, reduce function will apply once
8243-
if (sinks && !write_partials) {
8242+
// sinks - apply only on the first kv-chunk
8243+
if (sinks && ic_start == 0) {
82448244
const float s = ((float *)((char *) sinks->data))[h];
82458245

82468246
float ms = 1.0f;
82478247
float vs = 1.0f;
82488248

82498249
if (s > M) {
82508250
ms = expf(M - s);
8251+
M = s;
82518252
ggml_vec_scale_f32(DV, VKQ32, ms);
82528253
} else {
82538254
vs = expf(s - M);
@@ -8564,7 +8565,6 @@ static void ggml_flash_attn_ext_reduce_partials(
85648565
const ggml_tensor * q = dst->src[0];
85658566
const ggml_tensor * k = dst->src[1];
85668567
const ggml_tensor * v = dst->src[2];
8567-
const ggml_tensor * sinks = dst->src[4];
85688568

85698569
const int64_t DK = k->ne[0];
85708570
const int64_t DV = v->ne[0];
@@ -8616,20 +8616,6 @@ static void ggml_flash_attn_ext_reduce_partials(
86168616
M_final = M_new;
86178617
}
86188618

8619-
// Apply sinks once after combining all chunks
8620-
if (sinks) {
8621-
const float s = ((float *) sinks->data)[q_head];
8622-
8623-
if (s > M_final) {
8624-
const float ms = expf(M_final - s);
8625-
ggml_vec_scale_f32(DV, VKQ_final, ms);
8626-
S_final = S_final * ms + 1.0f;
8627-
M_final = s;
8628-
} else {
8629-
S_final = S_final + expf(s - M_final);
8630-
}
8631-
}
8632-
86338619
// Normalize and write to output
86348620
if (S_final != 0.0f) {
86358621
const float S_inv = 1.0f / S_final;

0 commit comments

Comments
 (0)