Skip to content

Commit d9a0518

Browse files
Merge pull request #193 from SKaiNET-developers/fix/763-kvcache-trace-fidelity
fix(kvcache): trace-faithful PositionalKVCache.update (#763)
2 parents a49fe30 + 82adc64 commit d9a0518

2 files changed

Lines changed: 80 additions & 0 deletions

File tree

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package sk.ainet.models.gemma
2+
3+
import sk.ainet.context.ExecutionContext
4+
import sk.ainet.lang.graph.DefaultExecutionTape
5+
import sk.ainet.lang.graph.DefaultGraphExecutionContext
6+
import sk.ainet.lang.tensor.Shape
7+
import sk.ainet.lang.tensor.VoidOpsTensor
8+
import sk.ainet.lang.tensor.data.TensorData
9+
import sk.ainet.lang.tensor.ops.VoidTensorOps
10+
import sk.ainet.lang.types.FP32
11+
import sk.ainet.tape.Execution
12+
import kotlin.test.Test
13+
import kotlin.test.assertEquals
14+
import kotlin.test.assertTrue
15+
16+
/**
17+
* Regression test for SKaiNET#763: [PositionalKVCache.update] must wire K/V functionally while
18+
* tracing (ctx.isRecording) instead of round-tripping through its heap buffer, which under tracing
19+
* baked an all-zero KV-cache as `stablehlo.constant` and disconnected the computed k_proj/v_proj —
20+
* so the exported decoder attended over K=V=0. (Masked by the unnormalized FFN in non-sandwich
21+
* configs; exposed by Gemma's post_ffw_norm.)
22+
*/
23+
class KvCacheTraceFidelityTest {
24+
@Test
25+
fun tracedDecoderKeepsComputedKV() {
26+
val meta = Gemma4ModelMetadata(
27+
architecture = "gemma3", embeddingLength = 64, contextLength = 128, blockCount = 2,
28+
headCount = 2, kvHeadCount = 1, intermediateSize = 128, headDim = 32, globalHeadDim = 32,
29+
vocabSize = 48, slidingWindow = 64, kvSharedLayers = 0, layerTypes = List(2) { "full_attention" },
30+
ropeParametersFull = Gemma4RopeConfig(base = 10000.0f),
31+
ropeParametersSliding = Gemma4RopeConfig(base = 10000.0f), maxPositionEmbeddings = 128,
32+
)
33+
val model = gemmaNetwork<FP32, Float>(meta, FP32::class, maxInferenceLen = 4, sandwichNorms = true)
34+
val input = VoidOpsTensor(object : TensorData<FP32, Float> {
35+
override val shape = Shape(4); override fun get(vararg i: Int) = 0.0f; override fun set(vararg i: Int, value: Float) {}
36+
}, FP32::class)
37+
val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())
38+
val tape = ctx.record {
39+
val ct = (this as DefaultGraphExecutionContext).currentTape ?: error("no tape")
40+
Execution.tapeStack.pushTape(ct)
41+
try { model.forward(input, this as ExecutionContext) } finally { Execution.tapeStack.popTape() }
42+
}.first
43+
val graph = (tape as DefaultExecutionTape).toComputeGraph(synthesizeExternalInputs = true, embedConstants = true)
44+
val mlir = sk.ainet.compile.hlo.toStableHlo(graph, "gemma").content
45+
46+
// The KV head shape here is [1, 4, 32] (nKVHeads=1, seq=4, headDim=32). Before the fix the
47+
// K and V for each block appeared as `stablehlo.constant dense<0.0> : tensor<1x4x32xf32>`.
48+
val kvZeroConstants = Regex("stablehlo\\.constant dense<0\\.0> : tensor<1x4x32xf32>").findAll(mlir).count()
49+
assertEquals(0, kvZeroConstants,
50+
"Traced sandwich Gemma decoder baked $kvZeroConstants zero KV-cache constants — K/V not wired (regression of #763)")
51+
// Sanity: K/V projections are present as real dot_generals (q,k,v,o + ffn etc.).
52+
assertTrue(Regex("dot_general").findAll(mlir).count() >= 6, "expected real projection dot_generals in the export")
53+
}
54+
}

transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/KVCache.kt

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,35 @@ public class PositionalKVCache<T : DType, V>(
215215
internal val valueBuf: FloatArray = FloatArray(maxSeqLen * nKVHeads * headDim)
216216
private var pos: Int = 0
217217

218+
// Functional K/V history kept only while tracing (see update()).
219+
private var tracedKeys: Tensor<T, V>? = null
220+
private var tracedValues: Tensor<T, V>? = null
221+
218222
override fun update(
219223
newKey: Tensor<T, V>,
220224
newValue: Tensor<T, V>,
221225
ctx: ExecutionContext
222226
): Pair<Tensor<T, V>, Tensor<T, V>> {
227+
// The eager buffer path below copies tensor *data* into a heap array
228+
// (writeAt) and reads it back via ctx.fromData (sliceView), bypassing
229+
// ctx.ops. Under tracing the incoming K/V carry no data, so that path
230+
// would bake an all-zero buffer as a constant and disconnect the
231+
// computed k_proj/v_proj from attention (the exported decoder then
232+
// attends over K=V=0). When recording, wire K/V functionally instead —
233+
// the same ops.concat history AppendKVCache uses — so the StableHLO
234+
// export carries the real projections.
235+
if (ctx.isRecording) {
236+
val ops = ctx.ops
237+
val seqDim = newKey.rank - 2
238+
val prevK = tracedKeys
239+
val prevV = tracedValues
240+
val fullK = if (prevK != null) ops.concat(listOf(prevK, newKey), dim = seqDim) else newKey
241+
val fullV = if (prevV != null) ops.concat(listOf(prevV, newValue), dim = seqDim) else newValue
242+
tracedKeys = fullK
243+
tracedValues = fullV
244+
pos += newKey.shape[seqDim]
245+
return fullK to fullV
246+
}
223247
val newLen = newKey.shape[newKey.rank - 2]
224248
writeAt(pos, newKey, newValue)
225249
pos += newLen
@@ -321,6 +345,8 @@ public class PositionalKVCache<T : DType, V>(
321345

322346
override fun reset() {
323347
pos = 0
348+
tracedKeys = null
349+
tracedValues = null
324350
}
325351

326352
override val position: Int get() = pos

0 commit comments

Comments
 (0)