|
21 | 21 | from pydantic import ValidationError |
22 | 22 |
|
23 | 23 | from openarmature.graph.events import LlmCompletionEvent, LlmFailedEvent, NodeEvent |
| 24 | +from openarmature.graph.middleware import RetryConfig, deterministic_backoff |
24 | 25 | from openarmature.graph.observer import ObserverEvent |
25 | 26 | from openarmature.llm import ( |
26 | 27 | PROVIDER_AUTHENTICATION, |
@@ -1336,6 +1337,138 @@ def _503(_req: httpx.Request) -> httpx.Response: |
1336 | 1337 | assert failed_events[0].error_type == "ProviderUnavailable" |
1337 | 1338 |
|
1338 | 1339 |
|
| 1340 | +# --------------------------------------------------------------------------- |
| 1341 | +# Call-level retry (proposal 0050) |
| 1342 | +# --------------------------------------------------------------------------- |
| 1343 | + |
| 1344 | + |
| 1345 | +def _ok_chat_completion() -> dict[str, object]: |
| 1346 | + return { |
| 1347 | + "id": "x", |
| 1348 | + "object": "chat.completion", |
| 1349 | + "created": 0, |
| 1350 | + "model": "m", |
| 1351 | + "choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}], |
| 1352 | + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, |
| 1353 | + } |
| 1354 | + |
| 1355 | + |
| 1356 | +def _fail_n_then_ok(calls: list[int], fail_count: int) -> Callable[[httpx.Request], httpx.Response]: |
| 1357 | + def handler(_req: httpx.Request) -> httpx.Response: |
| 1358 | + calls[0] += 1 |
| 1359 | + if calls[0] <= fail_count: |
| 1360 | + return httpx.Response(503, json={"error": {"message": "down"}}) |
| 1361 | + return httpx.Response(200, json=_ok_chat_completion()) |
| 1362 | + |
| 1363 | + return handler |
| 1364 | + |
| 1365 | + |
| 1366 | +async def test_call_level_retry_succeeds_after_transient() -> None: |
| 1367 | + calls = [0] |
| 1368 | + events, token = _collecting_dispatch() |
| 1369 | + provider = OpenAIProvider( |
| 1370 | + base_url="http://test", |
| 1371 | + model="m", |
| 1372 | + api_key="k", |
| 1373 | + transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=1)), |
| 1374 | + ) |
| 1375 | + try: |
| 1376 | + response = await provider.complete( |
| 1377 | + [UserMessage(content="hi")], |
| 1378 | + retry=RetryConfig(max_attempts=2, backoff=deterministic_backoff(0)), |
| 1379 | + ) |
| 1380 | + finally: |
| 1381 | + await provider.aclose() |
| 1382 | + _release_dispatch(token) |
| 1383 | + |
| 1384 | + # One transient failure then success: the wire call was retried. |
| 1385 | + assert calls[0] == 2 |
| 1386 | + assert response.message.content == "ok" |
| 1387 | + # Terminal-only: one LlmCompletionEvent, no LlmFailedEvent for the |
| 1388 | + # intermediate transient attempt. |
| 1389 | + assert len([e for e in events if isinstance(e, LlmCompletionEvent)]) == 1 |
| 1390 | + assert [e for e in events if isinstance(e, LlmFailedEvent)] == [] |
| 1391 | + |
| 1392 | + |
| 1393 | +async def test_call_level_retry_exhaustion_emits_one_failed_event() -> None: |
| 1394 | + calls = [0] |
| 1395 | + events, token = _collecting_dispatch() |
| 1396 | + provider = OpenAIProvider( |
| 1397 | + base_url="http://test", |
| 1398 | + model="m", |
| 1399 | + api_key="k", |
| 1400 | + transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=99)), |
| 1401 | + ) |
| 1402 | + try: |
| 1403 | + with pytest.raises(ProviderUnavailable): |
| 1404 | + await provider.complete( |
| 1405 | + [UserMessage(content="hi")], |
| 1406 | + retry=RetryConfig(max_attempts=3, backoff=deterministic_backoff(0)), |
| 1407 | + ) |
| 1408 | + finally: |
| 1409 | + await provider.aclose() |
| 1410 | + _release_dispatch(token) |
| 1411 | + |
| 1412 | + # Exhausted all 3 attempts, then propagated. Terminal-only: one |
| 1413 | + # LlmFailedEvent (not one per attempt), no LlmCompletionEvent. |
| 1414 | + assert calls[0] == 3 |
| 1415 | + assert [e for e in events if isinstance(e, LlmCompletionEvent)] == [] |
| 1416 | + assert len([e for e in events if isinstance(e, LlmFailedEvent)]) == 1 |
| 1417 | + |
| 1418 | + |
| 1419 | +async def test_call_level_retry_skips_non_transient() -> None: |
| 1420 | + calls = [0] |
| 1421 | + events, token = _collecting_dispatch() |
| 1422 | + |
| 1423 | + def _400(_req: httpx.Request) -> httpx.Response: |
| 1424 | + calls[0] += 1 |
| 1425 | + return httpx.Response(400, json={"error": {"message": "bad"}}) |
| 1426 | + |
| 1427 | + provider = OpenAIProvider( |
| 1428 | + base_url="http://test", model="m", api_key="k", transport=httpx.MockTransport(_400) |
| 1429 | + ) |
| 1430 | + try: |
| 1431 | + with pytest.raises(ProviderInvalidRequest): |
| 1432 | + await provider.complete( |
| 1433 | + [UserMessage(content="hi")], |
| 1434 | + retry=RetryConfig(max_attempts=5, backoff=deterministic_backoff(0)), |
| 1435 | + ) |
| 1436 | + finally: |
| 1437 | + await provider.aclose() |
| 1438 | + _release_dispatch(token) |
| 1439 | + |
| 1440 | + # provider_invalid_request is non-transient: no retry, single attempt. |
| 1441 | + assert calls[0] == 1 |
| 1442 | + assert len([e for e in events if isinstance(e, LlmFailedEvent)]) == 1 |
| 1443 | + |
| 1444 | + |
| 1445 | +async def test_call_level_retry_invokes_on_retry_per_attempt() -> None: |
| 1446 | + calls = [0] |
| 1447 | + retries: list[tuple[str, int]] = [] |
| 1448 | + |
| 1449 | + async def _on_retry(exc: Exception, attempt: int) -> None: |
| 1450 | + retries.append((type(exc).__name__, attempt)) |
| 1451 | + |
| 1452 | + provider = OpenAIProvider( |
| 1453 | + base_url="http://test", |
| 1454 | + model="m", |
| 1455 | + api_key="k", |
| 1456 | + transport=httpx.MockTransport(_fail_n_then_ok(calls, fail_count=2)), |
| 1457 | + ) |
| 1458 | + try: |
| 1459 | + await provider.complete( |
| 1460 | + [UserMessage(content="hi")], |
| 1461 | + retry=RetryConfig(max_attempts=3, backoff=deterministic_backoff(0), on_retry=_on_retry), |
| 1462 | + ) |
| 1463 | + finally: |
| 1464 | + await provider.aclose() |
| 1465 | + |
| 1466 | + # Two transient failures then success: on_retry fires once per |
| 1467 | + # retried attempt (before each backoff), with the 0-based index. |
| 1468 | + assert calls[0] == 3 |
| 1469 | + assert retries == [("ProviderUnavailable", 0), ("ProviderUnavailable", 1)] |
| 1470 | + |
| 1471 | + |
1339 | 1472 | # --------------------------------------------------------------------------- |
1340 | 1473 | # Proposal 0058: per-category field-mapping + pre-send + mutual exclusion |
1341 | 1474 | # --------------------------------------------------------------------------- |
|
0 commit comments