Skip to content

Commit d86f003

Browse files
Merge pull request #706 from SKaiNET-developers/feature/693-minerva-npz
feat(minerva): emit npz compiler input
2 parents 419b0c6 + 0cbda4a commit d86f003

10 files changed

Lines changed: 924 additions & 27 deletions

File tree

skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api

Lines changed: 90 additions & 10 deletions
Large diffs are not rendered by default.

skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import sk.ainet.tape.Execution
2121
public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
2222
public val backendName: String = MinervaExportBackend.backendName,
2323
public val compatibilityValidator: MinervaCompatibilityValidator = MinervaCompatibilityValidator(),
24-
public val graphCanonicalizer: MinervaGraphCanonicalizer = MinervaGraphCanonicalizer()
24+
public val graphCanonicalizer: MinervaGraphCanonicalizer = MinervaGraphCanonicalizer(),
25+
public val npzWriter: MinervaNpzModelWriter = MinervaNpzModelWriter()
2526
) {
2627

2728
/**
@@ -98,17 +99,31 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
9899
return loweringFailedResult(options, context, compatibilityReport, exception)
99100
}
100101

102+
val npzModel = try {
103+
npzWriter.write(intermediate, context)
104+
} catch (exception: MinervaNpzSchemaException) {
105+
return npzSchemaFailedResult(
106+
options = options,
107+
context = context,
108+
compatibilityReport = compatibilityReport,
109+
intermediate = intermediate,
110+
exception = exception
111+
)
112+
}
113+
101114
val failure = MinervaExportFailure(
102115
kind = MinervaExportFailureKind.NOT_IMPLEMENTED,
103-
stage = GraphExportStage.WRITING,
116+
stage = GraphExportStage.PACKAGING,
104117
code = "minerva.export.not_implemented",
105-
message = "Minerva export lowered the graph to phase-one IR; compiler invocation, packaging, and verification are implemented in follow-up issues.",
118+
message = "Minerva export lowered the graph and emitted the NPZ compiler input; compiler invocation, packaging, and verification are implemented in follow-up issues.",
106119
details = mapOf(
107-
"nextStep" to "Invoke the Minerva compiler and write the runtime project.",
108-
"issue" to "#693",
120+
"nextStep" to "Invoke libminerva compiler and package generated outputs.",
121+
"issue" to "#694",
109122
"layers" to intermediate.layerCount.toString(),
110123
"input" to intermediate.input.id,
111-
"output" to intermediate.output.id
124+
"output" to intermediate.output.id,
125+
"npzPath" to npzModel.logicalPath,
126+
"npzBytes" to npzModel.bytes.size.toString()
112127
)
113128
)
114129
context.error(
@@ -122,7 +137,8 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
122137
context = context,
123138
failure = failure,
124139
compatibilityReport = compatibilityReport,
125-
intermediate = intermediate
140+
intermediate = intermediate,
141+
npzModel = npzModel
126142
)
127143
}
128144

@@ -223,12 +239,49 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
223239
)
224240
}
225241

242+
private fun npzSchemaFailedResult(
243+
options: MinervaExportOptions,
244+
context: GraphExportContext,
245+
compatibilityReport: MinervaCompatibilityReport,
246+
intermediate: MinervaIntermediate,
247+
exception: MinervaNpzSchemaException
248+
): MinervaExportResult {
249+
val details = mutableMapOf(
250+
"code" to exception.code,
251+
"issue" to "#693"
252+
)
253+
exception.layerId?.let { details["layerId"] = it }
254+
exception.arrayName?.let { details["arrayName"] = it }
255+
details += exception.details
256+
val failure = MinervaExportFailure(
257+
kind = MinervaExportFailureKind.NPZ_SCHEMA_FAILED,
258+
stage = GraphExportStage.WRITING,
259+
code = exception.code,
260+
message = exception.message ?: "Minerva NPZ schema validation failed.",
261+
details = details
262+
)
263+
context.error(
264+
stage = failure.stage,
265+
code = failure.code,
266+
message = failure.message,
267+
details = failure.details
268+
)
269+
return failedResult(
270+
options = options,
271+
context = context,
272+
failure = failure,
273+
compatibilityReport = compatibilityReport,
274+
intermediate = intermediate
275+
)
276+
}
277+
226278
private fun failedResult(
227279
options: MinervaExportOptions,
228280
context: GraphExportContext,
229281
failure: MinervaExportFailure,
230282
compatibilityReport: MinervaCompatibilityReport? = null,
231-
intermediate: MinervaIntermediate? = null
283+
intermediate: MinervaIntermediate? = null,
284+
npzModel: MinervaNpzModel? = null
232285
): MinervaExportResult {
233286
return MinervaExportResult(
234287
options = options,
@@ -238,7 +291,8 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
238291
failure = failure,
239292
metadata = context.metadata,
240293
compatibilityReport = compatibilityReport,
241-
intermediate = intermediate
294+
intermediate = intermediate,
295+
npzModel = npzModel
242296
)
243297
}
244298

skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public enum class MinervaExportFailureKind {
9292
GRAPH_VALIDATION_FAILED,
9393
COMPATIBILITY_VALIDATION_FAILED,
9494
LOWERING_FAILED,
95+
NPZ_SCHEMA_FAILED,
9596
NOT_IMPLEMENTED
9697
}
9798

@@ -203,7 +204,8 @@ public data class MinervaExportResult(
203204
public val failure: MinervaExportFailure? = null,
204205
public val metadata: Map<String, String> = emptyMap(),
205206
public val compatibilityReport: MinervaCompatibilityReport? = null,
206-
public val intermediate: MinervaIntermediate? = null
207+
public val intermediate: MinervaIntermediate? = null,
208+
public val npzModel: MinervaNpzModel? = null
207209
) {
208210
init {
209211
require(status != GraphExportStatus.SUCCESS || bundle != null) {

skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizer.kt

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,11 +337,88 @@ public class MinervaGraphCanonicalizer @kotlin.jvm.JvmOverloads constructor(
337337
dtype = spec.dtype,
338338
role = role,
339339
sourceNodeId = sourceNode.id,
340+
values = tensorValues(spec, shape, sourceNode, context),
340341
metadata = spec.metadata.mapValues { it.value.toString() }
341342
)
342343
}
343344
}
344345

346+
private fun tensorValues(
347+
spec: TensorSpec,
348+
shape: List<Int>,
349+
sourceNode: GraphNode,
350+
context: GraphExportContext
351+
): List<Float>? {
352+
val elementCount = shape.fold(1) { acc, dim -> acc * dim }
353+
val values = when (val rawValues = spec.metadata["values"]) {
354+
null -> symbolicValues(spec, elementCount)
355+
is FloatArray -> rawValues.toList()
356+
is IntArray -> rawValues.map { it.toFloat() }
357+
is List<*> -> rawValues.map { value ->
358+
when (value) {
359+
is Number -> value.toFloat()
360+
else -> fail(
361+
context = context,
362+
code = "minerva.lowering.tensor_values_invalid",
363+
message = "Tensor '${spec.name}' on node '${sourceNode.id}' has non-numeric initializer data.",
364+
node = sourceNode,
365+
details = mapOf("remediation" to "Use numeric FloatArray or IntArray initializer metadata.")
366+
)
367+
}
368+
}
369+
else -> fail(
370+
context = context,
371+
code = "minerva.lowering.tensor_values_invalid",
372+
message = "Tensor '${spec.name}' on node '${sourceNode.id}' has unsupported initializer metadata.",
373+
node = sourceNode,
374+
details = mapOf(
375+
"valuesType" to rawValues::class.simpleName.orEmpty(),
376+
"remediation" to "Use numeric FloatArray or IntArray initializer metadata."
377+
)
378+
)
379+
} ?: return null
380+
if (values.size != elementCount) {
381+
fail(
382+
context = context,
383+
code = "minerva.lowering.tensor_values_shape_mismatch",
384+
message = "Tensor '${spec.name}' on node '${sourceNode.id}' initializer has ${values.size} value(s), expected $elementCount.",
385+
node = sourceNode,
386+
details = mapOf(
387+
"actual" to values.size.toString(),
388+
"expected" to elementCount.toString(),
389+
"remediation" to "Match initializer data length to the tensor shape."
390+
)
391+
)
392+
}
393+
if (values.any { !it.isFinite() }) {
394+
fail(
395+
context = context,
396+
code = "minerva.lowering.tensor_values_non_finite",
397+
message = "Tensor '${spec.name}' on node '${sourceNode.id}' initializer contains non-finite values.",
398+
node = sourceNode,
399+
details = mapOf("remediation" to "Use finite numeric initializer values.")
400+
)
401+
}
402+
return values
403+
}
404+
405+
private fun symbolicValues(spec: TensorSpec, elementCount: Int): List<Float>? {
406+
return when (val init = spec.metadata["init"]?.toString()) {
407+
"zeros" -> List(elementCount) { 0.0f }
408+
"ones" -> List(elementCount) { 1.0f }
409+
null, "unspecified" -> null
410+
else -> {
411+
if (init.startsWith("full(") && init.endsWith(")")) {
412+
val value = spec.metadata["value"] as? Number
413+
?: init.removePrefix("full(").removeSuffix(")").toFloatOrNull()
414+
if (value != null) List(elementCount) { value.toFloat() } else null
415+
} else {
416+
null
417+
}
418+
}
419+
}
420+
}
421+
345422
private fun tensorId(role: MinervaTensorRole, sourceNodeId: String, tensorName: String): String {
346423
val cleanName = tensorName.replace(Regex("[^A-Za-z0-9_]+"), "_").ifBlank { "tensor" }
347424
val cleanNode = sourceNodeId.replace(Regex("[^A-Za-z0-9_]+"), "_").ifBlank { "node" }

skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaIntermediateModels.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public data class MinervaTensorRef(
3737
public val dtype: String,
3838
public val role: MinervaTensorRole,
3939
public val sourceNodeId: String? = null,
40+
public val values: List<Float>? = null,
4041
public val metadata: Map<String, String> = emptyMap()
4142
) {
4243
init {
@@ -45,6 +46,12 @@ public data class MinervaTensorRef(
4546
require(shape.isNotEmpty()) { "tensor shape cannot be empty" }
4647
require(shape.all { it > 0 }) { "tensor shape dimensions must be positive" }
4748
require(dtype.isNotBlank()) { "tensor dtype cannot be blank" }
49+
require(values == null || values.size == elementCount) {
50+
"tensor values must match tensor element count"
51+
}
52+
require(values == null || values.all { it.isFinite() }) {
53+
"tensor values must be finite"
54+
}
4855
}
4956

5057
public val elementCount: Int
@@ -99,4 +106,3 @@ public data class MinervaIntermediate(
99106

100107
public fun requireLowered(): MinervaIntermediate = this
101108
}
102-

0 commit comments

Comments
 (0)