|
| 1 | +package sk.ainet.compile.hlo |
| 2 | + |
| 3 | +import sk.ainet.compile.export.GraphExportArtifact |
| 4 | +import sk.ainet.compile.export.GraphExportArtifactRole |
| 5 | +import sk.ainet.compile.export.GraphExportComponentRole |
| 6 | +import sk.ainet.compile.export.GraphExportContext |
| 7 | +import sk.ainet.compile.export.GraphExportConverter |
| 8 | +import sk.ainet.compile.export.GraphExportResult |
| 9 | +import sk.ainet.compile.export.GraphExportStage |
| 10 | +import sk.ainet.compile.export.GraphExportStatus |
| 11 | +import sk.ainet.compile.export.GraphExportWriter |
| 12 | +import sk.ainet.lang.graph.ComputeGraph |
| 13 | + |
| 14 | +/** |
| 15 | + * StableHLO component mapping in the shared graph-export architecture. |
| 16 | + */ |
| 17 | +public object StableHloExportArchitecture { |
| 18 | + public const val backendName: String = "stablehlo" |
| 19 | + |
| 20 | + public val componentNames: Map<GraphExportComponentRole, String> = mapOf( |
| 21 | + GraphExportComponentRole.CONVERTER to "StableHloConverter", |
| 22 | + GraphExportComponentRole.CONTEXT to "ConversionContext", |
| 23 | + GraphExportComponentRole.REGISTRY to "StableHloOperationRegistry", |
| 24 | + GraphExportComponentRole.FACTORY to "StableHloConverterFactory", |
| 25 | + GraphExportComponentRole.WRITER to "StableHloTextWriter", |
| 26 | + GraphExportComponentRole.VERIFIER to "MlirValidator" |
| 27 | + ) |
| 28 | +} |
| 29 | + |
| 30 | +/** |
| 31 | + * Adapter that exposes [StableHloConverter] through the shared export contract. |
| 32 | + */ |
| 33 | +public class StableHloGraphExportConverter @kotlin.jvm.JvmOverloads constructor( |
| 34 | + private val converter: StableHloConverter = StableHloConverterFactory.createBasic(), |
| 35 | + public val functionName: String = "main", |
| 36 | + override val backendName: String = StableHloExportArchitecture.backendName |
| 37 | +) : GraphExportConverter<ComputeGraph, StableHloModule> { |
| 38 | + |
| 39 | + override fun convert(input: ComputeGraph, context: GraphExportContext): StableHloModule { |
| 40 | + val resolvedFunctionName = context.targetName ?: functionName |
| 41 | + context.info( |
| 42 | + stage = GraphExportStage.LOWERING, |
| 43 | + code = "stablehlo.lowering.started", |
| 44 | + message = "Lowering ComputeGraph to StableHLO MLIR.", |
| 45 | + details = mapOf("functionName" to resolvedFunctionName) |
| 46 | + ) |
| 47 | + |
| 48 | + val module = converter.convert(input, resolvedFunctionName) |
| 49 | + |
| 50 | + context.info( |
| 51 | + stage = GraphExportStage.LOWERING, |
| 52 | + code = "stablehlo.lowering.completed", |
| 53 | + message = "Lowered ComputeGraph to StableHLO MLIR.", |
| 54 | + details = mapOf( |
| 55 | + "functionName" to module.functionName, |
| 56 | + "inputs" to module.inputSpecs.size.toString(), |
| 57 | + "outputs" to module.outputSpecs.size.toString(), |
| 58 | + "externalParameters" to module.externalParameters.size.toString() |
| 59 | + ) |
| 60 | + ) |
| 61 | + return module |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +/** |
| 66 | + * Shared-contract writer that renders a [StableHloModule] as MLIR text. |
| 67 | + */ |
| 68 | +public class StableHloTextWriter @kotlin.jvm.JvmOverloads constructor( |
| 69 | + public val logicalPath: String? = null, |
| 70 | + override val backendName: String = StableHloExportArchitecture.backendName |
| 71 | +) : GraphExportWriter<StableHloModule, String> { |
| 72 | + |
| 73 | + override fun write(intermediate: StableHloModule, context: GraphExportContext): String { |
| 74 | + val artifactPath = logicalPath ?: "${intermediate.functionName}.stablehlo.mlir" |
| 75 | + context.addArtifact( |
| 76 | + GraphExportArtifact( |
| 77 | + path = artifactPath, |
| 78 | + role = GraphExportArtifactRole.SOURCE, |
| 79 | + description = "StableHLO MLIR module text", |
| 80 | + metadata = mapOf( |
| 81 | + "functionName" to intermediate.functionName, |
| 82 | + "format" to "mlir" |
| 83 | + ) |
| 84 | + ) |
| 85 | + ) |
| 86 | + context.info( |
| 87 | + stage = GraphExportStage.WRITING, |
| 88 | + code = "stablehlo.writing.text", |
| 89 | + message = "Rendered StableHLO module as MLIR text.", |
| 90 | + details = mapOf( |
| 91 | + "path" to artifactPath, |
| 92 | + "characters" to intermediate.content.length.toString() |
| 93 | + ) |
| 94 | + ) |
| 95 | + return intermediate.content |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +/** |
| 100 | + * Convenience facade that composes StableHLO lowering and writing into shared results. |
| 101 | + */ |
| 102 | +public class StableHloGraphExporter @kotlin.jvm.JvmOverloads constructor( |
| 103 | + public val converter: StableHloGraphExportConverter = StableHloGraphExportConverter(), |
| 104 | + public val writer: StableHloTextWriter = StableHloTextWriter() |
| 105 | +) { |
| 106 | + public val backendName: String |
| 107 | + get() = converter.backendName |
| 108 | + |
| 109 | + public fun exportModule(graph: ComputeGraph): GraphExportResult<StableHloModule> { |
| 110 | + return exportModule(graph, GraphExportContext(backendName = backendName)) |
| 111 | + } |
| 112 | + |
| 113 | + public fun exportModule( |
| 114 | + graph: ComputeGraph, |
| 115 | + context: GraphExportContext |
| 116 | + ): GraphExportResult<StableHloModule> { |
| 117 | + return try { |
| 118 | + val module = converter.convert(graph, context) |
| 119 | + GraphExportResult.success( |
| 120 | + backendName = backendName, |
| 121 | + output = module, |
| 122 | + diagnostics = context.diagnosticReport(), |
| 123 | + artifacts = context.artifacts, |
| 124 | + metadata = context.metadata |
| 125 | + ) |
| 126 | + } catch (exception: Exception) { |
| 127 | + stableHloFailureResult( |
| 128 | + backendName = backendName, |
| 129 | + stage = GraphExportStage.LOWERING, |
| 130 | + exception = exception, |
| 131 | + context = context |
| 132 | + ) |
| 133 | + } |
| 134 | + } |
| 135 | + |
| 136 | + public fun exportText(graph: ComputeGraph): GraphExportResult<String> { |
| 137 | + return exportText(graph, GraphExportContext(backendName = backendName)) |
| 138 | + } |
| 139 | + |
| 140 | + public fun exportText( |
| 141 | + graph: ComputeGraph, |
| 142 | + context: GraphExportContext |
| 143 | + ): GraphExportResult<String> { |
| 144 | + var stage = GraphExportStage.LOWERING |
| 145 | + return try { |
| 146 | + val module = converter.convert(graph, context) |
| 147 | + stage = GraphExportStage.WRITING |
| 148 | + val text = writer.write(module, context) |
| 149 | + GraphExportResult.success( |
| 150 | + backendName = backendName, |
| 151 | + output = text, |
| 152 | + diagnostics = context.diagnosticReport(), |
| 153 | + artifacts = context.artifacts, |
| 154 | + metadata = context.metadata |
| 155 | + ) |
| 156 | + } catch (exception: Exception) { |
| 157 | + stableHloFailureResult( |
| 158 | + backendName = backendName, |
| 159 | + stage = stage, |
| 160 | + exception = exception, |
| 161 | + context = context |
| 162 | + ) |
| 163 | + } |
| 164 | + } |
| 165 | +} |
| 166 | + |
| 167 | +private fun <T> stableHloFailureResult( |
| 168 | + backendName: String, |
| 169 | + stage: GraphExportStage, |
| 170 | + exception: Exception, |
| 171 | + context: GraphExportContext |
| 172 | +): GraphExportResult<T> { |
| 173 | + val reason = exception.message ?: exception.toString() |
| 174 | + context.error( |
| 175 | + stage = stage, |
| 176 | + code = "stablehlo.export.failed", |
| 177 | + message = "StableHLO export failed: $reason" |
| 178 | + ) |
| 179 | + return GraphExportResult( |
| 180 | + backendName = backendName, |
| 181 | + status = GraphExportStatus.FAILED, |
| 182 | + output = null, |
| 183 | + diagnostics = context.diagnosticReport(), |
| 184 | + artifacts = context.artifacts, |
| 185 | + metadata = context.metadata |
| 186 | + ) |
| 187 | +} |
0 commit comments