Skip to content

Commit 3849170

Browse files
committed
add ref impl
1 parent 8c19a42 commit 3849170

6 files changed

Lines changed: 59 additions & 24 deletions

File tree

ggml/include/ggml-cpu.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ extern "C" {
1919
// abort ggml_graph_compute when true
2020
ggml_abort_callback abort_callback;
2121
void * abort_callback_data;
22+
23+
// use only reference implementations
24+
bool use_ref;
2225
};
2326

2427
// numa strategies
@@ -132,6 +135,8 @@ extern "C" {
132135
GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
133136
GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
134137

138+
GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref);
139+
135140
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
136141

137142
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);

ggml/src/ggml-cpu/ggml-cpu-impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ struct ggml_compute_params {
2424
void * wdata;
2525

2626
struct ggml_threadpool * threadpool;
27+
28+
// use reference implementation
29+
bool use_ref;
2730
};
2831

2932

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2936,11 +2936,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
29362936
set_numa_thread_affinity(state->ith);
29372937

29382938
struct ggml_compute_params params = {
2939-
/*.ith =*/ state->ith,
2940-
/*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
2941-
/*.wsize =*/ cplan->work_size,
2942-
/*.wdata =*/ cplan->work_data,
2943-
/*.threadpool=*/ tp,
2939+
/*.ith =*/ state->ith,
2940+
/*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
2941+
/*.wsize =*/ cplan->work_size,
2942+
/*.wdata =*/ cplan->work_data,
2943+
/*.threadpool =*/ tp,
2944+
/*.use_ref =*/ cplan->use_ref,
29442945
};
29452946

29462947
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);

ggml/src/ggml-cpu/ggml-cpu.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ struct ggml_backend_cpu_context {
105105

106106
ggml_abort_callback abort_callback;
107107
void * abort_callback_data;
108+
109+
bool use_ref; // use reference implementation
108110
};
109111

110112
static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
@@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend
143145

144146
cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
145147
cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
148+
cpu_plan->cplan.use_ref = cpu_ctx->use_ref;
146149

147150
return cpu_plan;
148151
}
@@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
182185

183186
cplan.abort_callback = cpu_ctx->abort_callback;
184187
cplan.abort_callback_data = cpu_ctx->abort_callback_data;
188+
cplan.use_ref = cpu_ctx->use_ref;
185189

186190
return ggml_graph_compute(cgraph, &cplan);
187191
}
@@ -223,6 +227,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
223227
ctx->work_size = 0;
224228
ctx->abort_callback = NULL;
225229
ctx->abort_callback_data = NULL;
230+
ctx->use_ref = false;
226231

227232
ggml_backend_t cpu_backend = new ggml_backend {
228233
/* .guid = */ ggml_backend_cpu_guid(),
@@ -270,6 +275,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_
270275
ctx->abort_callback_data = abort_callback_data;
271276
}
272277

278+
void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) {
279+
GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
280+
281+
struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
282+
ctx->use_ref = use_ref;
283+
}
284+
273285
// CPU backend - device
274286

275287
struct ggml_backend_cpu_device_context {
@@ -646,6 +658,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch
646658
if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) {
647659
return (void *)ggml_is_numa;
648660
}
661+
if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) {
662+
return (void *)ggml_backend_cpu_set_use_ref;
663+
}
649664

650665
// threadpool - TODO: move to ggml-base
651666
if (strcmp(name, "ggml_threadpool_new") == 0) {

ggml/src/ggml-cpu/ops.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8566,30 +8566,30 @@ static void ggml_flash_attn_ext_reduce_partials(
85668566
const ggml_tensor * k = dst->src[1];
85678567
const ggml_tensor * v = dst->src[2];
85688568

8569-
const int64_t DK = k->ne[0];
8570-
const int64_t DV = v->ne[0];
8571-
const int64_t nek1 = k->ne[1];
8569+
const int64_t DK = k->ne[0];
8570+
const int64_t DV = v->ne[0];
8571+
const int64_t nek1 = k->ne[1];
85728572
const int64_t n_q_heads = q->ne[2];
85738573

85748574
const int ith = params->ith;
85758575
const int nth = params->nth;
85768576

85778577
const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
8578-
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
8578+
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
85798579

8580-
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8581-
const int64_t partial_size = 2 + DV;
8582-
const float * partials_base = (const float *) params->wdata + partials_offset;
8580+
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8581+
const int64_t partial_size = 2 + DV;
8582+
const float * partials_base = (const float *) params->wdata + partials_offset;
85838583

85848584
// Output layout
85858585
const int64_t ne1 = dst->ne[1];
85868586
const int64_t ne2 = dst->ne[2];
8587-
const size_t nb1 = dst->nb[1];
8587+
const size_t nb1 = dst->nb[1];
85888588

85898589
// Each thread reduces a subset of query heads
85908590
for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
8591-
float M_final = -INFINITY;
8592-
float S_final = 0.0f;
8591+
float M_final = -INFINITY;
8592+
float S_final = 0.0f;
85938593
float * VKQ_final = thread_wdata;
85948594
memset(VKQ_final, 0, DV * sizeof(float));
85958595

@@ -8598,14 +8598,14 @@ static void ggml_flash_attn_ext_reduce_partials(
85988598
const int64_t ic_start = chunk_idx * chunk_size;
85998599
if (ic_start >= nek1) continue;
86008600

8601-
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8602-
const float M_chunk = partial[0];
8603-
const float S_chunk = partial[1];
8601+
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8602+
const float M_chunk = partial[0];
8603+
const float S_chunk = partial[1];
86048604
const float * VKQ_chunk = partial + 2;
86058605

86068606
if (S_chunk == 0.0f) continue;
86078607

8608-
const float M_new = fmaxf(M_final, M_chunk);
8608+
const float M_new = fmaxf(M_final, M_chunk);
86098609
const float scale_old = expf(M_final - M_new);
86108610
const float scale_new = expf(M_chunk - M_new);
86118611

@@ -8671,21 +8671,24 @@ static void ggml_compute_forward_flash_attn_ext_f16(
86718671
const int ith = params->ith;
86728672
const int nth = params->nth;
86738673

8674+
// When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
8675+
const bool use_ref = params->use_ref;
8676+
86748677
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
8675-
const bool use_split_kv_path = (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
8678+
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
86768679

86778680
if (use_split_kv_path) {
86788681
const int64_t chunk_size = (nek1 + nth - 1) / nth;
86798682

86808683
// Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
8681-
const int64_t partial_size = 2 + DV;
8682-
float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8684+
const int64_t partial_size = 2 + DV;
8685+
float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
86838686

86848687
const int64_t ic_start = ith * chunk_size;
86858688
const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
86868689

86878690
const int64_t partial_stride = nth * partial_size;
8688-
float * chunk_partials = partials_base + ith * partial_size;
8691+
float * chunk_partials = partials_base + ith * partial_size;
86898692

86908693
if (ic_start < nek1) {
86918694
for (int64_t q_head = 0; q_head < neq2; q_head++) {
@@ -8730,7 +8733,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
87308733

87318734
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
87328735
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8733-
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
8736+
const bool use_tiled = !use_ref &&
8737+
(q->type == GGML_TYPE_F32 &&
87348738
kv_is_f32_or_f16 &&
87358739
k->type == v->type &&
87368740
nek1 % KV_TILE_SZ == 0 &&

tests/test-backend-ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8591,6 +8591,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
85918591
output_printer->print_operation(info);
85928592
return false;
85938593
}
8594+
// Use reference implementation on the CPU backend for comparison
8595+
using ggml_backend_cpu_set_use_ref_t = void (*)(ggml_backend_t, bool);
8596+
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
8597+
auto * set_use_ref = (ggml_backend_cpu_set_use_ref_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_use_ref");
8598+
if (set_use_ref) {
8599+
set_use_ref(backend_cpu, true);
8600+
}
85948601

85958602
size_t n_ok = 0;
85968603
size_t tests_run = 0;

0 commit comments

Comments
 (0)