Skip to content

Commit 019b049

Browse files
michalharakalclaude
andcommitted
fix(rope): traceable interleaved RoPE for graph export (unblocks TinyLlama→IREE)
Interleaved RoPE's raw-array path (copyToFloatArray → scalar rotate → fromFloatArray) records the rotated Q/K as a disconnected constant under graph tracing, severing the link to the projection weights. Post-GQA head-broadcast that lowers to an insert_slice-into-tensor.empty() constant cascade that segfaults iree-compile (iree-dispatch-creation-convert-tensor-to-flow, null ElementsAttr::getType in greedy fold; seqLen>=2 only). Add applyRoPEInterleavedOps: a pure-tensor-op interleaved rotation (reshape [headDim]->[halfRotary,2], narrow even/odd, rotate with cos/sin tables, re-interleave), numerically identical to the raw path. Gated on input.ops is KspTensorOps so it runs only under tracing; eager keeps the raw fast path byte-identical (no perf/correctness change). Full-rotary only. Verified via the skainet-tinyllama-iree composite build: real TinyLlama exports + compiles to aarch64 .vmfb at seq=2 and seq=8; eager-jvm still coherent (matches llama.cpp); LlamaDslPipelineTest green. Perf-Tag: perf/b1-rope-traceable Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 3791f88 commit 019b049

1 file changed

Lines changed: 76 additions & 0 deletions

File tree

  • transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer

transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/RoPE.kt

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import sk.ainet.context.ExecutionContext
44
import sk.ainet.lang.nn.Module
55
import sk.ainet.lang.tensor.Shape
66
import sk.ainet.lang.tensor.Tensor
7+
import sk.ainet.lang.tensor.ops.KspTensorOps
78
import sk.ainet.lang.types.DType
89
import kotlin.math.cos
910
import kotlin.math.pow
@@ -259,6 +260,16 @@ public class RoPE<T : DType, V>(
259260
* `headDim - rotaryDim` floats of every head are left untouched.
260261
*/
261262
private fun applyRoPEInterleaved(input: Tensor<T, V>, position: Int, ctx: ExecutionContext): Tensor<T, V> {
263+
// Graph tracing: the raw-array path below reads input.data and rebuilds via
264+
// fromFloatArray, which records the rotated Q/K as a DISCONNECTED CONSTANT —
265+
// severing the link to the projection weights. Post-GQA-broadcast that lowers
266+
// to a slice-into-empty const cascade that crashes iree-compile. Under the
267+
// tracing wrapper (KspTensorOps), take the traceable op-based path so the
268+
// rotation is recorded as tensor ops. Full-rotary only (TinyLlama/Llama/
269+
// Mistral); partial rotary keeps the raw path (no GGUF model needs it traced).
270+
if (rotaryDim == headDim && input.ops is KspTensorOps) {
271+
return applyRoPEInterleavedOps(input, position, ctx)
272+
}
262273
val data = input.data.copyToFloatArray()
263274
val lastDim = input.shape[input.rank - 1]
264275
require(lastDim == headDim) { "RoPE input last dim ($lastDim) != headDim ($headDim)" }
@@ -287,4 +298,69 @@ public class RoPE<T : DType, V>(
287298

288299
return ctx.fromFloatArray(input.shape, input.dtype, data)
289300
}
301+
302+
/**
303+
* Traceable interleaved RoPE: pure tensor ops, numerically identical to
304+
* [applyRoPEInterleaved] but recordable to a compute graph. Used under
305+
* void/graph tracing where the raw-array path bakes a disconnected constant.
306+
*
307+
* Interleaved pairing `(x[2i], x[2i+1])` is realized by reshaping the head
308+
* dim `[headDim] -> [halfRotary, 2]` (row-major: `[i,0]=x[2i]`, `[i,1]=x[2i+1]`),
309+
* rotating the even/odd planes, then reshaping back. Full-rotary only
310+
* (`rotaryDim == headDim`); the caller gates on that.
311+
*/
312+
private fun applyRoPEInterleavedOps(input: Tensor<T, V>, position: Int, ctx: ExecutionContext): Tensor<T, V> {
313+
val ops = ctx.ops
314+
val rank = input.rank
315+
val lastDim = input.shape[rank - 1]
316+
require(lastDim == headDim) { "RoPE input last dim ($lastDim) != headDim ($headDim)" }
317+
val seqLen = input.shape[rank - 2]
318+
319+
// cos/sin tables [seqLen, halfRotary] for the requested positions — same
320+
// tables as the raw path, so the rotation is bit-for-bit equivalent.
321+
val cosData = FloatArray(seqLen * halfRotary)
322+
val sinData = FloatArray(seqLen * halfRotary)
323+
for (s in 0 until seqLen) {
324+
val pos = position + s
325+
for (i in 0 until halfRotary) {
326+
cosData[s * halfRotary + i] = cosTable[pos * halfRotary + i]
327+
sinData[s * halfRotary + i] = sinTable[pos * halfRotary + i]
328+
}
329+
}
330+
val tableShape = Shape(seqLen, halfRotary)
331+
@Suppress("UNCHECKED_CAST")
332+
val cosTensor: Tensor<T, V> = ctx.fromData(
333+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(tableShape, cosData) as sk.ainet.lang.tensor.data.TensorData<T, V>,
334+
input.dtype,
335+
)
336+
@Suppress("UNCHECKED_CAST")
337+
val sinTensor: Tensor<T, V> = ctx.fromData(
338+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(tableShape, sinData) as sk.ainet.lang.tensor.data.TensorData<T, V>,
339+
input.dtype,
340+
)
341+
342+
// [..., seqLen, headDim] -> [..., seqLen, halfRotary, 2] so interleaved pairs
343+
// land on the trailing size-2 axis.
344+
val leading = IntArray(rank - 1) { input.shape[it] }
345+
val pairedShape = Shape(*leading, halfRotary, 2)
346+
val paired = ops.reshape(input, pairedShape)
347+
348+
// even = pairs[..., 0], odd = pairs[..., 1] (narrow the size-2 axis, drop it).
349+
val pairAxis = rank // trailing axis index of pairedShape
350+
val planeShape = Shape(*leading, halfRotary)
351+
val even = ops.reshape(ops.narrow(paired, pairAxis, 0, 1), planeShape) // [..., seqLen, halfRotary]
352+
val odd = ops.reshape(ops.narrow(paired, pairAxis, 1, 1), planeShape)
353+
354+
// (even, odd) -> (even*cos - odd*sin, even*sin + odd*cos); cos/sin [seqLen, halfRotary]
355+
// broadcast over the leading (head/batch) dims.
356+
val rotEven = ops.subtract(ops.multiply(even, cosTensor), ops.multiply(odd, sinTensor))
357+
val rotOdd = ops.add(ops.multiply(even, sinTensor), ops.multiply(odd, cosTensor))
358+
359+
// Re-interleave: stack on a new trailing axis -> [..., halfRotary, 2] -> [..., headDim].
360+
val recombined = ops.concat(
361+
listOf(ops.unsqueeze(rotEven, rank), ops.unsqueeze(rotOdd, rank)),
362+
dim = rank,
363+
)
364+
return ops.reshape(recombined, input.shape)
365+
}
290366
}

0 commit comments

Comments
 (0)