Skip to content

Commit f032810

Browse files
Merge pull request #698 from SKaiNET-developers/feature/689-stablehlo-export-contracts
feat(hlo): expose StableHLO graph export adapter
2 parents 1509e7f + 42a112d commit f032810

3 files changed

Lines changed: 355 additions & 0 deletions

File tree

skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ public final class sk/ainet/compile/hlo/ConversionContext {
5959
public final fun getModuleDeclarations ()Ljava/lang/String;
6060
public final fun getTypeMapper ()Lsk/ainet/compile/hlo/TypeMapper;
6161
public final fun getValueName (Ljava/lang/String;)Ljava/lang/String;
62+
public final fun getValueName (Ljava/lang/String;I)Ljava/lang/String;
6263
public final fun getValueType (Ljava/lang/String;)Ljava/lang/String;
6364
public final fun nextTempValue ()Ljava/lang/String;
6465
public final fun registerExternalParameter (Lsk/ainet/compile/hlo/ExternalParameterRef;)V
66+
public final fun resolveOperands (Lsk/ainet/lang/graph/GraphNode;)Ljava/util/List;
6567
public final fun setGraph (Lsk/ainet/lang/graph/ComputeGraph;)V
68+
public final fun setValueName (Ljava/lang/String;ILjava/lang/String;)V
6669
public final fun setValueName (Ljava/lang/String;Ljava/lang/String;)V
6770
public final fun setValueType (Ljava/lang/String;Ljava/lang/String;)V
6871
}
@@ -201,6 +204,38 @@ public final class sk/ainet/compile/hlo/StableHloConverterFactory {
201204
public static synthetic fun createFast$default (Lsk/ainet/compile/hlo/ConstantMaterializationPolicy;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloConverter;
202205
}
203206

207+
public final class sk/ainet/compile/hlo/StableHloExportArchitecture {
208+
public static final field INSTANCE Lsk/ainet/compile/hlo/StableHloExportArchitecture;
209+
public static final field backendName Ljava/lang/String;
210+
public final fun getComponentNames ()Ljava/util/Map;
211+
}
212+
213+
public final class sk/ainet/compile/hlo/StableHloGraphExportConverter : sk/ainet/compile/export/GraphExportConverter {
214+
public fun <init> ()V
215+
public fun <init> (Lsk/ainet/compile/hlo/StableHloConverter;)V
216+
public fun <init> (Lsk/ainet/compile/hlo/StableHloConverter;Ljava/lang/String;)V
217+
public fun <init> (Lsk/ainet/compile/hlo/StableHloConverter;Ljava/lang/String;Ljava/lang/String;)V
218+
public synthetic fun <init> (Lsk/ainet/compile/hlo/StableHloConverter;Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
219+
public synthetic fun convert (Ljava/lang/Object;Lsk/ainet/compile/export/GraphExportContext;)Ljava/lang/Object;
220+
public fun convert (Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/export/GraphExportContext;)Lsk/ainet/compile/hlo/StableHloModule;
221+
public fun getBackendName ()Ljava/lang/String;
222+
public final fun getFunctionName ()Ljava/lang/String;
223+
}
224+
225+
public final class sk/ainet/compile/hlo/StableHloGraphExporter {
226+
public fun <init> ()V
227+
public fun <init> (Lsk/ainet/compile/hlo/StableHloGraphExportConverter;)V
228+
public fun <init> (Lsk/ainet/compile/hlo/StableHloGraphExportConverter;Lsk/ainet/compile/hlo/StableHloTextWriter;)V
229+
public synthetic fun <init> (Lsk/ainet/compile/hlo/StableHloGraphExportConverter;Lsk/ainet/compile/hlo/StableHloTextWriter;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
230+
public final fun exportModule (Lsk/ainet/lang/graph/ComputeGraph;)Lsk/ainet/compile/export/GraphExportResult;
231+
public final fun exportModule (Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/export/GraphExportContext;)Lsk/ainet/compile/export/GraphExportResult;
232+
public final fun exportText (Lsk/ainet/lang/graph/ComputeGraph;)Lsk/ainet/compile/export/GraphExportResult;
233+
public final fun exportText (Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/export/GraphExportContext;)Lsk/ainet/compile/export/GraphExportResult;
234+
public final fun getBackendName ()Ljava/lang/String;
235+
public final fun getConverter ()Lsk/ainet/compile/hlo/StableHloGraphExportConverter;
236+
public final fun getWriter ()Lsk/ainet/compile/hlo/StableHloTextWriter;
237+
}
238+
204239
public final class sk/ainet/compile/hlo/StableHloModule {
205240
public fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;)V
206241
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/util/Map;Ljava/util/List;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
@@ -259,6 +294,17 @@ public final class sk/ainet/compile/hlo/StableHloOptimizer$Companion {
259294
public final fun createDefault ()Lsk/ainet/compile/hlo/StableHloOptimizer;
260295
}
261296

297+
public final class sk/ainet/compile/hlo/StableHloTextWriter : sk/ainet/compile/export/GraphExportWriter {
298+
public fun <init> ()V
299+
public fun <init> (Ljava/lang/String;)V
300+
public fun <init> (Ljava/lang/String;Ljava/lang/String;)V
301+
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
302+
public fun getBackendName ()Ljava/lang/String;
303+
public final fun getLogicalPath ()Ljava/lang/String;
304+
public synthetic fun write (Ljava/lang/Object;Lsk/ainet/compile/export/GraphExportContext;)Ljava/lang/Object;
305+
public fun write (Lsk/ainet/compile/hlo/StableHloModule;Lsk/ainet/compile/export/GraphExportContext;)Ljava/lang/String;
306+
}
307+
262308
public final class sk/ainet/compile/hlo/TypeMapper {
263309
public fun <init> ()V
264310
public final fun areTypesCompatible (Lsk/ainet/lang/tensor/ops/TensorSpec;Lsk/ainet/lang/tensor/ops/TensorSpec;)Z
@@ -277,6 +323,12 @@ public final class sk/ainet/compile/hlo/converters/ActivationOperationsConverter
277323
public fun getSupportedOperations ()Ljava/util/Set;
278324
}
279325

326+
public final class sk/ainet/compile/hlo/converters/AttentionOperationsConverter : sk/ainet/compile/hlo/StableHloOperationConverter {
327+
public fun <init> ()V
328+
public fun convert (Lsk/ainet/lang/graph/GraphNode;Ljava/util/List;Lsk/ainet/compile/hlo/ConversionContext;)Lsk/ainet/compile/hlo/ConversionResult;
329+
public fun getSupportedOperations ()Ljava/util/Set;
330+
}
331+
280332
public final class sk/ainet/compile/hlo/converters/BasicMathConverter : sk/ainet/compile/hlo/StableHloOperationConverter {
281333
public fun <init> ()V
282334
public fun convert (Lsk/ainet/lang/graph/GraphNode;Ljava/util/List;Lsk/ainet/compile/hlo/ConversionContext;)Lsk/ainet/compile/hlo/ConversionResult;
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package sk.ainet.compile.hlo
2+
3+
import sk.ainet.compile.export.GraphExportArtifact
4+
import sk.ainet.compile.export.GraphExportArtifactRole
5+
import sk.ainet.compile.export.GraphExportComponentRole
6+
import sk.ainet.compile.export.GraphExportContext
7+
import sk.ainet.compile.export.GraphExportConverter
8+
import sk.ainet.compile.export.GraphExportResult
9+
import sk.ainet.compile.export.GraphExportStage
10+
import sk.ainet.compile.export.GraphExportStatus
11+
import sk.ainet.compile.export.GraphExportWriter
12+
import sk.ainet.lang.graph.ComputeGraph
13+
14+
/**
15+
* StableHLO component mapping in the shared graph-export architecture.
16+
*/
17+
public object StableHloExportArchitecture {
18+
public const val backendName: String = "stablehlo"
19+
20+
public val componentNames: Map<GraphExportComponentRole, String> = mapOf(
21+
GraphExportComponentRole.CONVERTER to "StableHloConverter",
22+
GraphExportComponentRole.CONTEXT to "ConversionContext",
23+
GraphExportComponentRole.REGISTRY to "StableHloOperationRegistry",
24+
GraphExportComponentRole.FACTORY to "StableHloConverterFactory",
25+
GraphExportComponentRole.WRITER to "StableHloTextWriter",
26+
GraphExportComponentRole.VERIFIER to "MlirValidator"
27+
)
28+
}
29+
30+
/**
31+
* Adapter that exposes [StableHloConverter] through the shared export contract.
32+
*/
33+
public class StableHloGraphExportConverter @kotlin.jvm.JvmOverloads constructor(
34+
private val converter: StableHloConverter = StableHloConverterFactory.createBasic(),
35+
public val functionName: String = "main",
36+
override val backendName: String = StableHloExportArchitecture.backendName
37+
) : GraphExportConverter<ComputeGraph, StableHloModule> {
38+
39+
override fun convert(input: ComputeGraph, context: GraphExportContext): StableHloModule {
40+
val resolvedFunctionName = context.targetName ?: functionName
41+
context.info(
42+
stage = GraphExportStage.LOWERING,
43+
code = "stablehlo.lowering.started",
44+
message = "Lowering ComputeGraph to StableHLO MLIR.",
45+
details = mapOf("functionName" to resolvedFunctionName)
46+
)
47+
48+
val module = converter.convert(input, resolvedFunctionName)
49+
50+
context.info(
51+
stage = GraphExportStage.LOWERING,
52+
code = "stablehlo.lowering.completed",
53+
message = "Lowered ComputeGraph to StableHLO MLIR.",
54+
details = mapOf(
55+
"functionName" to module.functionName,
56+
"inputs" to module.inputSpecs.size.toString(),
57+
"outputs" to module.outputSpecs.size.toString(),
58+
"externalParameters" to module.externalParameters.size.toString()
59+
)
60+
)
61+
return module
62+
}
63+
}
64+
65+
/**
66+
* Shared-contract writer that renders a [StableHloModule] as MLIR text.
67+
*/
68+
public class StableHloTextWriter @kotlin.jvm.JvmOverloads constructor(
69+
public val logicalPath: String? = null,
70+
override val backendName: String = StableHloExportArchitecture.backendName
71+
) : GraphExportWriter<StableHloModule, String> {
72+
73+
override fun write(intermediate: StableHloModule, context: GraphExportContext): String {
74+
val artifactPath = logicalPath ?: "${intermediate.functionName}.stablehlo.mlir"
75+
context.addArtifact(
76+
GraphExportArtifact(
77+
path = artifactPath,
78+
role = GraphExportArtifactRole.SOURCE,
79+
description = "StableHLO MLIR module text",
80+
metadata = mapOf(
81+
"functionName" to intermediate.functionName,
82+
"format" to "mlir"
83+
)
84+
)
85+
)
86+
context.info(
87+
stage = GraphExportStage.WRITING,
88+
code = "stablehlo.writing.text",
89+
message = "Rendered StableHLO module as MLIR text.",
90+
details = mapOf(
91+
"path" to artifactPath,
92+
"characters" to intermediate.content.length.toString()
93+
)
94+
)
95+
return intermediate.content
96+
}
97+
}
98+
99+
/**
100+
* Convenience facade that composes StableHLO lowering and writing into shared results.
101+
*/
102+
public class StableHloGraphExporter @kotlin.jvm.JvmOverloads constructor(
103+
public val converter: StableHloGraphExportConverter = StableHloGraphExportConverter(),
104+
public val writer: StableHloTextWriter = StableHloTextWriter()
105+
) {
106+
public val backendName: String
107+
get() = converter.backendName
108+
109+
public fun exportModule(graph: ComputeGraph): GraphExportResult<StableHloModule> {
110+
return exportModule(graph, GraphExportContext(backendName = backendName))
111+
}
112+
113+
public fun exportModule(
114+
graph: ComputeGraph,
115+
context: GraphExportContext
116+
): GraphExportResult<StableHloModule> {
117+
return try {
118+
val module = converter.convert(graph, context)
119+
GraphExportResult.success(
120+
backendName = backendName,
121+
output = module,
122+
diagnostics = context.diagnosticReport(),
123+
artifacts = context.artifacts,
124+
metadata = context.metadata
125+
)
126+
} catch (exception: Exception) {
127+
stableHloFailureResult(
128+
backendName = backendName,
129+
stage = GraphExportStage.LOWERING,
130+
exception = exception,
131+
context = context
132+
)
133+
}
134+
}
135+
136+
public fun exportText(graph: ComputeGraph): GraphExportResult<String> {
137+
return exportText(graph, GraphExportContext(backendName = backendName))
138+
}
139+
140+
public fun exportText(
141+
graph: ComputeGraph,
142+
context: GraphExportContext
143+
): GraphExportResult<String> {
144+
var stage = GraphExportStage.LOWERING
145+
return try {
146+
val module = converter.convert(graph, context)
147+
stage = GraphExportStage.WRITING
148+
val text = writer.write(module, context)
149+
GraphExportResult.success(
150+
backendName = backendName,
151+
output = text,
152+
diagnostics = context.diagnosticReport(),
153+
artifacts = context.artifacts,
154+
metadata = context.metadata
155+
)
156+
} catch (exception: Exception) {
157+
stableHloFailureResult(
158+
backendName = backendName,
159+
stage = stage,
160+
exception = exception,
161+
context = context
162+
)
163+
}
164+
}
165+
}
166+
167+
private fun <T> stableHloFailureResult(
168+
backendName: String,
169+
stage: GraphExportStage,
170+
exception: Exception,
171+
context: GraphExportContext
172+
): GraphExportResult<T> {
173+
val reason = exception.message ?: exception.toString()
174+
context.error(
175+
stage = stage,
176+
code = "stablehlo.export.failed",
177+
message = "StableHLO export failed: $reason"
178+
)
179+
return GraphExportResult(
180+
backendName = backendName,
181+
status = GraphExportStatus.FAILED,
182+
output = null,
183+
diagnostics = context.diagnosticReport(),
184+
artifacts = context.artifacts,
185+
metadata = context.metadata
186+
)
187+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package sk.ainet.compile.hlo
2+
3+
import kotlin.test.Test
4+
import kotlin.test.assertEquals
5+
import kotlin.test.assertFalse
6+
import kotlin.test.assertTrue
7+
import sk.ainet.compile.export.GraphExportArtifactRole
8+
import sk.ainet.compile.export.GraphExportComponentRole
9+
import sk.ainet.compile.export.GraphExportContext
10+
import sk.ainet.compile.export.GraphExportStage
11+
import sk.ainet.compile.export.GraphExportStatus
12+
import sk.ainet.lang.graph.DefaultComputeGraph
13+
import sk.ainet.lang.graph.GraphEdge
14+
import sk.ainet.lang.graph.GraphNode
15+
import sk.ainet.lang.tensor.ops.AddOperation
16+
import sk.ainet.lang.tensor.ops.InputOperation
17+
import sk.ainet.lang.tensor.ops.TensorSpec
18+
import sk.ainet.lang.types.DType
19+
20+
class StableHloGraphExportTest {
21+
22+
@Test
23+
fun graphExportConverterPreservesStableHloOutput() {
24+
val graph = simpleAddGraph()
25+
val directConverter = StableHloConverterFactory.createBasic()
26+
val expected = directConverter.convert(graph, "shared_export")
27+
val context = GraphExportContext(
28+
backendName = StableHloExportArchitecture.backendName,
29+
targetName = "shared_export"
30+
)
31+
32+
val actual = StableHloGraphExportConverter(converter = directConverter).convert(graph, context)
33+
34+
assertEquals(expected.content, actual.content)
35+
assertEquals(expected.functionName, actual.functionName)
36+
assertFalse(context.diagnosticReport().hasErrors)
37+
assertTrue(context.diagnostics.any { it.code == "stablehlo.lowering.started" })
38+
assertTrue(context.diagnostics.any { it.code == "stablehlo.lowering.completed" })
39+
}
40+
41+
@Test
42+
fun textWriterRecordsLogicalStableHloArtifact() {
43+
val module = toStableHlo(simpleAddGraph(), "text_export")
44+
val context = GraphExportContext(backendName = StableHloExportArchitecture.backendName)
45+
val writer = StableHloTextWriter(logicalPath = "build/generated/text_export.mlir")
46+
47+
val text = writer.write(module, context)
48+
49+
assertEquals(module.content, text)
50+
assertEquals(1, context.artifacts.size)
51+
assertEquals("build/generated/text_export.mlir", context.artifacts.single().path)
52+
assertEquals(GraphExportArtifactRole.SOURCE, context.artifacts.single().role)
53+
assertTrue(context.diagnostics.any { it.stage == GraphExportStage.WRITING })
54+
}
55+
56+
@Test
57+
fun graphExporterReturnsSharedResultEnvelopeForStableHloText() {
58+
val graph = simpleAddGraph()
59+
val expected = StableHloConverterFactory.createBasic().convert(graph, "result_export")
60+
val context = GraphExportContext(
61+
backendName = StableHloExportArchitecture.backendName,
62+
targetName = "result_export"
63+
)
64+
val exporter = StableHloGraphExporter(
65+
writer = StableHloTextWriter(logicalPath = "result_export.mlir")
66+
)
67+
68+
val result = exporter.exportText(graph, context)
69+
70+
assertEquals(GraphExportStatus.SUCCESS, result.status)
71+
assertEquals(expected.content, result.requireSuccess())
72+
assertFalse(result.diagnostics.hasErrors)
73+
assertEquals(1, result.artifacts.size)
74+
assertEquals("result_export.mlir", result.artifacts.single().path)
75+
assertEquals(
76+
"StableHloConverter",
77+
StableHloExportArchitecture.componentNames[GraphExportComponentRole.CONVERTER]
78+
)
79+
assertEquals(
80+
"StableHloTextWriter",
81+
StableHloExportArchitecture.componentNames[GraphExportComponentRole.WRITER]
82+
)
83+
}
84+
85+
private fun simpleAddGraph(): DefaultComputeGraph {
86+
val graph = DefaultComputeGraph()
87+
val inputA = GraphNode(
88+
id = "a",
89+
operation = InputOperation<DType, Any>(),
90+
inputs = emptyList(),
91+
outputs = listOf(TensorSpec("a", listOf(2, 3), "FP32"))
92+
)
93+
val inputB = GraphNode(
94+
id = "b",
95+
operation = InputOperation<DType, Any>(),
96+
inputs = emptyList(),
97+
outputs = listOf(TensorSpec("b", listOf(2, 3), "FP32"))
98+
)
99+
val add = GraphNode(
100+
id = "add1",
101+
operation = AddOperation<DType, Any>(),
102+
inputs = listOf(
103+
TensorSpec("a", listOf(2, 3), "FP32"),
104+
TensorSpec("b", listOf(2, 3), "FP32")
105+
),
106+
outputs = listOf(TensorSpec("c", listOf(2, 3), "FP32"))
107+
)
108+
109+
graph.addNode(inputA)
110+
graph.addNode(inputB)
111+
graph.addNode(add)
112+
graph.addEdge(GraphEdge("e1", inputA, add, 0, 0, inputA.outputs[0]))
113+
graph.addEdge(GraphEdge("e2", inputB, add, 0, 1, inputB.outputs[0]))
114+
return graph
115+
}
116+
}

0 commit comments

Comments
 (0)