Commit e6935c3

michalharakal and claude committed
Make Whisper encoder compile end-to-end via SKaiNET → StableHLO → IREE
Prior state: `iree-compile` rejected the Whisper-encoder MLIR emitted by
`StableHloConverterFactory.createExtended()` — 141 `// Unsupported` comments over
296 graph nodes and 12 `stablehlo.custom_call @reduce_*` ops. Cascade analysis in
`docs/whisper-iree-issues/` identified a mix of missing converters and
invalid-MLIR emissions; a single upstream miss fanned out via
`StableHloConverter.processNode`'s `mapNotNull` operand lookup and appeared
downstream as bogus "wrong arity" failures.

Result: 0 Unsupported, 0 custom_call reductions, 35 real `stablehlo.reduce`,
32 `stablehlo.dot_general` (including the 8 rank-4 attention matmuls), and a
40 MB `.vmfb` compiled with `iree-compile --iree-hal-target-backends=llvm-cpu`.

Individual fixes, each guarded by an end-to-end assertion in the new
`Conv1dTapeToHloTest`:

* `TraceToGraphBuilder.buildOutputSpecs`: fall back to
  `trace.outputs[i].shape.dimensions` when the KSP-generated tracer omits
  `outputShapes` from attributes. Fixes `tensor<?xf32>` on every conv/gelu/etc.
  whose wrapper didn't go through `OpAttributeFactory.shapesAndDTypes`.

* `NeuralNetOperationsConverter.buildConv1d/ConvolutionOperation`: emit
  `stablehlo.broadcast_in_dim %bias, dims = [1]` before the bias add so the
  add's operands share a type; previously produced
  `stablehlo.add %conv, %bias : tensor<N,C,L>` with a rank-mismatched bias.

* `BasicMathConverter`: rewrite `add`/`subtract`/`multiply`/`divide` lowering.
  Old code short-circuited broadcasting whenever dtypes matched, and the
  broadcast path itself emitted split-line syntax MLIR rejects. New path takes
  `node.outputs[0]` as the target spec, adapts each operand via
  `stablehlo.convert` then `stablehlo.broadcast_in_dim` with NumPy-style
  right-aligned dim mapping (see the sketch below).

* New `UnaryMathConverter`: `sqrt`, `rsqrt`, `exp`, `expm1`, `log`, `log1p`,
  `abs`, `sign`, `negate`, `ceil`, `floor`, `round`, `cos`, `sin` → 1:1
  StableHLO primitives.

* New `ScalarOperationsConverter`: `addScalar`/`subScalar`/`mulScalar`/
  `divScalar`/`rsubScalar`/`rdivScalar` materialize the scalar as a splat
  `stablehlo.constant` then apply the matching binary op (reversed for
  `r*Scalar`). Reads the scalar from `operation.parameters["b"]` (the KSP
  tracer stores it there via `OpAttributeFactory.scalarOp`).

* `ReductionOperationsConverter`: emit real `stablehlo.reduce` instead of
  `stablehlo.custom_call @reduce_*` (the previous form was both non-standard
  and syntactically malformed). `mean` = reduce-sum + divide-by-count;
  `variance` decomposes to `E[X²] − E[X]²` via two reduces.

* `ActivationOperationsConverter.convertSoftmax`: softmax's two reductions now
  use `stablehlo.reduce` applying `stablehlo.maximum` and `stablehlo.add`.
  Same motivation as the reductions above.

* `LinalgOperationsConverter`: route every matmul variant
  (`matmul`/`dot`/`mm`/`bmm`/`batch_matmul`) through the rank-generic batched
  lowering. Prior `matmul` path hardcoded `contracting_dims = [1] x [0]` and
  produced invalid MLIR for Whisper attention's rank-4 matmul
  `[1,6,1500,64] × [1,6,64,1500]`; now emits
  `batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2]`. Rank-2
  continues to emit the compact form (empty batching clause).

* `StableHloConverter.processNodes`: enrich the fallback "Unsupported op"
  comment with the quoted name and the registry's full key list. Makes the
  MLIR self-diagnostic so a future missing-converter case surfaces without
  needing a local reproducer.

Registry entries added for the two new converters in all three factory entry
points (`createBasic`, `createExtended`, `createFast`).

`docs/whisper-iree-issues/` is included for context — it captured the failure
surface that these changes close.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
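The right-aligned dim mapping referenced in the `BasicMathConverter` bullet is easy to misread, so here is a minimal standalone sketch of the rule. The helper name and `main` driver are illustrative only, not the converter's actual API:

```kotlin
// Right-aligned (NumPy-style) broadcast mapping: operand dim i maps to output
// dim i + (outputRank - operandRank), so trailing dimensions line up. The
// resulting list is what feeds stablehlo.broadcast_in_dim's dims attribute.
fun broadcastInDims(operandRank: Int, outputRank: Int): List<Int> {
    require(operandRank <= outputRank) { "operand rank exceeds output rank" }
    val offset = outputRank - operandRank
    return (0 until operandRank).map { it + offset }
}

fun main() {
    // tensor<3000xf32> adapted into tensor<1x384x3000xf32> -> dims = [2]
    println(broadcastInDims(1, 3)) // [2]
    // tensor<1500x64xf32> adapted into tensor<1x6x1500x64xf32> -> dims = [2, 3]
    println(broadcastInDims(2, 4)) // [2, 3]
    // Note: the conv-bias fix above uses dims = [1] (the channel dim of NCL);
    // that is a layout-specific mapping in NeuralNetOperationsConverter,
    // not this right-aligned rule.
}
```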
1 parent c04257c · commit e6935c3

16 files changed: 1076 additions & 292 deletions

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# Conv1d/2d/3dOperation.inferOutputs() echoes input shape instead of computing output shape

## Problem

`Conv1dOperation.inferOutputs()` in `TensorOperations.kt` (line ~439) returns the
input tensor's shape as the output shape, ignoring weight shape, stride, padding,
and dilation:

```kotlin
override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
    require(inputs.size >= 2) { "Conv1d operation requires at least 2 inputs" }
    val outputShape = inputs[0].shape // <-- BUG: just copies input shape
    return listOf(TensorSpec("conv1d_output", outputShape, inputs[0].dtype, ...))
}
```

Conv2dOperation (line ~471) and Conv3dOperation (line ~503) have the identical bug.

## Expected

```kotlin
override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
    val inShape = inputs[0].shape  // [N, Cin, L]
    val wShape = inputs[1].shape   // [Cout, Cin/g, K]
    val stride = (parameters["stride"] as? Int) ?: 1
    val padding = (parameters["padding"] as? Int) ?: 0
    val dilation = (parameters["dilation"] as? Int) ?: 1
    val outShape =
        if (inShape != null && wShape != null && inShape.size == 3 && wShape.size == 3)
            listOf(inShape[0], wShape[0],
                   (inShape[2] + 2 * padding - dilation * (wShape[2] - 1) - 1) / stride + 1)
        else null
    return listOf(TensorSpec("conv1d_output", outShape, inputs[0].dtype, ...))
}
```

The formula already exists in `VoidTensorOps.calculateConv1dShape()` (line ~747);
a standalone check of it appears after this document. `ConvShapeUtils` was added
to the JAR but `inferOutputs()` does not call it yet.

## Impact

When the StableHLO converter calls `inferOutputs()` to determine the MLIR output
type, it gets the wrong shape. For Whisper's first conv1d:

```
Input:  [1, 80, 3000]   Weight: [384, 80, 3]   stride=1 padding=1
Actual: [1, 80, 3000]   ← wrong (echoed input)
Expect: [1, 384, 3000]  ← correct
```

This produces `tensor<?xf32>` in the MLIR (12 occurrences), which `iree-compile`
rejects.

## Parameters are available

PR #532 stores stride/padding/dilation in `operation.parameters`:

```kotlin
// RecordingExecution.kt:238-261
val params = mapOf("stride" to stride, "padding" to padding, "dilation" to dilation, "groups" to groups)
record(Conv1dOperation<T, V>(params), ...)
```

Verified by test: `assertEquals(1, recorded.operation.parameters["stride"])`

## Suggested fix

Call `ConvShapeUtils` from all three `inferOutputs()` methods. A single PR can
cover conv1d/2d/3d since the bug and fix are identical.

## Test

See `Conv1dTapeToHloTest.kt` — asserts `tensor<?` does not appear in output MLIR.
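The output-length formula above is easy to sanity-check in isolation. A minimal sketch, with an illustrative function name (not `VoidTensorOps.calculateConv1dShape` itself):

```kotlin
// outLength = (inLength + 2*padding - dilation*(kernel - 1) - 1) / stride + 1
fun conv1dOutLength(
    inLength: Int, kernel: Int,
    stride: Int = 1, padding: Int = 0, dilation: Int = 1
): Int = (inLength + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1

fun main() {
    // Whisper's first conv1d: input [1, 80, 3000], weight [384, 80, 3],
    // stride = 1, padding = 1 -> output [1, 384, 3000].
    val out = listOf(1, 384, conv1dOutLength(3000, kernel = 3, stride = 1, padding = 1))
    println(out) // [1, 384, 3000]
}
```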
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# toComputeGraph() loses edge wiring and produces wrong op types

## Problem

After `tape.toComputeGraph(synthesizeExternalInputs = true)`, many graph nodes have:

1. **Wrong input edge count** — binary ops (add, matmul, subtract) don't get 2 input
   edges wired; unary ops (gelu, softmax, reshape) don't get 1 input edge wired.
   The StableHLO converter checks arity and emits "Unsupported X arity" comments.

2. **Wrong operation type** — some ops have `operation.type = "trace"` instead of a
   recognized category. The converter uses `operation.name` for dispatch but some
   converters also check `type`, and "trace" doesn't match any registered converter.

## Scope

Whisper encoder tape produces 296 graph nodes. After StableHLO conversion:

- 166 nodes emit valid `stablehlo.*` ops
- 157 nodes emit `// Unsupported ...` comments (some nodes emit both)

Breakdown of unsupported:

```
32  add        — wrong arity (expected 2 inputs)
24  matmul     — wrong arity (expected 2 inputs)
16  unsqueeze  — wrong arity (expected 1 input)
12  reshape    — wrong arity (expected 1 input)
 9  subtract   — wrong arity (expected 2 inputs)
 9  sqrt       — type "trace" not recognized
 9  addScalar  — type "trace" not recognized
 9  multiply   — wrong arity (expected 2 inputs)
 9  divide     — wrong arity (expected 2 inputs)
 7  variance   — wrong arity (expected 1 input)
 7  mean       — wrong arity (expected 1 input)
 4  softmax    — wrong arity (expected 1 input)
 4  mulScalar  — type "trace" not recognized
 4  gelu       — wrong arity (expected 1 input)
 2  mean       — type "trace" not recognized
```

## Root cause hypothesis

`DefaultExecutionTape.toComputeGraph()` builds graph edges by matching tensor ref
IDs between operation outputs and subsequent operation inputs. If:

- The ref ID scheme changed between recording and graph construction, edges don't
  connect and binary ops appear to have 0 or 1 inputs.
- Weight tensors created before `startRecording()` may not have their ref IDs in
  the tape's scope, so edges from weights to consumers are missing.

A sketch of this wiring and its failure mode follows this document.

The `type = "trace"` issue: `KspTensorOps` (the auto-generated tracing wrapper)
may record operations with a generic "trace" type string for ops that don't have
an explicit `Operation` subclass (sqrt, addScalar, mulScalar, etc.).

## Impact

The generated MLIR is structurally incomplete — most ops are comments instead of
valid StableHLO operations. `iree-compile` cannot process it.

## Suggested investigation

1. In `DefaultExecutionTape.toComputeGraph()`, check how `GraphEdge` source/target
   are resolved from tape trace inputs/outputs. Are ref IDs stable?
2. For the `type = "trace"` ops: check what `KspTensorOps` records as the operation
   type for `sqrt`, `addScalar`, `mulScalar`. The HLO converter's operation
   registry should recognize these names regardless of type.
3. The `synthesizeExternalInputs = true` flag should create input/weight nodes for
   external tensors — verify these get edges to their consumers.

## Test

See `Conv1dTapeToHloTest.kt` — asserts that `Unsupported` does not appear in
output MLIR for a simple conv1d → gelu → add pipeline.
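To make the ref-ID hypothesis concrete, here is a minimal, purely illustrative model of ID-based edge wiring and the silent-drop failure mode; none of these types or names are the actual `DefaultExecutionTape` API:

```kotlin
// Illustrative-only model of tape -> graph edge wiring. An edge exists when a
// consumer's input ref ID matches some producer's output ref ID; if the ID
// scheme shifts between recording and graph construction, the lookup misses
// and a binary op silently ends up with fewer than 2 input edges.
data class TraceOp(val id: String, val inputRefs: List<String>, val outputRefs: List<String>)

fun wireEdges(ops: List<TraceOp>): List<Pair<String, String>> {
    val producerByRef = buildMap {
        for (op in ops) for (ref in op.outputRefs) put(ref, op.id)
    }
    return ops.flatMap { op ->
        // mapNotNull mirrors the silent-drop behavior: a missing producer
        // yields no edge instead of an error, so arity checks fail downstream.
        op.inputRefs.mapNotNull { ref -> producerByRef[ref]?.let { it to op.id } }
    }
}

fun main() {
    val ops = listOf(
        TraceOp("conv", inputRefs = listOf("x0", "w0"), outputRefs = listOf("t1")),
        TraceOp("add", inputRefs = listOf("t1", "bias#7"), outputRefs = listOf("t2")),
    )
    // "bias#7" was created before startRecording(), so no producer is in scope:
    println(wireEdges(ops)) // [(conv, add)] — add gets 1 input edge, not 2
}
```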

docs/whisper-iree-issues/README.md

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# SKaiNET Upstream Issues — Whisper IREE Pipeline

Two issues block the native SKaiNET DSL → StableHLO → IREE compilation path.

## Issue A: Conv1dOperation.inferOutputs echoes input shape

`Conv1dOperation.inferOutputs()` returns `inputs[0].shape` instead of
computing `[batch, outChannels, outLength]`. Same bug in Conv2d/Conv3d.

**File:** `skainet-lang/skainet-lang-core/.../tensor/ops/TensorOperations.kt`
**Fix:** Use `ConvShapeUtils` (already in JAR) from `inferOutputs()`.

## Issue B: toComputeGraph loses edge wiring and op types

`tape.toComputeGraph()` produces nodes where:

- Binary ops (add, matmul, subtract, ...) have wrong input edge count
- Some ops have `operation.type = "trace"` instead of recognized names

157 of 296 Whisper encoder nodes emit "Unsupported ... arity" in MLIR.

**File:** `skainet-compile/skainet-compile-dag/.../tape/extensions.kt` or
`DefaultExecutionTape.toComputeGraph()`

## Test

`Conv1dTapeToHloTest.kt` is a KMP commonTest that (a skeleton sketch follows
this document):

1. Builds a tape-recording context
2. Runs conv1d → gelu → add through `ctx.ops`
3. Converts the tape to a ComputeGraph
4. Exports to StableHLO MLIR
5. Asserts: no `tensor<?`, no `Unsupported`, valid `stablehlo.convolution`

Place in: `skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/`

Run: `./gradlew :skainet-compile:skainet-compile-hlo:allTests --tests "*Conv1dTapeToHloTest*"`

Currently **fails** on both issues. Will **pass** when both are fixed.
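A skeleton of the test described above, assuming `kotlin.test`. Steps 1-4 are stubbed because the tape-recording API is not shown in this commit; only the step-5 assertions are concrete:

```kotlin
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

class Conv1dTapeToHloTestSketch {
    // Placeholder for steps 1-4: build a tape-recording context, run
    // conv1d -> gelu -> add through ctx.ops, call
    // tape.toComputeGraph(synthesizeExternalInputs = true), and export via
    // the converter from StableHloConverterFactory.createExtended().
    private fun exportPipelineMlir(): String = TODO("steps 1-4 of the README")

    @Test
    fun conv1dGeluAddLowersCleanly() {
        val mlir = exportPipelineMlir()
        // Step 5: the three assertions this README names.
        assertFalse("tensor<?" in mlir, "dynamic shape leaked into MLIR")
        assertFalse("Unsupported" in mlir, "some node had no converter")
        assertTrue("stablehlo.convolution" in mlir, "conv1d did not lower")
    }
}
```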

skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt

Lines changed: 2 additions & 2 deletions
```diff
@@ -350,8 +350,8 @@ public class TraceToGraphBuilder(
         val count = trace.outputs.size
         return List(count) { i ->
             val name = trace.outputs[i].id
-            val shape = shapes?.getOrNull(i)
-            val dtype = dtypes?.getOrNull(i) ?: "unknown"
+            val shape = shapes?.getOrNull(i) ?: trace.outputs[i].shape.dimensions.toList()
+            val dtype = dtypes?.getOrNull(i) ?: trace.outputs[i].dtype::class.simpleName ?: "unknown"
             TensorSpec(name = name, shape = shape, dtype = dtype)
         }
     }
```

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt

Lines changed: 8 additions & 1 deletion
```diff
@@ -178,7 +178,14 @@ public class StableHloConverter @kotlin.jvm.JvmOverloads constructor(
                 processNode(node, context)
             } catch (e: Exception) {
                 context.emitComment("Error processing node ${node.id}: ${e.message}")
-                context.emitComment("Unsupported op ${node.operation.name} (type=${node.operation.type}) for node ${node.id}")
+                // Quote the name so trailing whitespace / casing surprises are visible,
+                // and include the registry's full key set so "no converter found"
+                // failures are self-diagnostic (is the name missing, or mis-matched?).
+                val known = registry.getSupportedOperations().sorted().joinToString(", ")
+                context.emitComment(
+                    "Unsupported op '${node.operation.name}' (type=${node.operation.type}) " +
+                        "for node ${node.id}. Known names: [$known]"
+                )
             }
         }
     }
```
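With this change, a converter miss prints along these lines (node id, error message, and registry contents invented for illustration):

```
// Error processing node node_42: ...
// Unsupported op 'sqrt' (type=trace) for node node_42. Known names: [abs, add, addScalar, conv1d, ...]
```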

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt

Lines changed: 22 additions & 0 deletions
```diff
@@ -8,7 +8,9 @@ import sk.ainet.compile.hlo.converters.LinalgOperationsConverter
 import sk.ainet.compile.hlo.converters.MathOperationsConverter
 import sk.ainet.compile.hlo.converters.NeuralNetOperationsConverter
 import sk.ainet.compile.hlo.converters.ReductionOperationsConverter
+import sk.ainet.compile.hlo.converters.ScalarOperationsConverter
 import sk.ainet.compile.hlo.converters.ShapeOperationsConverter
+import sk.ainet.compile.hlo.converters.UnaryMathConverter
 import kotlin.jvm.JvmStatic

 /**
@@ -53,6 +55,15 @@ public object StableHloConverterFactory {
         // Register reduction operations converter
         registry.register(ReductionOperationsConverter())

+        // Register elementwise unary math converter (sqrt, exp, log, abs, …).
+        // Must be present so downstream consumers don't cascade-fail with
+        // "wrong arity" when an upstream op is silently dropped.
+        registry.register(UnaryMathConverter())
+
+        // Register tensor+scalar ops (addScalar / mulScalar / …) emitted by the
+        // KSP-generated tracing wrapper for `tensor op Number` expressions.
+        registry.register(ScalarOperationsConverter())
+
         // Register constant operations converter
         registry.register(ConstantOperationsConverter())

@@ -98,6 +109,15 @@ public object StableHloConverterFactory {
         // Register reduction operations converter
         registry.register(ReductionOperationsConverter())

+        // Register elementwise unary math converter (sqrt, exp, log, abs, …).
+        // Must be present so downstream consumers don't cascade-fail with
+        // "wrong arity" when an upstream op is silently dropped.
+        registry.register(UnaryMathConverter())
+
+        // Register tensor+scalar ops (addScalar / mulScalar / …) emitted by the
+        // KSP-generated tracing wrapper for `tensor op Number` expressions.
+        registry.register(ScalarOperationsConverter())
+
         // Register constant operations converter
         registry.register(ConstantOperationsConverter())

@@ -128,6 +148,8 @@ public object StableHloConverterFactory {
         registry.register(ActivationOperationsConverter())
         registry.register(ShapeOperationsConverter())
         registry.register(ReductionOperationsConverter())
+        registry.register(UnaryMathConverter())
+        registry.register(ScalarOperationsConverter())
         registry.register(ConstantOperationsConverter())

         return StableHloConverter(registry, typeMapper, null, policy)
```

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt

Lines changed: 14 additions & 4 deletions
```diff
@@ -146,18 +146,26 @@ public class ActivationOperationsConverter : StableHloOperationConverter {
         // mapped to its position in the reduced tensor.
         val broadcastDims = (0 until rank).filter { it != axis }.joinToString(", ")

+        val maxInit = context.nextTempValue()
         val maxValue = context.nextTempValue()
         val maxBroadcast = context.nextTempValue()
         val shiftedValue = context.nextTempValue()
         val expValue = context.nextTempValue()
+        val sumInit = context.nextTempValue()
         val sumValue = context.nextTempValue()
         val sumBroadcast = context.nextTempValue()
         val resultValue = context.nextTempValue()

+        // Identity for stablehlo.maximum on floats: -inf. Spell it via the bit
+        // pattern so MLIR parses it regardless of how the element type prints.
+        val maxIdentity = "0xFF800000"
+
         val operations = listOf(
             // Reduce-max along the softmax axis (for numerical stability).
-            "$maxValue = stablehlo.custom_call @reduce_max(${operands[0]}) " +
-                "{dimensions = [$axis], keepdim = false} : $reducedType",
+            "$maxInit = stablehlo.constant dense<$maxIdentity> : tensor<$elementType>",
+            "$maxValue = stablehlo.reduce(${operands[0]} init: $maxInit) " +
+                "applies stablehlo.maximum across dimensions = [$axis] : " +
+                "($outputType, tensor<$elementType>) -> $reducedType",

             // Broadcast reduced max back to the input shape.
             "$maxBroadcast = stablehlo.broadcast_in_dim $maxValue, " +
@@ -170,8 +178,10 @@ public class ActivationOperationsConverter : StableHloOperationConverter {
             "$expValue = stablehlo.exponential $shiftedValue : $outputType",

             // Reduce-sum along the softmax axis.
-            "$sumValue = stablehlo.custom_call @reduce_sum($expValue) " +
-                "{dimensions = [$axis], keepdim = false} : $reducedType",
+            "$sumInit = stablehlo.constant dense<0.0> : tensor<$elementType>",
+            "$sumValue = stablehlo.reduce($expValue init: $sumInit) " +
+                "applies stablehlo.add across dimensions = [$axis] : " +
+                "($outputType, tensor<$elementType>) -> $reducedType",

             // Broadcast the sum back to the input shape.
             "$sumBroadcast = stablehlo.broadcast_in_dim $sumValue, " +
```
