Skip to content

Commit efb19f5

Browse files
Bind fan-out lineage vars in instance middleware (#189)
* Bind fan-out lineage vars in instance middleware current_fan_out_index() (and the lineage chains) returned None inside fan-out instance_middleware: the engine binds those ContextVars per-node inside the inner subgraph (compiled.py), but instance_middleware wraps the subgraph from outside, before any node runs. The documented use (RetryMiddleware) doesn't read the index, so it sat latent; custom instance middleware reading the index or calling set_invocation_metadata saw None. Bind the three fan-out lineage ContextVars (fan_out_index + the per-depth index/branch chains) to the instance's child_context around the instance_middleware chain via a _bind_instance_lineage context manager, resetting on exit. Bind only when there is instance middleware to read them (the inner nodes bind them otherwise), so the no-middleware path is unchanged. Reset before the error-handling below so its saves keep their existing context. * Build fan-out test graphs step-by-step From CoPilot review of #189: expand the two new instance-middleware tests' inner and parent graph construction from inline method chains to the step-by-step named-builder pattern used throughout the module.
1 parent c2f66ed commit efb19f5

3 files changed

Lines changed: 147 additions & 2 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ All notable changes to `openarmature-python` are documented in this file.
44

55
The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The package follows [Semantic Versioning](https://semver.org/); pre-1.0 minor bumps may carry behavioral changes per [spec governance](https://github.com/LunarCommand/openarmature-spec/blob/main/GOVERNANCE.md).
66

7+
## [Unreleased]
8+
9+
### Fixed
10+
11+
- **`current_fan_out_index()` inside fan-out instance middleware** now returns the executing instance's index (and `current_fan_out_index_chain()` its lineage) instead of `None`. The engine set the fan-out lineage ContextVars per-node, inside the inner subgraph, which left them unset in `instance_middleware` that wraps the subgraph from outside; they are now set around the instance-middleware chain. The documented `instance_middleware` use (`RetryMiddleware`) does not read the index, so no shipped behavior changes. This corrects the value seen by custom instance middleware that reads the index or calls `set_invocation_metadata`.
12+
713
## [0.15.0] — 2026-06-22
814

915
### Added

src/openarmature/graph/fan_out.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,20 @@
3232

3333
import asyncio
3434
import time
35-
from collections.abc import Callable, Mapping, Sequence
35+
from collections.abc import Callable, Iterator, Mapping, Sequence
36+
from contextlib import AbstractContextManager, contextmanager, nullcontext
3637
from dataclasses import dataclass, field
3738
from typing import TYPE_CHECKING, Any, Literal, cast
3839

40+
from openarmature.observability.correlation import (
41+
_reset_branch_name_chain,
42+
_reset_fan_out_index,
43+
_reset_fan_out_index_chain,
44+
_set_branch_name_chain,
45+
_set_fan_out_index,
46+
_set_fan_out_index_chain,
47+
)
48+
3949
from .errors import (
4050
FanOutEmpty,
4151
FanOutInvalidConcurrency,
@@ -57,6 +67,25 @@
5767
ConcurrencyResolver = Callable[[Any], int | None]
5868

5969

70+
@contextmanager
71+
def _bind_instance_lineage(child_context: _InvocationContext) -> Iterator[None]:
72+
"""Bind the fan-out lineage ContextVars (the instance index and the
73+
per-depth index / branch chains) to ``child_context`` for the duration of
74+
the ``with`` block, resetting them on exit."""
75+
# compiled.py binds these per-node, inside the inner subgraph; the
76+
# instance_middleware chain runs outside that, so current_fan_out_index()
77+
# and set_invocation_metadata's lineage view would otherwise be unset there.
78+
fan_out_token = _set_fan_out_index(child_context.fan_out_index)
79+
fan_out_chain_token = _set_fan_out_index_chain(child_context.fan_out_index_chain)
80+
branch_chain_token = _set_branch_name_chain(child_context.branch_name_chain)
81+
try:
82+
yield
83+
finally:
84+
_reset_branch_name_chain(branch_chain_token)
85+
_reset_fan_out_index_chain(fan_out_chain_token)
86+
_reset_fan_out_index(fan_out_token)
87+
88+
6089
@dataclass(frozen=True)
6190
class FanOutConfig:
6291
"""Frozen configuration for a :class:`FanOutNode`.
@@ -291,8 +320,17 @@ async def innermost(s: ChildT) -> Mapping[str, Any]:
291320
return _extract_instance_partial(cfg, final_inst_state)
292321

293322
chain: ChainCall = compose_chain(cfg.instance_middleware, innermost)
323+
# Bind the lineage ContextVars around the chain only when there is
324+
# instance middleware to read them; with none, the inner subgraph's
325+
# nodes bind them and this level would be a redundant no-op. The
326+
# context manager resets before the error-handling below so that
327+
# path's saves keep their existing context.
328+
lineage: AbstractContextManager[None] = (
329+
_bind_instance_lineage(child_context) if cfg.instance_middleware else nullcontext()
330+
)
294331
try:
295-
partial = await chain(instance_state)
332+
with lineage:
333+
partial = await chain(instance_state)
296334
except Exception as exc:
297335
if cfg.error_policy == "collect":
298336
# Per §10.11.2 collect mode: the failure becomes a

tests/unit/test_fan_out.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,107 @@ async def maybe_fail(state: WorkerState) -> Mapping[str, Any]:
654654
assert instance_attempts == {7: 2, 9: 2}
655655

656656

657+
async def test_instance_middleware_sees_fan_out_index() -> None:
658+
# An instance_middleware that reads current_fan_out_index() / its chain
659+
# observes the instance's own index: the engine sets the lineage ContextVars
660+
# around the middleware chain, not only inside node bodies. (Regression --
661+
# the index was None here when only compiled.py set it, deeper in node
662+
# execution, so the middleware wrapping the inner subgraph saw nothing.)
663+
from openarmature.observability.correlation import (
664+
current_fan_out_index,
665+
current_fan_out_index_chain,
666+
)
667+
668+
seen_index: dict[int, int | None] = {}
669+
seen_chain: dict[int, tuple[int | None, ...]] = {}
670+
671+
class _RecordIndexMW:
672+
async def __call__(self, state: WorkerState, next_: Any, /) -> Any:
673+
# Key by the item so each instance is identifiable without relying
674+
# on the index under test.
675+
seen_index[state.item] = current_fan_out_index()
676+
seen_chain[state.item] = current_fan_out_index_chain()
677+
return await next_(state)
678+
679+
async def compute(state: WorkerState) -> Mapping[str, Any]:
680+
return {"result": state.item}
681+
682+
inner_builder: GraphBuilder[WorkerState] = GraphBuilder(WorkerState)
683+
inner_builder.set_entry("compute")
684+
inner_builder.add_node("compute", compute)
685+
inner_builder.add_edge("compute", END)
686+
inner = inner_builder.compile()
687+
688+
parent_builder: GraphBuilder[InstanceMwParentState] = GraphBuilder(InstanceMwParentState)
689+
parent_builder.set_entry("process")
690+
parent_builder.add_fan_out_node(
691+
"process",
692+
subgraph=inner,
693+
items_field="items",
694+
item_field="item",
695+
collect_field="result",
696+
target_field="results",
697+
instance_middleware=[_RecordIndexMW()],
698+
)
699+
parent_builder.add_edge("process", END)
700+
parent = parent_builder.compile()
701+
702+
await parent.invoke(InstanceMwParentState(items=[10, 20, 30]))
703+
await parent.drain()
704+
705+
# items 10/20/30 are fan-out indices 0/1/2 in order; the chain carries the
706+
# instance index at the leaf.
707+
assert seen_index == {10: 0, 20: 1, 30: 2}
708+
assert seen_chain == {10: (0,), 20: (1,), 30: (2,)}
709+
710+
711+
async def test_instance_middleware_lineage_reset_on_failure() -> None:
712+
# The lineage ContextVars reset even when an instance fails: the binding's
713+
# finally runs on the exception path, so a failed instance leaks nothing
714+
# into the parent scope.
715+
from openarmature.observability.correlation import current_fan_out_index
716+
717+
seen: list[int | None] = []
718+
719+
class _RecordMW:
720+
async def __call__(self, state: WorkerState, next_: Any, /) -> Any:
721+
seen.append(current_fan_out_index())
722+
return await next_(state)
723+
724+
async def boom(_state: WorkerState) -> Mapping[str, Any]:
725+
raise RuntimeError("boom")
726+
727+
inner_builder: GraphBuilder[WorkerState] = GraphBuilder(WorkerState)
728+
inner_builder.set_entry("boom")
729+
inner_builder.add_node("boom", boom)
730+
inner_builder.add_edge("boom", END)
731+
inner = inner_builder.compile()
732+
733+
parent_builder: GraphBuilder[InstanceMwParentState] = GraphBuilder(InstanceMwParentState)
734+
parent_builder.set_entry("process")
735+
parent_builder.add_fan_out_node(
736+
"process",
737+
subgraph=inner,
738+
items_field="items",
739+
item_field="item",
740+
collect_field="result",
741+
target_field="results",
742+
instance_middleware=[_RecordMW()],
743+
concurrency=1,
744+
)
745+
parent_builder.add_edge("process", END)
746+
parent = parent_builder.compile()
747+
748+
with pytest.raises(NodeException):
749+
await parent.invoke(InstanceMwParentState(items=[1, 2]))
750+
await parent.drain()
751+
752+
# The middleware saw the instance index (the bind happened) ...
753+
assert seen and all(idx is not None for idx in seen)
754+
# ... and the bind's finally reset it despite the failure.
755+
assert current_fan_out_index() is None
756+
757+
657758
# ---------------------------------------------------------------------------
658759
# Fan-in determinism under nondeterministic completion order (§9.4)
659760
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)