Skip to content

Commit e9c2058

Browse files
committed
ggml-ve: use KV shadow in compiled graph
1 parent d0f4ae9 commit e9c2058

4 files changed

Lines changed: 198 additions & 17 deletions

File tree

ggml/src/ggml-ve/ggml-ve.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,15 @@ ggml_status backend_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
325325
if (time_enabled) {
326326
char buf[256];
327327
int counts[GGML_OP_COUNT] = {0};
328+
int64_t n_tok = 1;
328329
for (int i = 0; i < cgraph->n_nodes; ++i) {
329330
ggml_op op = cgraph->nodes[i]->op;
330331
if ((int) op < GGML_OP_COUNT) counts[op]++;
332+
if (op == GGML_OP_MUL_MAT && cgraph->nodes[i]->ne[1] > n_tok) {
333+
n_tok = cgraph->nodes[i]->ne[1];
334+
}
331335
}
332-
int off = std::snprintf(buf, sizeof(buf), "n=%d", cgraph->n_nodes);
336+
int off = std::snprintf(buf, sizeof(buf), "n=%d,ntok=%lld", cgraph->n_nodes, (long long) n_tok);
333337
for (int i = 0; i < GGML_OP_COUNT && off < (int) sizeof(buf) - 32; ++i) {
334338
if (counts[i] > 0) {
335339
off += std::snprintf(buf + off, sizeof(buf) - off,
@@ -339,6 +343,22 @@ ggml_status backend_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
339343
}
340344
gtime_sig = buf;
341345
}
346+
bool gtime_recorded = false;
347+
auto record_gtime = [&]() {
348+
if (!time_enabled || gtime_recorded) return;
349+
auto end = std::chrono::steady_clock::now();
350+
double ns = std::chrono::duration_cast<std::chrono::nanoseconds>(end - gtime_start).count();
351+
// Linear search — there are < 50 distinct cgraph shapes per token.
352+
bool found = false;
353+
for (auto & e : gtimes) {
354+
if (e.sig == gtime_sig) { e.count++; e.ns_total += ns; found = true; break; }
355+
}
356+
if (!found) {
357+
gtime_entry e; e.sig = gtime_sig; e.count = 1; e.ns_total = ns;
358+
gtimes.push_back(std::move(e));
359+
}
360+
gtime_recorded = true;
361+
};
342362

343363
// --- Compiled-graph fast path (opt-in via GGML_VE_COMPILE_GRAPH=1) ----
344364
// What actually compiles is decided by the trace pre-pass's
@@ -393,6 +413,7 @@ ggml_status backend_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
393413
if (entry.executable && entry.cg) {
394414
if (gc.execute(entry.cg, ctx, cgraph)) {
395415
ctx->ops_total() += cgraph->n_nodes;
416+
record_gtime();
396417
return GGML_STATUS_SUCCESS;
397418
}
398419
// Execute regressed — remember and stop trying.
@@ -441,6 +462,7 @@ ggml_status backend_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
441462
}
442463
if (new_entry.cg && new_entry.executable) {
443464
ctx->ops_total() += cgraph->n_nodes;
465+
record_gtime();
444466
return GGML_STATUS_SUCCESS;
445467
}
446468
// Compile or execute didn't pan out (or trace refused). Whatever
@@ -477,6 +499,7 @@ ggml_status backend_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
477499
node->src[0] ? ggml_type_name(node->src[0]->type) : "-", src0_buft,
478500
node->src[1] ? ggml_type_name(node->src[1]->type) : "-", src1_buft);
479501
ctx->abort_pending();
502+
record_gtime();
480503
return GGML_STATUS_FAILED;
481504
}
482505
ctx->ops_total()++;
@@ -493,19 +516,7 @@ ggml_status backend_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph)
493516
// legacy. Trust the scheduler — anything that legitimately needs to
494517
// be sync'd will route through backend_synchronize.
495518

496-
if (time_enabled) {
497-
auto end = std::chrono::steady_clock::now();
498-
double ns = std::chrono::duration_cast<std::chrono::nanoseconds>(end - gtime_start).count();
499-
// Linear search — there are < 50 distinct cgraph shapes per token.
500-
bool found = false;
501-
for (auto & e : gtimes) {
502-
if (e.sig == gtime_sig) { e.count++; e.ns_total += ns; found = true; break; }
503-
}
504-
if (!found) {
505-
gtime_entry e; e.sig = gtime_sig; e.count = 1; e.ns_total = ns;
506-
gtimes.push_back(std::move(e));
507-
}
508-
}
519+
record_gtime();
509520
return GGML_STATUS_SUCCESS;
510521
}
511522

ggml/src/ggml-ve/graph_compiler.cpp

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "common.hpp"
88
#include "device.hpp"
99
#include "hbm_cache.hpp"
10+
#include "kv_shadow_cache.hpp"
1011

1112
#include "ggml.h"
1213
#include "ggml-backend.h"
@@ -363,8 +364,10 @@ bool GraphCompiler::trace_one(ggml_tensor * node) {
363364
const ggml_tensor * K = node->src[1];
364365
const ggml_tensor * V = node->src[2];
365366
if (!K || !V) return false;
367+
const ggml_tensor * Kc = canonical(K);
366368
op.p.flash_attn.n_kv_heads = K->ne[2];
367369
op.p.flash_attn.kv_type = (int) K->type;
370+
op.p.flash_attn.kv_seq_max = Kc ? Kc->ne[1] : K->ne[1];
368371
op.p.flash_attn.nb_k1 = K->nb[1];
369372
op.p.flash_attn.nb_k2 = K->nb[2];
370373
op.p.flash_attn.nb_v1 = V->nb[1];
@@ -481,6 +484,47 @@ bool GraphCompiler::trace_one(ggml_tensor * node) {
481484
op.vebp_wsc_idx = sp.wsc_slot;
482485
}
483486

487+
// BF16 KV-cache col-major shadow companions. The interpreter path keeps
488+
// these shadows up to date in ops/set_rows.cpp and consumes them in
489+
// ops/flash_attn.cpp. Compiled graphs bypass those hooks, so mirror the
490+
// SET_ROWS writes in generated code and read the same shadows from FA.
491+
static const bool kv_shadow_enabled =
492+
std::getenv("GGML_VE_NO_KV_SHADOW") == nullptr;
493+
if (kv_shadow_enabled && op.type == OpType::SET_ROWS &&
494+
op.dst_kind == BufferKind::KV_CACHE && op.dst_type == GGML_TYPE_BF16 &&
495+
op.dst_idx >= 0 && op.src0_ne[0] > 0 && op.ne[1] > 0) {
496+
KvShadowSpec sp;
497+
sp.source_slot = op.dst_idx;
498+
sp.channels = op.src0_ne[0];
499+
sp.seq_max = op.ne[1];
500+
sp.shadow_slot = (int) tensor_slot_order_.size();
501+
tensor_slot_order_.push_back({nullptr, BufferKind::KV_SHADOW});
502+
kv_shadow_specs_.push_back(sp);
503+
op.dst_shadow_idx = sp.shadow_slot;
504+
}
505+
if (kv_shadow_enabled && op.type == OpType::FLASH_ATTN &&
506+
op.p.flash_attn.kv_type == (int) GGML_TYPE_BF16 &&
507+
op.src1_idx >= 0 && op.src2_idx >= 0 && op.src1_ne[0] > 0 && op.src1_ne[1] > 0) {
508+
const int64_t channels = op.src1_ne[0] * op.p.flash_attn.n_kv_heads;
509+
KvShadowSpec ksp;
510+
ksp.source_slot = op.src1_idx;
511+
ksp.channels = channels;
512+
ksp.seq_max = op.p.flash_attn.kv_seq_max;
513+
ksp.shadow_slot = (int) tensor_slot_order_.size();
514+
tensor_slot_order_.push_back({nullptr, BufferKind::KV_SHADOW});
515+
kv_shadow_specs_.push_back(ksp);
516+
op.src1_shadow_idx = ksp.shadow_slot;
517+
518+
KvShadowSpec vsp;
519+
vsp.source_slot = op.src2_idx;
520+
vsp.channels = channels;
521+
vsp.seq_max = op.p.flash_attn.kv_seq_max;
522+
vsp.shadow_slot = (int) tensor_slot_order_.size();
523+
tensor_slot_order_.push_back({nullptr, BufferKind::KV_SHADOW});
524+
kv_shadow_specs_.push_back(vsp);
525+
op.src2_shadow_idx = vsp.shadow_slot;
526+
}
527+
484528
traced_ops_.push_back(op);
485529
return true;
486530
}
@@ -497,6 +541,7 @@ bool GraphCompiler::trace(ggml_cgraph * cgraph) {
497541
colmajor_specs_.clear();
498542
q4k_specs_.clear();
499543
vebp_specs_.clear();
544+
kv_shadow_specs_.clear();
500545
trace_valid_ = true;
501546

502547
// n_tokens of this cgraph: 1 = decode, N = prompt eval. Any MUL_MAT's dst
@@ -1169,14 +1214,22 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
11691214
ss << " int cols = " << cols << ";\n";
11701215
ss << " char* dstbase = (char*)" << dst << ";\n";
11711216
ss << " const float* src0p = (const float*)" << src0 << ";\n";
1217+
if (op.dst_shadow_idx >= 0) {
1218+
ss << " uint16_t* shbase = (uint16_t*)p[" << op.dst_shadow_idx << "];\n";
1219+
ss << " int seq_max = " << op.ne[1] << ";\n";
1220+
}
11721221
ss << " #pragma omp for\n";
11731222
ss << " for (long t = 0; t < n_tok; t++) {\n";
11741223
ss << " int64_t idx0 = positions[t];\n";
11751224
ss << " const float* src = src0p + t * cols;\n";
11761225
ss << " uint16_t* drow = (uint16_t*)(dstbase + idx0 * " << dst_row_bytes << ");\n";
11771226
ss << " for (int j = 0; j < cols; j++) {\n";
11781227
ss << " uint32_t u; memcpy(&u, &src[j], 4);\n";
1179-
ss << " drow[j] = (uint16_t)(u >> 16);\n";
1228+
ss << " uint16_t bf = (uint16_t)(u >> 16);\n";
1229+
ss << " drow[j] = bf;\n";
1230+
if (op.dst_shadow_idx >= 0) {
1231+
ss << " shbase[(size_t)j * seq_max + idx0] = bf;\n";
1232+
}
11801233
ss << " }\n";
11811234
ss << " }\n";
11821235
ss << " }\n";
@@ -1203,6 +1256,12 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
12031256
int n_q_heads = (op.src0_ne[2] > 1) ? (int) op.src0_ne[2]
12041257
: (int) op.src0_ne[1];
12051258
int n_kv_heads= (int) op.p.flash_attn.n_kv_heads;
1259+
const bool has_kv_shadow = op.src1_shadow_idx >= 0 && op.src2_shadow_idx >= 0;
1260+
const int colmajor_min = []{
1261+
const char * e = std::getenv("GGML_VE_COLMAJOR_FA_MIN");
1262+
int v = e ? std::atoi(e) : 96;
1263+
return v > 0 ? v : 1;
1264+
}();
12061265
ss << " {\n";
12071266
ss << " int head_dim = " << head_dim << ";\n";
12081267
ss << " int n_q_heads = " << n_q_heads << ";\n";
@@ -1217,17 +1276,34 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
12171276
ss << " const char* qp = (const char*)" << src0 << ";\n";
12181277
ss << " const void* kp = (const void*)" << src1 << ";\n";
12191278
ss << " const void* vp = (const void*)" << src2 << ";\n";
1279+
if (has_kv_shadow) {
1280+
ss << " const uint16_t* kp_col = (const uint16_t*)p[" << op.src1_shadow_idx << "];\n";
1281+
ss << " const uint16_t* vp_col = (const uint16_t*)p[" << op.src2_shadow_idx << "];\n";
1282+
ss << " int seq_max = " << op.p.flash_attn.kv_seq_max << ";\n";
1283+
ss << " int colmajor_min = " << colmajor_min << ";\n";
1284+
}
12201285
if (op.p.flash_attn.kv_type == (int) GGML_TYPE_BF16) {
12211286
// One query token at a time; the strided _inner shares heads
12221287
// across the team via `#pragma omp for` (no fork). seq_len =
12231288
// positions[t]+1 is the causal mask. Decode = the n_tok==1,
12241289
// q_nb2==head_dim*4 special case.
12251290
ss << " for (int64_t t = 0; t < n_tok; t++) {\n";
12261291
ss << " int seq_len = positions[t] + 1;\n";
1292+
if (has_kv_shadow) {
1293+
ss << " if (seq_len >= colmajor_min) {\n";
1294+
ss << " attention_f32q_bf16kv_colmajor_inner_strided(\n";
1295+
ss << " (float*)(outp + t*o_nb2), (const float*)(qp + t*q_nb1),\n";
1296+
ss << " kp_col, vp_col, head_dim, n_q_heads, n_kv_heads, seq_len, seq_max, scale,\n";
1297+
ss << " q_nb2, o_nb1);\n";
1298+
ss << " } else {\n";
1299+
}
12271300
ss << " attention_f32q_bf16kv_fused_gqa_inner_strided(\n";
12281301
ss << " (float*)(outp + t*o_nb2), (const float*)(qp + t*q_nb1),\n";
12291302
ss << " kp, vp, head_dim, n_q_heads, n_kv_heads, seq_len, scale,\n";
12301303
ss << " q_nb2, o_nb1, nb_k1, nb_k2, nb_v1, nb_v2);\n";
1304+
if (has_kv_shadow) {
1305+
ss << " }\n";
1306+
}
12311307
ss << " }\n";
12321308
} else {
12331309
// F32 KV: no strided _inner; per-token omp-single fallback (the
@@ -1369,6 +1445,9 @@ std::string GraphCompiler::generate_source(const std::string & func_name) const
13691445
ss << "extern void attention_f32q_bf16kv_fused_gqa_inner_strided(float* out, const float* q, const void* k, const void* v,"
13701446
<< " int head_dim, int n_q_heads, int n_kv_heads, int seq_len, float scale,"
13711447
<< " size_t q_nb2, size_t o_nb1, size_t nb_k1, size_t nb_k2, size_t nb_v1, size_t nb_v2);\n";
1448+
ss << "extern void attention_f32q_bf16kv_colmajor_inner_strided(float* out, const float* q, const uint16_t* k_col, const uint16_t* v_col,"
1449+
<< " int head_dim, int n_q_heads, int n_kv_heads, int seq_len, int seq_max, float scale,"
1450+
<< " size_t q_nb2, size_t o_nb1);\n";
13721451
ss << "extern void swiglu_hbm_full_inner(float* y, float* gate, float* up, int nc, int nr);\n\n";
13731452

13741453
// Ops are grouped into small static chunk functions instead of one giant
@@ -1628,6 +1707,7 @@ CompiledGraph * GraphCompiler::load_compiled(const std::string & so_path, const
16281707
cg->colmajor_specs = colmajor_specs_;
16291708
cg->q4k_specs = q4k_specs_;
16301709
cg->vebp_specs = vebp_specs_;
1710+
cg->kv_shadow_specs = kv_shadow_specs_;
16311711
return cg;
16321712
}
16331713

@@ -1762,7 +1842,8 @@ bool GraphCompiler::execute(CompiledGraph * graph,
17621842
|| graph->slot_kinds[slot_pos] == BufferKind::WEIGHT_Q4K_HDR
17631843
|| graph->slot_kinds[slot_pos] == BufferKind::WEIGHT_VEBP_WS
17641844
|| graph->slot_kinds[slot_pos] == BufferKind::WEIGHT_VEBP_WN
1765-
|| graph->slot_kinds[slot_pos] == BufferKind::WEIGHT_VEBP_WSC)) {
1845+
|| graph->slot_kinds[slot_pos] == BufferKind::WEIGHT_VEBP_WSC
1846+
|| graph->slot_kinds[slot_pos] == BufferKind::KV_SHADOW)) {
17661847
++slot_pos;
17671848
}
17681849
};
@@ -2013,7 +2094,34 @@ bool GraphCompiler::execute(CompiledGraph * graph,
20132094
}
20142095
}
20152096

2016-
// After walking the cgraph + populating colmajor slots, slot_pos should
2097+
// BF16 KV-cache col-major shadow slots. Generated SET_ROWS mirrors into
2098+
// these shadows and generated FLASH_ATTN can read them directly, preserving
2099+
// the interpreter's kv-shadow long-context optimization inside the fused graph.
2100+
if (!graph->kv_shadow_specs.empty() && bctx) {
2101+
auto * dev = bctx->dev();
2102+
auto * shadows = dev ? dev->kv_shadow : nullptr;
2103+
if (!shadows) {
2104+
if (debug_enabled()) fprintf(stderr, "[VE-GC] kv-shadow cache unavailable — abort\n");
2105+
return false;
2106+
}
2107+
for (const auto & sp : graph->kv_shadow_specs) {
2108+
if (sp.source_slot < 0 || sp.source_slot >= (int) tptrs.size()) continue;
2109+
if (sp.shadow_slot < 0 || sp.shadow_slot >= (int) tptrs.size()) continue;
2110+
VEDAdeviceptr src = tptrs[sp.source_slot];
2111+
if (src == 0) return false;
2112+
kv_shadow * sh = shadows->get_or_create(src, sp.channels, sp.seq_max);
2113+
if (!sh || sh->shadow_hbm == 0) {
2114+
if (debug_enabled()) {
2115+
fprintf(stderr, "[VE-GC] kv-shadow lookup failed for slot %d (channels=%ld seq=%ld)\n",
2116+
sp.source_slot, (long) sp.channels, (long) sp.seq_max);
2117+
}
2118+
return false;
2119+
}
2120+
tptrs[sp.shadow_slot] = sh->shadow_hbm;
2121+
}
2122+
}
2123+
2124+
// After walking the cgraph + populating companion slots, slot_pos should
20172125
// have stepped past every WEIGHT/KV/INTERMEDIATE slot in the table —
20182126
// anything less means try_push ran out of cgraph tensors before filling
20192127
// all expected real slots.
@@ -2039,6 +2147,7 @@ bool GraphCompiler::execute(CompiledGraph * graph,
20392147
if (graph->slot_kinds[i] == BufferKind::WEIGHT_VEBP_WS) continue;
20402148
if (graph->slot_kinds[i] == BufferKind::WEIGHT_VEBP_WN) continue;
20412149
if (graph->slot_kinds[i] == BufferKind::WEIGHT_VEBP_WSC) continue;
2150+
if (graph->slot_kinds[i] == BufferKind::KV_SHADOW) continue;
20422151
if (tptrs[i] == 0) {
20432152
if (debug_enabled()) {
20442153
fprintf(stderr, "[VE-GC] slot %zu kind=%d is NULL — abort\n",

ggml/src/ggml-ve/graph_compiler.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ enum class BufferKind {
6262
WEIGHT_VEBP_WN, // VEBP interleaved nonzero plane (lazy from hbm cache)
6363
WEIGHT_VEBP_WSC, // VEBP interleaved group scales (lazy from hbm cache)
6464
KV_CACHE, // persistent across tokens, mutates each step
65+
KV_SHADOW, // BF16 KV cache shadow in seq-unit-stride col-major layout
6566
INTERMEDIATE, // scratch, reused inside the kernel
6667
INPUT, // pseudo: input embedding row
6768
OUTPUT, // pseudo: output logits
@@ -96,6 +97,12 @@ struct TracedOp {
9697
int vebp_ws_idx = -1;
9798
int vebp_wn_idx = -1;
9899
int vebp_wsc_idx = -1;
100+
// For BF16 KV cache ops, companion slots holding the col-major shadow.
101+
// SET_ROWS uses dst_shadow_idx to mirror freshly-written rows; FLASH_ATTN
102+
// uses src1/src2 shadow slots for K/V reads when seq_len is large enough.
103+
int dst_shadow_idx = -1;
104+
int src1_shadow_idx = -1;
105+
int src2_shadow_idx = -1;
99106
BufferKind dst_kind = BufferKind::INTERMEDIATE;
100107
BufferKind src0_kind = BufferKind::INTERMEDIATE;
101108
BufferKind src1_kind = BufferKind::INTERMEDIATE;
@@ -117,6 +124,7 @@ struct TracedOp {
117124
float ext_factor, beta_fast, beta_slow; int n_ctx_orig; } rope;
118125
struct { float scale, max_bias, softcap;
119126
int kv_type; int64_t n_kv_heads;
127+
int64_t kv_seq_max;
120128
size_t nb_k1, nb_k2, nb_v1, nb_v2;
121129
// Q byte strides (src0 nb): per-token (nb1) and per-head (nb2).
122130
// For prompt eval Q is permuted [D,N,H] so the head stride is
@@ -161,6 +169,15 @@ struct VebpSpec {
161169
int64_t K = 0;
162170
};
163171

172+
// Spec for a BF16 KV-cache col-major shadow companion slot. The shadow is
173+
// keyed by the row-major KV cache HBM pointer and allocated via kv_shadow_cache.
174+
struct KvShadowSpec {
175+
int source_slot = -1;  // slot-table index of the shadowed tensor (SET_ROWS dst or FLASH_ATTN K/V src); its device pointer is the key handed to get_or_create. Negative = unset, skipped at execute.
176+
int shadow_slot = -1;  // slot-table index reserved during trace (kind KV_SHADOW); execute stores the shadow's HBM pointer here. Negative = unset, skipped at execute.
177+
int64_t channels = 0;  // channel count passed to get_or_create: src0 ne[0] for SET_ROWS, K ne[0] * n_kv_heads for FLASH_ATTN.
178+
int64_t seq_max = 0;  // sequence extent passed to get_or_create: dst ne[1] for SET_ROWS, flash_attn.kv_seq_max for FLASH_ATTN.
179+
};
180+
164181
struct CompiledGraph {
165182
VEDAmodule module = 0;
166183
VEDAfunction run_func = 0;
@@ -190,6 +207,9 @@ struct CompiledGraph {
190207
// hbm_cache::get_or_upload_vebp. Empty if no VEBP MUL_MAT.
191208
std::vector<VebpSpec> vebp_specs;
192209

210+
// KV-shadow companion slots populated at execute time via kv_shadow_cache.
211+
std::vector<KvShadowSpec> kv_shadow_specs;
212+
193213
// Reusable HMEM staging buffers for the kernel's `input` (token id)
194214
// and `output` (logits row) args. Allocated lazily on first execute
195215
// and reused for every subsequent call of the same graph — the
@@ -269,6 +289,8 @@ class GraphCompiler {
269289
std::vector<Q4KSpec> q4k_specs_;
270290
// VEBP companion slots created during trace, populated at execute.
271291
std::vector<VebpSpec> vebp_specs_;
292+
// KV-shadow companion slots created during trace, populated at execute.
293+
std::vector<KvShadowSpec> kv_shadow_specs_;
272294

273295
bool trace_valid_ = false;
274296

0 commit comments

Comments
 (0)