77#include " common.hpp"
88#include " device.hpp"
99#include " hbm_cache.hpp"
10+ #include " kv_shadow_cache.hpp"
1011
1112#include " ggml.h"
1213#include " ggml-backend.h"
@@ -363,8 +364,10 @@ bool GraphCompiler::trace_one(ggml_tensor * node) {
363364 const ggml_tensor * K = node->src [1 ];
364365 const ggml_tensor * V = node->src [2 ];
365366 if (!K || !V) return false ;
367+ const ggml_tensor * Kc = canonical (K);
366368 op.p .flash_attn .n_kv_heads = K->ne [2 ];
367369 op.p .flash_attn .kv_type = (int ) K->type ;
370+ op.p .flash_attn .kv_seq_max = Kc ? Kc->ne [1 ] : K->ne [1 ];
368371 op.p .flash_attn .nb_k1 = K->nb [1 ];
369372 op.p .flash_attn .nb_k2 = K->nb [2 ];
370373 op.p .flash_attn .nb_v1 = V->nb [1 ];
@@ -481,6 +484,47 @@ bool GraphCompiler::trace_one(ggml_tensor * node) {
481484 op.vebp_wsc_idx = sp.wsc_slot ;
482485 }
483486
487+ // BF16 KV-cache col-major shadow companions. The interpreter path keeps
488+ // these shadows up to date in ops/set_rows.cpp and consumes them in
489+ // ops/flash_attn.cpp. Compiled graphs bypass those hooks, so mirror the
490+ // SET_ROWS writes in generated code and read the same shadows from FA.
491+ static const bool kv_shadow_enabled =
492+ std::getenv (" GGML_VE_NO_KV_SHADOW" ) == nullptr ;
493+ if (kv_shadow_enabled && op.type == OpType::SET_ROWS &&
494+ op.dst_kind == BufferKind::KV_CACHE && op.dst_type == GGML_TYPE_BF16 &&
495+ op.dst_idx >= 0 && op.src0_ne [0 ] > 0 && op.ne [1 ] > 0 ) {
496+ KvShadowSpec sp;
497+ sp.source_slot = op.dst_idx ;
498+ sp.channels = op.src0_ne [0 ];
499+ sp.seq_max = op.ne [1 ];
500+ sp.shadow_slot = (int ) tensor_slot_order_.size ();
501+ tensor_slot_order_.push_back ({nullptr , BufferKind::KV_SHADOW});
502+ kv_shadow_specs_.push_back (sp);
503+ op.dst_shadow_idx = sp.shadow_slot ;
504+ }
505+ if (kv_shadow_enabled && op.type == OpType::FLASH_ATTN &&
506+ op.p .flash_attn .kv_type == (int ) GGML_TYPE_BF16 &&
507+ op.src1_idx >= 0 && op.src2_idx >= 0 && op.src1_ne [0 ] > 0 && op.src1_ne [1 ] > 0 ) {
508+ const int64_t channels = op.src1_ne [0 ] * op.p .flash_attn .n_kv_heads ;
509+ KvShadowSpec ksp;
510+ ksp.source_slot = op.src1_idx ;
511+ ksp.channels = channels;
512+ ksp.seq_max = op.p .flash_attn .kv_seq_max ;
513+ ksp.shadow_slot = (int ) tensor_slot_order_.size ();
514+ tensor_slot_order_.push_back ({nullptr , BufferKind::KV_SHADOW});
515+ kv_shadow_specs_.push_back (ksp);
516+ op.src1_shadow_idx = ksp.shadow_slot ;
517+
518+ KvShadowSpec vsp;
519+ vsp.source_slot = op.src2_idx ;
520+ vsp.channels = channels;
521+ vsp.seq_max = op.p .flash_attn .kv_seq_max ;
522+ vsp.shadow_slot = (int ) tensor_slot_order_.size ();
523+ tensor_slot_order_.push_back ({nullptr , BufferKind::KV_SHADOW});
524+ kv_shadow_specs_.push_back (vsp);
525+ op.src2_shadow_idx = vsp.shadow_slot ;
526+ }
527+
484528 traced_ops_.push_back (op);
485529 return true ;
486530}
@@ -497,6 +541,7 @@ bool GraphCompiler::trace(ggml_cgraph * cgraph) {
497541 colmajor_specs_.clear ();
498542 q4k_specs_.clear ();
499543 vebp_specs_.clear ();
544+ kv_shadow_specs_.clear ();
500545 trace_valid_ = true ;
501546
502547 // n_tokens of this cgraph: 1 = decode, N = prompt eval. Any MUL_MAT's dst
@@ -1169,14 +1214,22 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
11691214 ss << " int cols = " << cols << " ;\n " ;
11701215 ss << " char* dstbase = (char*)" << dst << " ;\n " ;
11711216 ss << " const float* src0p = (const float*)" << src0 << " ;\n " ;
1217+ if (op.dst_shadow_idx >= 0 ) {
1218+ ss << " uint16_t* shbase = (uint16_t*)p[" << op.dst_shadow_idx << " ];\n " ;
1219+ ss << " int seq_max = " << op.ne [1 ] << " ;\n " ;
1220+ }
11721221 ss << " #pragma omp for\n " ;
11731222 ss << " for (long t = 0; t < n_tok; t++) {\n " ;
11741223 ss << " int64_t idx0 = positions[t];\n " ;
11751224 ss << " const float* src = src0p + t * cols;\n " ;
11761225 ss << " uint16_t* drow = (uint16_t*)(dstbase + idx0 * " << dst_row_bytes << " );\n " ;
11771226 ss << " for (int j = 0; j < cols; j++) {\n " ;
11781227 ss << " uint32_t u; memcpy(&u, &src[j], 4);\n " ;
1179- ss << " drow[j] = (uint16_t)(u >> 16);\n " ;
1228+ ss << " uint16_t bf = (uint16_t)(u >> 16);\n " ;
1229+ ss << " drow[j] = bf;\n " ;
1230+ if (op.dst_shadow_idx >= 0 ) {
1231+ ss << " shbase[(size_t)j * seq_max + idx0] = bf;\n " ;
1232+ }
11801233 ss << " }\n " ;
11811234 ss << " }\n " ;
11821235 ss << " }\n " ;
@@ -1203,6 +1256,12 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
12031256 int n_q_heads = (op.src0_ne [2 ] > 1 ) ? (int ) op.src0_ne [2 ]
12041257 : (int ) op.src0_ne [1 ];
12051258 int n_kv_heads= (int ) op.p .flash_attn .n_kv_heads ;
1259+ const bool has_kv_shadow = op.src1_shadow_idx >= 0 && op.src2_shadow_idx >= 0 ;
1260+ const int colmajor_min = []{
1261+ const char * e = std::getenv (" GGML_VE_COLMAJOR_FA_MIN" );
1262+ int v = e ? std::atoi (e) : 96 ;
1263+ return v > 0 ? v : 1 ;
1264+ }();
12061265 ss << " {\n " ;
12071266 ss << " int head_dim = " << head_dim << " ;\n " ;
12081267 ss << " int n_q_heads = " << n_q_heads << " ;\n " ;
@@ -1217,17 +1276,34 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
12171276 ss << " const char* qp = (const char*)" << src0 << " ;\n " ;
12181277 ss << " const void* kp = (const void*)" << src1 << " ;\n " ;
12191278 ss << " const void* vp = (const void*)" << src2 << " ;\n " ;
1279+ if (has_kv_shadow) {
1280+ ss << " const uint16_t* kp_col = (const uint16_t*)p[" << op.src1_shadow_idx << " ];\n " ;
1281+ ss << " const uint16_t* vp_col = (const uint16_t*)p[" << op.src2_shadow_idx << " ];\n " ;
1282+ ss << " int seq_max = " << op.p .flash_attn .kv_seq_max << " ;\n " ;
1283+ ss << " int colmajor_min = " << colmajor_min << " ;\n " ;
1284+ }
12201285 if (op.p .flash_attn .kv_type == (int ) GGML_TYPE_BF16) {
12211286 // One query token at a time; the strided _inner shares heads
12221287 // across the team via `#pragma omp for` (no fork). seq_len =
12231288 // positions[t]+1 is the causal mask. Decode = the n_tok==1,
12241289 // q_nb2==head_dim*4 special case.
12251290 ss << " for (int64_t t = 0; t < n_tok; t++) {\n " ;
12261291 ss << " int seq_len = positions[t] + 1;\n " ;
1292+ if (has_kv_shadow) {
1293+ ss << " if (seq_len >= colmajor_min) {\n " ;
1294+ ss << " attention_f32q_bf16kv_colmajor_inner_strided(\n " ;
1295+ ss << " (float*)(outp + t*o_nb2), (const float*)(qp + t*q_nb1),\n " ;
1296+ ss << " kp_col, vp_col, head_dim, n_q_heads, n_kv_heads, seq_len, seq_max, scale,\n " ;
1297+ ss << " q_nb2, o_nb1);\n " ;
1298+ ss << " } else {\n " ;
1299+ }
12271300 ss << " attention_f32q_bf16kv_fused_gqa_inner_strided(\n " ;
12281301 ss << " (float*)(outp + t*o_nb2), (const float*)(qp + t*q_nb1),\n " ;
12291302 ss << " kp, vp, head_dim, n_q_heads, n_kv_heads, seq_len, scale,\n " ;
12301303 ss << " q_nb2, o_nb1, nb_k1, nb_k2, nb_v1, nb_v2);\n " ;
1304+ if (has_kv_shadow) {
1305+ ss << " }\n " ;
1306+ }
12311307 ss << " }\n " ;
12321308 } else {
12331309 // F32 KV: no strided _inner; per-token omp-single fallback (the
@@ -1369,6 +1445,9 @@ std::string GraphCompiler::generate_source(const std::string & func_name) const
13691445 ss << " extern void attention_f32q_bf16kv_fused_gqa_inner_strided(float* out, const float* q, const void* k, const void* v,"
13701446 << " int head_dim, int n_q_heads, int n_kv_heads, int seq_len, float scale,"
13711447 << " size_t q_nb2, size_t o_nb1, size_t nb_k1, size_t nb_k2, size_t nb_v1, size_t nb_v2);\n " ;
1448+ ss << " extern void attention_f32q_bf16kv_colmajor_inner_strided(float* out, const float* q, const uint16_t* k_col, const uint16_t* v_col,"
1449+ << " int head_dim, int n_q_heads, int n_kv_heads, int seq_len, int seq_max, float scale,"
1450+ << " size_t q_nb2, size_t o_nb1);\n " ;
13721451 ss << " extern void swiglu_hbm_full_inner(float* y, float* gate, float* up, int nc, int nr);\n\n " ;
13731452
13741453 // Ops are grouped into small static chunk functions instead of one giant
@@ -1628,6 +1707,7 @@ CompiledGraph * GraphCompiler::load_compiled(const std::string & so_path, const
16281707 cg->colmajor_specs = colmajor_specs_;
16291708 cg->q4k_specs = q4k_specs_;
16301709 cg->vebp_specs = vebp_specs_;
1710+ cg->kv_shadow_specs = kv_shadow_specs_;
16311711 return cg;
16321712}
16331713
@@ -1762,7 +1842,8 @@ bool GraphCompiler::execute(CompiledGraph * graph,
17621842 || graph->slot_kinds [slot_pos] == BufferKind::WEIGHT_Q4K_HDR
17631843 || graph->slot_kinds [slot_pos] == BufferKind::WEIGHT_VEBP_WS
17641844 || graph->slot_kinds [slot_pos] == BufferKind::WEIGHT_VEBP_WN
1765- || graph->slot_kinds [slot_pos] == BufferKind::WEIGHT_VEBP_WSC)) {
1845+ || graph->slot_kinds [slot_pos] == BufferKind::WEIGHT_VEBP_WSC
1846+ || graph->slot_kinds [slot_pos] == BufferKind::KV_SHADOW)) {
17661847 ++slot_pos;
17671848 }
17681849 };
@@ -2013,7 +2094,34 @@ bool GraphCompiler::execute(CompiledGraph * graph,
20132094 }
20142095 }
20152096
2016- // After walking the cgraph + populating colmajor slots, slot_pos should
2097+ // BF16 KV-cache col-major shadow slots. Generated SET_ROWS mirrors into
2098+ // these shadows and generated FLASH_ATTN can read them directly, preserving
2099+ // the interpreter's kv-shadow long-context optimization inside the fused graph.
2100+ if (!graph->kv_shadow_specs .empty () && bctx) {
2101+ auto * dev = bctx->dev ();
2102+ auto * shadows = dev ? dev->kv_shadow : nullptr ;
2103+ if (!shadows) {
2104+ if (debug_enabled ()) fprintf (stderr, " [VE-GC] kv-shadow cache unavailable — abort\n " );
2105+ return false ;
2106+ }
2107+ for (const auto & sp : graph->kv_shadow_specs ) {
2108+ if (sp.source_slot < 0 || sp.source_slot >= (int ) tptrs.size ()) continue ;
2109+ if (sp.shadow_slot < 0 || sp.shadow_slot >= (int ) tptrs.size ()) continue ;
2110+ VEDAdeviceptr src = tptrs[sp.source_slot ];
2111+ if (src == 0 ) return false ;
2112+ kv_shadow * sh = shadows->get_or_create (src, sp.channels , sp.seq_max );
2113+ if (!sh || sh->shadow_hbm == 0 ) {
2114+ if (debug_enabled ()) {
2115+ fprintf (stderr, " [VE-GC] kv-shadow lookup failed for slot %d (channels=%ld seq=%ld)\n " ,
2116+ sp.source_slot , (long ) sp.channels , (long ) sp.seq_max );
2117+ }
2118+ return false ;
2119+ }
2120+ tptrs[sp.shadow_slot ] = sh->shadow_hbm ;
2121+ }
2122+ }
2123+
2124+ // After walking the cgraph + populating companion slots, slot_pos should
20172125 // have stepped past every WEIGHT/KV/INTERMEDIATE slot in the table —
20182126 // anything less means try_push ran out of cgraph tensors before filling
20192127 // all expected real slots.
@@ -2039,6 +2147,7 @@ bool GraphCompiler::execute(CompiledGraph * graph,
20392147 if (graph->slot_kinds [i] == BufferKind::WEIGHT_VEBP_WS) continue ;
20402148 if (graph->slot_kinds [i] == BufferKind::WEIGHT_VEBP_WN) continue ;
20412149 if (graph->slot_kinds [i] == BufferKind::WEIGHT_VEBP_WSC) continue ;
2150+ if (graph->slot_kinds [i] == BufferKind::KV_SHADOW) continue ;
20422151 if (tptrs[i] == 0 ) {
20432152 if (debug_enabled ()) {
20442153 fprintf (stderr, " [VE-GC] slot %zu kind=%d is NULL — abort\n " ,
0 commit comments