Skip to content

Commit c855195

Browse files
Wire get_invocation_metadata fixtures (043/045/046)
Wire the get_invocation_metadata read-access family (proposal 0048) into the YAML conformance harness via a hand-built runner -- the cross-cap adapter doesn't model augment_metadata, capture_invocation_metadata_into, retry_middleware, or direct_call. The runner processes a node's in-node directives in YAML key order (043 augments then captures; 045 attempt 1 captures then augments), drives 045's retry via RetryMiddleware + a category-carrying transient, and handles 046's no-graph direct call. Asserts final_state field equality plus the immutability invariant (the read is a MappingProxyType). Move 043/045/046 to _SUPPORTED_FIXTURES; 044 (fan-out scoping) stays a follow-up -- its fan-out collection of per-instance captures is a distinct shape. Test-only.
1 parent 5874173 commit c855195

1 file changed

Lines changed: 167 additions & 7 deletions

File tree

tests/conformance/test_observability.py

Lines changed: 167 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ def _reset_otel_global_tracer_provider(restore_to: object) -> None:
9494
# v0.42.0 — proposal 0050 call-level-retry per-attempt LLM span
9595
# surface. Single-attempt default: one span, attempt_index 0.
9696
"057-llm-attempt-index-single-attempt-default",
97+
# proposal 0048 get_invocation_metadata read access: 043 roundtrip,
98+
# 045 retry-scoping, 046 outside-invocation. 044 (fan-out) is a follow-up.
99+
"043-get-invocation-metadata-roundtrip",
100+
"045-get-invocation-metadata-retry-scoping",
101+
"046-get-invocation-metadata-outside-invocation",
97102
"001-otel-basic-trace",
98103
"002-otel-subgraph-hierarchy",
99104
"003-otel-error-status",
@@ -316,13 +321,10 @@ def _reset_otel_global_tracer_provider(restore_to: object) -> None:
316321
# conformance runner test_observability_langfuse.py -- see
317322
# _LANGFUSE_HARNESS_FIXTURES, NOT here (they are not unit-only).
318323
(
319-
(
320-
"043-get-invocation-metadata-roundtrip",
321-
"044-get-invocation-metadata-fan-out-scoping",
322-
"045-get-invocation-metadata-retry-scoping",
323-
"046-get-invocation-metadata-outside-invocation",
324-
),
325-
"proposal 0048 get_invocation_metadata; covered by test_observability_metadata.py",
324+
("044-get-invocation-metadata-fan-out-scoping",),
325+
"proposal 0048 get_invocation_metadata fan-out scoping; covered by "
326+
"test_observability_metadata.py (043/045/046 now wired into "
327+
"_SUPPORTED_FIXTURES; 044's fan-out collection is a follow-up)",
326328
),
327329
# Fixture-harness catch-up tier 1 wired the rest of the 0057/0058
328330
# family into _SUPPORTED_FIXTURES; these three stay here, each blocked
@@ -576,6 +578,12 @@ async def test_observability_fixture(fixture_path: Path) -> None:
576578
"098-langfuse-tool-observation",
577579
}:
578580
await _run_tool_fixture(spec)
581+
elif fixture_id in {
582+
"043-get-invocation-metadata-roundtrip",
583+
"045-get-invocation-metadata-retry-scoping",
584+
"046-get-invocation-metadata-outside-invocation",
585+
}:
586+
await _run_get_invocation_metadata_fixture(spec)
579587
else:
580588
raise AssertionError(f"no driver for supported fixture {fixture_id!r}")
581589

@@ -1390,6 +1398,158 @@ async def _body(_s: Any) -> dict[str, Any]:
13901398
raise AssertionError(f"case {case_name!r}: {e}") from e
13911399

13921400

1401+
# ---------------------------------------------------------------------------
1402+
# Proposal 0048 read access: get_invocation_metadata (fixtures 043 / 045 / 046).
1403+
# Hand-built -- the cross-cap adapter does not model augment_metadata /
1404+
# capture_invocation_metadata_into / retry_middleware / direct_call. Mirrors the
1405+
# unit tests in test_observability_metadata.py.
1406+
# ---------------------------------------------------------------------------
1407+
1408+
1409+
class _MetadataRetryTransient(Exception):
1410+
# The default retry classifier treats provider_rate_limit as retryable; this
1411+
# is the test's "transient_marker_error" for fixture 045.
1412+
category = "provider_rate_limit"
1413+
1414+
1415+
def _apply_metadata_directives(
1416+
directives: Mapping[str, Any], types_seen: dict[str, type]
1417+
) -> tuple[dict[str, Any], bool]:
1418+
"""Run a node's (or per-attempt's) in-node metadata directives, returning the
1419+
resulting state update and whether the node should then raise."""
1420+
# Directives run IN KEY ORDER -- 043 augments then captures, 045 attempt 1
1421+
# captures then augments, and the YAML key order encodes that. The capture
1422+
# records the read's type so the immutability invariant can verify it was a
1423+
# MappingProxyType. ``raises`` is terminal (a real node body that raises runs
1424+
# nothing after), so stop at it rather than processing later directives.
1425+
from openarmature.observability.metadata import ( # noqa: PLC0415
1426+
get_invocation_metadata,
1427+
set_invocation_metadata,
1428+
)
1429+
1430+
update: dict[str, Any] = {}
1431+
should_raise = False
1432+
for key, val in directives.items():
1433+
if key == "augment_metadata":
1434+
set_invocation_metadata(**cast("dict[str, Any]", val))
1435+
elif key == "capture_invocation_metadata_into":
1436+
read = get_invocation_metadata()
1437+
types_seen[cast("str", val)] = type(read)
1438+
update[cast("str", val)] = dict(read)
1439+
elif key == "raises":
1440+
should_raise = True
1441+
break
1442+
return update, should_raise
1443+
1444+
1445+
def _make_metadata_node_body(node_spec: Mapping[str, Any], types_seen: dict[str, type]) -> Any:
1446+
"""Build a node body that runs the node's metadata directives -- attempt-keyed
1447+
when the node declares ``per_attempt_behavior`` (045), else once per call."""
1448+
per_attempt = cast("list[dict[str, Any]] | None", node_spec.get("per_attempt_behavior"))
1449+
if per_attempt is not None:
1450+
by_attempt = {int(b["attempt"]): b for b in per_attempt}
1451+
attempts: list[int] = []
1452+
1453+
async def _retry_body(_s: Any) -> dict[str, Any]:
1454+
n = len(attempts)
1455+
attempts.append(n)
1456+
update, should_raise = _apply_metadata_directives(by_attempt.get(n, {}), types_seen)
1457+
if should_raise:
1458+
raise _MetadataRetryTransient()
1459+
return update
1460+
1461+
return _retry_body
1462+
1463+
async def _body(_s: Any) -> dict[str, Any]:
1464+
update, should_raise = _apply_metadata_directives(node_spec, types_seen)
1465+
if should_raise:
1466+
raise _MetadataRetryTransient()
1467+
return update
1468+
1469+
return _body
1470+
1471+
1472+
async def _run_get_invocation_metadata_fixture(spec: Mapping[str, Any]) -> None:
1473+
"""Drive every case of a get_invocation_metadata fixture (043 / 045 / 046)."""
1474+
for case in cast("list[dict[str, Any]]", spec["cases"]):
1475+
await _run_get_invocation_metadata_case(case)
1476+
1477+
1478+
async def _run_get_invocation_metadata_case(case: Mapping[str, Any]) -> None:
1479+
"""Assert one case: a bare get_invocation_metadata() call (046), or a graph
1480+
whose final_state captures the in-node reads (043 / 045)."""
1481+
from types import MappingProxyType # noqa: PLC0415
1482+
1483+
from openarmature.observability.metadata import get_invocation_metadata # noqa: PLC0415
1484+
1485+
expected = cast("dict[str, Any]", case["expected"])
1486+
invariants = cast("dict[str, Any]", expected.get("invariants") or {})
1487+
1488+
# Fixture 046: a bare get_invocation_metadata() call outside any invocation.
1489+
direct_call = cast("dict[str, Any] | None", case.get("direct_call"))
1490+
if direct_call is not None:
1491+
result = get_invocation_metadata()
1492+
dc = cast("dict[str, Any]", expected["direct_call_result"])
1493+
assert dict(result) == cast("dict[str, Any]", dc.get("value") or {}), (
1494+
f"direct_call value {dict(result)!r} != {dc.get('value')!r}"
1495+
)
1496+
if dc.get("type") == "immutable_mapping":
1497+
assert isinstance(result, MappingProxyType), "direct_call result is not an immutable mapping"
1498+
# ``exception: null`` -- reaching here means the call did not raise.
1499+
return
1500+
1501+
# Fixtures 043 / 045: build the graph, invoke with caller metadata, assert
1502+
# final_state field equality + the immutability invariant.
1503+
from openarmature.graph import END, GraphBuilder # noqa: PLC0415
1504+
from openarmature.graph.middleware import RetryConfig, RetryMiddleware # noqa: PLC0415
1505+
1506+
from .adapter import build_state_cls # noqa: PLC0415
1507+
1508+
types_seen: dict[str, type] = {}
1509+
state_cls = build_state_cls("MetadataFixtureState", cast("dict[str, Any]", case["state"]["fields"]))
1510+
builder = GraphBuilder(state_cls)
1511+
for node_name, node_spec_any in cast("dict[str, Any]", case["nodes"]).items():
1512+
node_spec = cast("dict[str, Any]", node_spec_any)
1513+
body = _make_metadata_node_body(node_spec, types_seen)
1514+
retry_cfg = cast("dict[str, Any] | None", node_spec.get("retry_middleware"))
1515+
if retry_cfg is not None:
1516+
# The fixture's abstract ``classifier`` (transient_marker) maps to the
1517+
# default retry classifier: _MetadataRetryTransient carries the
1518+
# provider_rate_limit category, which that classifier retries. Only
1519+
# max_attempts is read off the directive.
1520+
builder.add_node(
1521+
node_name,
1522+
body,
1523+
middleware=[
1524+
RetryMiddleware(
1525+
RetryConfig(max_attempts=int(retry_cfg["max_attempts"]), backoff=lambda _i: 0.0)
1526+
)
1527+
],
1528+
)
1529+
else:
1530+
builder.add_node(node_name, body)
1531+
for edge in cast("list[dict[str, str]]", case["edges"]):
1532+
target = END if edge["to"] == "END" else edge["to"]
1533+
builder.add_edge(edge["from"], target)
1534+
builder.set_entry(cast("str", case["entry"]))
1535+
graph = builder.compile()
1536+
1537+
final = await graph.invoke(
1538+
state_cls(**cast("dict[str, Any]", case.get("initial_state") or {})),
1539+
metadata=cast("dict[str, Any] | None", case.get("caller_metadata")),
1540+
)
1541+
await graph.drain()
1542+
1543+
for field_name, expected_value in cast("dict[str, Any]", expected.get("final_state") or {}).items():
1544+
actual = getattr(final, field_name)
1545+
assert actual == expected_value, f"final_state.{field_name}: {actual!r} != {expected_value!r}"
1546+
1547+
if invariants.get("read_returns_immutable_mapping"):
1548+
assert types_seen and all(t is MappingProxyType for t in types_seen.values()), (
1549+
f"read_returns_immutable_mapping: captured read types {types_seen!r} not all MappingProxyType"
1550+
)
1551+
1552+
13931553
def _normalize_attr_value(value: Any) -> Any:
13941554
"""OTel attribute values can be tuple or list shapes for sequence
13951555
types depending on how they were set; normalize for comparison."""

0 commit comments

Comments
 (0)