|
13 | 13 | from __future__ import annotations |
14 | 14 |
|
15 | 15 | import asyncio |
| 16 | +from collections.abc import Mapping |
16 | 17 | from typing import Any |
17 | 18 |
|
18 | 19 | import pytest |
@@ -555,3 +556,151 @@ async def _read_after_write(_s: _SimpleState) -> dict[str, Any]: |
555 | 556 | # Caller baseline + in-node write, both visible to the read. |
556 | 557 | assert captured == {"tenantId": "T1", "audit_kind": "fraud"} |
557 | 558 | assert captured_type == [MappingProxyType] |
| 559 | + |
| 560 | + |
| 561 | +# Spec observability §3.4 *Per-attempt scoping*: under retry |
| 562 | +# middleware, each attempt sees only the metadata in scope at |
| 563 | +# retry-entry plus that attempt's own writes; failed-attempt |
| 564 | +# writes are discarded along with the attempt itself. The pin |
| 565 | +# below mirrors the spec's fixture 045 case shape (attempt 0 |
| 566 | +# writes + fails, attempt 1 asserts marker absent + writes + |
| 567 | +# succeeds, downstream reads successful attempt's marker). |
| 568 | +# Companion test verifies the same discard discipline on |
| 569 | +# terminal failure (all retries exhausted). |
| 570 | + |
| 571 | + |
| 572 | +class _RetryTransient(Exception): |
| 573 | + """Carries a transient category so the default classifier |
| 574 | + treats it as retryable. Matches the ``provider_rate_limit`` |
| 575 | + category used in ``tests/unit/test_middleware.py``.""" |
| 576 | + |
| 577 | + category = "provider_rate_limit" |
| 578 | + |
| 579 | + |
| 580 | +async def test_per_attempt_scoping_under_retry_discards_failed_attempt_writes() -> None: |
| 581 | + from openarmature.graph.middleware import RetryMiddleware |
| 582 | + |
| 583 | + captured_attempt_1_read: dict[str, Any] = {} |
| 584 | + captured_downstream_read: dict[str, Any] = {} |
| 585 | + attempts: list[int] = [] |
| 586 | + |
| 587 | + async def _retried(_s: _SimpleState) -> dict[str, Any]: |
| 588 | + attempt_n = len(attempts) |
| 589 | + attempts.append(attempt_n) |
| 590 | + if attempt_n == 0: |
| 591 | + # First attempt: write a marker, then raise transient. |
| 592 | + set_invocation_metadata(attempt_marker="first") |
| 593 | + raise _RetryTransient() |
| 594 | + # Second attempt: read first — assert the failed-attempt's |
| 595 | + # marker is NOT visible — then write a new marker and succeed. |
| 596 | + captured_attempt_1_read.update(dict(get_invocation_metadata())) |
| 597 | + set_invocation_metadata(attempt_marker="second") |
| 598 | + return {"counter": 1} |
| 599 | + |
| 600 | + async def _downstream(_s: _SimpleState) -> dict[str, Any]: |
| 601 | + captured_downstream_read.update(dict(get_invocation_metadata())) |
| 602 | + return {"counter": 2} |
| 603 | + |
| 604 | + graph = ( |
| 605 | + GraphBuilder(_SimpleState) |
| 606 | + .add_node( |
| 607 | + "retried", |
| 608 | + _retried, |
| 609 | + middleware=[RetryMiddleware(max_attempts=2, backoff=lambda _i: 0.0)], |
| 610 | + ) |
| 611 | + .add_node("downstream", _downstream) |
| 612 | + .add_edge("retried", "downstream") |
| 613 | + .add_edge("downstream", END) |
| 614 | + .set_entry("retried") |
| 615 | + .compile() |
| 616 | + ) |
| 617 | + await graph.invoke(_SimpleState(), metadata={"tenantId": "T1"}) |
| 618 | + |
| 619 | + assert attempts == [0, 1] |
| 620 | + # Attempt 1's read: baseline only — attempt 0's transient |
| 621 | + # ``attempt_marker=first`` write was discarded on failure. |
| 622 | + assert captured_attempt_1_read == {"tenantId": "T1"} |
| 623 | + # Downstream node: baseline + the successful attempt's write |
| 624 | + # persists past the retry boundary. |
| 625 | + assert captured_downstream_read == {"tenantId": "T1", "attempt_marker": "second"} |
| 626 | + |
| 627 | + |
| 628 | +async def test_terminal_failure_discards_final_failed_attempt_writes() -> None: |
| 629 | + # Exercises the middleware directly via ``compose_chain`` so the |
| 630 | + # post-retry metadata view is readable in the test scope (the |
| 631 | + # engine's outer invoke() reset would otherwise pop the var back |
| 632 | + # to empty before control returns to the test, masking the |
| 633 | + # middleware's own discard). The contract pinned here is that |
| 634 | + # AFTER the retry middleware re-raises a terminal failure, the |
| 635 | + # metadata ContextVar is back at the pre-attempt baseline — no |
| 636 | + # leak of the final failed attempt's writes. |
| 637 | + from openarmature.graph.middleware import RetryMiddleware, compose_chain |
| 638 | + from openarmature.observability.metadata import ( |
| 639 | + _reset_invocation_metadata, |
| 640 | + _set_invocation_metadata, |
| 641 | + validate_invocation_metadata, |
| 642 | + ) |
| 643 | + |
| 644 | + attempts: list[int] = [] |
| 645 | + |
| 646 | + async def _always_fails(_state: Any) -> Mapping[str, Any]: |
| 647 | + attempts.append(len(attempts)) |
| 648 | + set_invocation_metadata(attempt_marker=f"attempt_{len(attempts) - 1}") |
| 649 | + raise _RetryTransient() |
| 650 | + |
| 651 | + retry = RetryMiddleware(max_attempts=2, backoff=lambda _i: 0.0) |
| 652 | + chain = compose_chain([retry], _always_fails) |
| 653 | + |
| 654 | + # Establish a baseline outside the middleware so we can read it |
| 655 | + # back post-failure. Mirrors how the engine sets the baseline |
| 656 | + # at the invoke() boundary. |
| 657 | + baseline_token = _set_invocation_metadata(validate_invocation_metadata({"tenantId": "T1"})) |
| 658 | + try: |
| 659 | + with pytest.raises(_RetryTransient): |
| 660 | + await chain(_SimpleState()) |
| 661 | + # Both attempts ran. |
| 662 | + assert attempts == [0, 1] |
| 663 | + # Post-failure view: the pre-attempt baseline, with NO |
| 664 | + # ``attempt_marker`` leaked from the final failed attempt. |
| 665 | + assert dict(get_invocation_metadata()) == {"tenantId": "T1"} |
| 666 | + finally: |
| 667 | + _reset_invocation_metadata(baseline_token) |
| 668 | + |
| 669 | + |
| 670 | +async def test_cancellation_discards_in_flight_attempt_writes() -> None: |
| 671 | + # Spec §3.4: failed-attempt metadata writes are discarded along |
| 672 | + # with the attempt. When ``CancelledError`` (or any other |
| 673 | + # ``BaseException``) ends the attempt, the same discard discipline |
| 674 | + # applies — cancellation IS a failed attempt from the |
| 675 | + # metadata-scoping perspective. Spec §6.1: cancellation MUST |
| 676 | + # propagate (no retry, no swallow), so the reset must happen IN |
| 677 | + # ADDITION to, not instead of, propagating ``CancelledError``. |
| 678 | + from openarmature.graph.middleware import RetryMiddleware, compose_chain |
| 679 | + from openarmature.observability.metadata import ( |
| 680 | + _reset_invocation_metadata, |
| 681 | + _set_invocation_metadata, |
| 682 | + validate_invocation_metadata, |
| 683 | + ) |
| 684 | + |
| 685 | + attempts: list[int] = [] |
| 686 | + |
| 687 | + async def _writes_then_cancels(_state: Any) -> Mapping[str, Any]: |
| 688 | + attempts.append(len(attempts)) |
| 689 | + set_invocation_metadata(attempt_marker="leaked") |
| 690 | + raise asyncio.CancelledError("aborted") |
| 691 | + |
| 692 | + retry = RetryMiddleware(max_attempts=3, backoff=lambda _i: 0.0) |
| 693 | + chain = compose_chain([retry], _writes_then_cancels) |
| 694 | + |
| 695 | + baseline_token = _set_invocation_metadata(validate_invocation_metadata({"tenantId": "T1"})) |
| 696 | + try: |
| 697 | + with pytest.raises(asyncio.CancelledError): |
| 698 | + await chain(_SimpleState()) |
| 699 | + # Cancellation propagated — exactly ONE attempt ran (retry |
| 700 | + # MUST NOT swallow ``CancelledError`` per spec §6.1). |
| 701 | + assert attempts == [0] |
| 702 | + # The cancelled attempt's metadata write was discarded per |
| 703 | + # §3.4 — post-failure view is the pre-attempt baseline. |
| 704 | + assert dict(get_invocation_metadata()) == {"tenantId": "T1"} |
| 705 | + finally: |
| 706 | + _reset_invocation_metadata(baseline_token) |
0 commit comments