Skip to content

Commit faeebe4

Browse files
michalharakalclaude
andcommitted
Add DTypeConstraintResolutionPass + register in pipeline (W7 of #615)
New GraphOptimizationPass that walks the graph, reads each node's attached DTypePolicy from `metadata["dtype_policy"]` (set by the W6 DSL extension), and enforces the RFC's "fail before execution" rule. Three policy arms: - Require(target): every input edge MUST already have the target dtype. Mismatch throws DtypeConstraintViolationException with a message naming the failing edge, the required dtype, the actual dtype, and the suggested resolution (loader-side policy via SafeTensorsParametersLoader.withPolicy / StreamingGguf- ParametersLoader.withPolicy). - Prefer(target): mismatch emits a diagnostic and falls through. - OneOf(allowed): input must be in the allowed set; otherwise throw. - Any: no-op (and nodes without an attached policy are passed through unvisited). Visited nodes get `metadata["dtype_resolved"] = true` so the W8 ResolvedComputeGraph wrapper can verify the pass has run. dtypeStringMatches() handles both the canonical DType.name strings ("Float32", "BFloat16") and the short class-derived aliases the DAG DSL produces ("FP32", "BF16", …) — the resolver works whether the graph came from the DSL or from a TensorSpec-string-driven source like the StableHLO converter. NO new cast kernels in this PR — when a Require mismatches and no cast kernel is registered, the pass fails fast. Cast-node insertion is a follow-up that ships with concrete cast kernels. Registered in all three pipeline factories (createDefault, createAggressive, createLLM) BEFORE the fusion passes so fusion sees dtype-resolved nodes. All existing pipeline tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e03934a commit faeebe4

3 files changed

Lines changed: 283 additions & 0 deletions

File tree

skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sk.ainet.compile.opt
22

33
import sk.ainet.lang.graph.ComputeGraph
44
import sk.ainet.compile.opt.passes.ConstantFoldingPass
5+
import sk.ainet.compile.opt.passes.DTypeConstraintResolutionPass
56
import sk.ainet.compile.opt.passes.DeadCodeEliminationPass
67
import sk.ainet.compile.opt.passes.LLMFusionPass
78
import sk.ainet.compile.opt.passes.OperationFusionPass
@@ -73,6 +74,12 @@ public class GraphOptimizationPipeline(
7374
*/
7475
public fun createDefault(): GraphOptimizationPipeline = GraphOptimizationPipeline(
7576
passes = listOf(
77+
// Resolve dtype constraints first so fusion / DCE / constant
78+
// folding see the resolved-or-failed graph rather than a
79+
// mix of policy-tagged and bare nodes. Per the RFC, this
80+
// is the boundary where dtype problems surface — every
81+
// later pass can assume dtype-validity.
82+
DTypeConstraintResolutionPass(),
7683
DeadCodeEliminationPass(),
7784
ConstantFoldingPass(),
7885
OperationFusionPass()
@@ -84,6 +91,7 @@ public class GraphOptimizationPipeline(
8491
*/
8592
public fun createAggressive(): GraphOptimizationPipeline = GraphOptimizationPipeline(
8693
passes = listOf(
94+
DTypeConstraintResolutionPass(),
8795
DeadCodeEliminationPass(),
8896
ConstantFoldingPass(),
8997
OperationFusionPass()
@@ -95,6 +103,7 @@ public class GraphOptimizationPipeline(
95103
* Creates an LLM-optimized pipeline with transformer-specific passes.
96104
*
97105
* Pass ordering:
106+
* 0. DTypeConstraintResolution — resolve dtype policies before fusion
98107
* 1. TransposeElimination — fold transposes into matmuls
99108
* 2. SharedWeightDedup — deduplicate tied weights (e.g. token_embd ↔ output)
100109
* 3. LLMFusion — fuse RMSNorm, SwiGLU, QKV patterns
@@ -103,6 +112,7 @@ public class GraphOptimizationPipeline(
103112
*/
104113
public fun createLLM(): GraphOptimizationPipeline = GraphOptimizationPipeline(
105114
passes = listOf(
115+
DTypeConstraintResolutionPass(),
106116
TransposeEliminationPass(),
107117
SharedWeightDeduplicationPass(),
108118
LLMFusionPass(),
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package sk.ainet.compile.opt.passes
2+
3+
import sk.ainet.compile.opt.GraphOptimizationPass
4+
import sk.ainet.compile.opt.GraphOptimizationResult
5+
import sk.ainet.lang.graph.ComputeGraph
6+
import sk.ainet.lang.graph.GraphNode
7+
import sk.ainet.lang.types.BF16
8+
import sk.ainet.lang.types.DType
9+
import sk.ainet.lang.types.DTypePolicy
10+
import sk.ainet.lang.types.FP16
11+
import sk.ainet.lang.types.FP32
12+
import sk.ainet.lang.types.FP64
13+
import sk.ainet.lang.types.Int8
14+
import sk.ainet.lang.types.Int16
15+
import sk.ainet.lang.types.Int32
16+
import sk.ainet.lang.types.Int64
17+
18+
/**
19+
* Pass that enforces per-node [DTypePolicy] constraints attached to
20+
* graph nodes (via the `dag { … dtypePolicy(…) }` DSL extension from
21+
* W6 of #615). Implements the RFC's "fail before execution" rule —
22+
* any [DTypePolicy.Require] that can't be satisfied raises
23+
* [DtypeConstraintViolationException] *here*, at graph-prep time,
24+
* not at forward execution.
25+
*
26+
* Policy semantics:
27+
* - `Any`: never visited; nodes without an attached policy are
28+
* passed through.
29+
* - `Require(target)`: every input edge to the node MUST already
30+
* have dtype matching `target`. Mismatch throws
31+
* [DtypeConstraintViolationException].
32+
* - `Prefer(target)`: input dtype matching `target` is preferred;
33+
* mismatches emit a diagnostic but do not fail.
34+
* - `OneOf(allowed)`: every input edge's dtype MUST already be in
35+
* `allowed`. Mismatch throws.
36+
*
37+
* **Scope intentionally narrow.** This pass does not insert cast
38+
* nodes today — when a `Require` mismatches, it fails fast (which
39+
* is the RFC's prescribed behaviour when no cast kernel exists).
40+
* Cast-node insertion is a follow-up that ships alongside concrete
41+
* cast kernels (Q4_K → Int8, FP32 → BF16, …). See the
42+
* out-of-scope section of issue #615.
43+
*
44+
* Side effect on the graph: visited nodes get
45+
* `metadata["dtype_resolved"] = true` so downstream passes (and the
46+
* future `ResolvedComputeGraph` wrapper from W8) can confirm the
47+
* pass has run.
48+
*/
49+
public class DTypeConstraintResolutionPass : GraphOptimizationPass {
50+
51+
override val name: String = "dtype-constraint-resolution"
52+
53+
override fun apply(graph: ComputeGraph): GraphOptimizationResult {
54+
val diagnostics = mutableListOf<String>()
55+
var changed = false
56+
57+
for (node in graph.nodes) {
58+
val policy = node.metadata[POLICY_KEY] as? DTypePolicy ?: continue
59+
val inputDtypes = node.inputs.map { it.dtype }
60+
61+
when (policy) {
62+
DTypePolicy.Any -> { /* permissive; no-op */ }
63+
64+
is DTypePolicy.Require -> {
65+
val targetName = policy.target.name
66+
for ((i, dtypeStr) in inputDtypes.withIndex()) {
67+
if (!dtypeStringMatches(dtypeStr, policy.target)) {
68+
throw DtypeConstraintViolationException(
69+
"Node '${node.id}' (${node.operationName}) declares " +
70+
"DTypePolicy.Require($targetName) but input $i has dtype '$dtypeStr'. " +
71+
"Cast kernels are not registered for this conversion; resolve at the " +
72+
"loader (e.g. SafeTensorsParametersLoader.withPolicy) or change the " +
73+
"policy to Prefer/OneOf to permit fallback."
74+
)
75+
}
76+
}
77+
}
78+
79+
is DTypePolicy.Prefer -> {
80+
val targetName = policy.target.name
81+
for ((i, dtypeStr) in inputDtypes.withIndex()) {
82+
if (!dtypeStringMatches(dtypeStr, policy.target)) {
83+
diagnostics += "Node '${node.id}' (${node.operationName}) prefers " +
84+
"$targetName but input $i has dtype '$dtypeStr' — using the existing dtype."
85+
}
86+
}
87+
}
88+
89+
is DTypePolicy.OneOf -> {
90+
val allowedNames = policy.allowed.joinToString { it.name }
91+
for ((i, dtypeStr) in inputDtypes.withIndex()) {
92+
if (policy.allowed.none { dtypeStringMatches(dtypeStr, it) }) {
93+
throw DtypeConstraintViolationException(
94+
"Node '${node.id}' (${node.operationName}) declares " +
95+
"DTypePolicy.OneOf($allowedNames) but input $i has dtype " +
96+
"'$dtypeStr' which is outside the allowed set. Cast kernels " +
97+
"are not registered; resolve at the loader."
98+
)
99+
}
100+
}
101+
}
102+
}
103+
104+
// Mark the node as resolved by this pass. Use copy to keep
105+
// the immutable-copy convention the other passes follow.
106+
val resolved = node.copy(metadata = node.metadata + (RESOLVED_KEY to true))
107+
graph.removeNode(node)
108+
graph.addNode(resolved)
109+
changed = true
110+
}
111+
112+
return GraphOptimizationResult(graph, changed = changed, diagnostics = diagnostics)
113+
}
114+
115+
/**
116+
* Matches the string form used by [sk.ainet.lang.tensor.ops.TensorSpec.dtype]
117+
* against a typed [DType]. Handles both registry-canonical names
118+
* (`"Float32"`, `"BFloat16"`) and the short class-derived
119+
* aliases produced by the DAG DSL's `dtypeName()` helper (`"FP32"`,
120+
* `"BF16"`, `"Int8"`, …).
121+
*/
122+
internal fun dtypeStringMatches(dtypeStr: String, dtype: DType): Boolean {
123+
if (dtypeStr == dtype.name) return true
124+
return when (dtype) {
125+
FP32 -> dtypeStr == "FP32" || dtypeStr == "F32"
126+
FP16 -> dtypeStr == "FP16" || dtypeStr == "F16"
127+
BF16 -> dtypeStr == "BF16"
128+
FP64 -> dtypeStr == "FP64" || dtypeStr == "F64"
129+
Int8 -> dtypeStr == "Int8" || dtypeStr == "I8"
130+
Int16 -> dtypeStr == "Int16" || dtypeStr == "I16"
131+
Int32 -> dtypeStr == "Int32" || dtypeStr == "I32"
132+
Int64 -> dtypeStr == "Int64" || dtypeStr == "I64"
133+
else -> false
134+
}
135+
}
136+
137+
public companion object {
138+
/** Attribute key shared with the DSL extension (W6). */
139+
public const val POLICY_KEY: String = "dtype_policy"
140+
141+
/** Marker the pass writes onto every node it visits. */
142+
public const val RESOLVED_KEY: String = "dtype_resolved"
143+
}
144+
}
145+
146+
/**
147+
* Raised when [DTypeConstraintResolutionPass] cannot satisfy a hard
148+
* [DTypePolicy.Require] (or `OneOf` rejection) and no cast kernel
149+
* is available to bridge the gap. Surfaces dtype problems at
150+
* graph-prep time, before forward execution — exactly the RFC's
151+
* "fail before execution" boundary.
152+
*/
153+
public class DtypeConstraintViolationException(message: String) : RuntimeException(message)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package sk.ainet.compile.opt
2+
3+
import kotlin.test.Test
4+
import kotlin.test.assertEquals
5+
import kotlin.test.assertFailsWith
6+
import kotlin.test.assertFalse
7+
import kotlin.test.assertTrue
8+
import sk.ainet.compile.opt.passes.DTypeConstraintResolutionPass
9+
import sk.ainet.compile.opt.passes.DtypeConstraintViolationException
10+
import sk.ainet.lang.graph.DefaultComputeGraph
11+
import sk.ainet.lang.graph.GraphNode
12+
import sk.ainet.lang.tensor.ops.GenericOperation
13+
import sk.ainet.lang.tensor.ops.TensorSpec
14+
import sk.ainet.lang.types.BF16
15+
import sk.ainet.lang.types.DTypePolicy
16+
import sk.ainet.lang.types.FP32
17+
import sk.ainet.lang.types.Int8
18+
19+
class DTypeConstraintResolutionPassTest {
20+
21+
private fun node(
22+
id: String,
23+
opName: String = "matmul",
24+
inputDtype: String = "Float32",
25+
policy: DTypePolicy? = null,
26+
): GraphNode {
27+
val meta = if (policy != null) mapOf<String, Any>(DTypeConstraintResolutionPass.POLICY_KEY to policy) else emptyMap()
28+
return GraphNode(
29+
id = id,
30+
operation = GenericOperation(opName),
31+
inputs = listOf(TensorSpec(name = "$id-in", shape = listOf(4, 4), dtype = inputDtype)),
32+
outputs = listOf(TensorSpec(name = "$id-out", shape = listOf(4, 4), dtype = inputDtype)),
33+
metadata = meta,
34+
)
35+
}
36+
37+
@Test
38+
fun nodes_without_policy_are_passed_through() {
39+
val g = DefaultComputeGraph()
40+
g.addNode(node("n0"))
41+
g.addNode(node("n1"))
42+
val result = DTypeConstraintResolutionPass().apply(g)
43+
assertFalse(result.changed, "no policy = no work")
44+
// Neither node should be marked resolved (only visited nodes get the marker).
45+
assertEquals(emptyList(), result.graph.nodes.filter { it.metadata.containsKey(DTypeConstraintResolutionPass.RESOLVED_KEY) })
46+
}
47+
48+
@Test
49+
fun any_policy_passes_through() {
50+
val g = DefaultComputeGraph()
51+
g.addNode(node("n0", policy = DTypePolicy.Any))
52+
val result = DTypeConstraintResolutionPass().apply(g)
53+
assertTrue(result.changed, "the resolved-marker write counts as a change")
54+
val n = result.graph.nodes.single()
55+
assertTrue(n.metadata[DTypeConstraintResolutionPass.RESOLVED_KEY] == true)
56+
}
57+
58+
@Test
59+
fun require_matching_dtype_passes() {
60+
val g = DefaultComputeGraph()
61+
g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.Require(FP32)))
62+
val result = DTypeConstraintResolutionPass().apply(g)
63+
assertTrue(result.changed)
64+
}
65+
66+
@Test
67+
fun require_mismatched_dtype_fails_fast() {
68+
val g = DefaultComputeGraph()
69+
g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.Require(BF16)))
70+
val ex = assertFailsWith<DtypeConstraintViolationException> {
71+
DTypeConstraintResolutionPass().apply(g)
72+
}
73+
val msg = ex.message ?: ""
74+
assertTrue(msg.contains("BFloat16"), "msg must name the required dtype: $msg")
75+
assertTrue(msg.contains("Float32"), "msg must name the actual input dtype: $msg")
76+
assertTrue(msg.contains("Cast kernels"), "msg must hint at the resolution path: $msg")
77+
}
78+
79+
@Test
80+
fun require_mismatched_dtype_with_short_alias_also_resolves() {
81+
// DAG DSL emits dtype strings like "FP32" / "BF16" via dtypeName().
82+
// The pass must handle both the registry canonical name and the short alias.
83+
val g = DefaultComputeGraph()
84+
g.addNode(node("n0", inputDtype = "FP32", policy = DTypePolicy.Require(FP32)))
85+
val result = DTypeConstraintResolutionPass().apply(g)
86+
assertTrue(result.changed, "alias 'FP32' must satisfy Require(FP32)")
87+
}
88+
89+
@Test
90+
fun prefer_mismatched_dtype_emits_diagnostic_no_throw() {
91+
val g = DefaultComputeGraph()
92+
g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.Prefer(BF16)))
93+
val result = DTypeConstraintResolutionPass().apply(g)
94+
assertTrue(result.changed)
95+
assertTrue(
96+
result.diagnostics.any { it.contains("prefers") && it.contains("BFloat16") },
97+
"diagnostic must mention the preference: ${result.diagnostics}",
98+
)
99+
}
100+
101+
@Test
102+
fun oneOf_in_set_passes() {
103+
val g = DefaultComputeGraph()
104+
g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.OneOf(setOf(FP32, BF16))))
105+
val result = DTypeConstraintResolutionPass().apply(g)
106+
assertTrue(result.changed)
107+
}
108+
109+
@Test
110+
fun oneOf_outside_set_fails_fast() {
111+
val g = DefaultComputeGraph()
112+
g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.OneOf(setOf(BF16, Int8))))
113+
val ex = assertFailsWith<DtypeConstraintViolationException> {
114+
DTypeConstraintResolutionPass().apply(g)
115+
}
116+
val msg = ex.message ?: ""
117+
assertTrue(msg.contains("OneOf"), msg)
118+
assertTrue(msg.contains("Float32"), msg)
119+
}
120+
}

0 commit comments

Comments
 (0)