@@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k(
80428042 }
80438043}
80448044
8045- // ggml_compute_forward_flash_attn_ext
8046-
80478045static void ggml_compute_forward_flash_attn_ext_f16_one_chunk (
80488046 const ggml_compute_params * params,
80498047 ggml_tensor * dst,
8050- int ir0, int ir1) {
8048+ int ir0, int ir1,
8049+ int64_t ic_start, int64_t ic_end,
8050+ float * partials, int64_t partial_stride) {
8051+
8052+ const bool write_partials = (partials != nullptr );
80518053 const ggml_tensor * q = dst->src [0 ];
80528054 const ggml_tensor * k = dst->src [1 ];
80538055 const ggml_tensor * v = dst->src [2 ];
@@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
81248126
81258127 int ith = params->ith ;
81268128
8127- // loop over n_batch and n_head
81288129 for (int ir = ir0; ir < ir1; ++ir) {
81298130 // q indices
81308131 const int iq3 = ir/(neq2*neq1);
@@ -8165,7 +8166,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
81658166 // loop over n_kv and n_head_kv
81668167 // ref: https://arxiv.org/pdf/2112.05682.pdf
81678168
8168- for (int64_t ic = 0 ; ic < nek1 ; ++ic) {
8169+ for (int64_t ic = ic_start ; ic < ic_end ; ++ic) {
81698170 const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32 (mp[ic]) : 0 .0f ;
81708171 if (mv == -INFINITY ) {
81718172 continue ;
@@ -8238,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
82388239 }
82398240 }
82408241
8241- // sinks
8242- if (sinks) {
8242+ // sinks - skip when writing partials, reduce function will apply once
8243+ if (sinks && !write_partials ) {
82438244 const float s = ((float *)((char *) sinks->data ))[h];
82448245
82458246 float ms = 1 .0f ;
@@ -8255,20 +8256,26 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
82558256 S = S*ms + vs;
82568257 }
82578258
8258- // V /= S
8259- const float S_inv = S == 0 .0f ? 0 .0f : 1 .0f /S;
8260- ggml_vec_scale_f32 (DV , VKQ32 , S_inv);
8261-
8262- // dst indices
8263- const int i1 = iq1;
8264- const int i2 = iq2;
8265- const int i3 = iq3;
8259+ if (write_partials) {
8260+ // Write M, S, VKQ to partials for later reduction
8261+ // partials layout: [M, S, VKQ[DV]] per query head
8262+ float * partial = partials + ir * partial_stride;
8263+ partial[0 ] = M;
8264+ partial[1 ] = S;
8265+ memcpy (partial + 2 , VKQ32 , DV * sizeof (float ));
8266+ } else {
8267+ // V /= S
8268+ const float S_inv = S == 0 .0f ? 0 .0f : 1 .0f /S;
8269+ ggml_vec_scale_f32 (DV , VKQ32 , S_inv);
82668270
8267- // original
8268- // memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
8271+ // dst indices
8272+ const int i1 = iq1;
8273+ const int i2 = iq2;
8274+ const int i3 = iq3;
82698275
8270- // permute(0, 2, 1, 3)
8271- memcpy ((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 , nb1);
8276+ // permute(0, 2, 1, 3)
8277+ memcpy ((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 , nb1);
8278+ }
82728279 }
82738280}
82748281
@@ -8546,6 +8553,93 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
85468553 }
85478554}
85488555
8556+ // Reduction function: combines partial results across KV chunks
8557+ // Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
8558+ static void ggml_flash_attn_ext_reduce_partials (
8559+ const ggml_compute_params * params,
8560+ ggml_tensor * dst,
8561+ const int64_t n_chunks,
8562+ const int64_t chunk_size) {
8563+
8564+ const ggml_tensor * q = dst->src [0 ];
8565+ const ggml_tensor * k = dst->src [1 ];
8566+ const ggml_tensor * v = dst->src [2 ];
8567+ const ggml_tensor * sinks = dst->src [4 ];
8568+
8569+ const int64_t DK = k->ne [0 ];
8570+ const int64_t DV = v->ne [0 ];
8571+ const int64_t nek1 = k->ne [1 ];
8572+ const int64_t n_q_heads = q->ne [2 ];
8573+
8574+ const int ith = params->ith ;
8575+ const int nth = params->nth ;
8576+
8577+ const int64_t wdata_per_thread = DK + 2 *DV + CACHE_LINE_SIZE_F32 ;
8578+ float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
8579+
8580+ const int64_t partials_offset = nth * (DK + 2 *DV + CACHE_LINE_SIZE_F32 );
8581+ const int64_t partial_size = 2 + DV ;
8582+ const float * partials_base = (const float *) params->wdata + partials_offset;
8583+
8584+ // Output layout
8585+ const int64_t ne1 = dst->ne [1 ];
8586+ const int64_t ne2 = dst->ne [2 ];
8587+ const size_t nb1 = dst->nb [1 ];
8588+
8589+ // Each thread reduces a subset of query heads
8590+ for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
8591+ float M_final = -INFINITY ;
8592+ float S_final = 0 .0f ;
8593+ float * VKQ_final = thread_wdata;
8594+ memset (VKQ_final, 0 , DV * sizeof (float ));
8595+
8596+ // Combine partials from all chunks
8597+ for (int64_t chunk_idx = 0 ; chunk_idx < n_chunks; ++chunk_idx) {
8598+ const int64_t ic_start = chunk_idx * chunk_size;
8599+ if (ic_start >= nek1) continue ;
8600+
8601+ const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8602+ const float M_chunk = partial[0 ];
8603+ const float S_chunk = partial[1 ];
8604+ const float * VKQ_chunk = partial + 2 ;
8605+
8606+ if (S_chunk == 0 .0f ) continue ;
8607+
8608+ const float M_new = fmaxf (M_final, M_chunk);
8609+ const float scale_old = expf (M_final - M_new);
8610+ const float scale_new = expf (M_chunk - M_new);
8611+
8612+ for (int64_t d = 0 ; d < DV ; ++d) {
8613+ VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
8614+ }
8615+ S_final = S_final * scale_old + S_chunk * scale_new;
8616+ M_final = M_new;
8617+ }
8618+
8619+ // Apply sinks once after combining all chunks
8620+ if (sinks) {
8621+ const float s = ((float *) sinks->data )[q_head];
8622+
8623+ if (s > M_final) {
8624+ const float ms = expf (M_final - s);
8625+ ggml_vec_scale_f32 (DV , VKQ_final, ms);
8626+ S_final = S_final * ms + 1 .0f ;
8627+ M_final = s;
8628+ } else {
8629+ S_final = S_final + expf (s - M_final);
8630+ }
8631+ }
8632+
8633+ // Normalize and write to output
8634+ if (S_final != 0 .0f ) {
8635+ const float S_inv = 1 .0f / S_final;
8636+ ggml_vec_scale_f32 (DV , VKQ_final, S_inv);
8637+ }
8638+ // iq1=0, iq3=0 for decode
8639+ memcpy ((char *) dst->data + (0 *ne2*ne1 + q_head + 0 *ne1)*nb1, VKQ_final, nb1);
8640+ }
8641+ }
8642+
85498643static void ggml_compute_forward_flash_attn_ext_f16 (
85508644 const ggml_compute_params * params,
85518645 ggml_tensor * dst) {
@@ -8567,6 +8661,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
85678661 const int64_t DV = nev0;
85688662 const int64_t N = neq1;
85698663
8664+
85708665 GGML_ASSERT (ne0 == DV );
85718666 GGML_ASSERT (ne2 == N);
85728667
@@ -8587,60 +8682,88 @@ static void ggml_compute_forward_flash_attn_ext_f16(
85878682 GGML_ASSERT (nb1 <= nb2);
85888683 GGML_ASSERT (nb2 <= nb3);
85898684
8590- // parallelize by q rows using ggml_vec_dot_f32
8591-
8592- // total rows in q
8593- const int64_t nr = neq1*neq2*neq3;
8594-
8595- // rows per thread
85968685 const int ith = params->ith ;
85978686 const int nth = params->nth ;
85988687
8599- // disable for NUMA
8600- const bool disable_chunking = ggml_is_numa () ;
8688+ const bool kv_is_f32_or_f16 = (k-> type == GGML_TYPE_F32 || k-> type == GGML_TYPE_F16 );
8689+ const bool use_split_kv_path = (neq1 == 1 && neq3 == 1 ) && kv_is_f32_or_f16 && (k-> type == v-> type ) && q-> type == GGML_TYPE_F32 && nek1 >= 512 ;
86018690
8602- // 4x chunks per thread
8603- int nth_scaled = nth * 4 ;
8604- int64_t chunk_size = (nr + nth_scaled - 1 ) / nth_scaled;
8605- int64_t nchunk = (nr + chunk_size - 1 ) / chunk_size;
8691+ if (use_split_kv_path) {
8692+ const int64_t chunk_size = (nek1 + nth - 1 ) / nth;
86068693
8607- if (nth == 1 || nchunk < nth || disable_chunking) {
8608- nchunk = nth ;
8609- }
8694+ // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
8695+ const int64_t partial_size = 2 + DV ;
8696+ float * partials_base = ( float *) params-> wdata + nth * ( DK + 2 * DV + CACHE_LINE_SIZE_F32 );
86108697
8611- if (ith == 0 ) {
8612- // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
8613- ggml_threadpool_chunk_set (params->threadpool , nth);
8614- }
8698+ const int64_t ic_start = ith * chunk_size;
8699+ const int64_t ic_end = std::min (ic_start + chunk_size, nek1);
86158700
8616- ggml_barrier (params->threadpool );
8701+ const int64_t partial_stride = nth * partial_size;
8702+ float * chunk_partials = partials_base + ith * partial_size;
8703+
8704+ if (ic_start < nek1) {
8705+ for (int64_t q_head = 0 ; q_head < neq2; q_head++) {
8706+ ggml_compute_forward_flash_attn_ext_f16_one_chunk (
8707+ params, dst, q_head, q_head + 1 , ic_start, ic_end,
8708+ chunk_partials, partial_stride);
8709+ }
8710+ } else {
8711+ for (int64_t q_head = 0 ; q_head < neq2; q_head++) {
8712+ float * q_partials = chunk_partials + q_head * partial_stride;
8713+ q_partials[0 ] = -INFINITY ; // M
8714+ q_partials[1 ] = 0 .0f ; // S
8715+ }
8716+ }
8717+
8718+ ggml_barrier (params->threadpool );
8719+ ggml_flash_attn_ext_reduce_partials (params, dst, nth, chunk_size);
8720+ } else {
86178721
8618- // The number of elements in each chunk
8619- const int64_t dr = (nr + nchunk - 1 ) / nchunk ;
8722+ // total rows in q
8723+ const int64_t nr = neq1*neq2*neq3 ;
86208724
8621- static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV ;
8622- static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8623- const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16 );
8624- const bool use_tiled = (q->type == GGML_TYPE_F32 &&
8625- kv_is_f32_or_f16 &&
8626- k->type == v->type &&
8627- nek1 % KV_TILE_SZ == 0 &&
8628- neq1 >= Q_TILE_SZ ); // Only use tiled for batch >= tile size
8725+ // disable for NUMA
8726+ const bool disable_chunking = ggml_is_numa ();
86298727
8630- // The first chunk comes from our thread_id, the rest will get auto-assigned.
8631- int current_chunk = ith;
8728+ // 4x chunks per thread
8729+ int nth_scaled = nth * 4 ;
8730+ int64_t chunk_size = (nr + nth_scaled - 1 ) / nth_scaled;
8731+ int64_t nchunk = (nr + chunk_size - 1 ) / chunk_size;
86328732
8633- while (current_chunk < nchunk) {
8634- const int64_t ir0 = dr * current_chunk ;
8635- const int64_t ir1 = MIN (ir0 + dr, nr);
8733+ if (nth == 1 || nchunk < nth || disable_chunking ) {
8734+ nchunk = nth ;
8735+ }
86368736
8637- if (use_tiled) {
8638- ggml_compute_forward_flash_attn_ext_tiled (params, dst, ir0, ir1);
8639- } else {
8640- ggml_compute_forward_flash_attn_ext_f16_one_chunk (params, dst, ir0, ir1);
8737+ if (ith == 0 ) {
8738+ ggml_threadpool_chunk_set (params->threadpool , nth);
86418739 }
86428740
8643- current_chunk = ggml_threadpool_chunk_add (params->threadpool , 1 );
8741+ ggml_barrier (params->threadpool );
8742+
8743+ const int64_t dr = (nr + nchunk - 1 ) / nchunk;
8744+
8745+ static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV ;
8746+ static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8747+ const bool use_tiled = (q->type == GGML_TYPE_F32 &&
8748+ kv_is_f32_or_f16 &&
8749+ k->type == v->type &&
8750+ nek1 % KV_TILE_SZ == 0 &&
8751+ neq1 >= Q_TILE_SZ );
8752+
8753+ int current_chunk = ith;
8754+
8755+ while (current_chunk < nchunk) {
8756+ const int64_t ir0 = dr * current_chunk;
8757+ const int64_t ir1 = MIN (ir0 + dr, nr);
8758+
8759+ if (use_tiled) {
8760+ ggml_compute_forward_flash_attn_ext_tiled (params, dst, ir0, ir1);
8761+ } else {
8762+ ggml_compute_forward_flash_attn_ext_f16_one_chunk (params, dst, ir0, ir1, 0 , nek1, nullptr , 0 );
8763+ }
8764+
8765+ current_chunk = ggml_threadpool_chunk_add (params->threadpool , 1 );
8766+ }
86448767 }
86458768}
86468769
0 commit comments