Commit 4c8283a

Merge pull request #544 from SKaiNET-developers/feature/sdpa-tape-recording-and-hlo
Record and emit scaledDotProductAttention for IREE (#543)
2 parents d68fd02 + da82fda commit 4c8283a

5 files changed: 366 additions & 3 deletions


Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# scaledDotProductAttention: not recorded by tape, no StableHLO converter

## Summary

`ctx.ops.scaledDotProductAttention()` exists in the TensorOps interface, VoidTensorOps, and DefaultCpuOps — but it is not tape-recorded and has no StableHLO converter. This blocks multi-head attention in the SKaiNET → IREE compilation path.

## Impact

Without SDPA, Whisper's multi-head attention must be decomposed into individual ops (reshape, transpose, matmul, softmax, matmul). The per-batch K transpose requires raw FloatArray manipulation, which creates zero constants in the VMFB (proven in TapeAttentionPermuteBugTest).

Result: the GPU Whisper encoder produces wrong hidden states → the decoder outputs "," instead of a real transcription.
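For reference, the decomposed path computes the following per head. This is a plain-Kotlin sketch for illustration only (the real path goes through SKaiNET tensor ops, not nested loops); note the explicitly materialized `kT` buffer, the step that forces the raw FloatArray manipulation described above:

```kotlin
import kotlin.math.exp
import kotlin.math.sqrt

// Illustrative single-head attention with an explicit K transpose.
// Plain Kotlin only; no SKaiNET APIs are used or implied.
fun attentionWithExplicitTranspose(
    q: Array<FloatArray>, // [S][D]
    k: Array<FloatArray>, // [T][D]
    v: Array<FloatArray>  // [T][D]
): Array<FloatArray> {
    val d = q[0].size
    // The problematic step: K.T is materialized as a fresh buffer.
    val kT = Array(d) { i -> FloatArray(k.size) { t -> k[t][i] } }
    val scale = 1.0f / sqrt(d.toFloat())
    return Array(q.size) { s ->
        val scores = FloatArray(k.size) { t ->
            var acc = 0f
            for (i in 0 until d) acc += q[s][i] * kT[i][t]
            acc * scale
        }
        // Numerically stable softmax over the scores row
        val m = scores.maxOrNull() ?: Float.NEGATIVE_INFINITY
        val exps = FloatArray(scores.size) { exp(scores[it] - m) }
        val sum = exps.sum()
        // output row = weights @ V
        FloatArray(v[0].size) { j ->
            var acc = 0f
            for (t in v.indices) acc += (exps[t] / sum) * v[t][j]
            acc
        }
    }
}
```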
## Three fixes needed

### 1. RecordingExecution: record SDPA

**File:** `skainet-compile-core/.../tape/RecordingExecution.kt` line 436

Current (just delegates, no recording):

```kotlin
override fun <T : DType, V> scaledDotProductAttention(...) =
    base.scaledDotProductAttention(query, key, value, mask, scale, causal)
```

Fix (same pattern as conv1d in PR #532):

```kotlin
override fun <T : DType, V> scaledDotProductAttention(
    query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>,
    mask: Tensor<T, V>?, scale: Float, causal: Boolean
): Tensor<T, V> {
    val out = base.scaledDotProductAttention(query, key, value, mask, scale, causal)
    val params = mapOf("scale" to scale, "causal" to causal)
    record(ScaledDotProductAttentionOperation(params),
        listOfNotNull(query, key, value, mask), listOf(out))
    return out
}
```

### 2. TensorOperations: add ScaledDotProductAttentionOperation

**File:** `skainet-lang-core/.../tensor/ops/TensorOperations.kt`

```kotlin
class ScaledDotProductAttentionOperation<T : DType, V>(
    parameters: Map<String, Any> = emptyMap()
) : BaseOperation<T, V>("scaledDotProductAttention", "nn", parameters) {
    override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
        // Output shape = query shape: [batch, nHeads, seqLen, headDim]
        return listOf(TensorSpec("sdpa_output", inputs[0].shape, inputs[0].dtype))
    }
}
```
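A quick sanity check of the shape rule. This is a sketch: the `FP32` dtype literal and the exact `TensorSpec` construction are assumptions, so substitute whatever dtype value `TensorSpec` actually takes:

```kotlin
// Sketch: the output spec follows the query shape even when K/V use a
// different sequence length T. FP32 here is an assumed dtype literal.
val op = ScaledDotProductAttentionOperation<FP32, Float>()
val qSpec = TensorSpec("q", Shape(1, 6, 4, 64), FP32)  // [B, H, S, D]
val kSpec = TensorSpec("k", Shape(1, 6, 10, 64), FP32) // [B, H, T, D]
val vSpec = TensorSpec("v", Shape(1, 6, 10, 64), FP32)
val outSpec = op.inferOutputs(listOf(qSpec, kSpec, vSpec)).single()
check(outSpec.shape == Shape(1, 6, 4, 64)) // query shape, independent of T
```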

### 3. StableHLO converter: decompose SDPA

**File:** `skainet-compile-hlo/.../converters/NeuralNetOperationsConverter.kt`

Register "scaledDotProductAttention" and decompose into:

```mlir
// scores = Q @ K.T (batched matmul with K transposed)
%scores = stablehlo.dot_general %query, %key,
    batching_dims = [0, 1] x [0, 1],
    contracting_dims = [3] x [3]
    : (tensor<BxHxSxDxf32>, tensor<BxHxTxDxf32>) -> tensor<BxHxSxTxf32>

// scale
%scaled = stablehlo.multiply %scores, %scale_splat

// optional mask (additive)
%masked = stablehlo.add %scaled, %mask  // if mask != null

// softmax over last dim (no stablehlo softmax op; decomposed into
// max-reduce, subtract, exponential, sum-reduce, divide in the converter)
%weights = ...

// output = weights @ V (batched matmul)
%output = stablehlo.dot_general %weights, %value,
    batching_dims = [0, 1] x [0, 1],
    contracting_dims = [3] x [2]
```

Note: `contracting_dims = [3] x [3]` for Q@K.T because we contract the headDim of Q (last dim) with the headDim of K (also last dim). This differs from a standard matmul, where you contract the last dim of A with the second-to-last dim of B — here K is NOT pre-transposed.
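To make that concrete: `dot_general` with `batching_dims = [0, 1] x [0, 1]` and `contracting_dims = [3] x [3]` computes `scores[b,h,s,t] = sum_d Q[b,h,s,d] * K[b,h,t,d]`. A plain-Kotlin semantic check (illustration only, no SKaiNET APIs):

```kotlin
// scores[b][h][s][t] = sum over d of Q[b][h][s][d] * K[b][h][t][d],
// i.e. Q @ K.T without ever materializing a transposed K.
fun dotGeneralQKt(
    q: Array<Array<Array<FloatArray>>>, // [B][H][S][D]
    k: Array<Array<Array<FloatArray>>>  // [B][H][T][D]
): Array<Array<Array<FloatArray>>> =
    Array(q.size) { b ->
        Array(q[b].size) { h ->
            Array(q[b][h].size) { s ->
                FloatArray(k[b][h].size) { t ->
                    var acc = 0f
                    for (d in q[b][h][s].indices) acc += q[b][h][s][d] * k[b][h][t][d]
                    acc
                }
            }
        }
    }
```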
## Test

```kotlin
val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())
val q = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...) // [batch, heads, seq, headDim]
val k = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...)
val v = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...)

val (tape, out) = ctx.record {
    ctx.ops.scaledDotProductAttention(q, k, v)
}

val graph = tape!!.toComputeGraph(synthesizeExternalInputs = true)
val module = StableHloConverterFactory.createExtended().convert(graph, "test_sdpa")

// Should contain dot_general for Q@K.T and weights@V
assertTrue(module.content.contains("stablehlo.dot_general"))
assertFalse(module.content.contains("dense<0.0>")) // no zero constants
```

skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt

Lines changed: 14 additions & 1 deletion
```diff
@@ -433,7 +433,20 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
     override fun <T : DType, V> indexSelect(input: Tensor<T, V>, indices: Tensor<DType, *>, dim: Int): Tensor<T, V> = base.indexSelect(input, indices, dim)
     override fun <T : DType, V> exp(tensor: Tensor<T, V>): Tensor<T, V> = base.exp(tensor)
     override fun <T : DType, V> expm1(tensor: Tensor<T, V>): Tensor<T, V> = base.expm1(tensor)
-    override fun <T : DType, V> scaledDotProductAttention(query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>, mask: Tensor<T, V>?, scale: Float, causal: Boolean): Tensor<T, V> = base.scaledDotProductAttention(query, key, value, mask, scale, causal)
+    override fun <T : DType, V> scaledDotProductAttention(
+        query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>,
+        mask: Tensor<T, V>?, scale: Float, causal: Boolean
+    ): Tensor<T, V> {
+        val out = base.scaledDotProductAttention(query, key, value, mask, scale, causal)
+        val params = mutableMapOf<String, Any>(
+            "scale" to scale,
+            "causal" to causal
+        )
+        @Suppress("UNCHECKED_CAST")
+        val inputs = listOfNotNull(query, key, value, mask) as List<Tensor<T, V>>
+        record(ScaledDotProductAttentionOperation(params), inputs, listOf(out))
+        return out
+    }
 }

 private class ConcatRecordingOperation(
```
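
With this change, SDPA survives the tape → graph conversion instead of silently disappearing. A minimal check, using the same tape and graph APIs exercised by the new test below:

```kotlin
// Sketch: after recording, the graph should contain an SDPA node.
val (tape, _) = ctx.record { ctx.ops.scaledDotProductAttention(q, k, v) }
val graph = tape!!.toComputeGraph(synthesizeExternalInputs = true)
check(graph.getTopologicalOrder().any { it.operation.name == "scaledDotProductAttention" })
```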

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt

Lines changed: 149 additions & 2 deletions
```diff
@@ -28,7 +28,9 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
         // Normalization operations
         "batchNorm", "batchNormalization", "BatchNormalization",
         "layerNorm", "layerNormalization", "LayerNormalization",
-        "rmsNorm", "rms_norm", "RMSNorm", "RmsNorm"
+        "rmsNorm", "rms_norm", "RMSNorm", "RmsNorm",
+        // Attention
+        "scaledDotProductAttention"
     )

     override fun convert(
@@ -44,6 +46,7 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
             "batchnorm", "batchnormalization" -> convertBatchNorm(node, operands, context)
             "layernorm", "layernormalization" -> convertLayerNorm(node, operands, context)
             "rmsnorm", "rms_norm" -> convertRmsNorm(node, operands, context)
+            "scaleddotproductattention" -> convertSdpa(node, operands, context)
             else -> ConversionResult.Unsupported(
                 node.operation.name,
                 "Operation not supported by NeuralNetOperationsConverter"
@@ -770,5 +773,149 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
                 "epsilon = $epsilon, feature_index = $featureIndex : $outputType"
         }
     }
-
+
+    /**
+     * Convert scaledDotProductAttention to StableHLO.
+     * Decomposes into: Q @ K.T (batched) → scale → optional mask → softmax → @ V (batched)
+     *
+     * Input shapes: Q[B,H,S,D], K[B,H,T,D], V[B,H,T,D], optional mask[B,H,S,T] or broadcastable
+     * Output: [B,H,S,D]
+     */
+    private fun convertSdpa(
+        node: GraphNode,
+        operands: List<String>,
+        context: ConversionContext
+    ): ConversionResult {
+        if (operands.size < 3) {
+            return ConversionResult.Failure("SDPA requires at least 3 operands (Q, K, V), got ${operands.size}")
+        }
+
+        val query = operands[0] // [B, H, S, D]
+        val key = operands[1]   // [B, H, T, D]
+        val value = operands[2] // [B, H, T, D]
+        val mask = if (operands.size >= 4) operands[3] else null
+
+        val querySpec = node.inputs.getOrNull(0)
+        val keySpec = node.inputs.getOrNull(1)
+        val valueSpec = node.inputs.getOrNull(2)
+
+        val outputSpec = node.outputs.firstOrNull()
+        val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
+            ?: "tensor<?xf32>"
+
+        // Infer shapes for intermediate types
+        val qShape = querySpec?.shape ?: return ConversionResult.Failure("Unknown Q shape")
+        val kShape = keySpec?.shape ?: return ConversionResult.Failure("Unknown K shape")
+        val vShape = valueSpec?.shape ?: return ConversionResult.Failure("Unknown V shape")
+
+        val rank = qShape.size
+        if (rank != 4) {
+            return ConversionResult.Failure("SDPA expects 4D tensors [B,H,S,D], got rank $rank")
+        }
+
+        val batch = qShape[0]
+        val heads = qShape[1]
+        val seqQ = qShape[2]
+        val headDim = qShape[3]
+        val seqK = kShape[2]
+
+        val queryType = context.getValueType(query) ?: "tensor<${qShape.joinToString("x")}xf32>"
+        val keyType = context.getValueType(key) ?: "tensor<${kShape.joinToString("x")}xf32>"
+        val valueType = context.getValueType(value) ?: "tensor<${vShape.joinToString("x")}xf32>"
+
+        // scores = Q @ K.T: [B,H,S,D] @ [B,H,T,D] → [B,H,S,T]
+        // dot_general with batching_dims=[0,1], contracting_dims=[3]x[3]
+        val scoresType = "tensor<${batch}x${heads}x${seqQ}x${seqK}xf32>"
+        val scoresVal = context.nextTempValue()
+        context.emitOperation(
+            "$scoresVal = stablehlo.dot_general $query, $key, " +
+                "batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [3] " +
+                ": ($queryType, $keyType) -> $scoresType"
+        )
+        context.setValueType(scoresVal, scoresType)
+
+        // Scale
+        val scale = node.operation.parameters["scale"] as? Float
+            ?: (1.0f / kotlin.math.sqrt(headDim.toFloat()))
+        val scaledVal = context.nextTempValue()
+        val scaleConst = context.nextTempValue()
+        context.emitOperation("$scaleConst = stablehlo.constant dense<$scale> : tensor<f32>")
+        context.emitOperation(
+            "$scaledVal = stablehlo.broadcast_in_dim $scaleConst, dims = [] " +
+                ": (tensor<f32>) -> $scoresType"
+        )
+        val scaledScores = context.nextTempValue()
+        context.emitOperation(
+            "$scaledScores = stablehlo.multiply $scoresVal, $scaledVal : $scoresType"
+        )
+        context.setValueType(scaledScores, scoresType)
+
+        // Optional mask
+        var presoft = scaledScores
+        if (mask != null) {
+            val maskedVal = context.nextTempValue()
+            val maskType = context.getValueType(mask) ?: scoresType
+            context.emitOperation(
+                "$maskedVal = stablehlo.add $presoft, $mask : $scoresType"
+            )
+            context.setValueType(maskedVal, scoresType)
+            presoft = maskedVal
+        }
+
+        // Softmax over last dim (seqK)
+        // Decompose: exp(x - max(x)) / sum(exp(x - max(x)))
+        val maxVal = context.nextTempValue()
+        val maxInitVal = context.nextTempValue()
+        context.emitOperation("$maxInitVal = stablehlo.constant dense<0xFF800000> : tensor<f32>") // -inf
+        context.emitOperation(
+            "$maxVal = stablehlo.reduce($presoft init: $maxInitVal) applies stablehlo.maximum " +
+                "across dimensions = [${rank - 1}] : ($scoresType, tensor<f32>) -> " +
+                "tensor<${batch}x${heads}x${seqQ}xf32>"
+        )
+
+        val maxBcast = context.nextTempValue()
+        val reducedType = "tensor<${batch}x${heads}x${seqQ}xf32>"
+        context.emitOperation(
+            "$maxBcast = stablehlo.broadcast_in_dim $maxVal, dims = [0, 1, 2] " +
+                ": ($reducedType) -> $scoresType"
+        )
+
+        val shifted = context.nextTempValue()
+        context.emitOperation("$shifted = stablehlo.subtract $presoft, $maxBcast : $scoresType")
+
+        val expVal = context.nextTempValue()
+        context.emitOperation("$expVal = stablehlo.exponential $shifted : $scoresType")
+
+        val sumInit = context.nextTempValue()
+        context.emitOperation("$sumInit = stablehlo.constant dense<0.0> : tensor<f32>")
+        val sumVal = context.nextTempValue()
+        context.emitOperation(
+            "$sumVal = stablehlo.reduce($expVal init: $sumInit) applies stablehlo.add " +
+                "across dimensions = [${rank - 1}] : ($scoresType, tensor<f32>) -> $reducedType"
+        )
+
+        val sumBcast = context.nextTempValue()
+        context.emitOperation(
+            "$sumBcast = stablehlo.broadcast_in_dim $sumVal, dims = [0, 1, 2] " +
+                ": ($reducedType) -> $scoresType"
+        )
+
+        val weightsVal = context.nextTempValue()
+        context.emitOperation("$weightsVal = stablehlo.divide $expVal, $sumBcast : $scoresType")
+        context.setValueType(weightsVal, scoresType)
+
+        // output = weights @ V: [B,H,S,T] @ [B,H,T,D] → [B,H,S,D]
+        val resultValue = context.nextTempValue()
+        context.emitOperation(
+            "$resultValue = stablehlo.dot_general $weightsVal, $value, " +
+                "batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] " +
+                ": ($scoresType, $valueType) -> $outputType"
+        )
+
+        return ConversionResult.Success(
+            outputValueName = resultValue,
+            emittedOperations = emptyList()
+        )
+    }
 }
```
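
The softmax emitted above is the standard numerically stable decomposition: subtracting the row max changes nothing mathematically (the `exp(-max)` factor cancels between numerator and denominator) but keeps `stablehlo.exponential` from overflowing. A plain-Kotlin reference for what the reduce/broadcast/divide sequence computes along the last (seqK) dimension:

```kotlin
import kotlin.math.exp

// Mirrors the emitted ops: max-reduce → subtract → exponential →
// sum-reduce → divide, per row of the scores tensor.
fun softmaxRow(x: FloatArray): FloatArray {
    val m = x.maxOrNull() ?: Float.NEGATIVE_INFINITY // the dense<0xFF800000> (-inf) init
    val exps = FloatArray(x.size) { exp(x[it] - m) }
    val sum = exps.sum()
    return FloatArray(x.size) { exps[it] / sum }
}
```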
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
```kotlin
package sk.ainet.compile.hlo

import sk.ainet.lang.graph.DefaultGraphExecutionContext
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.ops.VoidTensorOps
import sk.ainet.lang.tape.toComputeGraph
import sk.ainet.lang.types.FP32
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

class SdpaHloExportTest {

    @Test
    fun sdpa_produces_dot_general_ops() {
        val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())

        val q = ctx.fromFloatArray<FP32, Float>(Shape(1, 2, 4, 8), FP32::class, FloatArray(64))
        val k = ctx.fromFloatArray<FP32, Float>(Shape(1, 2, 4, 8), FP32::class, FloatArray(64))
        val v = ctx.fromFloatArray<FP32, Float>(Shape(1, 2, 4, 8), FP32::class, FloatArray(64))

        @Suppress("UNCHECKED_CAST")
        val inputIds = setOf(
            ctx.session.refOf(q as Tensor<*, *>).id,
            ctx.session.refOf(k as Tensor<*, *>).id,
            ctx.session.refOf(v as Tensor<*, *>).id
        )

        val (tape, out) = ctx.record {
            ctx.ops.scaledDotProductAttention(q, k, v)
        }

        println("Output shape: ${out.shape}")

        val graph = tape!!.toComputeGraph(
            synthesizeExternalInputs = true,
            inputTensorIds = inputIds
        )
        val nodes = graph.getTopologicalOrder()
        println("Graph: ${nodes.size} nodes")
        println("Ops: ${nodes.map { it.operation.name }}")

        val module = StableHloConverterFactory.createExtended().convert(graph, "sdpa_test")
        println("MLIR:\n${module.content}")

        // Should contain dot_general for Q@K.T and weights@V
        assertTrue(module.content.contains("stablehlo.dot_general"), "Should have dot_general ops")

        // Should NOT contain large zero constant tensors (from raw permutation).
        // Scalar zeros (tensor<f32>) for the softmax init are fine.
        assertFalse(
            module.content.contains(Regex("dense<0\\.0> : tensor<\\d+x")),
            "Should not have large zero constant tensors"
        )

        // Should contain exponential (softmax decomposition)
        assertTrue(module.content.contains("stablehlo.exponential"), "Should have softmax (exponential)")
    }
}
```
