From 82adc64b86815e2019cae3832c52479a9d3b878e Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Thu, 25 Jun 2026 11:22:40 +0200 Subject: [PATCH] fix(kvcache): trace-faithful PositionalKVCache.update (SKaiNET#763) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PositionalKVCache.update copied the incoming K/V tensor *data* into a heap FloatArray (writeAt) and read it back via ctx.fromData (sliceView), bypassing ctx.ops. Under tracing the symbolic K/V carry no data, so toComputeGraph(embedConstants=true) baked an all-zero KV buffer as stablehlo.constant and disconnected the computed k_proj/v_proj — the exported decoder then attended over K=V=0. In eager inference the buffer holds real data, so the bug was invisible there; in export it was masked in plain decoders by the unnormalized FFN dominating the residual stream, and exposed by Gemma's sandwichNorms (post_ffw_norm normalizes the FFN, so the lost attention becomes significant — ~1.4x/block logit error). Fix: when ctx.isRecording, wire K/V functionally through ops.concat (the same history AppendKVCache already uses) instead of the raw heap buffer, so the StableHLO export carries the real projections. The eager fast-path (heap buffer) is unchanged. As a side effect the traced graph no longer surfaces dangling KV-cache buffer leaves as graph outputs. Verified end-to-end via skainet-iree-conformance gemma-decoder (real 2-block Gemma3 with qk-norm + sandwichNorms + layer-output-scale): export -> iree-compile -> iree-run-module matches the SKaiNET-CPU oracle at max_abs_err 3.8e-6 (was ~5.3). Adds KvCacheTraceFidelityTest. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../models/gemma/KvCacheTraceFidelityTest.kt | 54 +++++++++++++++++++ .../sk/ainet/lang/nn/transformer/KVCache.kt | 26 +++++++++ 2 files changed, 80 insertions(+) create mode 100644 llm-inference/gemma/src/jvmTest/kotlin/sk/ainet/models/gemma/KvCacheTraceFidelityTest.kt diff --git a/llm-inference/gemma/src/jvmTest/kotlin/sk/ainet/models/gemma/KvCacheTraceFidelityTest.kt b/llm-inference/gemma/src/jvmTest/kotlin/sk/ainet/models/gemma/KvCacheTraceFidelityTest.kt new file mode 100644 index 00000000..facd8520 --- /dev/null +++ b/llm-inference/gemma/src/jvmTest/kotlin/sk/ainet/models/gemma/KvCacheTraceFidelityTest.kt @@ -0,0 +1,54 @@ +package sk.ainet.models.gemma + +import sk.ainet.context.ExecutionContext +import sk.ainet.lang.graph.DefaultExecutionTape +import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.VoidOpsTensor +import sk.ainet.lang.tensor.data.TensorData +import sk.ainet.lang.tensor.ops.VoidTensorOps +import sk.ainet.lang.types.FP32 +import sk.ainet.tape.Execution +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +/** + * Regression test for SKaiNET#763: [PositionalKVCache.update] must wire K/V functionally while + * tracing (ctx.isRecording) instead of round-tripping through its heap buffer, which under tracing + * baked an all-zero KV-cache as `stablehlo.constant` and disconnected the computed k_proj/v_proj — + * so the exported decoder attended over K=V=0. (Masked by the unnormalized FFN in non-sandwich + * configs; exposed by Gemma's post_ffw_norm.) + */ +class KvCacheTraceFidelityTest { + @Test + fun tracedDecoderKeepsComputedKV() { + val meta = Gemma4ModelMetadata( + architecture = "gemma3", embeddingLength = 64, contextLength = 128, blockCount = 2, + headCount = 2, kvHeadCount = 1, intermediateSize = 128, headDim = 32, globalHeadDim = 32, + vocabSize = 48, slidingWindow = 64, kvSharedLayers = 0, layerTypes = List(2) { "full_attention" }, + ropeParametersFull = Gemma4RopeConfig(base = 10000.0f), + ropeParametersSliding = Gemma4RopeConfig(base = 10000.0f), maxPositionEmbeddings = 128, + ) + val model = gemmaNetwork(meta, FP32::class, maxInferenceLen = 4, sandwichNorms = true) + val input = VoidOpsTensor(object : TensorData { + override val shape = Shape(4); override fun get(vararg i: Int) = 0.0f; override fun set(vararg i: Int, value: Float) {} + }, FP32::class) + val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps()) + val tape = ctx.record { + val ct = (this as DefaultGraphExecutionContext).currentTape ?: error("no tape") + Execution.tapeStack.pushTape(ct) + try { model.forward(input, this as ExecutionContext) } finally { Execution.tapeStack.popTape() } + }.first + val graph = (tape as DefaultExecutionTape).toComputeGraph(synthesizeExternalInputs = true, embedConstants = true) + val mlir = sk.ainet.compile.hlo.toStableHlo(graph, "gemma").content + + // The KV head shape here is [1, 4, 32] (nKVHeads=1, seq=4, headDim=32). Before the fix the + // K and V for each block appeared as `stablehlo.constant dense<0.0> : tensor<1x4x32xf32>`. + val kvZeroConstants = Regex("stablehlo\\.constant dense<0\\.0> : tensor<1x4x32xf32>").findAll(mlir).count() + assertEquals(0, kvZeroConstants, + "Traced sandwich Gemma decoder baked $kvZeroConstants zero KV-cache constants — K/V not wired (regression of #763)") + // Sanity: K/V projections are present as real dot_generals (q,k,v,o + ffn etc.). + assertTrue(Regex("dot_general").findAll(mlir).count() >= 6, "expected real projection dot_generals in the export") + } +} diff --git a/transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/KVCache.kt b/transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/KVCache.kt index bb78f467..b91cfc32 100644 --- a/transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/KVCache.kt +++ b/transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/KVCache.kt @@ -215,11 +215,35 @@ public class PositionalKVCache( internal val valueBuf: FloatArray = FloatArray(maxSeqLen * nKVHeads * headDim) private var pos: Int = 0 + // Functional K/V history kept only while tracing (see update()). + private var tracedKeys: Tensor? = null + private var tracedValues: Tensor? = null + override fun update( newKey: Tensor, newValue: Tensor, ctx: ExecutionContext ): Pair, Tensor> { + // The eager buffer path below copies tensor *data* into a heap array + // (writeAt) and reads it back via ctx.fromData (sliceView), bypassing + // ctx.ops. Under tracing the incoming K/V carry no data, so that path + // would bake an all-zero buffer as a constant and disconnect the + // computed k_proj/v_proj from attention (the exported decoder then + // attends over K=V=0). When recording, wire K/V functionally instead — + // the same ops.concat history AppendKVCache uses — so the StableHLO + // export carries the real projections. + if (ctx.isRecording) { + val ops = ctx.ops + val seqDim = newKey.rank - 2 + val prevK = tracedKeys + val prevV = tracedValues + val fullK = if (prevK != null) ops.concat(listOf(prevK, newKey), dim = seqDim) else newKey + val fullV = if (prevV != null) ops.concat(listOf(prevV, newValue), dim = seqDim) else newValue + tracedKeys = fullK + tracedValues = fullV + pos += newKey.shape[seqDim] + return fullK to fullV + } val newLen = newKey.shape[newKey.rank - 2] writeAt(pos, newKey, newValue) pos += newLen @@ -321,6 +345,8 @@ public class PositionalKVCache( override fun reset() { pos = 0 + tracedKeys = null + tracedValues = null } override val position: Int get() = pos