|
| 1 | +package sk.ainet.models.gemma |
| 2 | + |
| 3 | +import sk.ainet.context.ExecutionContext |
| 4 | +import sk.ainet.lang.graph.DefaultExecutionTape |
| 5 | +import sk.ainet.lang.graph.DefaultGraphExecutionContext |
| 6 | +import sk.ainet.lang.tensor.Shape |
| 7 | +import sk.ainet.lang.tensor.VoidOpsTensor |
| 8 | +import sk.ainet.lang.tensor.data.TensorData |
| 9 | +import sk.ainet.lang.tensor.ops.VoidTensorOps |
| 10 | +import sk.ainet.lang.types.FP32 |
| 11 | +import sk.ainet.tape.Execution |
| 12 | +import kotlin.test.Test |
| 13 | +import kotlin.test.assertEquals |
| 14 | +import kotlin.test.assertTrue |
| 15 | + |
/**
 * Guards against SKaiNET#763 resurfacing in the traced/exported Gemma decoder.
 *
 * While a trace is being recorded (ctx.isRecording), [PositionalKVCache.update] has to wire K/V
 * functionally. Round-tripping through its heap buffer instead made the trace bake an all-zero
 * KV-cache as `stablehlo.constant` and disconnected the computed k_proj/v_proj, so the exported
 * decoder attended over K=V=0. Non-sandwich configs masked this through their unnormalized FFN;
 * Gemma's post_ffw_norm exposed it.
 */
class KvCacheTraceFidelityTest {
    @Test
    fun tracedDecoderKeepsComputedKV() {
        // Arrange: a two-block sandwich-norm decoder and a 4-element all-zero input.
        val decoder = gemmaNetwork<FP32, Float>(
            tinyFullAttentionMetadata(),
            FP32::class,
            maxInferenceLen = 4,
            sandwichNorms = true,
        )
        val tokens = zeroInput()

        // Act: record one forward pass. The recording context's tape is also pushed onto the
        // global tape stack for the duration of the pass, and always popped afterwards.
        val graphCtx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())
        val recording = graphCtx.record {
            val activeTape = (this as DefaultGraphExecutionContext).currentTape ?: error("no tape")
            Execution.tapeStack.pushTape(activeTape)
            try {
                decoder.forward(tokens, this as ExecutionContext)
            } finally {
                Execution.tapeStack.popTape()
            }
        }
        val recordedTape = recording.first as DefaultExecutionTape
        val graph = recordedTape.toComputeGraph(synthesizeExternalInputs = true, embedConstants = true)
        val mlir = sk.ainet.compile.hlo.toStableHlo(graph, "gemma").content

        // Assert: with nKVHeads=1, seq=4 and headDim=32 the KV head shape is [1, 4, 32]. Prior to
        // the fix, every block's K and V showed up in the export as
        // `stablehlo.constant dense<0.0> : tensor<1x4x32xf32>`.
        val zeroKvConstant = Regex("stablehlo\\.constant dense<0\\.0> : tensor<1x4x32xf32>")
        val kvZeroConstants = zeroKvConstant.findAll(mlir).count()
        assertEquals(
            0,
            kvZeroConstants,
            "Traced sandwich Gemma decoder baked $kvZeroConstants zero KV-cache constants — K/V not wired (regression of #763)",
        )

        // Sanity check: the projections (q, k, v, o, plus the FFN ones) must appear as real
        // dot_general ops in the exported module.
        val dotGeneral = Regex("dot_general")
        val dotGeneralCount = dotGeneral.findAll(mlir).count()
        assertTrue(dotGeneralCount >= 6, "expected real projection dot_generals in the export")
    }

    /** Metadata for a small two-block model in which both layers are "full_attention". */
    private fun tinyFullAttentionMetadata() = Gemma4ModelMetadata(
        architecture = "gemma3",
        embeddingLength = 64,
        contextLength = 128,
        blockCount = 2,
        headCount = 2,
        kvHeadCount = 1,
        intermediateSize = 128,
        headDim = 32,
        globalHeadDim = 32,
        vocabSize = 48,
        slidingWindow = 64,
        kvSharedLayers = 0,
        layerTypes = List(2) { "full_attention" },
        ropeParametersFull = Gemma4RopeConfig(base = 10000.0f),
        ropeParametersSliding = Gemma4RopeConfig(base = 10000.0f),
        maxPositionEmbeddings = 128,
    )

    /** A `Shape(4)` FP32 tensor whose data reads 0.0f at every index and ignores writes. */
    private fun zeroInput() = VoidOpsTensor(
        object : TensorData<FP32, Float> {
            override val shape = Shape(4)

            override fun get(vararg i: Int) = 0.0f

            override fun set(vararg i: Int, value: Float) {}
        },
        FP32::class,
    )
}
0 commit comments