Skip to content

Commit 353c85f

Browse files
committed
ggml-cpu: split across kv for faster TG
1 parent 15818ac commit 353c85f

2 files changed

Lines changed: 191 additions & 61 deletions

File tree

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "ggml-backend.h"
66
#include "traits.h"
77
#include "ggml-cpu-impl.h"
8-
#include "ggml-cpu.h"
98
#include "ggml-impl.h"
109
#include "quants.h"
1110
#include "ggml-threading.h"
@@ -2867,12 +2866,20 @@ struct ggml_cplan ggml_graph_plan(
28672866
} break;
28682867
case GGML_OP_FLASH_ATTN_EXT:
28692868
{
2869+
const int64_t neq2 = node->src[0]->ne[2]; // number of query heads
28702870
const int64_t DK = node->src[1]->ne[0];
28712871
const int64_t DV = node->src[2]->ne[0];
28722872

28732873
// Tiled flash attention scratch (tile sizes defined in common.h)
28742874
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
2875-
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
2875+
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
2876+
2877+
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
2878+
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
2879+
size_t n_chunks = n_tasks;
2880+
size_t decode = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));
2881+
2882+
cur += MAX(prefill, decode);
28762883
} break;
28772884
case GGML_OP_FLASH_ATTN_BACK:
28782885
{

ggml/src/ggml-cpu/ops.cpp

Lines changed: 182 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k(
80428042
}
80438043
}
80448044

8045-
// ggml_compute_forward_flash_attn_ext
8046-
80478045
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
80488046
const ggml_compute_params * params,
80498047
ggml_tensor * dst,
8050-
int ir0, int ir1) {
8048+
int ir0, int ir1,
8049+
int64_t ic_start, int64_t ic_end,
8050+
float * partials, int64_t partial_stride) {
8051+
8052+
const bool write_partials = (partials != nullptr);
80518053
const ggml_tensor * q = dst->src[0];
80528054
const ggml_tensor * k = dst->src[1];
80538055
const ggml_tensor * v = dst->src[2];
@@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
81248126

81258127
int ith = params->ith;
81268128

8127-
// loop over n_batch and n_head
81288129
for (int ir = ir0; ir < ir1; ++ir) {
81298130
// q indices
81308131
const int iq3 = ir/(neq2*neq1);
@@ -8165,7 +8166,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
81658166
// loop over n_kv and n_head_kv
81668167
// ref: https://arxiv.org/pdf/2112.05682.pdf
81678168

8168-
for (int64_t ic = 0; ic < nek1; ++ic) {
8169+
for (int64_t ic = ic_start; ic < ic_end; ++ic) {
81698170
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
81708171
if (mv == -INFINITY) {
81718172
continue;
@@ -8238,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
82388239
}
82398240
}
82408241

8241-
// sinks
8242-
if (sinks) {
8242+
// sinks - skip when writing partials, reduce function will apply once
8243+
if (sinks && !write_partials) {
82438244
const float s = ((float *)((char *) sinks->data))[h];
82448245

82458246
float ms = 1.0f;
@@ -8255,20 +8256,26 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
82558256
S = S*ms + vs;
82568257
}
82578258

8258-
// V /= S
8259-
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8260-
ggml_vec_scale_f32(DV, VKQ32, S_inv);
8261-
8262-
// dst indices
8263-
const int i1 = iq1;
8264-
const int i2 = iq2;
8265-
const int i3 = iq3;
8259+
if (write_partials) {
8260+
// Write M, S, VKQ to partials for later reduction
8261+
// partials layout: [M, S, VKQ[DV]] per query head
8262+
float * partial = partials + ir * partial_stride;
8263+
partial[0] = M;
8264+
partial[1] = S;
8265+
memcpy(partial + 2, VKQ32, DV * sizeof(float));
8266+
} else {
8267+
// V /= S
8268+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8269+
ggml_vec_scale_f32(DV, VKQ32, S_inv);
82668270

8267-
// original
8268-
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
8271+
// dst indices
8272+
const int i1 = iq1;
8273+
const int i2 = iq2;
8274+
const int i3 = iq3;
82698275

8270-
// permute(0, 2, 1, 3)
8271-
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8276+
// permute(0, 2, 1, 3)
8277+
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8278+
}
82728279
}
82738280
}
82748281

@@ -8546,6 +8553,93 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
85468553
}
85478554
}
85488555

8556+
// Reduction function: combines partial results across KV chunks
8557+
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
8558+
static void ggml_flash_attn_ext_reduce_partials(
8559+
const ggml_compute_params * params,
8560+
ggml_tensor * dst,
8561+
const int64_t n_chunks,
8562+
const int64_t chunk_size) {
8563+
8564+
const ggml_tensor * q = dst->src[0];
8565+
const ggml_tensor * k = dst->src[1];
8566+
const ggml_tensor * v = dst->src[2];
8567+
const ggml_tensor * sinks = dst->src[4];
8568+
8569+
const int64_t DK = k->ne[0];
8570+
const int64_t DV = v->ne[0];
8571+
const int64_t nek1 = k->ne[1];
8572+
const int64_t n_q_heads = q->ne[2];
8573+
8574+
const int ith = params->ith;
8575+
const int nth = params->nth;
8576+
8577+
const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
8578+
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
8579+
8580+
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8581+
const int64_t partial_size = 2 + DV;
8582+
const float * partials_base = (const float *) params->wdata + partials_offset;
8583+
8584+
// Output layout
8585+
const int64_t ne1 = dst->ne[1];
8586+
const int64_t ne2 = dst->ne[2];
8587+
const size_t nb1 = dst->nb[1];
8588+
8589+
// Each thread reduces a subset of query heads
8590+
for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
8591+
float M_final = -INFINITY;
8592+
float S_final = 0.0f;
8593+
float * VKQ_final = thread_wdata;
8594+
memset(VKQ_final, 0, DV * sizeof(float));
8595+
8596+
// Combine partials from all chunks
8597+
for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
8598+
const int64_t ic_start = chunk_idx * chunk_size;
8599+
if (ic_start >= nek1) continue;
8600+
8601+
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8602+
const float M_chunk = partial[0];
8603+
const float S_chunk = partial[1];
8604+
const float * VKQ_chunk = partial + 2;
8605+
8606+
if (S_chunk == 0.0f) continue;
8607+
8608+
const float M_new = fmaxf(M_final, M_chunk);
8609+
const float scale_old = expf(M_final - M_new);
8610+
const float scale_new = expf(M_chunk - M_new);
8611+
8612+
for (int64_t d = 0; d < DV; ++d) {
8613+
VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
8614+
}
8615+
S_final = S_final * scale_old + S_chunk * scale_new;
8616+
M_final = M_new;
8617+
}
8618+
8619+
// Apply sinks once after combining all chunks
8620+
if (sinks) {
8621+
const float s = ((float *) sinks->data)[q_head];
8622+
8623+
if (s > M_final) {
8624+
const float ms = expf(M_final - s);
8625+
ggml_vec_scale_f32(DV, VKQ_final, ms);
8626+
S_final = S_final * ms + 1.0f;
8627+
M_final = s;
8628+
} else {
8629+
S_final = S_final + expf(s - M_final);
8630+
}
8631+
}
8632+
8633+
// Normalize and write to output
8634+
if (S_final != 0.0f) {
8635+
const float S_inv = 1.0f / S_final;
8636+
ggml_vec_scale_f32(DV, VKQ_final, S_inv);
8637+
}
8638+
// iq1=0, iq3=0 for decode
8639+
memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
8640+
}
8641+
}
8642+
85498643
static void ggml_compute_forward_flash_attn_ext_f16(
85508644
const ggml_compute_params * params,
85518645
ggml_tensor * dst) {
@@ -8567,6 +8661,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
85678661
const int64_t DV = nev0;
85688662
const int64_t N = neq1;
85698663

8664+
85708665
GGML_ASSERT(ne0 == DV);
85718666
GGML_ASSERT(ne2 == N);
85728667

@@ -8587,60 +8682,88 @@ static void ggml_compute_forward_flash_attn_ext_f16(
85878682
GGML_ASSERT(nb1 <= nb2);
85888683
GGML_ASSERT(nb2 <= nb3);
85898684

8590-
// parallelize by q rows using ggml_vec_dot_f32
8591-
8592-
// total rows in q
8593-
const int64_t nr = neq1*neq2*neq3;
8594-
8595-
// rows per thread
85968685
const int ith = params->ith;
85978686
const int nth = params->nth;
85988687

8599-
// disable for NUMA
8600-
const bool disable_chunking = ggml_is_numa();
8688+
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
8689+
const bool use_split_kv_path = (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
86018690

8602-
// 4x chunks per thread
8603-
int nth_scaled = nth * 4;
8604-
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8605-
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
8691+
if (use_split_kv_path) {
8692+
const int64_t chunk_size = (nek1 + nth - 1) / nth;
86068693

8607-
if (nth == 1 || nchunk < nth || disable_chunking) {
8608-
nchunk = nth;
8609-
}
8694+
// Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
8695+
const int64_t partial_size = 2 + DV;
8696+
float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
86108697

8611-
if (ith == 0) {
8612-
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
8613-
ggml_threadpool_chunk_set(params->threadpool, nth);
8614-
}
8698+
const int64_t ic_start = ith * chunk_size;
8699+
const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
86158700

8616-
ggml_barrier(params->threadpool);
8701+
const int64_t partial_stride = nth * partial_size;
8702+
float * chunk_partials = partials_base + ith * partial_size;
8703+
8704+
if (ic_start < nek1) {
8705+
for (int64_t q_head = 0; q_head < neq2; q_head++) {
8706+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8707+
params, dst, q_head, q_head + 1, ic_start, ic_end,
8708+
chunk_partials, partial_stride);
8709+
}
8710+
} else {
8711+
for (int64_t q_head = 0; q_head < neq2; q_head++) {
8712+
float * q_partials = chunk_partials + q_head * partial_stride;
8713+
q_partials[0] = -INFINITY; // M
8714+
q_partials[1] = 0.0f; // S
8715+
}
8716+
}
8717+
8718+
ggml_barrier(params->threadpool);
8719+
ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
8720+
} else {
86178721

8618-
// The number of elements in each chunk
8619-
const int64_t dr = (nr + nchunk - 1) / nchunk;
8722+
// total rows in q
8723+
const int64_t nr = neq1*neq2*neq3;
86208724

8621-
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
8622-
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8623-
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
8624-
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
8625-
kv_is_f32_or_f16 &&
8626-
k->type == v->type &&
8627-
nek1 % KV_TILE_SZ == 0 &&
8628-
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
8725+
// disable for NUMA
8726+
const bool disable_chunking = ggml_is_numa();
86298727

8630-
// The first chunk comes from our thread_id, the rest will get auto-assigned.
8631-
int current_chunk = ith;
8728+
// 4x chunks per thread
8729+
int nth_scaled = nth * 4;
8730+
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8731+
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
86328732

8633-
while (current_chunk < nchunk) {
8634-
const int64_t ir0 = dr * current_chunk;
8635-
const int64_t ir1 = MIN(ir0 + dr, nr);
8733+
if (nth == 1 || nchunk < nth || disable_chunking) {
8734+
nchunk = nth;
8735+
}
86368736

8637-
if (use_tiled) {
8638-
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
8639-
} else {
8640-
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
8737+
if (ith == 0) {
8738+
ggml_threadpool_chunk_set(params->threadpool, nth);
86418739
}
86428740

8643-
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
8741+
ggml_barrier(params->threadpool);
8742+
8743+
const int64_t dr = (nr + nchunk - 1) / nchunk;
8744+
8745+
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
8746+
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8747+
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
8748+
kv_is_f32_or_f16 &&
8749+
k->type == v->type &&
8750+
nek1 % KV_TILE_SZ == 0 &&
8751+
neq1 >= Q_TILE_SZ);
8752+
8753+
int current_chunk = ith;
8754+
8755+
while (current_chunk < nchunk) {
8756+
const int64_t ir0 = dr * current_chunk;
8757+
const int64_t ir1 = MIN(ir0 + dr, nr);
8758+
8759+
if (use_tiled) {
8760+
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
8761+
} else {
8762+
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
8763+
}
8764+
8765+
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
8766+
}
86448767
}
86458768
}
86468769

0 commit comments

Comments
 (0)