Skip to content

Commit f6e02a5

Browse files
Merge pull request #81 from SKaiNET-developers/feature/ISSUE-80-mha-reshape-permute
fix(mha): materialise multi-head reshape — fixes forwardBatched divergence
2 parents 1e49140 + 2c611d6 commit f6e02a5

2 files changed

Lines changed: 275 additions & 6 deletions

File tree

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/MultiHeadAttention.kt

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,29 @@ public class MultiHeadAttention<T : DType, V>(
197197
v = ops.add(v, params[vWIdx + 1].value)
198198
}
199199

200-
// Reshape to multi-head: [seqLen, dim] → [nHeads, seqLen, headDim]
200+
// Reshape to multi-head and put heads first.
201+
//
202+
// Q/K/V projections produce [seqLen, qDim] where qDim = nHeads*headDim.
203+
// Row-major flat layout is [s, h, d] → s*qDim + h*headDim + d. SDPA
204+
// expects [batch, nHeads, seqLen, headDim] — i.e. heads-first layout
205+
// [h, s, d] → h*seqLen*headDim + s*headDim + d.
206+
//
207+
// For seqLen == 1 the two layouts coincide flat-byte-for-flat-byte,
208+
// so a naked `reshape(t, Shape(nHeads, seqLen, headDim))` was visibly
209+
// correct in the autoregressive (one-token-per-forward) path. For
210+
// seqLen > 1 it silently reorders the data: `t.get(h, s, d)` reads
211+
// `data[h*seqLen*headDim + s*headDim + d]` from a buffer laid out as
212+
// `s*nHeads*headDim + h*headDim + d`, mixing the rows of head h with
213+
// values from other heads. That is the root cause of the batched-
214+
// prefill divergence (commit `bd3eb9c`).
215+
//
216+
// The correct transformation needs an explicit dim-0/dim-1 swap.
217+
// SKaiNET's `ops.transpose` only swaps the LAST two dims, so we
218+
// can't reuse it here; we materialise the permute via a copy.
201219
val seqLen = if (input.rank >= 2) input.shape[input.rank - 2] else 1
202-
q = ops.reshape(q, Shape(nHeads, seqLen, headDim))
203-
k = ops.reshape(k, Shape(nKVHeads, seqLen, headDim))
204-
var vReshaped = ops.reshape(v, Shape(nKVHeads, seqLen, headDim))
220+
q = swapSeqHeadDims(ops.reshape(q, Shape(seqLen, nHeads, headDim)), ctx)
221+
k = swapSeqHeadDims(ops.reshape(k, Shape(seqLen, nKVHeads, headDim)), ctx)
222+
var vReshaped = swapSeqHeadDims(ops.reshape(v, Shape(seqLen, nKVHeads, headDim)), ctx)
205223

206224
// Optional QK-Norm
207225
if (qNorm != null && kNorm != null) {
@@ -276,9 +294,18 @@ public class MultiHeadAttention<T : DType, V>(
276294
)
277295
if (mhaDump) mhaDumpStat("[blk.0.mha post-SDPA ]", attnOut)
278296

279-
// Remove batch dim and merge heads: [1, nHeads, seqLen, headDim] → [seqLen, qDim]
297+
// Remove batch dim and merge heads.
298+
//
299+
// SDPA returns [1, nHeads, seqLen, headDim]. We need [seqLen, qDim].
300+
// Symmetric inverse of the heads-first permute on the input side:
301+
// first squeeze the batch dim → [nHeads, seqLen, headDim], then
302+
// swap dims 0/1 → [seqLen, nHeads, headDim], finally reshape to
303+
// [seqLen, qDim] (contiguous: row s = concatenation of heads 0..nHeads-1
304+
// for that token). For seqLen == 1 the swap is identity, so this
305+
// matches the prior naked reshape for the autoregressive case.
280306
val squeezed = ops.squeeze(attnOut, 0)
281-
val merged = ops.reshape(squeezed, Shape(seqLen, qDim))
307+
val swappedBack = swapSeqHeadDims(squeezed, ctx)
308+
val merged = ops.reshape(swappedBack, Shape(seqLen, qDim))
282309

283310
// Output projection: merged @ wO^T (+ bias if enabled)
284311
var output = linearProject(ops, merged, wO)
@@ -333,6 +360,38 @@ public class MultiHeadAttention<T : DType, V>(
333360
)
334361
}
335362

363+
/**
 * Swap dims 0 and 1 of a rank-3 tensor: `[D0, D1, D2]` → `[D1, D0, D2]`.
 *
 * SKaiNET's [TensorOps.transpose] only swaps the last two dims, so the
 * general case is materialised via a copy: every contiguous `D2`-length
 * row at `[i, j, :]` is moved to `[j, i, :]` in a freshly allocated buffer.
 *
 * When `D0 == 1` or `D1 == 1` the swapped tensor has the same flat layout
 * as the input, so the function short-circuits to a plain reshape and no
 * copy is made.
 *
 * NOTE(review): the copy path round-trips through a `FloatArray` and
 * rebuilds the result as dense float data; presumably this is only used
 * with float-backed tensors — confirm before using with other dtypes.
 *
 * @param t rank-3 tensor to permute.
 * @param ctx execution context used for the reshape fast path and for
 *   wrapping the permuted buffer in a new tensor.
 * @return a tensor of shape `[D1, D0, D2]` with dims 0 and 1 exchanged.
 * @throws IllegalArgumentException if [t] is not rank 3.
 */
private fun swapSeqHeadDims(t: Tensor<T, V>, ctx: ExecutionContext): Tensor<T, V> {
    require(t.rank == 3) { "swapSeqHeadDims: expected rank-3 tensor, got rank ${t.rank}" }
    val d0 = t.shape[0]
    val d1 = t.shape[1]
    val d2 = t.shape[2]
    if (d0 == 1 || d1 == 1) {
        // A unit dim in either swapped position leaves the flat element
        // order untouched, so reinterpreting the shape is sufficient.
        return ctx.ops.reshape(t, Shape(d1, d0, d2))
    }
    val src = t.data.copyToFloatArray()
    val out = FloatArray(d1 * d0 * d2)
    for (i in 0 until d0) {
        for (j in 0 until d1) {
            // Row [i, j, :] in the source becomes row [j, i, :] in the output.
            val srcOff = (i * d1 + j) * d2
            val dstOff = (j * d0 + i) * d2
            src.copyInto(out, dstOff, srcOff, srcOff + d2)
        }
    }
    @Suppress("UNCHECKED_CAST")
    val data = sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(Shape(d1, d0, d2), out)
        as sk.ainet.lang.tensor.data.TensorData<T, V>
    return ctx.fromData(data, t.dtype)
}
394+
336395
private fun repeatKVHeads(t: Tensor<T, V>, repeats: Int, ops: sk.ainet.lang.tensor.ops.TensorOps): Tensor<T, V> {
337396
if (repeats == 1) return t
338397
// Repeat each KV head individually so head mapping matches GQA:
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
package sk.ainet.apps.kllama
2+
3+
import java.nio.file.Path
4+
import kotlin.io.path.exists
5+
import kotlin.test.Test
6+
import kotlin.test.assertEquals
7+
import kotlin.test.assertTrue
8+
import kotlinx.coroutines.runBlocking
9+
import sk.ainet.apps.llm.OptimizedLLMMode
10+
import sk.ainet.apps.llm.OptimizedLLMRuntime
11+
import sk.ainet.context.DirectCpuExecutionContext
12+
import sk.ainet.io.JvmRandomAccessSource
13+
import sk.ainet.io.model.QuantPolicy
14+
import sk.ainet.lang.tensor.Tensor
15+
import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData
16+
import sk.ainet.lang.tensor.data.MemorySegmentTensorData
17+
import sk.ainet.lang.types.FP32
18+
import sk.ainet.models.llama.LlamaNetworkLoader
19+
20+
/**
 * Verifies that `forwardBatched(IntArray)` produces the same last-position
 * logits as the equivalent autoregressive `forward(t)` per token. This is
 * the regression test the `bd3eb9c` revert was missing — without it,
 * batched prefill quietly diverged from the autoregressive baseline.
 *
 * Uses TinyLlama 1.1B Q8_0 (DEQUANTIZE_TO_FP32 policy → pure FP32 forward
 * pass). This sidesteps the Gemma 4 forward-pass correctness issues
 * tracked separately on develop, so this test is a clean check on the
 * batched-vs-autoregressive plumbing only.
 *
 * When the model file is not present each test prints a `[skip]` line and
 * returns early; note that the test framework then reports it as passed,
 * not as skipped.
 *
 * Non-finite logits (NaN/Inf) on either side are treated as a divergence:
 * they make `max_abs_diff` NaN or infinite, which fails the tolerance check.
 */
class BatchedPrefillEquivalenceTest {

    /** Summary of an element-wise comparison between two logit vectors. */
    private data class LogitStats(
        val maxAbsDiff: Float,
        val maxRelDiff: Float,
        val argmaxAuto: Int,
        val argmaxBatch: Int
    )

    @Test
    fun `forwardBatched matches autoregressive at N=1`() {
        runEquivalence(intArrayOf(450)) // first prompt token only — should be trivial
    }

    @Test
    fun `forwardBatched matches autoregressive at N=2`() {
        runEquivalence(intArrayOf(450, 7483))
    }

    @Test
    fun `forwardBatched matches autoregressive prefill at last position`() {
        if (!MODEL_PATH.exists()) {
            println("[skip] Model not at $MODEL_PATH")
            return
        }
        runBlocking {
            // Fixed prompt — encode once, replay through both paths.
            val ctx = DirectCpuExecutionContext()
            val tokenizer = JvmRandomAccessSource.open(MODEL_PATH.toString()).use { source ->
                GGUFTokenizer.fromRandomAccessSource(source)
            }
            val prompt = "The capital of France is"
            val promptTokens = tokenizer.encode(prompt)
            require(promptTokens.size >= 2) { "Need ≥2 tokens to exercise the loop" }
            println("[diag] prompt tokens: ${promptTokens.toList()}")

            // --- Autoregressive baseline: feed one token per forward call ---
            val autoLogits = run {
                val runtime = newRuntime(ctx)
                var l: Tensor<FP32, Float> = runtime.forward(promptTokens[0])
                for (i in 1 until promptTokens.size) {
                    l = runtime.forward(promptTokens[i])
                }
                extractLogits(l)
            }

            // --- Batched: a fresh runtime so no state leaks from the baseline ---
            val batchLogits = run {
                val runtime = newRuntime(ctx)
                extractLogits(runtime.forwardBatched(promptTokens))
            }

            // --- Compare ---
            assertEquals(autoLogits.size, batchLogits.size,
                "logit vector length mismatch")
            val tol = 1e-3f
            val stats = logitStats(autoLogits, batchLogits)
            val maxAbsDiff = stats.maxAbsDiff
            val maxRelDiff = stats.maxRelDiff
            val argmaxAuto = stats.argmaxAuto
            val argmaxBatch = stats.argmaxBatch
            println("[diag] max_abs_diff=$maxAbsDiff max_rel_diff=$maxRelDiff " +
                "argmax_auto=$argmaxAuto argmax_batch=$argmaxBatch " +
                "auto[argmax]=${autoLogits[argmaxAuto]} " +
                "batch[argmax]=${batchLogits[argmaxBatch]}")
            assertEquals(argmaxAuto, argmaxBatch,
                "argmax token differs: auto=$argmaxAuto batch=$argmaxBatch")
            // NaN < tol is false, so a non-finite diff fails here as well.
            assertTrue(maxAbsDiff < tol,
                "max_abs_diff=$maxAbsDiff exceeds tolerance $tol; " +
                    "batched prefill diverges from autoregressive")
        }
    }

    /**
     * Runs [promptTokens] through the autoregressive and the batched path on
     * two freshly loaded runtimes and asserts that the final logits agree
     * (same argmax, max absolute difference below 1e-3).
     */
    private fun runEquivalence(promptTokens: IntArray) {
        if (!MODEL_PATH.exists()) {
            println("[skip] Model not at $MODEL_PATH")
            return
        }
        runBlocking {
            val ctx = DirectCpuExecutionContext()
            println("[diag] N=${promptTokens.size} prompt tokens: ${promptTokens.toList()}")

            val autoLogits = run {
                val runtime = newRuntime(ctx)
                var l: Tensor<FP32, Float> = runtime.forward(promptTokens[0])
                for (i in 1 until promptTokens.size) l = runtime.forward(promptTokens[i])
                extractLogits(l)
            }
            val batchLogits = run {
                val runtime = newRuntime(ctx)
                extractLogits(runtime.forwardBatched(promptTokens))
            }
            assertEquals(autoLogits.size, batchLogits.size)
            val stats = logitStats(autoLogits, batchLogits)
            val maxAbsDiff = stats.maxAbsDiff
            val argmaxAuto = stats.argmaxAuto
            val argmaxBatch = stats.argmaxBatch
            println("[diag] N=${promptTokens.size} max_abs_diff=$maxAbsDiff " +
                "argmax_auto=$argmaxAuto argmax_batch=$argmaxBatch " +
                "auto_top=${autoLogits[argmaxAuto]} batch_top=${batchLogits[argmaxBatch]}")
            assertEquals(argmaxAuto, argmaxBatch,
                "argmax differs at N=${promptTokens.size}")
            // NaN < 1e-3f is false, so a non-finite diff fails here as well.
            assertTrue(maxAbsDiff < 1e-3f,
                "max_abs_diff=$maxAbsDiff exceeds 1e-3 at N=${promptTokens.size}")
        }
    }

    /**
     * Loads the model from [MODEL_PATH] with the DEQUANTIZE_TO_FP32 policy and
     * wraps it in a DIRECT-mode runtime. Each call loads a new model instance.
     */
    private suspend fun newRuntime(ctx: DirectCpuExecutionContext) = OptimizedLLMRuntime(
        model = LlamaNetworkLoader.fromGguf(
            randomAccessProvider = { JvmRandomAccessSource.open(MODEL_PATH.toString()) },
            quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32
        ).load<FP32, Float>(ctx),
        ctx = ctx,
        mode = OptimizedLLMMode.DIRECT,
        dtype = FP32::class
    )

    /**
     * Computes max absolute / relative difference and both argmax indices for
     * two logit vectors. Callers must have checked that the sizes match.
     *
     * A NaN difference is sticky: once seen, the corresponding maximum stays
     * NaN, so non-finite logits can never be masked by later finite values.
     */
    private fun logitStats(auto: FloatArray, batch: FloatArray): LogitStats {
        var maxAbsDiff = 0f
        var maxRelDiff = 0f
        var argmaxAuto = 0
        var argmaxBatch = 0
        for (i in auto.indices) {
            val a = auto[i]
            val b = batch[i]
            val d = kotlin.math.abs(a - b)
            // `d > max` is false for NaN, so NaN has to be checked explicitly.
            if (d.isNaN() || d > maxAbsDiff) maxAbsDiff = d
            // Relative diff is only meaningful away from zero; otherwise report 0.
            val r = if (kotlin.math.abs(a) > 1e-6f) d / kotlin.math.abs(a) else 0f
            if (r.isNaN() || r > maxRelDiff) maxRelDiff = r
            if (a > auto[argmaxAuto]) argmaxAuto = i
            if (b > batch[argmaxBatch]) argmaxBatch = i
        }
        return LogitStats(maxAbsDiff, maxRelDiff, argmaxAuto, argmaxBatch)
    }

    /**
     * Copies the first `t.shape.volume` floats out of [t]'s backing storage.
     *
     * Supports dense float-array data and memory-segment data; any other
     * backing type raises an [IllegalStateException].
     */
    private fun extractLogits(t: Tensor<FP32, Float>): FloatArray {
        val n = t.shape.volume
        return when (val data = t.data) {
            // copyOf(n) yields an exact copy when the buffer already has n
            // elements and truncates / zero-pads otherwise.
            is DenseFloatArrayTensorData<*> -> data.buffer.copyOf(n)
            is MemorySegmentTensorData<*> -> FloatArray(n).also { out ->
                java.lang.foreign.MemorySegment.copy(
                    data.segment,
                    java.lang.foreign.ValueLayout.JAVA_FLOAT,
                    data.segmentByteOffset,
                    out, 0, n
                )
            }
            else -> error("Unsupported tensor data type: ${data::class}")
        }
    }

    companion object {
        private val MODEL_PATH = Path.of(
            System.getProperty("user.home"),
            ".lmstudio/models/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
            "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
        )
    }
}

0 commit comments

Comments
 (0)