|
8 | 8 | Per spec §4 Error semantics: node, edge, reducer, and routing errors carry |
9 | 9 | recoverable state; state validation errors do not. |
10 | 10 |
|
11 | | -Per spec v0.3.0 §6 Observer hooks: between merge and edge evaluation, the |
12 | | -engine dispatches a `NodeEvent` for the just-completed node onto the |
13 | | -invocation's delivery queue. On node/reducer/state-validation failure, the |
14 | | -event is dispatched (with `error` populated) before the failure propagates. |
15 | | -Routing errors do NOT produce their own event — they arise after the |
16 | | -preceding node's event has already been dispatched. |
| 11 | +Per spec v0.6.0 §6 Observer hooks: each node attempt produces a |
| 12 | +started/completed event PAIR. The engine dispatches the started event |
| 13 | +before invoking the wrapped node function and the completed event after |
| 14 | +the reducer merge succeeds (with `post_state` populated) or after the |
| 15 | +node, reducer, or state validation fails (with `error` populated). |
| 16 | +Routing errors do NOT produce their own event pair — they arise after |
| 17 | +the preceding node's completed event has already been dispatched. |
17 | 18 |
|
18 | 19 | `CompiledGraph[StateT]` and `_merge_partial[StateT]` carry the concrete state |
19 | 20 | subclass through to `invoke()`'s return type, so consumers don't need |
|
42 | 43 | _DRAIN_SENTINEL, |
43 | 44 | Observer, |
44 | 45 | RemoveHandle, |
| 46 | + SubscribedObserver, |
| 47 | + _coerce_subscribed, |
45 | 48 | _dispatch, |
46 | 49 | _InvocationContext, |
47 | 50 | _QueuedItem, |
@@ -113,31 +116,41 @@ class CompiledGraph[StateT: State]: |
113 | 116 | # Observer plumbing — see attach_observer/drain. Mutable on a frozen |
114 | 117 | # dataclass: the list reference is fixed but its contents change. |
115 | 118 | # Parameterized factories so pyright infers the element types. |
116 | | - _attached_observers: list[Observer] = field(default_factory=list[Observer]) |
| 119 | + _attached_observers: list[SubscribedObserver] = field(default_factory=list[SubscribedObserver]) |
117 | 120 | # `set` (not list) so a per-task `add_done_callback(self._active_workers.discard)` |
118 | 121 | # auto-removes completed workers — long-running services that never call |
119 | 122 | # drain() don't accumulate completed Task references indefinitely. |
120 | 123 | _active_workers: set[asyncio.Task[None]] = field(default_factory=set[asyncio.Task[None]]) |
121 | 124 |
|
122 | 125 | # ------------------------------------------------------------------ |
123 | | - # Observer registration (spec v0.3.0 §6) |
| 126 | + # Observer registration (spec v0.6.0 §6) |
124 | 127 | # ------------------------------------------------------------------ |
125 | 128 |
|
126 | | - def attach_observer(self, observer: Observer) -> RemoveHandle: |
| 129 | + def attach_observer( |
| 130 | + self, |
| 131 | + observer: Observer, |
| 132 | + *, |
| 133 | + phases: Iterable[str] | None = None, |
| 134 | + ) -> RemoveHandle: |
127 | 135 | """Register a graph-attached observer. |
128 | 136 |
|
129 | | - Per spec v0.3.0 §6: graph-attached observers fire on every invocation |
| 137 | + Per spec v0.6.0 §6: graph-attached observers fire on every invocation |
130 | 138 | of this graph until removed — including when this graph runs as a |
131 | 139 | subgraph inside a parent. Returns a `RemoveHandle` whose `.remove()` |
132 | 140 | method detaches the observer; idempotent. |
133 | 141 |
|
| 142 | + `phases` selects the phase strings (`"started"`, `"completed"`) the |
| 143 | + observer subscribes to; default is both. An empty `phases` set |
| 144 | + raises `ValueError` at registration time. |
| 145 | +
|
134 | 146 | Per spec: changes to the registered set during a graph run do NOT |
135 | 147 | take effect until the next invocation. The set of observers |
136 | 148 | delivering events for an in-flight invocation is fixed at the point |
137 | 149 | the invocation begins. |
138 | 150 | """ |
139 | | - self._attached_observers.append(observer) |
140 | | - return RemoveHandle(_observers=self._attached_observers, _observer=observer) |
| 151 | + subscribed = _coerce_subscribed(observer, phases=phases) |
| 152 | + self._attached_observers.append(subscribed) |
| 153 | + return RemoveHandle(_observers=self._attached_observers, _observer=subscribed) |
141 | 154 |
|
142 | 155 | async def drain(self) -> None: |
143 | 156 | """Await delivery of every observer event produced by prior |
@@ -166,23 +179,27 @@ async def drain(self) -> None: |
166 | 179 | async def invoke( |
167 | 180 | self, |
168 | 181 | initial_state: StateT, |
169 | | - observers: Iterable[Observer] | None = None, |
| 182 | + observers: Iterable[Observer | SubscribedObserver] | None = None, |
170 | 183 | ) -> StateT: |
171 | 184 | """Run the graph from `initial_state` to END and return the final state. |
172 | 185 |
|
173 | 186 | Optional `observers` are invocation-scoped — they fire only for this |
174 | 187 | run, after all graph-attached observers (including subgraph-attached |
175 | | - ones for events originating in subgraphs) per spec v0.3.0 §6. |
| 188 | + ones for events originating in subgraphs) per spec v0.6.0 §6. |
| 189 | +
|
| 190 | + Each entry in `observers` may be either a bare `Observer` callable |
| 191 | + (subscribes to both phases) or a `SubscribedObserver` wrapping an |
| 192 | + observer with an explicit `phases` set. |
176 | 193 |
|
177 | | - Per spec v0.3.0 §6: this method returns as soon as the graph |
| 194 | + Per spec v0.6.0 §6: this method returns as soon as the graph |
178 | 195 | execution loop completes, regardless of whether the observer |
179 | 196 | delivery queue has finished processing every dispatched event. Use |
180 | 197 | `await compiled.drain()` if you need delivery-completion guarantees. |
181 | 198 |
|
182 | 199 | Raises one of the runtime error categories from spec §4 on failure. |
183 | 200 | """ |
184 | 201 |
|
185 | | - invocation_scoped = tuple(observers) if observers else () |
| 202 | + invocation_scoped = tuple(_coerce_subscribed(o) for o in (observers or ())) |
186 | 203 | queue: asyncio.Queue[_QueuedItem | None] = asyncio.Queue() |
187 | 204 | context = _InvocationContext( |
188 | 205 | queue=queue, |
@@ -271,62 +288,79 @@ async def _step_function_node( |
271 | 288 | state: StateT, |
272 | 289 | context: _InvocationContext, |
273 | 290 | ) -> StateT: |
274 | | - """Run one function-node step: take a step, run, merge, dispatch. |
275 | | -
|
276 | | - Dispatches a `NodeEvent` exactly once per call: |
277 | | - - On run failure (NodeException): event with error populated. |
278 | | - - On merge failure (ReducerError or StateValidationError): event with |
279 | | - error populated; the original error propagates unchanged after. |
280 | | - - On success: event with post_state populated, then return. |
| 291 | + """Run one function-node step: take a step, dispatch started, run, |
| 292 | + merge, dispatch completed. |
| 293 | +
|
| 294 | + Per spec v0.6.0 §6: each attempt produces a started/completed pair. |
| 295 | + Both events share the same `step`. The completed event carries |
| 296 | + `post_state` on success, or `error` on failure (one of run, reducer, |
| 297 | + or state-validation). The completed event is dispatched before the |
| 298 | + failure propagates. |
281 | 299 | """ |
282 | 300 | step = context.take_step() |
283 | 301 | namespace = context.namespace_prefix + (current,) |
284 | 302 | pre_state = state |
285 | 303 |
|
| 304 | + self._dispatch_started(context, current, namespace, step, pre_state) |
| 305 | + |
286 | 306 | try: |
287 | 307 | partial = await node.run(state) |
288 | 308 | except Exception as e: |
289 | 309 | wrapped = NodeException(node_name=current, cause=e, recoverable_state=state) |
290 | | - self._dispatch_failure_event(context, current, namespace, step, pre_state, wrapped) |
| 310 | + self._dispatch_completed(context, current, namespace, step, pre_state, error=wrapped) |
291 | 311 | raise wrapped from e |
292 | 312 |
|
293 | 313 | try: |
294 | 314 | new_state = _merge_partial(state, partial, self.reducers, current) |
295 | 315 | except (ReducerError, StateValidationError) as e: |
296 | | - self._dispatch_failure_event(context, current, namespace, step, pre_state, e) |
| 316 | + self._dispatch_completed(context, current, namespace, step, pre_state, error=e) |
297 | 317 | raise |
298 | 318 |
|
| 319 | + self._dispatch_completed(context, current, namespace, step, pre_state, post_state=new_state) |
| 320 | + return new_state |
| 321 | + |
| 322 | + @staticmethod |
| 323 | + def _dispatch_started( |
| 324 | + context: _InvocationContext, |
| 325 | + current: str, |
| 326 | + namespace: tuple[str, ...], |
| 327 | + step: int, |
| 328 | + pre_state: State, |
| 329 | + ) -> None: |
299 | 330 | _dispatch( |
300 | 331 | context, |
301 | 332 | NodeEvent( |
302 | 333 | node_name=current, |
303 | 334 | namespace=namespace, |
304 | 335 | step=step, |
| 336 | + phase="started", |
305 | 337 | pre_state=pre_state, |
306 | | - post_state=new_state, |
| 338 | + post_state=None, |
307 | 339 | error=None, |
308 | 340 | parent_states=context.parent_states_prefix, |
309 | 341 | ), |
310 | 342 | ) |
311 | | - return new_state |
312 | 343 |
|
313 | 344 | @staticmethod |
314 | | - def _dispatch_failure_event( |
| 345 | + def _dispatch_completed( |
315 | 346 | context: _InvocationContext, |
316 | 347 | current: str, |
317 | 348 | namespace: tuple[str, ...], |
318 | 349 | step: int, |
319 | 350 | pre_state: State, |
320 | | - error: RuntimeGraphError, |
| 351 | + *, |
| 352 | + post_state: State | None = None, |
| 353 | + error: RuntimeGraphError | None = None, |
321 | 354 | ) -> None: |
322 | 355 | _dispatch( |
323 | 356 | context, |
324 | 357 | NodeEvent( |
325 | 358 | node_name=current, |
326 | 359 | namespace=namespace, |
327 | 360 | step=step, |
| 361 | + phase="completed", |
328 | 362 | pre_state=pre_state, |
329 | | - post_state=None, |
| 363 | + post_state=post_state, |
330 | 364 | error=error, |
331 | 365 | parent_states=context.parent_states_prefix, |
332 | 366 | ), |
|
0 commit comments