@@ -8566,30 +8566,30 @@ static void ggml_flash_attn_ext_reduce_partials(
85668566 const ggml_tensor * k = dst->src [1 ];
85678567 const ggml_tensor * v = dst->src [2 ];
85688568
8569- const int64_t DK = k->ne [0 ];
8570- const int64_t DV = v->ne [0 ];
8571- const int64_t nek1 = k->ne [1 ];
8569+ const int64_t DK = k->ne [0 ];
8570+ const int64_t DV = v->ne [0 ];
8571+ const int64_t nek1 = k->ne [1 ];
85728572 const int64_t n_q_heads = q->ne [2 ];
85738573
85748574 const int ith = params->ith ;
85758575 const int nth = params->nth ;
85768576
85778577 const int64_t wdata_per_thread = DK + 2 *DV + CACHE_LINE_SIZE_F32 ;
8578- float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
8578+ float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
85798579
8580- const int64_t partials_offset = nth * (DK + 2 *DV + CACHE_LINE_SIZE_F32 );
8581- const int64_t partial_size = 2 + DV ;
8582- const float * partials_base = (const float *) params->wdata + partials_offset;
8580+ const int64_t partials_offset = nth * (DK + 2 *DV + CACHE_LINE_SIZE_F32 );
8581+ const int64_t partial_size = 2 + DV ;
8582+ const float * partials_base = (const float *) params->wdata + partials_offset;
85838583
85848584 // Output layout
85858585 const int64_t ne1 = dst->ne [1 ];
85868586 const int64_t ne2 = dst->ne [2 ];
8587- const size_t nb1 = dst->nb [1 ];
8587+ const size_t nb1 = dst->nb [1 ];
85888588
85898589 // Each thread reduces a subset of query heads
85908590 for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
8591- float M_final = -INFINITY ;
8592- float S_final = 0 .0f ;
8591+ float M_final = -INFINITY ;
8592+ float S_final = 0 .0f ;
85938593 float * VKQ_final = thread_wdata;
85948594 memset (VKQ_final, 0 , DV * sizeof (float ));
85958595
@@ -8598,14 +8598,14 @@ static void ggml_flash_attn_ext_reduce_partials(
85988598 const int64_t ic_start = chunk_idx * chunk_size;
85998599 if (ic_start >= nek1) continue ;
86008600
8601- const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8602- const float M_chunk = partial[0 ];
8603- const float S_chunk = partial[1 ];
8601+ const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8602+ const float M_chunk = partial[0 ];
8603+ const float S_chunk = partial[1 ];
86048604 const float * VKQ_chunk = partial + 2 ;
86058605
86068606 if (S_chunk == 0 .0f ) continue ;
86078607
8608- const float M_new = fmaxf (M_final, M_chunk);
8608+ const float M_new = fmaxf (M_final, M_chunk);
86098609 const float scale_old = expf (M_final - M_new);
86108610 const float scale_new = expf (M_chunk - M_new);
86118611
@@ -8671,21 +8671,24 @@ static void ggml_compute_forward_flash_attn_ext_f16(
86718671 const int ith = params->ith ;
86728672 const int nth = params->nth ;
86738673
8674+ // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
8675+ const bool use_ref = params->use_ref ;
8676+
86748677 const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16 );
8675- const bool use_split_kv_path = (neq1 == 1 && neq3 == 1 ) && kv_is_f32_or_f16 && (k->type == v->type ) && q->type == GGML_TYPE_F32 && nek1 >= 512 ;
8678+ const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1 ) && kv_is_f32_or_f16 && (k->type == v->type ) && q->type == GGML_TYPE_F32 && nek1 >= 512 ;
86768679
86778680 if (use_split_kv_path) {
86788681 const int64_t chunk_size = (nek1 + nth - 1 ) / nth;
86798682
86808683 // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
8681- const int64_t partial_size = 2 + DV ;
8682- float * partials_base = (float *) params->wdata + nth * (DK + 2 *DV + CACHE_LINE_SIZE_F32 );
8684+ const int64_t partial_size = 2 + DV ;
8685+ float * partials_base = (float *) params->wdata + nth * (DK + 2 *DV + CACHE_LINE_SIZE_F32 );
86838686
86848687 const int64_t ic_start = ith * chunk_size;
86858688 const int64_t ic_end = std::min (ic_start + chunk_size, nek1);
86868689
86878690 const int64_t partial_stride = nth * partial_size;
8688- float * chunk_partials = partials_base + ith * partial_size;
8691+ float * chunk_partials = partials_base + ith * partial_size;
86898692
86908693 if (ic_start < nek1) {
86918694 for (int64_t q_head = 0 ; q_head < neq2; q_head++) {
@@ -8730,7 +8733,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
87308733
87318734 static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV ;
87328735 static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8733- const bool use_tiled = (q->type == GGML_TYPE_F32 &&
8736+ const bool use_tiled = !use_ref &&
8737+ (q->type == GGML_TYPE_F32 &&
87348738 kv_is_f32_or_f16 &&
87358739 k->type == v->type &&
87368740 nek1 % KV_TILE_SZ == 0 &&
0 commit comments