@@ -8239,15 +8239,16 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
82398239 }
82408240 }
82418241
8242- // sinks - skip when writing partials, reduce function will apply once
8243- if (sinks && !write_partials ) {
8242+ // sinks - apply only on the first kv-chunk
8243+ if (sinks && ic_start == 0 ) {
82448244 const float s = ((float *)((char *) sinks->data ))[h];
82458245
82468246 float ms = 1 .0f ;
82478247 float vs = 1 .0f ;
82488248
82498249 if (s > M) {
82508250 ms = expf (M - s);
8251+ M = s;
82518252 ggml_vec_scale_f32 (DV , VKQ32 , ms);
82528253 } else {
82538254 vs = expf (s - M);
@@ -8564,7 +8565,6 @@ static void ggml_flash_attn_ext_reduce_partials(
85648565 const ggml_tensor * q = dst->src [0 ];
85658566 const ggml_tensor * k = dst->src [1 ];
85668567 const ggml_tensor * v = dst->src [2 ];
8567- const ggml_tensor * sinks = dst->src [4 ];
85688568
85698569 const int64_t DK = k->ne [0 ];
85708570 const int64_t DV = v->ne [0 ];
@@ -8616,20 +8616,6 @@ static void ggml_flash_attn_ext_reduce_partials(
86168616 M_final = M_new;
86178617 }
86188618
8619- // Apply sinks once after combining all chunks
8620- if (sinks) {
8621- const float s = ((float *) sinks->data )[q_head];
8622-
8623- if (s > M_final) {
8624- const float ms = expf (M_final - s);
8625- ggml_vec_scale_f32 (DV , VKQ_final, ms);
8626- S_final = S_final * ms + 1 .0f ;
8627- M_final = s;
8628- } else {
8629- S_final = S_final + expf (s - M_final);
8630- }
8631- }
8632-
86338619 // Normalize and write to output
86348620 if (S_final != 0 .0f ) {
86358621 const float S_inv = 1 .0f / S_final;
0 commit comments