Skip to content

Commit 2f7a099

Browse files
graph: parameterize GraphBuilder/CompiledGraph on StateT (PEP 695 generics) (#3)
* graph: parameterize builder/compiled/subgraph on StateT (PEP 695 generics) Carry the concrete State subclass through at type-check time. GraphBuilder[StateT], CompiledGraph[StateT], SubgraphNode[ParentT, ChildT], ConditionalEdge[StateT], ProjectionStrategy[ParentT, ChildT], Node[StateT], FunctionNode[StateT] — all generic. Consumers no longer need cast(MyState, ...) on invoke() returns, projection arguments, or edge function parameters. PEP 695 syntax throughout (targets Python 3.12+). SubgraphNode carries both parent and child types; add_subgraph_node[ChildT: State] infers ChildT from the compiled argument. FieldNameMatching[ParentT, ChildT] is generic for Protocol conformance even though ParentT is unused in project_in (keeps default-factory inference clean at the SubgraphNode boundary). _merge_partial[StateT] parameterized so type(prior).model_validate preserves the subclass without an internal cast. Conformance adapter pins to [State, State, State] since YAML fixtures are schema-agnostic. Adds tests/unit/test_generics.py with assert_type to lock the typing surface against regression. 41 tests pass, pyright strict clean. * gitignore _docs/
1 parent 37d6536 commit 2f7a099

11 files changed

Lines changed: 213 additions & 60 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
.claude/
22

3+
# Local-only implementation notes (concepts walkthrough, rough edges log,
4+
# future-work log, langgraph comparison). Kept out of public history.
5+
_docs/
6+
37
# Python
48
__pycache__/
59
*.py[cod]

src/openarmature/graph/builder.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
Per spec §2: compilation MUST fail if the graph has no declared entry,
44
unreachable nodes, dangling edges, a node with more than one outgoing edge,
55
or a field with more than one declared reducer.
6+
7+
`GraphBuilder[StateT]` is parameterized on the graph's state type. Node
8+
functions, conditional-edge functions, and the returned `CompiledGraph[StateT]`
9+
all carry `StateT` forward so consumers get typed `invoke()` return values and
10+
a type-checked `state` parameter on every callback — without `cast(...)` calls.
611
"""
712

813
from collections.abc import Awaitable, Callable, Mapping
@@ -24,35 +29,37 @@
2429
from .subgraph import SubgraphNode
2530

2631

27-
class GraphBuilder:
32+
class GraphBuilder[StateT: State]:
2833
"""Mutable builder for a graph; call `compile()` to produce a `CompiledGraph`."""
2934

30-
def __init__(self, state_cls: type[State]) -> None:
31-
self.state_cls = state_cls
32-
self._nodes: dict[str, Node] = {}
33-
self._edges: list[StaticEdge | ConditionalEdge] = []
35+
def __init__(self, state_cls: type[StateT]) -> None:
36+
self.state_cls: type[StateT] = state_cls
37+
self._nodes: dict[str, Node[StateT]] = {}
38+
self._edges: list[StaticEdge | ConditionalEdge[StateT]] = []
3439
self._entry: str | None = None
3540

3641
def add_node(
3742
self,
3843
name: str,
39-
fn: Callable[[Any], Awaitable[Mapping[str, Any]]],
44+
fn: Callable[[StateT], Awaitable[Mapping[str, Any]]],
4045
) -> Self:
4146
if name in self._nodes:
4247
raise ValueError(f"node {name!r} already declared")
43-
self._nodes[name] = FunctionNode(name=name, fn=fn)
48+
self._nodes[name] = FunctionNode[StateT](name=name, fn=fn)
4449
return self
4550

46-
def add_subgraph_node(
51+
def add_subgraph_node[ChildT: State](
4752
self,
4853
name: str,
49-
compiled: CompiledGraph,
50-
projection: ProjectionStrategy | None = None,
54+
compiled: CompiledGraph[ChildT],
55+
projection: ProjectionStrategy[StateT, ChildT] | None = None,
5156
) -> Self:
5257
if name in self._nodes:
5358
raise ValueError(f"node {name!r} already declared")
54-
proj: ProjectionStrategy = projection if projection is not None else FieldNameMatching()
55-
self._nodes[name] = SubgraphNode(name=name, compiled=compiled, projection=proj)
59+
proj: ProjectionStrategy[StateT, ChildT] = (
60+
projection if projection is not None else FieldNameMatching[StateT, ChildT]()
61+
)
62+
self._nodes[name] = SubgraphNode[StateT, ChildT](name=name, compiled=compiled, projection=proj)
5663
return self
5764

5865
def add_edge(self, source: str, target: str | EndSentinel) -> Self:
@@ -62,16 +69,16 @@ def add_edge(self, source: str, target: str | EndSentinel) -> Self:
6269
def add_conditional_edge(
6370
self,
6471
source: str,
65-
fn: Callable[[Any], str | EndSentinel],
72+
fn: Callable[[StateT], str | EndSentinel],
6673
) -> Self:
67-
self._edges.append(ConditionalEdge(source=source, fn=fn))
74+
self._edges.append(ConditionalEdge[StateT](source=source, fn=fn))
6875
return self
6976

7077
def set_entry(self, name: str) -> Self:
7178
self._entry = name
7279
return self
7380

74-
def compile(self) -> CompiledGraph:
81+
def compile(self) -> CompiledGraph[StateT]:
7582
# 1. ConflictingReducers — state schema check.
7683
per_field = field_reducers(self.state_cls)
7784
for fname, declared in per_field.items():
@@ -98,7 +105,7 @@ def compile(self) -> CompiledGraph:
98105
raise DanglingEdge(source=edge.source, target=edge.target)
99106

100107
# 5. MultipleOutgoingEdges + index by source for the reachability pass.
101-
edges_by_source: dict[str, StaticEdge | ConditionalEdge] = {}
108+
edges_by_source: dict[str, StaticEdge | ConditionalEdge[StateT]] = {}
102109
for edge in self._edges:
103110
if edge.source in edges_by_source:
104111
raise MultipleOutgoingEdges(edge.source)
@@ -112,7 +119,7 @@ def compile(self) -> CompiledGraph:
112119
if node_name not in reachable:
113120
raise UnreachableNode(node_name)
114121

115-
return CompiledGraph(
122+
return CompiledGraph[StateT](
116123
state_cls=self.state_cls,
117124
entry=self._entry,
118125
nodes=dict(self._nodes),
@@ -122,7 +129,7 @@ def compile(self) -> CompiledGraph:
122129

123130
def _reachable_nodes(
124131
self,
125-
edges_by_source: Mapping[str, StaticEdge | ConditionalEdge],
132+
edges_by_source: Mapping[str, StaticEdge | ConditionalEdge[StateT]],
126133
) -> set[str]:
127134
assert self._entry is not None
128135
reachable: set[str] = {self._entry}

src/openarmature/graph/compiled.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
88
Per spec §4 Error semantics: node, edge, reducer, and routing errors carry
99
recoverable state; state validation errors do not.
10+
11+
`CompiledGraph[StateT]` and `_merge_partial[StateT]` carry the concrete state
12+
subclass through to `invoke()`'s return type, so consumers don't need
13+
`cast(MyState, ...)` at the call site.
1014
"""
1115

1216
from collections.abc import Mapping
@@ -28,12 +32,12 @@
2832
from .state import State
2933

3034

31-
def _merge_partial(
32-
prior: State,
35+
def _merge_partial[StateT: State](
36+
prior: StateT,
3337
partial: Mapping[str, Any],
3438
reducers: Mapping[str, Reducer],
3539
producing_node: str,
36-
) -> State:
40+
) -> StateT:
3741
"""Apply per-field reducers to merge a node's partial update into prior state.
3842
3943
Re-validates the resulting state against the schema (per spec §2 SHOULD
@@ -60,6 +64,7 @@ def _merge_partial(
6064
) from e
6165

6266
try:
67+
# type(prior) narrows to `type[StateT]`; model_validate returns StateT.
6368
return type(prior).model_validate(new_values)
6469
except ValidationError as e:
6570
offending = sorted({str(err["loc"][0]) for err in e.errors() if err["loc"]})
@@ -71,16 +76,16 @@ def _merge_partial(
7176

7277

7378
@dataclass(frozen=True)
74-
class CompiledGraph:
79+
class CompiledGraph[StateT: State]:
7580
"""An immutable, executable graph produced by `GraphBuilder.compile()`."""
7681

77-
state_cls: type[State]
82+
state_cls: type[StateT]
7883
entry: str
79-
nodes: Mapping[str, Node]
80-
edges: Mapping[str, StaticEdge | ConditionalEdge]
84+
nodes: Mapping[str, Node[StateT]]
85+
edges: Mapping[str, StaticEdge | ConditionalEdge[StateT]]
8186
reducers: Mapping[str, Reducer]
8287

83-
async def invoke(self, initial_state: State) -> State:
88+
async def invoke(self, initial_state: StateT) -> StateT:
8489
"""Run the graph from `initial_state` to END and return the final state.
8590
8691
Raises one of the runtime error categories from spec §4 on failure.

src/openarmature/graph/edges.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
Per spec §2 Concepts (Edge, END): edges are static or conditional; each node
44
has exactly one outgoing edge. END is a distinct engine sentinel (not a
55
reserved node name) used as a routing target to halt execution.
6+
7+
`ConditionalEdge` is generic on the outer graph's state type so the routing
8+
function's parameter is typed against the user's `State` subclass — not
9+
`Any` — at type-check time.
610
"""
711

812
from collections.abc import Callable
913
from dataclasses import dataclass
10-
from typing import Any, Final
14+
from typing import Final
15+
16+
from .state import State
1117

1218

1319
class EndSentinel:
@@ -29,11 +35,11 @@ class StaticEdge:
2935

3036

3137
@dataclass(frozen=True)
32-
class ConditionalEdge:
38+
class ConditionalEdge[StateT: State]:
3339
"""Routes from `source` to whichever node `fn(state)` returns. The function
3440
MUST return either a declared node name or `END`; any other value raises
3541
`RoutingError` at runtime.
3642
"""
3743

3844
source: str
39-
fn: Callable[[Any], str | EndSentinel]
45+
fn: Callable[[StateT], str | EndSentinel]

src/openarmature/graph/nodes.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,33 @@
55
partial update which the engine merges via reducers.
66
77
The `Node` Protocol exists so subgraphs can compose as nodes alongside
8-
plain function-backed nodes (see `subgraph.SubgraphNode`).
8+
plain function-backed nodes (see `subgraph.SubgraphNode`). Both are
9+
parameterized on `StateT` so the outer graph's state type flows through
10+
to node functions at type-check time.
911
"""
1012

1113
from collections.abc import Awaitable, Callable, Mapping
1214
from dataclasses import dataclass
1315
from typing import Any, Protocol
1416

17+
from .state import State
1518

16-
class Node(Protocol):
19+
20+
class Node[StateT: State](Protocol):
1721
"""A unit of work in a compiled graph."""
1822

1923
@property
2024
def name(self) -> str: ...
2125

22-
async def run(self, state: Any) -> Mapping[str, Any]: ...
26+
async def run(self, state: StateT) -> Mapping[str, Any]: ...
2327

2428

2529
@dataclass(frozen=True)
26-
class FunctionNode:
30+
class FunctionNode[StateT: State]:
2731
"""A node backed by an async callable."""
2832

2933
name: str
30-
fn: Callable[[Any], Awaitable[Mapping[str, Any]]]
34+
fn: Callable[[StateT], Awaitable[Mapping[str, Any]]]
3135

32-
async def run(self, state: Any) -> Mapping[str, Any]:
36+
async def run(self, state: StateT) -> Mapping[str, Any]:
3337
return await self.fn(state)

src/openarmature/graph/projection.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
88
`ProjectionStrategy` is exposed as a seam so proposal 0002 (explicit
99
input/output mapping) can slot in without changes to the engine's compile or
10-
execute paths.
10+
execute paths. Parameterized on the parent and child state types so
11+
consumer-authored projections get typed `project_in` / `project_out`
12+
signatures without `cast(...)` gymnastics.
1113
"""
1214

1315
from collections.abc import Mapping
@@ -16,30 +18,37 @@
1618
from .state import State
1719

1820

19-
class ProjectionStrategy(Protocol):
21+
class ProjectionStrategy[ParentT: State, ChildT: State](Protocol):
2022
"""Strategy for moving state across the parent ↔ subgraph boundary."""
2123

22-
def project_in(self, parent_state: State, subgraph_state_cls: type[State]) -> State: ...
24+
def project_in(self, parent_state: ParentT, subgraph_state_cls: type[ChildT]) -> ChildT: ...
2325

2426
def project_out(
2527
self,
26-
subgraph_final_state: State,
27-
parent_state: State,
28-
subgraph_state_cls: type[State],
28+
subgraph_final_state: ChildT,
29+
parent_state: ParentT,
30+
subgraph_state_cls: type[ChildT],
2931
) -> Mapping[str, Any]: ...
3032

3133

32-
class FieldNameMatching:
33-
"""Default projection per spec v0.1.1 §2 Subgraph."""
34+
class FieldNameMatching[ParentT: State, ChildT: State]:
35+
"""Default projection per spec v0.1.1 §2 Subgraph.
3436
35-
def project_in(self, parent_state: State, subgraph_state_cls: type[State]) -> State:
37+
Parameterized for protocol conformance under generics. `ParentT` is not
38+
consumed (the default projection ignores parent state on the way in),
39+
but carrying the type variable keeps the default assignable to
40+
`ProjectionStrategy[ParentT, ChildT]` without type gymnastics at the
41+
SubgraphNode default-factory site.
42+
"""
43+
44+
def project_in(self, parent_state: ParentT, subgraph_state_cls: type[ChildT]) -> ChildT:
3645
return subgraph_state_cls()
3746

3847
def project_out(
3948
self,
40-
subgraph_final_state: State,
41-
parent_state: State,
42-
subgraph_state_cls: type[State],
49+
subgraph_final_state: ChildT,
50+
parent_state: ParentT,
51+
subgraph_state_cls: type[ChildT],
4352
) -> Mapping[str, Any]:
4453
parent_fields = set(type(parent_state).model_fields.keys())
4554
sub_fields = set(subgraph_state_cls.model_fields.keys())

src/openarmature/graph/subgraph.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
graph. The subgraph runs against its own state schema; projection between
55
parent and subgraph is delegated to a `ProjectionStrategy` (default:
66
`FieldNameMatching`).
7+
8+
Parameterized on both the parent's state type (`ParentT`) and the subgraph's
9+
state type (`ChildT`). The outer graph only ever sees `run(state: ParentT)`
10+
— the `ChildT` lives on the `compiled` and `projection` fields and is
11+
invisible at the outer graph's node dispatch site.
712
"""
813

914
from collections.abc import Mapping
@@ -18,14 +23,16 @@
1823

1924

2025
@dataclass(frozen=True)
21-
class SubgraphNode:
26+
class SubgraphNode[ParentT: State, ChildT: State]:
2227
"""A node backed by a compiled subgraph."""
2328

2429
name: str
25-
compiled: "CompiledGraph"
26-
projection: ProjectionStrategy = field(default_factory=FieldNameMatching)
30+
compiled: "CompiledGraph[ChildT]"
31+
projection: ProjectionStrategy[ParentT, ChildT] = field(
32+
default_factory=FieldNameMatching[ParentT, ChildT]
33+
)
2734

28-
async def run(self, state: State) -> Mapping[str, Any]:
35+
async def run(self, state: ParentT) -> Mapping[str, Any]:
2936
sub_initial = self.projection.project_in(state, self.compiled.state_cls)
3037
sub_final = await self.compiled.invoke(sub_initial)
3138
return self.projection.project_out(sub_final, state, self.compiled.state_cls)

tests/conformance/adapter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,9 @@ async def fn(_state: Any) -> Mapping[str, Any]:
119119

120120
def _make_subgraph_fn(
121121
node_name: str,
122-
compiled: CompiledGraph,
122+
compiled: CompiledGraph[State],
123123
trace: list[str],
124-
projection: ProjectionStrategy,
124+
projection: ProjectionStrategy[State, State],
125125
) -> Callable[[Any], Awaitable[Mapping[str, Any]]]:
126126
"""Outer-graph node that delegates to a compiled subgraph.
127127
@@ -160,7 +160,7 @@ class BuiltGraph:
160160
"""Result of translating a fixture into runnable engine constructs."""
161161

162162
state_cls: type[State]
163-
builder: GraphBuilder
163+
builder: GraphBuilder[State]
164164
trace: list[str]
165165

166166
def initial_state(self, overrides: Mapping[str, Any]) -> State:
@@ -170,7 +170,7 @@ def initial_state(self, overrides: Mapping[str, Any]) -> State:
170170
def build_graph(
171171
spec: Mapping[str, Any],
172172
*,
173-
subgraphs: Mapping[str, CompiledGraph] | None = None,
173+
subgraphs: Mapping[str, CompiledGraph[State]] | None = None,
174174
trace: list[str] | None = None,
175175
model_name: str = "FixtureState",
176176
) -> BuiltGraph:
@@ -196,7 +196,7 @@ def build_graph(
196196
compiled = subgraphs[sub_name]
197197
builder.add_node(
198198
node_name,
199-
_make_subgraph_fn(node_name, compiled, trace, FieldNameMatching()),
199+
_make_subgraph_fn(node_name, compiled, trace, FieldNameMatching[State, State]()),
200200
)
201201
elif "raises" in node_spec:
202202
builder.add_node(node_name, _make_raising_fn(node_name, node_spec["raises"], trace))

0 commit comments

Comments
 (0)