@@ -4,6 +4,7 @@ import sk.ainet.context.ExecutionContext
44import sk.ainet.lang.nn.Module
55import sk.ainet.lang.tensor.Shape
66import sk.ainet.lang.tensor.Tensor
7+ import sk.ainet.lang.tensor.ops.KspTensorOps
78import sk.ainet.lang.types.DType
89import kotlin.math.cos
910import kotlin.math.pow
@@ -259,6 +260,16 @@ public class RoPE<T : DType, V>(
259260 * `headDim - rotaryDim` floats of every head are left untouched.
260261 */
261262 private fun applyRoPEInterleaved (input : Tensor <T , V >, position : Int , ctx : ExecutionContext ): Tensor <T , V > {
263+ // Graph tracing: the raw-array path below reads input.data and rebuilds via
264+ // fromFloatArray, which records the rotated Q/K as a DISCONNECTED CONSTANT —
265+ // severing the link to the projection weights. Post-GQA-broadcast that lowers
266+ // to a slice-into-empty const cascade that crashes iree-compile. Under the
267+ // tracing wrapper (KspTensorOps), take the traceable op-based path so the
268+ // rotation is recorded as tensor ops. Full-rotary only (TinyLlama/Llama/
269+ // Mistral); partial rotary keeps the raw path (no GGUF model needs it traced).
270+ if (rotaryDim == headDim && input.ops is KspTensorOps ) {
271+ return applyRoPEInterleavedOps(input, position, ctx)
272+ }
262273 val data = input.data.copyToFloatArray()
263274 val lastDim = input.shape[input.rank - 1 ]
264275 require(lastDim == headDim) { " RoPE input last dim ($lastDim ) != headDim ($headDim )" }
@@ -287,4 +298,69 @@ public class RoPE<T : DType, V>(
287298
288299 return ctx.fromFloatArray(input.shape, input.dtype, data)
289300 }
301+
/**
 * Interleaved RoPE expressed entirely through `ctx.ops` tensor operations.
 *
 * Instead of editing a raw float buffer, the rotation of every adjacent pair
 * `(x[2i], x[2i+1])` is built from reshape / narrow / multiply / add / subtract /
 * unsqueeze / concat calls, so an op-recording backend sees the whole
 * computation. [applyRoPEInterleaved] only dispatches here when
 * `rotaryDim == headDim`; the reshape of the last axis to `[halfRotary, 2]`
 * below relies on that (it needs `2 * halfRotary == headDim`).
 *
 * Outline:
 * 1. Copy the `cosTable` / `sinTable` entries for positions
 *    `position until position + seqLen` into two `[seqLen, halfRotary]` tensors.
 * 2. View the last axis as `[halfRotary, 2]` and split it into an "even" plane
 *    (index 0) and an "odd" plane (index 1).
 * 3. Compute `even * cos - odd * sin` and `even * sin + odd * cos`.
 * 4. Put the two results back side by side on a trailing axis of size 2 and
 *    reshape to the input's shape.
 *
 * @param input tensor whose last axis must have size `headDim`; the
 *   second-to-last axis is read as the sequence length.
 * @param position table row used for the first sequence element; element `s`
 *   uses row `position + s`.
 * @param ctx execution context supplying `ops` and `fromData`.
 * @return the rotated tensor, reshaped to `input.shape`.
 * @throws IllegalArgumentException if the last axis of [input] is not `headDim`.
 */
private fun applyRoPEInterleavedOps(input: Tensor<T, V>, position: Int, ctx: ExecutionContext): Tensor<T, V> {
    val tensorOps = ctx.ops
    val dims = input.rank
    val innerSize = input.shape[dims - 1]
    require(innerSize == headDim) { " RoPE input last dim ($innerSize ) != headDim ($headDim )" }
    val steps = input.shape[dims - 2]

    // Rows `position until position + steps` sit back to back in the flat tables,
    // so the [steps, halfRotary] slice is one contiguous run starting at this offset.
    val tableOffset = position * halfRotary
    val sliceLength = steps * halfRotary
    val cosSlice = FloatArray(sliceLength)
    val sinSlice = FloatArray(sliceLength)
    for (k in 0 until sliceLength) {
        cosSlice[k] = cosTable[tableOffset + k]
        sinSlice[k] = sinTable[tableOffset + k]
    }

    val tableShape = Shape(steps, halfRotary)

    // Wraps a flat table slice as a [steps, halfRotary] tensor of the input's dtype.
    @Suppress("UNCHECKED_CAST")
    fun asTableTensor(values: FloatArray): Tensor<T, V> = ctx.fromData(
        sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(tableShape, values) as sk.ainet.lang.tensor.data.TensorData<T, V>,
        input.dtype,
    )

    val cosTensor = asTableTensor(cosSlice)
    val sinTensor = asTableTensor(sinSlice)

    // Every axis except the last one is carried through unchanged.
    val outerDims = IntArray(dims - 1) { axis -> input.shape[axis] }

    // [..., headDim] -> [..., halfRotary, 2]: each interleaved pair now occupies
    // the trailing size-2 axis.
    val pairView = tensorOps.reshape(input, Shape(*outerDims, halfRotary, 2))

    // The pair view has one more axis than the input, so its last axis index is `dims`.
    val lastAxis = dims
    val halfShape = Shape(*outerDims, halfRotary)

    // Take index 0 / index 1 of the pair axis, then drop that now size-1 axis.
    val evenSlice = tensorOps.narrow(pairView, lastAxis, 0, 1)
    val evenPlane = tensorOps.reshape(evenSlice, halfShape)
    val oddSlice = tensorOps.narrow(pairView, lastAxis, 1, 1)
    val oddPlane = tensorOps.reshape(oddSlice, halfShape)

    // 2-D rotation per pair; the [steps, halfRotary] tables are combined with the
    // [..., steps, halfRotary] planes (relies on the ops broadcasting leading axes).
    val evenTimesCos = tensorOps.multiply(evenPlane, cosTensor)
    val oddTimesSin = tensorOps.multiply(oddPlane, sinTensor)
    val rotatedEven = tensorOps.subtract(evenTimesCos, oddTimesSin)

    val evenTimesSin = tensorOps.multiply(evenPlane, sinTensor)
    val oddTimesCos = tensorOps.multiply(oddPlane, cosTensor)
    val rotatedOdd = tensorOps.add(evenTimesSin, oddTimesCos)

    // Re-interleave: give each plane a trailing size-1 axis, join them into a
    // size-2 axis, then flatten [..., halfRotary, 2] back to the input shape.
    val evenColumn = tensorOps.unsqueeze(rotatedEven, dims)
    val oddColumn = tensorOps.unsqueeze(rotatedOdd, dims)
    val interleaved = tensorOps.concat(listOf(evenColumn, oddColumn), dim = dims)
    return tensorOps.reshape(interleaved, input.shape)
}
290366}
0 commit comments