Skip to content

Commit 433632b

Browse files
Add call-level retry to Provider.complete() (#151)
Add an optional retry: RetryConfig parameter to complete(). When set, the wire call is retried in-call on transient provider errors per the config, so a node issuing several LLM calls in a loop does not re-run already-successful calls when a later call hits a transient failure. The request is built and validated once (pre-send errors are never retried) and the call stays terminal-only on the observability surface: exactly one LlmCompletionEvent or LlmFailedEvent fires per complete() call, with one call_id across attempts. The per-attempt span surface is deferred to a future sub-event; conformance.toml marks 0050 partial. Final piece of proposal 0050 (after failure isolation + the RetryConfig refactor). No spec-pin change.
1 parent 8dde25a commit 433632b

6 files changed

Lines changed: 271 additions & 6 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The
99
### Added
1010

1111
- **`FailureIsolationMiddleware`** (proposal 0050, pipeline-utilities §6.3). A third bundled middleware primitive alongside `RetryMiddleware` and `TimingMiddleware`. It catches exceptions escaping the wrapped node's inner chain and returns a configured degraded partial update, so a non-critical node can fail without aborting the whole invocation. Configuration: `degraded_update` (a static mapping or a `state -> partial_update` callable, resolved at catch time), `event_name` (required, no default, since a generic name makes downstream telemetry strictly worse), an optional `predicate` (`Exception -> bool`; only matching exceptions are caught, others propagate), and an optional async `on_caught` hook. It catches `Exception`; `BaseException` (cancellation) propagates, matching `RetryMiddleware`. On a catch it dispatches a new framework-emitted `FailureIsolatedEvent` (a distinct observer-event variant carrying `event_name`, the wrapped node's lineage identity, `pre_state` / `post_state`, and a `CaughtException` record of category plus message) onto the observer delivery queue; the bundled OTel and Langfuse observers render it as a marker span / observation. Compose it OUTER of `RetryMiddleware` for the "retry transients, degrade gracefully on exhaustion" pattern. Additive: existing pipelines see no behavior change, and the spec pin is unchanged (0050 is already within the v0.53.0 pin).
12+
- **Call-level retry on `Provider.complete()`** (proposal 0050, llm-provider §7). The provider's `complete()` gains an optional `retry: RetryConfig | None` parameter. When supplied, the wire call is retried in-call on transient provider errors per the config (classifier, backoff, `on_retry`, `max_attempts`), so a node issuing several LLM calls in a loop does not re-run the already-successful calls when a later call hits a transient failure. The request is built and validated once (pre-send validation errors are never retried), and the call stays terminal-only on the observability surface: exactly one `LlmCompletionEvent` (eventual success) or `LlmFailedEvent` (retry exhaustion or a non-transient error) fires per `complete()` call, with a single `call_id` shared across attempts. The per-attempt span surface (N per-attempt spans and the `openarmature.llm.attempt_index` attribute) is deferred to a future cycle; `conformance.toml` marks proposal 0050 `partial` accordingly. No spec-pin change.
1213

1314
### Changed
1415

conformance.toml

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,23 @@ status = "implemented"
372372
since = "0.13.0"
373373

374374
# Spec v0.42.0 (proposal 0050). Retry & degradation primitives —
375-
# failure-isolation middleware + call-level retry. Queued for
376-
# v0.14.0 (largest single piece in the roadmap).
375+
# failure-isolation middleware (§6.3) + call-level retry (§7). Both
376+
# primitives implemented across the v0.14.0 cycle:
377+
# FailureIsolationMiddleware (distinct FailureIsolatedEvent +
378+
# CaughtException) and the call-level ``retry`` parameter on
379+
# ``Provider.complete()`` — an in-call loop over transient §7 errors
380+
# reusing the §6.1 RetryConfig record. ``partial`` because §7.1's
381+
# per-attempt span surface — N ``openarmature.llm.complete`` spans +
382+
# the ``openarmature.llm.attempt_index`` attribute — is DEFERRED: the
383+
# python LLM span is rendered from the typed event, which is
384+
# terminal-only per the graph-engine §6 mutual-exclusion contract, so
385+
# per-attempt spans require a dedicated within-call sub-event
386+
# (LlmRetryAttemptEvent) scoped to a future cycle. Call-level retry
387+
# ships terminal-only: exactly one LlmCompletionEvent / LlmFailedEvent
388+
# per ``complete()`` call.
377389
[proposals."0050"]
378-
status = "not-yet"
390+
status = "partial"
391+
since = "0.14.0"
379392

380393
# Spec v0.43.0 (proposal 0051). Langfuse trace.input/trace.output
381394
# implementation-surface caveat. Purely textual: documents that the

docs/concepts/llms.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,51 @@ stateless calls. Conversational memory (if you want it) is the
8585
caller's responsibility: thread it through state and pass the
8686
accumulated message list into each call.
8787

88+
## Retrying transient failures
89+
90+
LLM endpoints fail in transient ways (rate limits, 503s, brief
91+
outages). Pass a `RetryConfig` to `complete(retry=...)` to retry the
92+
call in-place on those transient categories, without re-running any
93+
surrounding work:
94+
95+
```python
96+
from openarmature.graph import RetryConfig
97+
98+
response = await provider.complete(
99+
messages,
100+
retry=RetryConfig(max_attempts=3),
101+
)
102+
```
103+
104+
When `retry` is omitted the call is a single attempt (the default).
105+
With a config, the request is built and validated once, then the wire
106+
call is retried on transient errors per the config's classifier and
107+
backoff; a non-transient error (a bad request, an auth failure)
108+
propagates immediately without retrying. From observability's point of
109+
view the call stays a single unit: exactly one completion-or-failure
110+
event fires for the terminal outcome, regardless of how many attempts
111+
it took.
112+
113+
### Call-level vs node-level retry
114+
115+
There are two retry layers, for different jobs:
116+
117+
- **Call-level** (`complete(retry=...)`) retries one LLM call. Reach
118+
for it when a node issues several LLM calls in a loop (chunked
119+
processing, multi-step) and you do not want a transient failure on
120+
the fifth call to re-run the four that already succeeded.
121+
- **Node-level** (`RetryMiddleware`, see [Middleware](middleware.md))
122+
retries a whole node. Reach for it when the node does LLM work plus
123+
other work (a DB write, a parse) and you want to re-run the entire
124+
body on failure.
125+
126+
They use the same `RetryConfig` shape and compose: a node-level retry
127+
re-runs the node, and each fresh run gets its own call-level budget.
128+
The thing to avoid is stacking both with overlapping budgets without
129+
meaning to: a 3-attempt node retry wrapping a 5-call node with
130+
3-attempt call-level retry can issue up to 45 calls in the worst case.
131+
Pick intentional budgets per layer.
132+
88133
## Pre-flight readiness check
89134

90135
`Provider.ready()` is the optional pre-flight call you make before

src/openarmature/llm/provider.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from __future__ import annotations
3939

4040
from collections.abc import Sequence
41-
from typing import Any, Protocol, cast
41+
from typing import TYPE_CHECKING, Any, Protocol, cast
4242
from urllib.parse import unquote
4343

4444
import jsonschema
@@ -58,6 +58,9 @@
5858
)
5959
from .response import Response, RuntimeConfig
6060

61+
if TYPE_CHECKING:
62+
from openarmature.graph.middleware import RetryConfig
63+
6164

6265
class Provider(Protocol):
6366
"""The shape of any llm-provider implementation.
@@ -78,6 +81,7 @@ async def complete(
7881
config: RuntimeConfig | None = None,
7982
response_schema: dict[str, Any] | type[BaseModel] | None = None,
8083
tool_choice: ToolChoice | None = None,
84+
retry: RetryConfig | None = None,
8185
) -> Response:
8286
"""Perform a single completion call.
8387
@@ -102,6 +106,12 @@ async def complete(
102106
the wire ``tool_choice`` field is omitted and the
103107
provider's own default applies. Pre-send validation
104108
routes through ``provider_invalid_request``.
109+
retry: Optional call-level retry configuration. When
110+
supplied, transient provider errors are retried in-call
111+
per the config; the request is built and validated once,
112+
and exactly one observability event fires for the
113+
terminal outcome. ``None`` (the default) performs a
114+
single attempt.
105115
"""
106116
...
107117

src/openarmature/llm/providers/openai.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,14 @@
5050

5151
from __future__ import annotations
5252

53+
import asyncio
5354
import hashlib
5455
import json
5556
import re
5657
import time
5758
import uuid
5859
from collections.abc import Mapping, Sequence
59-
from typing import Any, Literal, cast
60+
from typing import TYPE_CHECKING, Any, Literal, cast
6061
from urllib.parse import urlparse
6162

6263
import httpx
@@ -116,6 +117,9 @@
116117
)
117118
from ..response import FinishReason, ParsedValue, Response, RuntimeConfig, Usage
118119

120+
if TYPE_CHECKING:
121+
from openarmature.graph.middleware import RetryConfig
122+
119123
# Runtime guard for ``OpenAIProvider(..., readiness_probe=...)``. The
120124
# Literal type narrows callers under static checkers but is not enforced
121125
# at runtime, so an unknown string would silently no-op both dispatch
@@ -348,6 +352,7 @@ async def complete(
348352
config: RuntimeConfig | None = None,
349353
response_schema: dict[str, Any] | type[BaseModel] | None = None,
350354
tool_choice: ToolChoice | None = None,
355+
retry: RetryConfig | None = None,
351356
) -> Response:
352357
"""Single completion call.
353358
@@ -370,6 +375,18 @@ async def complete(
370375
non-empty ``tools``, and ``ForceTool.name`` must appear in the
371376
supplied list. Violations raise ``provider_invalid_request``
372377
BEFORE any HTTP request is sent.
378+
379+
When ``retry`` is supplied, the wire call is retried on
380+
transient provider errors per the config's classifier and
381+
backoff (defaulting to the canonical transient categories with
382+
exponential jittered backoff). The request is built and
383+
validated once; pre-send validation errors are never retried.
384+
Exactly one observability event fires for the call's terminal
385+
outcome regardless of attempt count, and its ``latency_ms``
386+
covers the whole call, retries and backoff included. The
387+
``on_retry`` hook is not exception-isolated (mirroring
388+
``RetryMiddleware``); an exception raised by it propagates out
389+
of the call.
373390
"""
374391
# Spec observability §5.5 LLM provider span: when an
375392
# observability backend is active in the current invocation,
@@ -464,7 +481,7 @@ async def complete(
464481
include_response_format=(schema_dict is None or not self._force_prompt_augmentation_fallback),
465482
tool_choice=tool_choice,
466483
)
467-
response = await self._do_complete(body, schema_dict, schema_class)
484+
response = await self._do_complete_with_retry(body, schema_dict, schema_class, retry)
468485
except LlmProviderError as exc:
469486
# Failure path: dispatch a typed LlmFailedEvent per
470487
# proposal 0058. Only §7 category exceptions
@@ -510,6 +527,52 @@ async def complete(
510527
)
511528
return response
512529

530+
async def _do_complete_with_retry(
531+
self,
532+
body: dict[str, Any],
533+
schema_dict: dict[str, Any] | None,
534+
schema_class: type[BaseModel] | None,
535+
retry: RetryConfig | None,
536+
) -> Response:
537+
"""Run the wire call with optional call-level retry.
538+
539+
Loops the underlying wire call on transient provider errors per
540+
the retry config. Intermediate transient attempts are caught
541+
here and emit no observability event; only the terminal outcome
542+
(success, retry exhaustion, or a non-transient error) reaches
543+
``complete()``'s typed-event dispatch, so exactly one event
544+
fires per ``complete()`` call.
545+
"""
546+
if retry is None:
547+
return await self._do_complete(body, schema_dict, schema_class)
548+
# Lazy import avoids a module-load cycle: graph.middleware.retry
549+
# imports llm.errors. Resolve None config fields to the canonical
550+
# defaults, mirroring RetryMiddleware.
551+
from openarmature.graph.middleware.retry import (
552+
default_classifier,
553+
exponential_jitter_backoff,
554+
)
555+
556+
classifier = retry.classifier or default_classifier
557+
backoff = retry.backoff or exponential_jitter_backoff
558+
attempt = 0
559+
while True:
560+
try:
561+
return await self._do_complete(body, schema_dict, schema_class)
562+
except LlmProviderError as exc:
563+
# No graph state at the call boundary; pass None (the
564+
# default classifier ignores it). Re-raise on exhaustion
565+
# or a non-transient category so complete() emits the
566+
# single terminal LlmFailedEvent.
567+
if attempt + 1 >= retry.max_attempts or not classifier(exc, None):
568+
raise
569+
# on_retry is not exception-isolated (matches
570+
# RetryMiddleware); a raise propagates out of the call.
571+
if retry.on_retry is not None:
572+
await retry.on_retry(exc, attempt)
573+
await asyncio.sleep(backoff(attempt))
574+
attempt += 1
575+
513576
def _build_llm_completion_event(
514577
self,
515578
response: Response,

tests/unit/test_llm_provider.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pydantic import ValidationError
2222

2323
from openarmature.graph.events import LlmCompletionEvent, LlmFailedEvent, NodeEvent
24+
from openarmature.graph.middleware import RetryConfig, deterministic_backoff
2425
from openarmature.graph.observer import ObserverEvent
2526
from openarmature.llm import (
2627
PROVIDER_AUTHENTICATION,
@@ -1336,6 +1337,138 @@ def _503(_req: httpx.Request) -> httpx.Response:
13361337
assert failed_events[0].error_type == "ProviderUnavailable"
13371338

13381339

1340+
# ---------------------------------------------------------------------------
1341+
# Call-level retry (proposal 0050)
1342+
# ---------------------------------------------------------------------------
1343+
1344+
1345+
def _ok_chat_completion() -> dict[str, object]:
1346+
return {
1347+
"id": "x",
1348+
"object": "chat.completion",
1349+
"created": 0,
1350+
"model": "m",
1351+
"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}],
1352+
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
1353+
}
1354+
1355+
1356+
def _fail_n_then_ok(calls: list[int], fail_count: int) -> Callable[[httpx.Request], httpx.Response]:
1357+
def handler(_req: httpx.Request) -> httpx.Response:
1358+
calls[0] += 1
1359+
if calls[0] <= fail_count:
1360+
return httpx.Response(503, json={"error": {"message": "down"}})
1361+
return httpx.Response(200, json=_ok_chat_completion())
1362+
1363+
return handler
1364+
1365+
1366+
async def test_call_level_retry_succeeds_after_transient() -> None:
1367+
calls = [0]
1368+
events, token = _collecting_dispatch()
1369+
provider = OpenAIProvider(
1370+
base_url="http://test",
1371+
model="m",
1372+
api_key="k",
1373+
transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=1)),
1374+
)
1375+
try:
1376+
response = await provider.complete(
1377+
[UserMessage(content="hi")],
1378+
retry=RetryConfig(max_attempts=2, backoff=deterministic_backoff(0)),
1379+
)
1380+
finally:
1381+
await provider.aclose()
1382+
_release_dispatch(token)
1383+
1384+
# One transient failure then success: the wire call was retried.
1385+
assert calls[0] == 2
1386+
assert response.message.content == "ok"
1387+
# Terminal-only: one LlmCompletionEvent, no LlmFailedEvent for the
1388+
# intermediate transient attempt.
1389+
assert len([e for e in events if isinstance(e, LlmCompletionEvent)]) == 1
1390+
assert [e for e in events if isinstance(e, LlmFailedEvent)] == []
1391+
1392+
1393+
async def test_call_level_retry_exhaustion_emits_one_failed_event() -> None:
1394+
calls = [0]
1395+
events, token = _collecting_dispatch()
1396+
provider = OpenAIProvider(
1397+
base_url="http://test",
1398+
model="m",
1399+
api_key="k",
1400+
transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=99)),
1401+
)
1402+
try:
1403+
with pytest.raises(ProviderUnavailable):
1404+
await provider.complete(
1405+
[UserMessage(content="hi")],
1406+
retry=RetryConfig(max_attempts=3, backoff=deterministic_backoff(0)),
1407+
)
1408+
finally:
1409+
await provider.aclose()
1410+
_release_dispatch(token)
1411+
1412+
# Exhausted all 3 attempts, then propagated. Terminal-only: one
1413+
# LlmFailedEvent (not one per attempt), no LlmCompletionEvent.
1414+
assert calls[0] == 3
1415+
assert [e for e in events if isinstance(e, LlmCompletionEvent)] == []
1416+
assert len([e for e in events if isinstance(e, LlmFailedEvent)]) == 1
1417+
1418+
1419+
async def test_call_level_retry_skips_non_transient() -> None:
1420+
calls = [0]
1421+
events, token = _collecting_dispatch()
1422+
1423+
def _400(_req: httpx.Request) -> httpx.Response:
1424+
calls[0] += 1
1425+
return httpx.Response(400, json={"error": {"message": "bad"}})
1426+
1427+
provider = OpenAIProvider(
1428+
base_url="http://test", model="m", api_key="k", transport=httpx.MockTransport(_400)
1429+
)
1430+
try:
1431+
with pytest.raises(ProviderInvalidRequest):
1432+
await provider.complete(
1433+
[UserMessage(content="hi")],
1434+
retry=RetryConfig(max_attempts=5, backoff=deterministic_backoff(0)),
1435+
)
1436+
finally:
1437+
await provider.aclose()
1438+
_release_dispatch(token)
1439+
1440+
# provider_invalid_request is non-transient: no retry, single attempt.
1441+
assert calls[0] == 1
1442+
assert len([e for e in events if isinstance(e, LlmFailedEvent)]) == 1
1443+
1444+
1445+
async def test_call_level_retry_invokes_on_retry_per_attempt() -> None:
1446+
calls = [0]
1447+
retries: list[tuple[str, int]] = []
1448+
1449+
async def _on_retry(exc: Exception, attempt: int) -> None:
1450+
retries.append((type(exc).__name__, attempt))
1451+
1452+
provider = OpenAIProvider(
1453+
base_url="http://test",
1454+
model="m",
1455+
api_key="k",
1456+
transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=2)),
1457+
)
1458+
try:
1459+
await provider.complete(
1460+
[UserMessage(content="hi")],
1461+
retry=RetryConfig(max_attempts=3, backoff=deterministic_backoff(0), on_retry=_on_retry),
1462+
)
1463+
finally:
1464+
await provider.aclose()
1465+
1466+
# Two transient failures then success: on_retry fires once per
1467+
# retried attempt (before each backoff), with the 0-based index.
1468+
assert calls[0] == 3
1469+
assert retries == [("ProviderUnavailable", 0), ("ProviderUnavailable", 1)]
1470+
1471+
13391472
# ---------------------------------------------------------------------------
13401473
# Proposal 0058: per-category field-mapping + pre-send + mutual exclusion
13411474
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)