Skip to content

Commit a620cbd

Browse files
committed
Fix SWA KVarN audit findings: OOB shared memory, force-materialize crash, ring sizing, Vulkan
Four fixes from code audit of commit e211e4d: 1. CUDA WHT out-of-bounds (critical): kvarn_materialize_swa_kernel re-implemented the 128-dim inverse WHT inline, running the butterfly on all 128 threads without the threadIdx.x < 64 guard. Threads 64-127 read/wrote sh[128..255] on a float[128] shared array — UB that produced correct results in sh[0..127] (only threads 0-63 touch those) but could corrupt neighboring blocks under high occupancy. Replaced with kvarn_wht_128(sh), the guarded store-path WHT. Since H_128 is symmetric, the forward WHT is the correct inverse. 2. force-materialize null mat_idxs (major): self_kvarn_mat_idxs_swa was only built under !kvarn_force_materialize_enabled(), but the non-rotated (force-materialize) path still calls get_k/get_v -> materialize(swa=true, mat_idxs=nullptr), which derefs indices->type and crashes. Now built whenever the SWA cache is KVarN, independent of force-materialize. 3. SWA ring under-size (major): n_groups_per_stream = ceil(kv_size/128) was too small — the metadata window of kv_size cells spans ceil(kv_size/128)+1 tiles (sliding window is rarely tile-aligned), so the oldest in-window tile's record slot collided with a newer tile, silently zeroing it. Now ceil(kv_size/128)+2 for SWA, with a backstop assert documenting the invariant. 4. Vulkan SWA path (gap): kvarn_store.comp and kvarn_materialize.comp had no SWA support (linear group decode, group==0 sink, no swa push-constant). Vulkan advertises kvarn_native_ops, so SWA KVarN layers could offload to Vulkan and run the non-SWA shaders on absolute-position indices -> silent garbage. Added swa push-constant, ring slot math, per-cell position decode, and empty-cell zeroing to both shaders, mirroring CPU/CUDA. Host dispatch reads op_params[4] (store) and [6] (materialize) and asserts single-stream for SWA. Verified: test-kvarn green (CPU+CUDA SWA parity, GPU SWA path now uses guarded WHT); llama-perplexity KLD on Gemma 4 31B Q5/16k/kvarn4 = 0.7296 (statistically identical to pre-fix 0.7305 — fixes resolve latent bugs without changing validated quality); GGML_KVARN_FORCE_MATERIALIZE=1 smoke on Gemma 4 31B generates coherent text (no crash). Vulkan path is theoretical (not compiled in CUDA-only build).
1 parent e211e4d commit a620cbd

6 files changed

Lines changed: 97 additions & 38 deletions

File tree

ggml/src/ggml-cuda/kvarn.cu

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,21 +1335,16 @@ static __global__ void kvarn_materialize_swa_kernel(
13351335
return;
13361336
}
13371337

1338-
// inverse WHT (128-dim) via shared-memory butterfly; mirrors kvarn_wht_128
1338+
// inverse WHT (128-dim): reuse the store path's shared-memory butterfly. It
1339+
// guards the butterfly to lanes < 64 (each lane handles one pair) and applies
1340+
// the 1/sqrt(128) normalization. Running the butterfly unguarded over all 128
1341+
// lanes (as an earlier inline version did) makes lanes 64..127 read/write
1342+
// sh[128..255] out of bounds on this float[128] array.
13391343
__shared__ float sh[KVAR_N_DIM];
13401344
sh[dim] = rotated;
1341-
__syncthreads();
1342-
for (int stride = 1; stride < KVAR_N_DIM; stride *= 2) {
1343-
const int j = (dim / stride) * (2 * stride) + (dim % stride);
1344-
const float a = sh[j];
1345-
const float b = sh[j + stride];
1346-
sh[j] = a + b;
1347-
sh[j + stride] = a - b;
1348-
__syncthreads();
1349-
}
1350-
const float out_val = sh[dim] * 0.08838834764831845f;
1345+
kvarn_wht_128(sh);
13511346
half * out = dst + ((int64_t) cell * n_heads + head) * KVAR_N_DIM;
1352-
out[dim] = __float2half_rn(out_val);
1347+
out[dim] = __float2half_rn(sh[dim]);
13531348
}
13541349

13551350
template<int BITS, bool VALUE>

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,7 @@ struct vk_op_kvarn_store_push_constants {
11981198
uint32_t bits;
11991199
uint32_t iterations;
12001200
uint32_t value;
1201+
uint32_t swa; // SWA sliding-window ring store (absolute-position indices, no sink)
12011202
};
12021203
static_assert(sizeof(vk_op_kvarn_store_push_constants) <= 128, "sizeof(vk_op_kvarn_store_push_constants) must be <= 128");
12031204

@@ -1212,6 +1213,7 @@ struct vk_op_kvarn_materialize_push_constants {
12121213
uint32_t value;
12131214
uint32_t n_indices;
12141215
uint32_t emit_rotated;
1216+
uint32_t swa; // SWA sliding-window ring materialize (indices carry per-cell positions)
12151217
};
12161218
static_assert(sizeof(vk_op_kvarn_materialize_push_constants) <= 128, "sizeof(vk_op_kvarn_materialize_push_constants) must be <= 128");
12171219

@@ -9158,9 +9160,13 @@ static void ggml_vk_kvarn_store(ggml_backend_vk_context * ctx, vk_context& subct
91589160
const int bits = ggml_get_op_params_i32(dst, 0);
91599161
const int iterations = ggml_get_op_params_i32(dst, 1);
91609162
const bool value = ggml_get_op_params_i32(dst, 2) != 0;
9163+
const bool swa = ggml_get_op_params_i32(dst, 4) != 0; // KVAR_N_OP_PARAM_STORE_SWA
91619164
GGML_ASSERT(ggml_vk_kvarn_valid_bits(bits));
91629165
const int n_stream = (int) (stage->ne[2] / 384);
91639166
const int groups_per_stream = (int) (records->ne[2] / n_stream);
9167+
if (swa) {
9168+
GGML_ASSERT(n_stream == 1 && "SWA KVarN ring requires a single stream");
9169+
}
91649170

91659171
vk_op_kvarn_store_push_constants pc = {
91669172
(uint32_t) current->ne[1],
@@ -9171,6 +9177,7 @@ static void ggml_vk_kvarn_store(ggml_backend_vk_context * ctx, vk_context& subct
91719177
(uint32_t) bits,
91729178
(uint32_t) iterations,
91739179
value ? 1u : 0u,
9180+
swa ? 1u : 0u,
91749181
};
91759182

91769183
const vk_subbuffer current_buf = ggml_vk_tensor_subbuffer(ctx, current);
@@ -9201,9 +9208,13 @@ static void ggml_vk_kvarn_materialize(ggml_backend_vk_context * ctx, vk_context&
92019208
const int stream_start = ggml_get_op_params_i32(dst, 2);
92029209
const int n_stream = ggml_get_op_params_i32(dst, 3);
92039210
const bool emit_rotated = ggml_get_op_params_i32(dst, 5) != 0;
9211+
const bool swa = ggml_get_op_params_i32(dst, 6) != 0; // KVAR_N_OP_PARAM_MAT_SWA
92049212
GGML_ASSERT(ggml_vk_kvarn_valid_bits(bits));
92059213
const int n_total_stream = (int) (stage->ne[2] / 384);
92069214
const int groups_per_stream = (int) (records->ne[2] / n_total_stream);
9215+
if (swa) {
9216+
GGML_ASSERT(n_stream == 1 && "SWA KVarN ring materialize requires a single stream");
9217+
}
92079218

92089219
vk_op_kvarn_materialize_push_constants pc = {
92099220
(uint32_t) dst->ne[1],
@@ -9216,6 +9227,7 @@ static void ggml_vk_kvarn_materialize(ggml_backend_vk_context * ctx, vk_context&
92169227
value ? 1u : 0u,
92179228
(uint32_t) indices->ne[0],
92189229
emit_rotated ? 1u : 0u,
9230+
swa ? 1u : 0u,
92199231
};
92209232

92219233
const vk_subbuffer records_buf = ggml_vk_tensor_subbuffer(ctx, records);

ggml/src/ggml-vulkan/vulkan-shaders/kvarn_materialize.comp

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ layout(push_constant) uniform parameter {
1313
uint value;
1414
uint n_indices;
1515
uint emit_rotated;
16+
uint swa; // SWA sliding-window ring materialize (indices carry per-cell positions)
1617
} p;
1718

1819
layout(binding = 0, std430) readonly buffer Records { uint data_records[]; };
@@ -59,10 +60,18 @@ float load_axis_value(uint record_base, uint payload_words, uint axis, uint inde
5960
uint compute_live_group(uint stream, uint lane) {
6061
uint live_group = 0;
6162
for (uint i = lane; i < p.n_indices; i += KVAR_N_DIM) {
62-
const uint group_global = read_index_low(i) / KVAR_N_DIM;
63-
const uint idx_stream = group_global / p.groups_per_stream;
64-
if (idx_stream == stream) {
65-
live_group = max(live_group, group_global - stream * p.groups_per_stream);
63+
if (p.swa != 0u) {
64+
// SWA ring: indices carry absolute positions; negative marks empty cells
65+
const int v = int(read_index_low(i));
66+
if (v >= 0) {
67+
live_group = max(live_group, uint(v) / KVAR_N_DIM);
68+
}
69+
} else {
70+
const uint group_global = read_index_low(i) / KVAR_N_DIM;
71+
const uint idx_stream = group_global / p.groups_per_stream;
72+
if (idx_stream == stream) {
73+
live_group = max(live_group, group_global - stream * p.groups_per_stream);
74+
}
6675
}
6776
}
6877

@@ -111,16 +120,40 @@ void main() {
111120
const uint stream = p.stream_start + out_stream;
112121

113122
const uint live_group = compute_live_group(stream, lane);
114-
const uint group = token / KVAR_N_DIM;
115-
const uint pos = token % KVAR_N_DIM;
116123
const uint stage_base = stream * KVAR_N_DIM * KVAR_N_STAGE_GROUPS;
117124

125+
uint group;
126+
uint pos;
127+
bool from_stage;
128+
bool from_record;
129+
if (p.swa != 0u) {
130+
// SWA ring: one absolute position per output cell; negative => empty cell.
131+
const int abs_pos = int(read_index_low(token));
132+
if (abs_pos < 0) {
133+
sh_wht[lane] = 0.0;
134+
barrier();
135+
store_dst_pair(out_stream, token, head, lane);
136+
return;
137+
}
138+
group = uint(abs_pos) / KVAR_N_DIM;
139+
pos = uint(abs_pos) % KVAR_N_DIM;
140+
from_stage = (group + 1u >= live_group) && (group <= live_group);
141+
from_record = (!from_stage) && (group < live_group) && (live_group - group < p.groups_per_stream);
142+
} else {
143+
group = token / KVAR_N_DIM;
144+
pos = token % KVAR_N_DIM;
145+
from_stage = group == 0u || (group > 0u && group <= live_group && group + 1u >= live_group);
146+
from_record = (!from_stage) && group < live_group && group < p.groups_per_stream;
147+
}
148+
118149
float x = 0.0;
119-
if (group == 0 || (group > 0 && group <= live_group && group + 1 >= live_group)) {
120-
const uint stage_pos = stage_base + (group == 0 ? pos : KVAR_N_DIM + (((group - 1) & 1) * KVAR_N_DIM) + pos);
150+
if (from_stage) {
151+
const uint stage_slot = p.swa != 0u ? (group % KVAR_N_STAGE_GROUPS) : (group == 0u ? 0u : 1u + ((group - 1u) & 1u));
152+
const uint stage_pos = stage_base + stage_slot * KVAR_N_DIM + pos;
121153
x = load_stage_value(stage_pos, head, lane);
122-
} else if (group < live_group && group < p.groups_per_stream) {
123-
const uint record_group = stream * p.groups_per_stream + group;
154+
} else if (from_record) {
155+
const uint ring = p.swa != 0u ? (group % p.groups_per_stream) : group;
156+
const uint record_group = stream * p.groups_per_stream + ring;
124157
const uint record_base = (record_group * p.n_heads + head) * p.record_words;
125158
const uint row = p.value != 0 ? pos : lane;
126159
const uint col = p.value != 0 ? lane : pos;

ggml/src/ggml-vulkan/vulkan-shaders/kvarn_store.comp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ layout(push_constant) uniform parameter {
1111
uint bits;
1212
uint iterations;
1313
uint value;
14+
uint swa; // SWA sliding-window ring store (absolute-position indices, no group-0 sink)
1415
} p;
1516

1617
layout(binding = 0, std430) readonly buffer Current { float data_current[]; };
@@ -71,7 +72,9 @@ void wht_128(uint lane) {
7172
float stage_matrix_value(uint head, uint stage_base, uint stage_group, uint row, uint col) {
7273
const uint token = p.value != 0 ? row : col;
7374
const uint dim = p.value != 0 ? col : row;
74-
const uint stage_pos = stage_base + KVAR_N_DIM + (((stage_group - 1) & 1) * KVAR_N_DIM) + token;
75+
// SWA: 3-deep ping-pong over absolute tiles; non-SWA: tile 0 sink + slots 1/2.
76+
const uint stage_slot = p.swa != 0u ? (stage_group % KVAR_N_STAGE_GROUPS) : (1u + ((stage_group - 1u) & 1u));
77+
const uint stage_pos = stage_base + stage_slot * KVAR_N_DIM + token;
7578
return load_stage_value(stage_pos, head, dim);
7679
}
7780

@@ -240,24 +243,28 @@ void main() {
240243
for (uint token = 0; token < p.n_tokens; ++token) {
241244
const uint idx = read_index_low(token);
242245
const uint group_global = idx / KVAR_N_DIM;
243-
const uint stream = group_global / p.groups_per_stream;
244-
const uint group = group_global - stream * p.groups_per_stream;
245246
const uint pos = idx % KVAR_N_DIM;
246-
if (stream >= p.n_stream || group >= p.groups_per_stream) {
247+
// SWA: idx is the absolute token position; records form a ring (single
248+
// stream) and there is no permanent group-0 sink.
249+
const uint stream = p.swa != 0u ? 0u : group_global / p.groups_per_stream;
250+
const uint group = p.swa != 0u ? group_global : group_global - stream * p.groups_per_stream;
251+
if (stream >= p.n_stream || (p.swa == 0u && group >= p.groups_per_stream)) {
247252
return;
248253
}
249254

250255
const uint stage_base = stream * KVAR_N_DIM * KVAR_N_STAGE_GROUPS;
251-
if (group > 2 && pos == 0) {
252-
const uint flush_group = group - 2;
253-
const uint flush_record_group = stream * p.groups_per_stream + flush_group;
256+
if (pos == 0 && (p.swa != 0u ? group >= 2u : group > 2u)) {
257+
const uint flush_group = group - 2u;
258+
const uint flush_ring = p.swa != 0u ? flush_group % p.groups_per_stream : flush_group;
259+
const uint flush_record_group = stream * p.groups_per_stream + flush_ring;
254260
const uint record_base = (flush_record_group * p.n_heads + head) * p.record_words;
255261
quantize_stage(head, stage_base, flush_group, record_base, lane);
256262
}
257263

258264
sh_wht[lane] = data_current[(token * p.n_heads + head) * KVAR_N_DIM + lane];
259265
wht_128(lane);
260-
const uint stage_pos = stage_base + (group == 0 ? pos : KVAR_N_DIM + (((group - 1) & 1) * KVAR_N_DIM) + pos);
266+
const uint stage_slot = p.swa != 0u ? (group % KVAR_N_STAGE_GROUPS) : (group == 0u ? 0u : 1u + ((group - 1u) & 1u));
267+
const uint stage_pos = stage_base + stage_slot * KVAR_N_DIM + pos;
261268
store_stage_pair(stage_pos, head, lane);
262269
barrier();
263270
}

src/llama-graph.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2891,13 +2891,15 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
28912891

28922892
inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
28932893
inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
2894-
if (!kvarn_force_materialize_enabled()) {
2895-
if (const auto * kvarn_swa = dynamic_cast<const llama_kv_cache_kvarn_context *>(mctx_cur->get_swa())) {
2896-
inp->self_kvarn_mat_idxs_swa = kvarn_swa->build_input_kvarn_mat_idxs(ctx0);
2897-
// make the materialize indices available to the context at graph build time
2898-
// (get_k/get_v/materialize run during build, before set_input populates them)
2899-
const_cast<llama_kv_cache_kvarn_context *>(kvarn_swa)->set_mat_idxs(inp->self_kvarn_mat_idxs_swa);
2900-
}
2894+
// SWA KVarN materialize needs per-cell positions on BOTH the rotated and the
2895+
// force-materialize (non-rotated) paths, so build them whenever the SWA cache
2896+
// is a KVarN cache — independent of kvarn_force_materialize_enabled(). Omitting
2897+
// them under force-materialize left mat_idxs null and crashed in materialize().
2898+
if (const auto * kvarn_swa = dynamic_cast<const llama_kv_cache_kvarn_context *>(mctx_cur->get_swa())) {
2899+
inp->self_kvarn_mat_idxs_swa = kvarn_swa->build_input_kvarn_mat_idxs(ctx0);
2900+
// make the materialize indices available to the context at graph build time
2901+
// (get_k/get_v/materialize run during build, before set_input populates them)
2902+
const_cast<llama_kv_cache_kvarn_context *>(kvarn_swa)->set_mat_idxs(inp->self_kvarn_mat_idxs_swa);
29012903
}
29022904

29032905
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));

src/llama-kv-cache-kvarn.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,12 @@ llama_kv_cache_kvarn::llama_kv_cache_kvarn(
414414
hparams(hparams),
415415
params(params),
416416
n_stream(unified ? 1u : n_seq_max),
417-
n_groups_per_stream((kv_size + KVAR_N_GROUP - 1) / KVAR_N_GROUP),
417+
// SWA: the metadata window of up to kv_size cells spans kv_size/128 + 1 tiles
418+
// (a sliding window is rarely tile-aligned), so the record ring needs 2 extra
419+
// slots over the non-SWA count to represent the oldest in-window tile without
420+
// a slot collision and to keep (live_group - group) < groups_per_stream.
421+
n_groups_per_stream(((kv_size + KVAR_N_GROUP - 1) / KVAR_N_GROUP) +
422+
((n_swa > 0 && swa_type != LLAMA_SWA_TYPE_NONE) ? 2u : 0u)),
418423
swa(n_swa > 0 && swa_type != LLAMA_SWA_TYPE_NONE),
419424
metadata(std::make_unique<llama_kv_cache>(
420425
model,
@@ -437,6 +442,11 @@ llama_kv_cache_kvarn::llama_kv_cache_kvarn(
437442
GGML_ASSERT(swa || kv_size % KVAR_N_GROUP == 0);
438443
if (swa) {
439444
GGML_ASSERT(n_stream == 1 && "SWA KVarN ring requires a unified (single-stream) cache");
445+
// Backstop for the ring-size invariant above: the record ring must have
446+
// strictly more slots than the metadata window's worst-case tile span so
447+
// the oldest in-window tile still materializes from records.
448+
GGML_ASSERT(n_groups_per_stream > (kv_size + KVAR_N_GROUP - 1) / KVAR_N_GROUP &&
449+
"SWA KVarN record ring is too small for the sliding window");
440450
}
441451

442452
struct buft_comparator {

0 commit comments

Comments
 (0)