11"""Subgraph projection strategies.
22
3- Per spec v0.1.1 §2 Subgraph: the default is **no projection in** (a subgraph
3+ Per spec v0.2.0 §2 Subgraph: the default is **no projection in** (a subgraph
44runs from its own schema's field defaults) and **field-name matching for
55projection out** (subgraph fields whose names match parent fields are merged
66back into the parent via the parent's reducers).
77
8- `ProjectionStrategy` is exposed as a seam so proposal 0002 (explicit
9- input/output mapping) can slot in without changes to the engine's compile or
10- execute paths. Parameterized on the parent and child state types so
11- consumer-authored projections get typed `project_in` / `project_out`
12- signatures without `cast(...)` gymnastics.
8+ Spec v0.2.0 (proposal 0002) adds explicit input/output mapping: a
9+ subgraph-as-node MAY declare `inputs` (parent → subgraph, additive over the
10+ default of no-projection-in) and/or `outputs` (subgraph → parent, replacement
11+ for field-name matching). Implemented here as `ExplicitMapping`.
12+
13+ Strategies parameterize on parent and child state types so consumer-authored
14+ projections get typed `project_in` / `project_out` signatures without
15+ `cast(...)` gymnastics.
1316"""
1417
1518from collections .abc import Mapping
1619from typing import Any , Protocol
1720
21+ from .errors import MappingReferencesUndeclaredField
1822from .state import State
1923
2024
25+ def _field_name_match_projection [ChildT : State ](
26+ subgraph_final_state : ChildT ,
27+ parent_state : State ,
28+ subgraph_state_cls : type [ChildT ],
29+ ) -> Mapping [str , Any ]:
30+ """Spec v0.2 §2 default projection-out: subgraph fields whose names
31+ match parent fields are merged back via the parent's reducers; non-
32+ matching subgraph fields are discarded.
33+
34+ Shared by `FieldNameMatching.project_out` (which always uses it) and
35+ `ExplicitMapping.project_out` (which falls back to it when `outputs`
36+ was not declared, per spec v0.2).
37+ """
38+ parent_fields = set (type (parent_state ).model_fields .keys ())
39+ sub_fields = set (subgraph_state_cls .model_fields .keys ())
40+ shared = parent_fields & sub_fields
41+ return {name : getattr (subgraph_final_state , name ) for name in shared }
42+
43+
2144class ProjectionStrategy [ParentT : State , ChildT : State ](Protocol ):
22- """Strategy for moving state across the parent ↔ subgraph boundary."""
45+ """Strategy for moving state across the parent ↔ subgraph boundary.
46+
47+ Two required methods plus one optional hook:
2348
24- def project_in (self , parent_state : ParentT , subgraph_state_cls : type [ChildT ]) -> ChildT : ...
49+ - `project_in` and `project_out` are required: the engine calls them on
50+ every subgraph step.
51+ - `validate(parent_cls, subgraph_state_cls) -> None` is an *optional*
52+ compile-time validation hook. If a strategy defines it, the parent
53+ graph's `compile()` calls it once per `SubgraphNode`; the strategy
54+ may raise a `CompileError` subclass when its declarations don't
55+ match the supplied schemas. Declarative strategies like
56+ `ExplicitMapping` use this to catch field-name typos before any
57+ node runs. Imperative custom projections typically have nothing
58+ declarative to check and can simply omit the method — the engine
59+ uses duck typing (`getattr`) to find it.
60+ """
61+
62+ def project_in (self , parent_state : ParentT , subgraph_state_cls : type [ChildT ]) -> ChildT :
63+ """Build the subgraph's initial state at the moment it begins."""
64+ raise NotImplementedError
2565
2666 def project_out (
2767 self ,
2868 subgraph_final_state : ChildT ,
2969 parent_state : ParentT ,
3070 subgraph_state_cls : type [ChildT ],
31- ) -> Mapping [str , Any ]: ...
71+ ) -> Mapping [str , Any ]:
72+ """Project the subgraph's final state back to the parent as a partial update."""
73+ raise NotImplementedError
3274
3375
3476class FieldNameMatching [ParentT : State , ChildT : State ]:
35- """Default projection per spec v0.1.1 §2 Subgraph.
77+ """Default projection per spec v0.2.0 §2 Subgraph.
3678
3779 Parameterized for protocol conformance under generics. `ParentT` is not
3880 consumed (the default projection ignores parent state on the way in),
@@ -50,7 +92,83 @@ def project_out(
5092 parent_state : ParentT ,
5193 subgraph_state_cls : type [ChildT ],
5294 ) -> Mapping [str , Any ]:
53- parent_fields = set (type (parent_state ).model_fields .keys ())
95+ return _field_name_match_projection (subgraph_final_state , parent_state , subgraph_state_cls )
96+
97+
98+ class ExplicitMapping [ParentT : State , ChildT : State ]:
99+ """Per spec v0.2.0 §2: explicit input/output mapping.
100+
101+ `inputs`: subgraph_field → parent_field. At entry, the named parent field's
102+ current value is copied into the named subgraph field. Subgraph fields not
103+ listed receive their schema-declared defaults — there is NO field-name
104+ fallback (additive over the spec's default no-projection-in).
105+
106+ `outputs`: parent_field → subgraph_field. At exit, the named subgraph
107+ field's value is merged into the named parent field via the parent's
108+ reducer. Subgraph fields not listed are discarded — `outputs` REPLACES
109+ field-name matching for projection-out.
110+
111+ The two directions are independent: pass either, both, or neither. The
112+ spec distinguishes "absent" (default applies) from "present but empty"
113+ (only for `outputs`, where the defaults differ); `outputs=None` means
114+ absent (fall back to field-name matching), `outputs={}` means present
115+ and empty (project nothing). For `inputs` the two defaults coincide
116+ (no-projection-in either way), so the distinction is only meaningful
117+ for `outputs`.
118+ """
119+
120+ def __init__ (
121+ self ,
122+ * ,
123+ inputs : Mapping [str , str ] | None = None ,
124+ outputs : Mapping [str , str ] | None = None ,
125+ ) -> None :
126+ self .inputs : dict [str , str ] = dict (inputs ) if inputs is not None else {}
127+ # Preserve absence on outputs so project_out can fall back to
128+ # field-name matching when None.
129+ self .outputs : dict [str , str ] | None = dict (outputs ) if outputs is not None else None
130+
131+ def project_in (self , parent_state : ParentT , subgraph_state_cls : type [ChildT ]) -> ChildT :
132+ kwargs : dict [str , Any ] = {
133+ sub_field : getattr (parent_state , parent_field ) for sub_field , parent_field in self .inputs .items ()
134+ }
135+ return subgraph_state_cls (** kwargs )
136+
137+ def project_out (
138+ self ,
139+ subgraph_final_state : ChildT ,
140+ parent_state : ParentT ,
141+ subgraph_state_cls : type [ChildT ],
142+ ) -> Mapping [str , Any ]:
143+ if self .outputs is None :
144+ # Outputs absent → spec default of field-name matching applies.
145+ return _field_name_match_projection (subgraph_final_state , parent_state , subgraph_state_cls )
146+ return {
147+ parent_field : getattr (subgraph_final_state , sub_field )
148+ for parent_field , sub_field in self .outputs .items ()
149+ }
150+
151+ def validate (self , parent_cls : type [ParentT ], subgraph_state_cls : type [ChildT ]) -> None :
152+ parent_fields = set (parent_cls .model_fields .keys ())
54153 sub_fields = set (subgraph_state_cls .model_fields .keys ())
55- shared = parent_fields & sub_fields
56- return {name : getattr (subgraph_final_state , name ) for name in shared }
154+
155+ for sub_field , parent_field in self .inputs .items ():
156+ if sub_field not in sub_fields :
157+ raise MappingReferencesUndeclaredField (
158+ direction = "inputs" , side = "subgraph" , field_name = sub_field
159+ )
160+ if parent_field not in parent_fields :
161+ raise MappingReferencesUndeclaredField (
162+ direction = "inputs" , side = "parent" , field_name = parent_field
163+ )
164+
165+ if self .outputs is not None :
166+ for parent_field , sub_field in self .outputs .items ():
167+ if parent_field not in parent_fields :
168+ raise MappingReferencesUndeclaredField (
169+ direction = "outputs" , side = "parent" , field_name = parent_field
170+ )
171+ if sub_field not in sub_fields :
172+ raise MappingReferencesUndeclaredField (
173+ direction = "outputs" , side = "subgraph" , field_name = sub_field
174+ )
0 commit comments