Skip to content

Commit 0d28175

Browse files
michalharakal and claude committed
Fix StableHLO MLIR export to produce valid IREE-compilable output
The HLO generator was producing broken MLIR because the graph builder did not synthesize nodes for tensor inputs without known producers (model inputs and weights). This adds a finalize() step to TraceToGraphBuilder that creates synthetic "input" and "weight" nodes for unresolved references, fixes dense literal formatting to produce properly nested arrays matching tensor rank, and corrects the stablehlo.convolution assembly syntax (attribute dict, functional type annotation, batch_group_count). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 67b8105 commit 0d28175

8 files changed

Lines changed: 248 additions & 74 deletions

File tree

rgb2grayscale.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module {
2+
func.func @rgb2grayscale(%arg0: tensor<1x3x4x4xf32>) -> (tensor<1x1x4x4xf32>) {
3+
// input n1_input: t0 : tensor<1x3x4x4xf32>
4+
// weight n2_weight: frozen parameter
5+
%v0 = stablehlo.constant dense<[[[[0.2989]], [[0.587]], [[0.114]]]]> : tensor<1x3x1x1xf32>
6+
%v1 = stablehlo.convolution(%arg0, %v0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x4xf32>, tensor<1x3x1x1xf32>) -> tensor<1x1x4x4xf32>
7+
return %v1 : tensor<1x1x4x4xf32>
8+
}
9+
}

skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,19 @@ public open class DefaultExecutionTape(
240240
return newTape
241241
}
242242

243-
public fun toComputeGraph(): ComputeGraph {
243+
public fun toComputeGraph(
244+
synthesizeExternalInputs: Boolean = false,
245+
inputTensorIds: Set<String> = emptySet()
246+
): ComputeGraph {
244247
// Prefer trace-based offline build when traces are available to ensure
245248
// consistency with online GraphSink wiring rules (PRD FR6).
246249
if (_traces.isNotEmpty()) {
247250
val graph = DefaultComputeGraph()
248251
val builder = TraceToGraphBuilder(graph, session)
249252
builder.addAll(_traces)
253+
if (synthesizeExternalInputs) {
254+
builder.finalize(inputTensorIds)
255+
}
250256
return graph
251257
}
252258

skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/tape/extensions.kt

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,20 @@ import sk.ainet.lang.graph.DefaultComputeGraph
55
import sk.ainet.tape.ExecutionTape
66

77
/**
8-
* Convert the tape to a compute graph
8+
* Convert the tape to a compute graph.
9+
*
10+
* @param synthesizeExternalInputs When true, placeholder "input" and "weight" constant nodes are
11+
* created for tensor inputs that have no known producer in the trace. Required for StableHLO
12+
* compilation where every operand must be wired through graph edges.
13+
* @param inputTensorIds Tensor IDs that should always become function arguments (model inputs)
14+
* rather than constants, even if their data is resolvable.
915
*/
10-
public fun ExecutionTape.toComputeGraph(): ComputeGraph {
16+
public fun ExecutionTape.toComputeGraph(
17+
synthesizeExternalInputs: Boolean = false,
18+
inputTensorIds: Set<String> = emptySet()
19+
): ComputeGraph {
1120
return when (this) {
12-
is sk.ainet.lang.graph.DefaultExecutionTape -> this.toComputeGraph()
21+
is sk.ainet.lang.graph.DefaultExecutionTape -> this.toComputeGraph(synthesizeExternalInputs, inputTensorIds)
1322
else -> DefaultComputeGraph()
1423
}
1524
}

skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt

Lines changed: 140 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ import sk.ainet.lang.tensor.ops.ValidationResult
1818
* "e_<srcNodeId>_<srcOut>__<dstNodeId>_<dstIn>"
1919
* Example: e_n0_Add_0__n1_Relu_0. This is deterministic given the node IDs and wiring.
2020
*
21-
* Note: This builder does NOT synthesize explicit "input" placeholder nodes for tensors without a
22-
* known producer. Only real operation nodes are created, and edges are added solely between such
23-
* nodes when a producer is known. This matches the expectations of TracingAcceptanceTest.
21+
* By default this builder does NOT synthesize explicit "input" placeholder nodes for tensors
22+
* without a known producer. Call [finalize] after adding all traces to synthesize "input" and
23+
* "weight" constant nodes for unresolved external inputs. This is required for StableHLO
24+
* compilation where every operand must be wired through graph edges.
2425
*/
2526
public class TraceToGraphBuilder(
2627
private val graph: ComputeGraph,
@@ -32,6 +33,14 @@ public class TraceToGraphBuilder(
3233
private data class Producer(val node: GraphNode, val outIndex: Int, val spec: TensorSpec)
3334
private val producersByTensorId = mutableMapOf<String, Producer>()
3435

36+
private data class UnresolvedRef(
37+
val tensorRef: TensorRef,
38+
val consumerNode: GraphNode,
39+
val inputIndex: Int,
40+
val spec: TensorSpec
41+
)
42+
private val unresolvedByTensorId = mutableMapOf<String, MutableList<UnresolvedRef>>()
43+
3544
/**
3645
* Add a single OpTrace into the graph, wiring known producers to inputs
3746
* and registering the outputs as new producers.
@@ -106,6 +115,35 @@ public class TraceToGraphBuilder(
106115
}
107116
}
108117
}
118+
"conv2d" -> {
119+
// For Conv2d layer: conv2d(input, weight, bias?)
120+
// Resolve weight tensor (second input) from session
121+
if (trace.inputs.size >= 2) {
122+
val weightRef = trace.inputs[1]
123+
if (!producersByTensorId.containsKey(weightRef.id)) {
124+
val tensor = session?.resolve(weightRef)
125+
if (tensor != null) {
126+
val values = extractFloatArray(tensor)
127+
if (values != null) {
128+
parameters["weights"] = values
129+
}
130+
}
131+
}
132+
}
133+
// Resolve optional bias tensor (third input) from session
134+
if (trace.inputs.size >= 3) {
135+
val biasRef = trace.inputs[2]
136+
if (!producersByTensorId.containsKey(biasRef.id)) {
137+
val tensor = session?.resolve(biasRef)
138+
if (tensor != null) {
139+
val values = extractFloatArray(tensor)
140+
if (values != null) {
141+
parameters["bias_values"] = values
142+
}
143+
}
144+
}
145+
}
146+
}
109147
}
110148
}
111149

@@ -114,8 +152,6 @@ public class TraceToGraphBuilder(
114152
val inputSpecs = buildInputSpecs(trace)
115153
val outputSpecs = buildOutputSpecs(trace)
116154

117-
// Do not synthesize placeholder input nodes; leave unknown producers unresolved.
118-
119155
val nodeId = "n${nextNodeId++}_${trace.opType}"
120156
val node = GraphNode(
121157
id = nodeId,
@@ -125,7 +161,7 @@ public class TraceToGraphBuilder(
125161
)
126162
graph.addNode(node)
127163

128-
// Wire edges from producers
164+
// Wire edges from producers; track unresolved inputs for later finalization
129165
trace.inputs.forEachIndexed { idx, tRef ->
130166
val prod = producersByTensorId[tRef.id]
131167
if (prod != null) {
@@ -141,6 +177,15 @@ public class TraceToGraphBuilder(
141177
tensorSpec = tensorSpec
142178
)
143179
)
180+
} else {
181+
// Track for finalize() — no placeholder synthesized here by default
182+
val spec = inputSpecs.getOrNull(idx) ?: TensorSpec(
183+
name = tRef.id,
184+
shape = tRef.shape.dimensions.toList(),
185+
dtype = tRef.dtype::class.simpleName ?: "FP32"
186+
)
187+
unresolvedByTensorId.getOrPut(tRef.id) { mutableListOf() }
188+
.add(UnresolvedRef(tRef, node, idx, spec))
144189
}
145190
}
146191

@@ -159,6 +204,95 @@ public class TraceToGraphBuilder(
159204
traces.forEach { addTrace(it) }
160205
}
161206

207+
/**
208+
* Synthesize placeholder nodes for tensor inputs that had no known producer.
209+
*
210+
* For each unresolved tensor:
211+
* - If the tensor ID is in [inputTensorIds], an "input" placeholder node is created
212+
* (representing a function argument).
213+
* - Else if the original tensor can be resolved from the session and contains constant data
214+
* (e.g. model weights), a "weight" constant node is created.
215+
* - Otherwise an "input" placeholder node is created as a fallback.
216+
*
217+
* Edges are wired from the new nodes to every consumer that referenced the tensor.
218+
* Call this after [addAll] when building graphs for compilation.
219+
*
220+
* @param inputTensorIds Tensor IDs that should always become function arguments (model inputs).
221+
*/
222+
public fun finalize(inputTensorIds: Set<String> = emptySet()) {
223+
for ((tensorId, refs) in unresolvedByTensorId) {
224+
val firstRef = refs.first()
225+
val spec = firstRef.spec
226+
227+
// If explicitly marked as model input, create an input node unconditionally
228+
val forceInput = inputTensorIds.contains(tensorId)
229+
230+
// Try to resolve as a constant from the session
231+
val tensor = if (!forceInput) session?.resolve(firstRef.tensorRef) else null
232+
val constantValues = tensor?.let { extractFloatArray(it) }
233+
234+
val syntheticNode: GraphNode
235+
if (constantValues != null) {
236+
// Create a constant/weight node with embedded values
237+
val weightShape = tensor!!.shape.dimensions.toList()
238+
val weightDtype = tensor.dtype.simpleName ?: "FP32"
239+
val nodeId = "n${nextNodeId++}_weight"
240+
val op = TraceBackedOperation(
241+
name = "weight",
242+
type = "constant",
243+
parameters = mapOf(
244+
"initial_value" to constantValues.toList(),
245+
"trainable" to false
246+
)
247+
)
248+
syntheticNode = GraphNode(
249+
id = nodeId,
250+
operation = op,
251+
inputs = emptyList(),
252+
outputs = listOf(TensorSpec(
253+
name = tensorId,
254+
shape = weightShape,
255+
dtype = weightDtype
256+
))
257+
)
258+
} else {
259+
// Create an input placeholder node
260+
val nodeId = "n${nextNodeId++}_input"
261+
val op = TraceBackedOperation(
262+
name = "input",
263+
type = "input",
264+
parameters = emptyMap()
265+
)
266+
syntheticNode = GraphNode(
267+
id = nodeId,
268+
operation = op,
269+
inputs = emptyList(),
270+
outputs = listOf(spec)
271+
)
272+
}
273+
274+
graph.addNode(syntheticNode)
275+
276+
// Wire edges to all consumers
277+
for (ref in refs) {
278+
graph.addEdge(
279+
GraphEdge(
280+
id = "e_${syntheticNode.id}_0__${ref.consumerNode.id}_${ref.inputIndex}",
281+
source = syntheticNode,
282+
destination = ref.consumerNode,
283+
sourceOutputIndex = 0,
284+
destinationInputIndex = ref.inputIndex,
285+
tensorSpec = ref.spec
286+
)
287+
)
288+
}
289+
290+
// Register as producer
291+
producersByTensorId[tensorId] = Producer(syntheticNode, 0, spec)
292+
}
293+
unresolvedByTensorId.clear()
294+
}
295+
162296
private fun buildInputSpecs(trace: OpTrace): List<TensorSpec> {
163297
val shapes = (trace.attributes["inputShapes"] as? List<*>)?.map { it as? List<Int> }
164298
val dtypes = (trace.attributes["inputDTypes"] as? List<*>)?.map { it?.toString() }

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/TypeMapper.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class TypeMapper {
2525
"UI32", "UINT32" -> "ui32"
2626
"UI64", "UINT64" -> "ui64"
2727
"BF16", "BFLOAT16" -> "bf16"
28-
"F16", "FLOAT16" -> "f16"
28+
"FP16", "F16", "FLOAT16" -> "f16"
2929
"BOOL", "BOOLEAN" -> "i1"
3030
else -> {
3131
// Default fallback with warning

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -342,55 +342,54 @@ public class ConstantOperationsConverter : StableHloOperationConverter {
342342
}
343343

344344
/**
345-
* Format tensor values for MLIR dense constant
345+
* Format tensor values for MLIR dense constant.
346+
* MLIR dense<> syntax requires nested brackets matching the tensor rank:
347+
* scalar: dense<42.0>
348+
* 1D [3]: dense<[v0, v1, v2]>
349+
* 2D [2,3]: dense<[[v0,v1,v2],[v3,v4,v5]]>
350+
* 4D [1,3,1,1]: dense<[[[[v0],[v1],[v2]]]]>
346351
*/
347352
private fun formatTensorValues(values: List<*>, outputSpec: TensorSpec?): String {
348353
val shape = outputSpec?.shape ?: emptyList()
349-
354+
350355
return when {
351356
values.isEmpty() -> "0.0"
352-
shape.isEmpty() || shape.size == 1 -> {
353-
// 1D tensor or scalar
354-
values.joinToString(", ") { formatConstantValue(it as Number) }
355-
}
356-
shape.size == 2 -> {
357-
// 2D tensor - format as nested arrays
358-
formatAs2DTensor(values, shape)
359-
}
360-
else -> {
361-
// Multi-dimensional tensor - flatten for now
362-
values.joinToString(", ") { formatConstantValue(it as Number) }
363-
}
357+
values.size == 1 -> formatConstantValue(values[0] as Number)
358+
shape.isEmpty() -> "[" + values.joinToString(", ") { formatConstantValue(it as Number) } + "]"
359+
else -> formatNestedTensor(values, shape, 0, IntArray(1))
364360
}
365361
}
366-
362+
367363
/**
368-
* Format values as a 2D tensor for MLIR
364+
* Recursively format a flat list of values into nested MLIR dense literal
365+
* matching the given shape. [offset] tracks the current position in the flat values list.
369366
*/
370-
private fun formatAs2DTensor(values: List<*>, shape: List<Int>): String {
371-
if (shape.size != 2) return values.joinToString(", ") { formatConstantValue(it as Number) }
372-
373-
val rows = shape[0]
374-
val cols = shape[1]
375-
val result = StringBuilder()
376-
377-
result.append("[")
378-
for (i in 0 until rows) {
379-
if (i > 0) result.append(", ")
380-
result.append("[")
381-
for (j in 0 until cols) {
382-
if (j > 0) result.append(", ")
383-
val index = i * cols + j
384-
if (index < values.size) {
385-
result.append(formatConstantValue(values[index] as Number))
367+
private fun formatNestedTensor(values: List<*>, shape: List<Int>, dim: Int, offset: IntArray): String {
368+
if (dim == shape.size - 1) {
369+
// Innermost dimension: emit a flat array of values
370+
val size = shape[dim]
371+
val sb = StringBuilder("[")
372+
for (i in 0 until size) {
373+
if (i > 0) sb.append(", ")
374+
val idx = offset[0]++
375+
if (idx < values.size) {
376+
sb.append(formatConstantValue(values[idx] as Number))
386377
} else {
387-
result.append("0.0")
378+
sb.append("0.0")
388379
}
389380
}
390-
result.append("]")
381+
sb.append("]")
382+
return sb.toString()
391383
}
392-
result.append("]")
393-
394-
return result.toString()
384+
385+
// Non-innermost dimension: recurse
386+
val size = shape[dim]
387+
val sb = StringBuilder("[")
388+
for (i in 0 until size) {
389+
if (i > 0) sb.append(", ")
390+
sb.append(formatNestedTensor(values, shape, dim + 1, offset))
391+
}
392+
sb.append("]")
393+
return sb.toString()
395394
}
396395
}

0 commit comments

Comments (0)