Skip to content

Commit b942d7d

Browse files
Wire 044 get_invocation_metadata fan-out scoping (#192)
* Wire 044 get_invocation_metadata fan-out scoping Wire fixture 044 (get-invocation-metadata-fan-out-scoping) into the YAML conformance harness, completing the proposal 0048 read-access family. Refactor the runner into a build dispatcher (simple node chain vs fan-out). For 044, each instance augments item_id via a runtime-state middleware (029's shape), the inner node captures its read into a synthesized field the fan-out collects (list-append) into the outer per_instance_metadata, and a post-join serial node captures outermost_metadata (baseline only -- per-instance writes don't flow back across the join). Move 044 from _UNIT_TESTED to _SUPPORTED. Test-only. * Tighten 044 fan-out shape assertions From CoPilot review of #192: make the 044 build's shape assumptions explicit so a fixture-shape drift fails loudly. - Assert exactly one fan-out node (was: silently use the first). - Assert the inner node's keys are exactly {capture_invocation_metadata_into} (was: only checked the key's presence, dropping any extra directive).
1 parent fa95dc2 commit b942d7d

1 file changed

Lines changed: 142 additions & 21 deletions

File tree

tests/conformance/test_observability.py

Lines changed: 142 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ def _reset_otel_global_tracer_provider(restore_to: object) -> None:
9595
# surface. Single-attempt default: one span, attempt_index 0.
9696
"057-llm-attempt-index-single-attempt-default",
9797
# proposal 0048 get_invocation_metadata read access: 043 roundtrip,
98-
# 045 retry-scoping, 046 outside-invocation. 044 (fan-out) is a follow-up.
98+
# 044 fan-out scoping, 045 retry-scoping, 046 outside-invocation.
9999
"043-get-invocation-metadata-roundtrip",
100+
"044-get-invocation-metadata-fan-out-scoping",
100101
"045-get-invocation-metadata-retry-scoping",
101102
"046-get-invocation-metadata-outside-invocation",
102103
"001-otel-basic-trace",
@@ -320,12 +321,6 @@ def _reset_otel_global_tracer_provider(restore_to: object) -> None:
320321
# The Langfuse-mapping fixtures are fixture-tested by the sibling
321322
# conformance runner test_observability_langfuse.py -- see
322323
# _LANGFUSE_HARNESS_FIXTURES, NOT here (they are not unit-only).
323-
(
324-
("044-get-invocation-metadata-fan-out-scoping",),
325-
"proposal 0048 get_invocation_metadata fan-out scoping; covered by "
326-
"test_observability_metadata.py (043/045/046 now wired into "
327-
"_SUPPORTED_FIXTURES; 044's fan-out collection is a follow-up)",
328-
),
329324
# Fixture-harness catch-up tier 1 wired the rest of the 0057/0058
330325
# family into _SUPPORTED_FIXTURES; these three stay here, each blocked
331326
# on a spec-side fixture change that python picks up at the v0.16.0 pin
@@ -580,6 +575,7 @@ async def test_observability_fixture(fixture_path: Path) -> None:
580575
await _run_tool_fixture(spec)
581576
elif fixture_id in {
582577
"043-get-invocation-metadata-roundtrip",
578+
"044-get-invocation-metadata-fan-out-scoping",
583579
"045-get-invocation-metadata-retry-scoping",
584580
"046-get-invocation-metadata-outside-invocation",
585581
}:
@@ -1488,14 +1484,14 @@ async def _body(_s: Any) -> dict[str, Any]:
14881484

14891485

14901486
async def _run_get_invocation_metadata_fixture(spec: Mapping[str, Any]) -> None:
1491-
"""Drive every case of a get_invocation_metadata fixture (043 / 045 / 046)."""
1487+
"""Drive every case of a get_invocation_metadata fixture (043 / 044 / 045 / 046)."""
14921488
for case in cast("list[dict[str, Any]]", spec["cases"]):
14931489
await _run_get_invocation_metadata_case(case)
14941490

14951491

14961492
async def _run_get_invocation_metadata_case(case: Mapping[str, Any]) -> None:
14971493
"""Assert one case: a bare get_invocation_metadata() call (046), or a graph
1498-
whose final_state captures the in-node reads (043 / 045)."""
1494+
whose final_state captures the in-node reads (043 / 044 / 045)."""
14991495
from types import MappingProxyType # noqa: PLC0415
15001496

15011497
from openarmature.observability.metadata import get_invocation_metadata # noqa: PLC0415
@@ -1516,17 +1512,36 @@ async def _run_get_invocation_metadata_case(case: Mapping[str, Any]) -> None:
15161512
# ``exception: null`` -- reaching here means the call did not raise.
15171513
return
15181514

1519-
# Fixtures 043 / 045: build the graph, invoke with caller metadata, assert
1520-
# final_state field equality + the immutability invariant.
1515+
# Fixtures 043 / 044 / 045: build the graph (simple or fan-out), invoke with
1516+
# caller metadata, then assert final_state + the immutability invariant.
1517+
types_seen: dict[str, type] = {}
1518+
final = await _build_and_invoke_metadata_graph(case, types_seen)
1519+
1520+
for field_name, expected_value in cast("dict[str, Any]", expected.get("final_state") or {}).items():
1521+
actual = getattr(final, field_name)
1522+
assert actual == expected_value, f"final_state.{field_name}: {actual!r} != {expected_value!r}"
1523+
1524+
if invariants.get("read_returns_immutable_mapping"):
1525+
assert types_seen and all(t is MappingProxyType for t in types_seen.values()), (
1526+
f"read_returns_immutable_mapping: captured read types {types_seen!r} not all MappingProxyType"
1527+
)
1528+
1529+
1530+
async def _build_and_invoke_metadata_graph(case: Mapping[str, Any], types_seen: dict[str, type]) -> Any:
1531+
"""Build the case's graph (simple node chain, or 044's fan-out), invoke it
1532+
with the caller metadata, and return the final state."""
15211533
from openarmature.graph import END, GraphBuilder # noqa: PLC0415
15221534
from openarmature.graph.middleware import RetryConfig, RetryMiddleware # noqa: PLC0415
15231535

15241536
from .adapter import build_state_cls # noqa: PLC0415
15251537

1526-
types_seen: dict[str, type] = {}
1538+
nodes = cast("dict[str, Any]", case["nodes"])
1539+
if any("fan_out" in cast("dict[str, Any]", s) for s in nodes.values()):
1540+
return await _build_and_invoke_metadata_fan_out(case, types_seen)
1541+
15271542
state_cls = build_state_cls("MetadataFixtureState", cast("dict[str, Any]", case["state"]["fields"]))
15281543
builder = GraphBuilder(state_cls)
1529-
for node_name, node_spec_any in cast("dict[str, Any]", case["nodes"]).items():
1544+
for node_name, node_spec_any in nodes.items():
15301545
node_spec = cast("dict[str, Any]", node_spec_any)
15311546
body = _make_metadata_node_body(node_spec, types_seen)
15321547
retry_cfg = cast("dict[str, Any] | None", node_spec.get("retry_middleware"))
@@ -1551,8 +1566,7 @@ async def _run_get_invocation_metadata_case(case: Mapping[str, Any]) -> None:
15511566
else:
15521567
builder.add_node(node_name, body)
15531568
for edge in cast("list[dict[str, str]]", case["edges"]):
1554-
target = END if edge["to"] == "END" else edge["to"]
1555-
builder.add_edge(edge["from"], target)
1569+
builder.add_edge(edge["from"], END if edge["to"] == "END" else edge["to"])
15561570
builder.set_entry(cast("str", case["entry"]))
15571571
graph = builder.compile()
15581572

@@ -1561,15 +1575,122 @@ async def _run_get_invocation_metadata_case(case: Mapping[str, Any]) -> None:
15611575
metadata=cast("dict[str, Any] | None", case.get("caller_metadata")),
15621576
)
15631577
await graph.drain()
1578+
return final
15641579

1565-
for field_name, expected_value in cast("dict[str, Any]", expected.get("final_state") or {}).items():
1566-
actual = getattr(final, field_name)
1567-
assert actual == expected_value, f"final_state.{field_name}: {actual!r} != {expected_value!r}"
15681580

1569-
if invariants.get("read_returns_immutable_mapping"):
1570-
assert types_seen and all(t is MappingProxyType for t in types_seen.values()), (
1571-
f"read_returns_immutable_mapping: captured read types {types_seen!r} not all MappingProxyType"
1581+
async def _build_and_invoke_metadata_fan_out(case: Mapping[str, Any], types_seen: dict[str, type]) -> Any:
1582+
"""Build and invoke 044's fan-out graph, returning the final state."""
1583+
# Each instance augments item_id (augment_metadata_from_field) then captures
1584+
# its read, collected into the outer per_instance_metadata list; a post-join
1585+
# serial node captures outermost_metadata. The inner capture writes a
1586+
# synthesized inner field that the fan-out collects into the directive's
1587+
# named outer field.
1588+
from openarmature.graph import END, GraphBuilder # noqa: PLC0415
1589+
1590+
from .adapter import build_state_cls # noqa: PLC0415
1591+
1592+
nodes = cast("dict[str, Any]", case["nodes"])
1593+
fan_out_names = [n for n, s in nodes.items() if "fan_out" in cast("dict[str, Any]", s)]
1594+
assert len(fan_out_names) == 1, f"044 build expects exactly one fan-out node; got {fan_out_names}"
1595+
fan_out_name = fan_out_names[0]
1596+
fan_out_spec = cast("dict[str, Any]", nodes[fan_out_name])
1597+
fan_out_cfg = cast("dict[str, Any]", fan_out_spec["fan_out"])
1598+
items_field = cast("str", fan_out_cfg["items_field"])
1599+
augment_map = cast("dict[str, str]", fan_out_cfg.get("augment_metadata_from_field") or {})
1600+
1601+
# Inner subgraph: the capture node writes into a synthesized inner field; the
1602+
# fan-out collects it into the outer list named by the directive.
1603+
item_field = "oa_fan_out_item"
1604+
inner_capture_field = "oa_captured_read"
1605+
inner_spec = cast("dict[str, Any]", case["inner_subgraphs"][cast("str", fan_out_cfg["inner_subgraph"])])
1606+
inner_state_cls = build_state_cls(
1607+
"MetaInnerState",
1608+
{inner_capture_field: {"type": "dict", "default": {}}, item_field: {"type": "dict", "default": {}}},
1609+
)
1610+
inner_builder = GraphBuilder(inner_state_cls)
1611+
# 044's inner subgraph is a single capture node; assert that shape so a future
1612+
# multi-node or non-capture inner subgraph fails loudly rather than silently
1613+
# collecting into one slot.
1614+
inner_nodes = cast("dict[str, Any]", inner_spec["nodes"])
1615+
assert len(inner_nodes) == 1, (
1616+
f"fan-out metadata inner subgraph must be one capture node; got {sorted(inner_nodes)}"
1617+
)
1618+
inode_name, inode_spec_any = next(iter(inner_nodes.items()))
1619+
inode_spec = cast("dict[str, Any]", inode_spec_any)
1620+
assert set(inode_spec) == {"capture_invocation_metadata_into"}, (
1621+
f"fan-out inner node {inode_name!r} must declare only capture_invocation_metadata_into; "
1622+
f"got {sorted(inode_spec)}"
1623+
)
1624+
outer_target_field = cast("str", inode_spec["capture_invocation_metadata_into"])
1625+
inner_builder.add_node(inode_name, _make_metadata_capture_body(inner_capture_field, types_seen))
1626+
for edge in cast("list[dict[str, str]]", inner_spec["edges"]):
1627+
inner_builder.add_edge(edge["from"], END if edge["to"] == "END" else edge["to"])
1628+
inner_builder.set_entry(cast("str", inner_spec["entry"]))
1629+
inner_graph = inner_builder.compile()
1630+
1631+
# Outer state: the declared fields + the items_field source (shipped via
1632+
# initial_state, like 029).
1633+
outer_fields = dict(cast("dict[str, Any]", case["state"]["fields"]))
1634+
outer_fields.setdefault(items_field, {"type": "list<dict>", "default": []})
1635+
outer_state_cls = build_state_cls("MetaOuterState", outer_fields)
1636+
outer_builder = GraphBuilder(outer_state_cls)
1637+
outer_builder.add_fan_out_node(
1638+
fan_out_name,
1639+
subgraph=inner_graph,
1640+
items_field=items_field,
1641+
item_field=item_field,
1642+
collect_field=inner_capture_field,
1643+
target_field=outer_target_field,
1644+
instance_middleware=[_make_metadata_augment_middleware(augment_map, item_field)],
1645+
)
1646+
for node_name, node_spec_any in nodes.items():
1647+
if node_name == fan_out_name:
1648+
continue
1649+
outer_builder.add_node(
1650+
node_name, _make_metadata_node_body(cast("dict[str, Any]", node_spec_any), types_seen)
15721651
)
1652+
for edge in cast("list[dict[str, str]]", case["edges"]):
1653+
outer_builder.add_edge(edge["from"], END if edge["to"] == "END" else edge["to"])
1654+
outer_builder.set_entry(cast("str", case["entry"]))
1655+
graph = outer_builder.compile()
1656+
1657+
final = await graph.invoke(
1658+
outer_state_cls(**cast("dict[str, Any]", case.get("initial_state") or {})),
1659+
metadata=cast("dict[str, Any] | None", case.get("caller_metadata")),
1660+
)
1661+
await graph.drain()
1662+
return final
1663+
1664+
1665+
def _make_metadata_capture_body(capture_field: str, types_seen: dict[str, type]) -> Any:
1666+
"""Body that captures get_invocation_metadata() into a fixed state field --
1667+
044's inner-subgraph node, whose read the fan-out then collects."""
1668+
from openarmature.observability.metadata import get_invocation_metadata # noqa: PLC0415
1669+
1670+
async def _body(_s: Any) -> dict[str, Any]:
1671+
read = get_invocation_metadata()
1672+
types_seen[capture_field] = type(read)
1673+
return {capture_field: dict(read)}
1674+
1675+
return _body
1676+
1677+
1678+
def _make_metadata_augment_middleware(field_map: dict[str, str], item_field: str) -> Any:
1679+
"""Per-instance fan-out middleware (044's augment_metadata_from_field): read
1680+
the instance's item from item_field and set_invocation_metadata from the
1681+
mapped fields, before the inner read. Reads runtime state (the engine has
1682+
placed the item by the time the chain runs)."""
1683+
from openarmature.observability.metadata import set_invocation_metadata # noqa: PLC0415
1684+
1685+
class _AugmentMW:
1686+
async def __call__(self, state: Any, next_: Any, /) -> Any:
1687+
item = getattr(state, item_field, None)
1688+
if isinstance(item, Mapping):
1689+
item_map = cast("Mapping[str, Any]", item)
1690+
set_invocation_metadata(**{key: item_map[field] for key, field in field_map.items()})
1691+
return await next_(state)
1692+
1693+
return _AugmentMW()
15731694

15741695

15751696
def _normalize_attr_value(value: Any) -> Any:

0 commit comments

Comments
 (0)