Skip to content

Commit 21f9a71

Browse files
Merge pull request #676 from SKaiNET-developers/fix/conv-gather-pool-reducewindow-675
fix(dag+hlo): conv1d/gather/pooling/flatten + IREE-valid reduce_window — 7/7 models compile (#675)
2 parents 3b8aff3 + eab9b29 commit 21f9a71

5 files changed

Lines changed: 243 additions & 43 deletions

File tree

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,25 @@ public class MlirValidator {
112112
// not).
113113
if (trimmed.isEmpty() || trimmed.startsWith("//") || trimmed.startsWith("module")) continue
114114

115-
// Extract defined SSA values
116-
if (trimmed.contains(" = ")) {
117-
val parts = trimmed.split(" = ", limit = 2)
118-
if (parts.size == 2) {
119-
val valueName = parts[0].trim()
120-
if (valueName.startsWith("%")) {
121-
if (definedValues.contains(valueName)) {
122-
errors.add("Line ${lineNum + 1}: SSA value $valueName redefined")
123-
}
124-
definedValues.add(valueName)
125-
}
115+
// Extract defined SSA values. A line may carry more than one result
116+
// definition when an op with a region is emitted on a single line
117+
// (e.g. `reduce_window … ({ ^bb0(%a, %b): %r = … })`), so register every
118+
// `%name =` result, not just the leading assignment.
119+
Regex("(%[a-zA-Z0-9_]+)\\s*=").findAll(trimmed).forEach { m ->
120+
val valueName = m.groupValues[1]
121+
if (definedValues.contains(valueName)) {
122+
errors.add("Line ${lineNum + 1}: SSA value $valueName redefined")
126123
}
124+
definedValues.add(valueName)
127125
}
128-
126+
127+
// Register block-argument definitions from region entry blocks
128+
// (`^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): …`). These bind SSA
129+
// values without a ` = `, so they must be collected separately.
130+
Regex("\\^[a-zA-Z0-9_]*\\(([^)]*)\\)").findAll(trimmed).forEach { block ->
131+
Regex("%[a-zA-Z0-9_]+").findAll(block.groupValues[1]).forEach { definedValues.add(it.value) }
132+
}
133+
129134
// Extract used SSA values
130135
val usedInLine = extractUsedValues(trimmed)
131136
usedValues.addAll(usedInLine)

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt

Lines changed: 71 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
184184
val padding = extractPadding(params)
185185

186186
val resultValue = context.nextTempValue()
187-
187+
val inputType = node.inputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) } ?: outputType
188+
188189
// Build StableHLO reduce_window operation for max pooling
189190
val operations = buildMaxPoolOperations(
190191
resultValue = resultValue,
191192
input = operands[0],
193+
inputType = inputType,
192194
outputType = outputType,
193195
kernelSize = kernelSize,
194196
stride = stride,
@@ -225,11 +227,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
225227
val padding = extractPadding(params)
226228

227229
val resultValue = context.nextTempValue()
228-
230+
val inputType = node.inputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) } ?: outputType
231+
229232
// Build StableHLO reduce_window operation for average pooling
230233
val operations = buildAvgPoolOperations(
231234
resultValue = resultValue,
232235
input = operands[0],
236+
inputType = inputType,
233237
outputType = outputType,
234238
kernelSize = kernelSize,
235239
stride = stride,
@@ -691,64 +695,103 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
691695
}
692696
}
693697

698+
/** Matches the leading `<dim>x` prefixes (static, `?` dynamic, or `*` unranked) of a tensor type payload. */
private val tensorDimPrefix = Regex("""^(?:(?:\d+|\?|\*)x)*""")

/**
 * The MLIR element type ("f32"/"f16"/…) parsed from a `tensor<…xT>` string.
 *
 * Only the leading dimension prefixes are stripped, so rank-0 tensors (`tensor<f32>`)
 * and element types that themselves contain an `x` (`index`, `complex<f32>`) are
 * returned intact instead of being cut at the last `x`.
 *
 * @param tensorType textual tensor type such as `tensor<1x3x4x4xf32>`; a bare element
 *   type with no `tensor<…>` wrapper is passed through unchanged.
 * @return the element type, or `"f32"` when nothing remains after stripping.
 */
private fun elementTypeOf(tensorType: String): String {
    // Payload between the first '<' and the last '>'; both calls return the input
    // unchanged when the delimiter is absent, so bare element types still work.
    val payload = tensorType.substringAfter('<').substringBeforeLast('>')
    return payload
        .replaceFirst(tensorDimPrefix, "")
        .substringBefore(',') // drop a trailing encoding attribute, if any
        .trim()
        .ifBlank { "f32" }
}
701+
702+
/**
 * Render a `stablehlo.reduce_window` op as one line of MLIR in the generic
 * (quoted op name + explicit region) form.
 *
 * Per #675 the pretty `… applies <op> over window dimensions = …` spelling is not
 * accepted by IREE's StableHLO parser, so the generic spelling is emitted instead.
 * All window attributes carry four entries: the two leading ones are fixed
 * (`1` for dimensions/strides, `[0, 0]` for padding) and the last two come from
 * [kernelSize], [stride] and [padding]; padding is applied symmetrically.
 *
 * @param resultValue SSA name receiving the result (e.g. `%v3`); also used, minus the
 *   leading `%`, as a suffix for the region-local SSA names.
 * @param input SSA name of the tensor being reduced.
 * @param inputType textual type of [input].
 * @param initValue SSA name of the scalar init value.
 * @param reduceOp op applied inside the region body (e.g. `stablehlo.add`).
 * @param elem element type used for the scalar region arguments.
 * @param outputType textual type of the result.
 * @return the complete single-line operation text.
 */
private fun reduceWindowGeneric(
    resultValue: String,
    input: String,
    inputType: String,
    initValue: String,
    reduceOp: String,
    elem: String,
    kernelSize: Pair<Int, Int>,
    stride: Pair<Int, Int>,
    padding: Pair<Int, Int>,
    outputType: String,
): String {
    // Region-local names are suffixed with the result name: the line-based validator
    // tracks SSA names globally, so plain %lhs/%rhs would clash between pooling ops.
    val suffix = resultValue.removePrefix("%")
    val scalar = "tensor<$elem>"
    val lhs = "%lhs_$suffix"
    val rhs = "%rhs_$suffix"
    val acc = "%out_$suffix"

    // Region body kept on one line; the ops are separated by single spaces.
    val region = "^bb0($lhs: $scalar, $rhs: $scalar): " +
        "$acc = $reduceOp $lhs, $rhs : $scalar " +
        "stablehlo.return $acc : $scalar"

    val padH = padding.first
    val padW = padding.second
    val attributes = listOf(
        "window_dimensions = array<i64: 1, 1, ${kernelSize.first}, ${kernelSize.second}>",
        "window_strides = array<i64: 1, 1, ${stride.first}, ${stride.second}>",
        "base_dilations = array<i64: 1, 1, 1, 1>",
        "window_dilations = array<i64: 1, 1, 1, 1>",
        "padding = dense<[[0, 0], [0, 0], [$padH, $padH], [$padW, $padW]]> : tensor<4x2xi64>",
    ).joinToString(", ")

    return "$resultValue = \"stablehlo.reduce_window\"($input, $initValue) " +
        "({ $region }) {$attributes} : ($inputType, $scalar) -> $outputType"
}
741+
694742
/**
 * Emit the two operations that make up a max pool: the scalar init constant and a
 * generic-form `reduce_window` applying `stablehlo.maximum`.
 *
 * Both operations are sent to [context] in that order and also returned.
 *
 * @return `[init constant, reduce_window]` as emitted.
 */
private fun buildMaxPoolOperations(
    resultValue: String,
    input: String,
    inputType: String,
    outputType: String,
    kernelSize: Pair<Int, Int>,
    stride: Pair<Int, Int>,
    padding: Pair<Int, Int>,
    context: ConversionContext
): List<String> {
    val elementType = elementTypeOf(outputType)
    val initName = context.nextTempValue()

    // NOTE(review): the literal is the most-negative finite f32 value; confirm it is
    // still appropriate when the element type is not f32.
    val emitted = listOf(
        "$initName = stablehlo.constant dense<-3.4028235e+38> : tensor<$elementType>",
        reduceWindowGeneric(
            resultValue = resultValue,
            input = input,
            inputType = inputType,
            initValue = initName,
            reduceOp = "stablehlo.maximum",
            elem = elementType,
            kernelSize = kernelSize,
            stride = stride,
            padding = padding,
            outputType = outputType,
        ),
    )
    emitted.forEach { context.emitOperation(it) }
    return emitted
}
719-
763+
720764
/**
 * Emit an average pool as a windowed sum followed by a division by the kernel area.
 *
 * Four operations are sent to [context], in order: the zero init constant, the
 * kernel-area constant, the generic-form `reduce_window` applying `stablehlo.add`,
 * and the final `stablehlo.divide` that writes [resultValue].
 *
 * @return the four operation strings in emission order.
 */
private fun buildAvgPoolOperations(
    resultValue: String,
    input: String,
    inputType: String,
    outputType: String,
    kernelSize: Pair<Int, Int>,
    stride: Pair<Int, Int>,
    padding: Pair<Int, Int>,
    context: ConversionContext
): List<String> {
    val elementType = elementTypeOf(outputType)
    val windowArea = kernelSize.first * kernelSize.second

    // Request the temp names up front: zero init, divisor, window sum.
    val zeroName = context.nextTempValue()
    val divisorName = context.nextTempValue()
    val sumName = context.nextTempValue()

    val emitted = listOf(
        "$zeroName = stablehlo.constant dense<0.0> : tensor<$elementType>",
        // The divisor is a splat of the output type so both divide operands share one type.
        "$divisorName = stablehlo.constant dense<${windowArea}.0> : $outputType",
        reduceWindowGeneric(
            resultValue = sumName,
            input = input,
            inputType = inputType,
            initValue = zeroName,
            reduceOp = "stablehlo.add",
            elem = elementType,
            kernelSize = kernelSize,
            stride = stride,
            padding = padding,
            outputType = outputType,
        ),
        "$resultValue = stablehlo.divide $sumName, $divisorName : $outputType",
    )
    emitted.forEach { context.emitOperation(it) }
    return emitted
}
754797

skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NeuralNetOperationsConverterTest.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ class NeuralNetOperationsConverterTest {
7373

7474
assertNotNull(module)
7575
assertContains(module.content, "stablehlo.reduce_window")
76-
// Should contain kernel size and stride information
77-
assertContains(module.content, "window dimensions")
78-
assertContains(module.content, "stride")
76+
// Generic region form (IREE-parseable): window_dimensions / window_strides attrs.
77+
assertContains(module.content, "window_dimensions")
78+
assertContains(module.content, "window_strides")
7979
}
8080

8181
private fun createConv2dGraph(): DefaultComputeGraph {
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package sk.ainet.compile.hlo
2+
3+
import sk.ainet.lang.dag.avgPool2d
4+
import sk.ainet.lang.dag.conv1d
5+
import sk.ainet.lang.dag.dag
6+
import sk.ainet.lang.dag.flatten
7+
import sk.ainet.lang.dag.gather
8+
import sk.ainet.lang.dag.maxPool2d
9+
import sk.ainet.lang.graph.dsl.toComputeGraph
10+
import sk.ainet.lang.tensor.ops.TensorSpec
11+
import sk.ainet.lang.types.FP32
12+
import sk.ainet.lang.types.Int32
13+
import kotlin.test.Test
14+
import kotlin.test.assertFalse
15+
import kotlin.test.assertTrue
16+
17+
/**
 * Export regression tests for DAG-DSL graphs lowered to StableHLO text (#675, the
 * follow-up to #674).
 *
 * Each test builds a one-op graph, lowers it, and checks that the declared
 * result type in the emitted MLIR is the op's inferred output shape rather than an
 * echo of the first operand's shape. The max-pool test additionally checks that
 * `reduce_window` is not emitted in the pretty `applies … over window` spelling.
 */
class DagConvGatherPoolExportTest {

    /** Build a graph with [build], lower it under [name], and return the MLIR text. */
    private fun lower(name: String, build: sk.ainet.lang.dag.DagBuilder.() -> Unit): String {
        val converter = StableHloConverterFactory.createExtended()
        val graph = dag(build).toComputeGraph()
        return converter.convert(graph, name).content
    }

    @Test
    fun conv1d_declares_inferred_output_channels_and_length() {
        // input (1,3,8), weight (4,3,3), stride 1, pad 0 -> (1,4,6); must not stay 1x3x8.
        val mlir = lower("op_conv1d") {
            val signal = input<FP32>("x", TensorSpec("x", listOf(1, 3, 8), "FP32"))
            val kernel = input<FP32>("w", TensorSpec("w", listOf(4, 3, 3), "FP32"))
            val bias = input<FP32>("b", TensorSpec("b", listOf(4), "FP32"))
            output(conv1d(signal, kernel, bias, 1, 0, 1, 1))
        }
        val echoesInput = Regex("""return %\w+ : tensor<1x3x8xf32>""")
        assertTrue("-> tensor<1x4x6xf32>" in mlir, "conv1d result must be inferred 1x4x6:\n$mlir")
        assertFalse(mlir.contains(echoesInput), "return must not echo the input shape:\n$mlir")
    }

    @Test
    fun gather_declares_inferred_rows() {
        // table (8,4) gathered with 3 indices -> (3,4); must not stay 8x4.
        val mlir = lower("op_gather") {
            val table = input<FP32>("t", TensorSpec("t", listOf(8, 4), "FP32"))
            val indices = input<Int32>("idx", TensorSpec("idx", listOf(3), "INT32"))
            output(gather(table, indices, 0))
        }
        val echoesTable = Regex("""return %\w+ : tensor<8x4xf32>""")
        assertTrue("-> tensor<3x4xf32>" in mlir, "gather result must be inferred 3x4:\n$mlir")
        assertFalse(mlir.contains(echoesTable), "return must not echo the table shape:\n$mlir")
    }

    @Test
    fun maxpool2d_declares_pooled_shape_and_iree_valid_reduce_window() {
        // input (1,3,8,8), 2x2 window, stride 2 -> (1,3,4,4); must not stay 1x3x8x8.
        val mlir = lower("op_maxpool2d") {
            val image = input<FP32>("x", TensorSpec("x", listOf(1, 3, 8, 8), "FP32"))
            output(maxPool2d(image, 2 to 2, 2 to 2, 0 to 0))
        }
        assertTrue("tensor<1x3x4x4xf32>" in mlir, "maxpool output must be the pooled 1x3x4x4:\n$mlir")

        // The pretty `applies … over window` spelling must be gone in favour of the
        // generic region-based reduce_window.
        val prettyForm = Regex("""reduce_window\([^)]*\)\s+applies""")
        assertFalse(
            mlir.contains(prettyForm),
            "reduce_window must use the IREE-parseable generic form, not 'applies … over window':\n$mlir",
        )
    }

    @Test
    fun flatten_preserves_leading_batch_dim() {
        // (1,16,7,7) flattened over dims 1..3 -> (1,784). Collapsing to rank-1 (784)
        // would break a downstream dense matmul (mnist-cnn).
        val mlir = lower("op_flatten") {
            val features = input<FP32>("x", TensorSpec("x", listOf(1, 16, 7, 7), "FP32"))
            output(flatten(features, 1, 3))
        }
        val rankOneResult = Regex("""-> tensor<784xf32>""")
        assertTrue("tensor<1x784xf32>" in mlir, "flatten must keep batch: (1,16,7,7)->(1,784):\n$mlir")
        assertFalse(mlir.contains(rankOneResult), "flatten must not collapse the batch dim:\n$mlir")
    }

    @Test
    fun avgpool2d_declares_pooled_shape() {
        // input (1,3,8,8), 2x2 window, stride 2 -> (1,3,4,4).
        val mlir = lower("op_avgpool2d") {
            val image = input<FP32>("x", TensorSpec("x", listOf(1, 3, 8, 8), "FP32"))
            output(avgPool2d(image, 2 to 2, 2 to 2, 0 to 0, false))
        }
        assertTrue("tensor<1x3x4x4xf32>" in mlir, "avgpool output must be the pooled 1x3x4x4:\n$mlir")
    }
}

skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ public class DagBuilder {
196196
val target = reshapeTargetShape(operation) ?: return null
197197
return spec(target)
198198
}
199+
"flatten" -> {
200+
// Collapse dims [startDim..endDim] into one, preserving the others
201+
// (notably the leading batch dim). Without this, flatten echoes operand-0
202+
// or collapses everything, so a downstream dense matmul mis-types. (#675)
203+
val inS = input?.shape ?: return null
204+
val rank = inS.size
205+
val rawStart = operation.parameters["startDim"] as? Int ?: 1
206+
val rawEnd = operation.parameters["endDim"] as? Int ?: -1
207+
val start = if (rawStart < 0) rank + rawStart else rawStart
208+
val end = if (rawEnd < 0) rank + rawEnd else rawEnd
209+
if (start !in 0 until rank || end !in 0 until rank || start > end) return null
210+
val collapsed = inS.subList(start, end + 1).fold(1) { a, b -> a * b }
211+
return spec(inS.subList(0, start) + collapsed + inS.subList(end + 1, rank))
212+
}
199213
"matmul", "dot", "mm", "bmm", "batch_matmul" -> {
200214
val lhs = inputs.getOrNull(0)?.spec?.shape
201215
val rhs = inputs.getOrNull(1)?.spec?.shape
@@ -215,10 +229,57 @@ public class DagBuilder {
215229
out[axis] = shapes.sumOf { it[axis] }
216230
return spec(out)
217231
}
232+
"conv1d" -> {
233+
// (N, Cin, L) * (Cout, Cin/groups, K) -> (N, Cout, Lout). conv2d already
234+
// infers via Conv2dOperation; conv1d is a GenericOperation with no inference. (#675)
235+
val inS = inputs.getOrNull(0)?.spec?.shape
236+
val wS = inputs.getOrNull(1)?.spec?.shape
237+
if (inS == null || wS == null || inS.size != 3 || wS.size != 3) return null
238+
val stride = operation.parameters["stride"] as? Int ?: 1
239+
val pad = operation.parameters["padding"] as? Int ?: 0
240+
val dil = operation.parameters["dilation"] as? Int ?: 1
241+
return spec(listOf(inS[0], wS[0], windowedOutput(inS[2], wS[2], stride, pad, dil)))
242+
}
243+
"gather" -> {
244+
// table[..axis..] gathered by `indices` -> table[:axis] ⊕ indices.shape ⊕ table[axis+1:]. (#675)
245+
val table = inputs.getOrNull(0)?.spec?.shape
246+
val idx = inputs.getOrNull(1)?.spec?.shape
247+
if (table == null || idx == null || table.isEmpty()) return null
248+
val rawAxis = operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int ?: -1
249+
val axis = if (rawAxis < 0) table.size + rawAxis else rawAxis
250+
if (axis !in table.indices) return null
251+
return spec(table.subList(0, axis) + idx + table.subList(axis + 1, table.size))
252+
}
253+
"maxpool2d", "avgpool2d" -> {
254+
// (N, C, H, W) windowed by kernel/stride/padding -> (N, C, Hout, Wout). (#675)
255+
val inS = inputs.getOrNull(0)?.spec?.shape
256+
if (inS == null || inS.size != 4) return null
257+
val k = pairParam(operation, "kernel") ?: pairParam(operation, "kernelSize") ?: return null
258+
val s = pairParam(operation, "stride") ?: (1 to 1)
259+
val p = pairParam(operation, "padding") ?: (0 to 0)
260+
return spec(
261+
listOf(
262+
inS[0], inS[1],
263+
windowedOutput(inS[2], k.first, s.first, p.first, 1),
264+
windowedOutput(inS[3], k.second, s.second, p.second, 1),
265+
),
266+
)
267+
}
218268
}
219269
return null
220270
}
221271

272+
/**
 * Windowed (conv/pool) output extent:
 * `max(0, floor((in + 2·pad − dilation·(k−1) − 1) / stride) + 1)`.
 *
 * Uses true floor division. Kotlin's `/` truncates toward zero, which would report an
 * extent of 1 for a window that does not fit (numerator in `(-stride, 0)`). The result
 * is clamped at 0 so a non-fitting window never yields a negative dimension.
 *
 * @param inDim input extent along the windowed axis.
 * @param k kernel extent along that axis.
 * @param stride window step; a zero stride still fails with an arithmetic exception.
 * @param pad padding applied on each side of the axis.
 * @param dilation spacing between kernel taps.
 * @return the number of window positions, never negative.
 */
private fun windowedOutput(inDim: Int, k: Int, stride: Int, pad: Int, dilation: Int): Int {
    // Distance the window start can travel; negative when the dilated kernel exceeds
    // the padded input.
    val travel = inDim + 2 * pad - dilation * (k - 1) - 1
    return (travel.floorDiv(stride) + 1).coerceAtLeast(0)
}
275+
276+
/**
 * Read parameter [key] of [operation] as a pair of ints.
 *
 * @return the pair, or `null` when the parameter is absent, is not a `Pair`, or
 *   either component is not an `Int`.
 */
private fun pairParam(operation: Operation, key: String): Pair<Int, Int>? {
    val raw = operation.parameters[key] as? Pair<*, *> ?: return null
    val first = raw.first as? Int ?: return null
    val second = raw.second as? Int ?: return null
    return first to second
}
282+
222283
/** Recover a reshape/view target shape from the op's `newShape`/`shape` parameter. */
223284
private fun reshapeTargetShape(operation: Operation): List<Int>? {
224285
val raw = operation.parameters["newShape"]

0 commit comments

Comments
 (0)