Skip to content

Commit 3179e9e

Browse files
Merge pull request #196 from SKaiNET-developers/perf/fused-decode-attention
perf(mha)+fix(rope): fused decode-attention & traceable interleaved RoPE
2 parents e4a0799 + 019b049 commit 3179e9e

2 files changed

Lines changed: 157 additions & 0 deletions

File tree

transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/MultiHeadAttention.kt

Lines changed: 81 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -333,6 +333,24 @@ public class MultiHeadAttention<T : DType, V>(
333333
mhaDumpStat("[blk.0.mha cached-V (full) ]", fullV)
334334
}
335335

336+
// Fused decode-attention fast path — the hot autoregressive case.
337+
// When seqQ == 1 (one token per forward), self-attention, and no
338+
// sliding-window mask, compute scores → softmax → (GQA) weighted-V
339+
// directly from the cached K/V buffers in a single buffer-direct pass,
340+
// emitting the merged [1, qDim] output. This skips repeatKVHeads' concat
341+
// (built every token/layer), the unsqueeze → SDPA → squeeze → permute
342+
// chain, and every intermediate tensor those allocate — which the
343+
// jstack profile (docs/upstream/A2-PROFILE.md) showed dominate decode.
344+
// Numerically identical to the general path below for seqLen 1 (same
345+
// max-stable softmax, same GQA head mapping head h → kv head h/nRep).
346+
if (qSeqLen == 1 && !isCrossAttention && slidingWindow == null) {
347+
val merged = fusedDecodeAttention(q, fullK, fullV, scale, ctx)
348+
var output = linearProject(ops, merged, wO)
349+
if (bias) output = ops.add(output, params[oWIdx + 1].value)
350+
if (mhaDump) mhaDumpStat("[blk.0.mha post-fused-decode ]", output)
351+
return output
352+
}
353+
336354
// Expand KV heads for GQA if needed
337355
val expandedK = if (nKVHeads < nHeads) repeatKVHeads(fullK, nHeads / nKVHeads, ops) else fullK
338356
val expandedV = if (nKVHeads < nHeads) repeatKVHeads(fullV, nHeads / nKVHeads, ops) else fullV
@@ -391,6 +409,69 @@ public class MultiHeadAttention<T : DType, V>(
391409
return output
392410
}
393411

412+
/**
 * Fused single-token (decode) attention.
 *
 * Computes scaled dot-product scores, a max-stabilised softmax and the
 * weighted sum over V for every query head in one pass over flat float
 * buffers, without building intermediate tensors.
 *
 * GQA query head `h` reads KV head `h / (nHeads / nKVHeads)`, matching
 * [repeatKVHeads].
 *
 * NOTE(review): this path reads `.data` and rebuilds the result via
 * `ctx.fromData`, so the output is presumably not linked to its inputs when
 * the ops are being recorded for graph tracing — confirm the caller never
 * takes this path while tracing.
 *
 * @param q query tensor, `[nHeads, 1, headDim]` (heads-first, post-RoPE).
 * @param fullK cached keys, `[nKVHeads, seqKV, headDim]`.
 * @param fullV cached values, `[nKVHeads, seqKV, headDim]`.
 * @param scale factor applied to each raw `q · k` dot product.
 * @param ctx execution context used to wrap the result buffer in a tensor.
 * @return the merged `[1, qDim]` context; row 0 is the concatenation of each
 *   head's output in head order.
 * @throws IllegalArgumentException if the head configuration is not a valid
 *   GQA grouping or a buffer size does not match the shapes above.
 */
private fun fusedDecodeAttention(
    q: Tensor<T, V>,
    fullK: Tensor<T, V>,
    fullV: Tensor<T, V>,
    scale: Float,
    ctx: ExecutionContext,
): Tensor<T, V> {
    // nRep below is an integer division: nKVHeads > nHeads would make it 0
    // (divide-by-zero in h / nRep), and a non-divisible ratio would map the
    // last query heads to a KV head that does not exist.
    require(nKVHeads in 1..nHeads && nHeads % nKVHeads == 0) {
        "nHeads ($nHeads) must be a positive multiple of nKVHeads ($nKVHeads)"
    }
    val seqKV = fullK.shape[1]
    val nRep = nHeads / nKVHeads

    val qBuf = q.data.copyToFloatArray() // [nHeads * headDim]
    val kBuf = fullK.data.copyToFloatArray() // [nKVHeads * seqKV * headDim]
    val vBuf = fullV.data.copyToFloatArray() // [nKVHeads * seqKV * headDim]

    // The flat indexing below is only valid for the documented layouts; fail
    // loudly instead of reading the wrong elements or running off a buffer.
    val expectedQSize = nHeads * headDim
    val expectedKVSize = nKVHeads * seqKV * headDim
    require(qBuf.size == expectedQSize) {
        "fused decode expects q with $expectedQSize elements ([nHeads, 1, headDim]), got ${qBuf.size}"
    }
    require(kBuf.size == expectedKVSize && vBuf.size == expectedKVSize) {
        "fused decode expects K/V with $expectedKVSize elements ([nKVHeads, seqKV, headDim]), " +
            "got K=${kBuf.size}, V=${vBuf.size}"
    }

    val out = FloatArray(nHeads * headDim) // row-major [h, d], zero-initialised
    val scores = FloatArray(seqKV) // reused for every head

    for (h in 0 until nHeads) {
        val g = h / nRep // GQA: which KV head this query head reads
        val qOff = h * headDim
        val kvHeadBase = g * seqKV * headDim

        // scores[ki] = (q_h · k_{g,ki}) * scale, tracking the max for a stable softmax.
        var maxV = Float.NEGATIVE_INFINITY
        for (ki in 0 until seqKV) {
            val kOff = kvHeadBase + ki * headDim
            var dot = 0f
            for (d in 0 until headDim) dot += qBuf[qOff + d] * kBuf[kOff + d]
            val s = dot * scale
            scores[ki] = s
            if (s > maxV) maxV = s
        }

        // Unnormalised softmax weights; the 1/sum factor is applied once per output element.
        var sum = 0f
        for (ki in 0 until seqKV) {
            val e = kotlin.math.exp(scores[ki] - maxV)
            scores[ki] = e
            sum += e
        }
        // Guard against an empty or fully degenerate row rather than dividing by zero.
        val inv = if (sum > 0f) 1f / sum else 0f

        // context_h = Σ_ki softmax_ki * v_{g,ki}. Iterating keys in the outer
        // loop reads each V row contiguously; every output element still
        // accumulates its terms in ascending ki order.
        val oOff = h * headDim
        for (ki in 0 until seqKV) {
            val w = scores[ki]
            val vOff = kvHeadBase + ki * headDim
            for (d in 0 until headDim) out[oOff + d] += w * vBuf[vOff + d]
        }
        for (d in 0 until headDim) out[oOff + d] *= inv
    }

    @Suppress("UNCHECKED_CAST")
    return ctx.fromData(
        sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(Shape(1, qDim), out)
            as sk.ainet.lang.tensor.data.TensorData<T, V>,
        q.dtype,
    )
}
474+
394475
/**
395476
* Build an additive mask tensor of shape `[1, 1, seqQ, seqKV]` where allowed
396477
* (query, key) cells are 0 and masked cells are a large negative value so

transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/RoPE.kt

Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ import sk.ainet.context.ExecutionContext
44
import sk.ainet.lang.nn.Module
55
import sk.ainet.lang.tensor.Shape
66
import sk.ainet.lang.tensor.Tensor
7+
import sk.ainet.lang.tensor.ops.KspTensorOps
78
import sk.ainet.lang.types.DType
89
import kotlin.math.cos
910
import kotlin.math.pow
@@ -259,6 +260,16 @@ public class RoPE<T : DType, V>(
259260
* `headDim - rotaryDim` floats of every head are left untouched.
260261
*/
261262
private fun applyRoPEInterleaved(input: Tensor<T, V>, position: Int, ctx: ExecutionContext): Tensor<T, V> {
263+
// Graph tracing: the raw-array path below reads input.data and rebuilds via
264+
// fromFloatArray, which records the rotated Q/K as a DISCONNECTED CONSTANT —
265+
// severing the link to the projection weights. Post-GQA-broadcast that lowers
266+
// to a slice-into-empty const cascade that crashes iree-compile. Under the
267+
// tracing wrapper (KspTensorOps), take the traceable op-based path so the
268+
// rotation is recorded as tensor ops. Full-rotary only (TinyLlama/Llama/
269+
// Mistral); partial rotary keeps the raw path (no GGUF model needs it traced).
270+
if (rotaryDim == headDim && input.ops is KspTensorOps) {
271+
return applyRoPEInterleavedOps(input, position, ctx)
272+
}
262273
val data = input.data.copyToFloatArray()
263274
val lastDim = input.shape[input.rank - 1]
264275
require(lastDim == headDim) { "RoPE input last dim ($lastDim) != headDim ($headDim)" }
@@ -287,4 +298,69 @@ public class RoPE<T : DType, V>(
287298

288299
return ctx.fromFloatArray(input.shape, input.dtype, data)
289300
}
301+
302+
/**
 * Traceable interleaved RoPE: applies the rotation with tensor ops only, so
 * it can be recorded to a compute graph, instead of rewriting a raw float
 * array as [applyRoPEInterleaved] does.
 *
 * Interleaved pairing `(x[2i], x[2i+1])` is realized by reshaping the head
 * dim `[headDim] -> [halfRotary, 2]` (row-major: `[i,0]=x[2i]`, `[i,1]=x[2i+1]`),
 * rotating the even/odd planes, then reshaping back. Full-rotary only
 * (`rotaryDim == headDim`); the caller gates on that.
 *
 * NOTE(review): the caller selects this path by inspecting `input.ops`, while
 * the ops used here come from `ctx.ops` — presumably the same instance under
 * tracing; confirm.
 *
 * @param input tensor shaped `[..., seqLen, headDim]`.
 * @param position absolute position of the first row on the sequence axis;
 *   row `s` is rotated with the table entries for `position + s`.
 * @param ctx execution context supplying the ops and the tensor factory.
 * @return the rotated tensor, same shape as [input].
 * @throws IllegalArgumentException if [input] has rank below 2, its last dim
 *   is not `headDim`, [position] is negative, or the requested positions run
 *   past the precomputed cos/sin tables.
 */
private fun applyRoPEInterleavedOps(input: Tensor<T, V>, position: Int, ctx: ExecutionContext): Tensor<T, V> {
    val ops = ctx.ops
    val rank = input.rank
    require(rank >= 2) { "RoPE input must have rank >= 2 ([..., seqLen, headDim]), got rank $rank" }
    val lastDim = input.shape[rank - 1]
    require(lastDim == headDim) { "RoPE input last dim ($lastDim) != headDim ($headDim)" }
    val seqLen = input.shape[rank - 2]

    // Table rows for positions [position, position + seqLen) form one
    // contiguous slice of the flat tables. Validate the slice up front so an
    // over-long sequence fails with a clear message instead of an index error
    // (or, on targets without array bounds checks, silently wrong values).
    require(position >= 0) { "RoPE position must be non-negative, got $position" }
    val tableStart = position * halfRotary
    val tableEnd = (position + seqLen) * halfRotary
    require(tableEnd <= cosTable.size && tableEnd <= sinTable.size) {
        "RoPE positions $position until ${position + seqLen} exceed the precomputed cos/sin tables"
    }

    // cos/sin slices [seqLen, halfRotary], read from the same precomputed tables.
    val tableShape = Shape(seqLen, halfRotary)
    val cosData = FloatArray(seqLen * halfRotary) { cosTable[tableStart + it] }
    val sinData = FloatArray(seqLen * halfRotary) { sinTable[tableStart + it] }

    // Wraps a float table as a tensor carrying the input's dtype.
    @Suppress("UNCHECKED_CAST")
    fun tableTensor(values: FloatArray): Tensor<T, V> = ctx.fromData(
        sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(tableShape, values)
            as sk.ainet.lang.tensor.data.TensorData<T, V>,
        input.dtype,
    )
    val cosTensor = tableTensor(cosData)
    val sinTensor = tableTensor(sinData)

    // [..., seqLen, headDim] -> [..., seqLen, halfRotary, 2] so interleaved pairs
    // land on the trailing size-2 axis.
    val leading = IntArray(rank - 1) { input.shape[it] }
    val pairedShape = Shape(*leading, halfRotary, 2)
    val paired = ops.reshape(input, pairedShape)

    // even = pairs[..., 0], odd = pairs[..., 1] (narrow the size-2 axis, drop it).
    val pairAxis = rank // trailing axis index of pairedShape
    val planeShape = Shape(*leading, halfRotary)
    val even = ops.reshape(ops.narrow(paired, pairAxis, 0, 1), planeShape) // [..., seqLen, halfRotary]
    val odd = ops.reshape(ops.narrow(paired, pairAxis, 1, 1), planeShape)

    // (even, odd) -> (even*cos - odd*sin, even*sin + odd*cos); cos/sin [seqLen, halfRotary]
    // broadcast over the leading (head/batch) dims.
    val rotEven = ops.subtract(ops.multiply(even, cosTensor), ops.multiply(odd, sinTensor))
    val rotOdd = ops.add(ops.multiply(even, sinTensor), ops.multiply(odd, cosTensor))

    // Re-interleave: stack on a new trailing axis -> [..., halfRotary, 2] -> [..., headDim].
    val recombined = ops.concat(
        listOf(ops.unsqueeze(rotEven, rank), ops.unsqueeze(rotOdd, rank)),
        dim = rank,
    )
    return ops.reshape(recombined, input.shape)
}
290366
}

0 commit comments

Comments
 (0)