-
Notifications
You must be signed in to change notification settings - Fork 5
add structured error types and retry logic for 429/5xx #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,9 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import logging | ||
| import random | ||
| import time | ||
| from collections.abc import AsyncIterator | ||
| from collections.abc import Iterator | ||
| from collections.abc import Mapping | ||
|
|
@@ -13,7 +16,10 @@ | |
| from pydantic import BaseModel | ||
| from pydantic import ValidationError | ||
|
|
||
| from gumloop.errors import APIStatusError | ||
| from gumloop.errors import AuthenticationError | ||
| from gumloop.errors import RateLimitError | ||
| from gumloop.errors import ServerError | ||
| from gumloop.errors import to_api_error | ||
| from gumloop.types import StreamEvent | ||
|
|
||
|
|
@@ -22,6 +28,11 @@ | |
| _DONE_SENTINEL = "[DONE]" | ||
| _T = TypeVar("_T", bound=BaseModel) | ||
|
|
||
| DEFAULT_MAX_RETRIES = 2 | ||
| # Base delay in seconds for exponential backoff; actual delay is base * 2^attempt + jitter. | ||
| _RETRY_BASE_DELAY = 0.5 | ||
| _RETRY_MAX_DELAY = 60.0 | ||
|
|
||
|
|
||
| def _auth_headers(access_token: str | None, user_id: str | None) -> dict[str, str]: | ||
| if not access_token: | ||
|
|
@@ -41,6 +52,55 @@ def _omit_none_params(params: Mapping[str, Any] | None) -> dict[str, Any] | None | |
| return {k: v for k, v in params.items() if v is not None} | ||
|
|
||
|
|
||
| _IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "DELETE", "OPTIONS", "PUT"}) | ||
|
|
||
|
|
||
| def _should_retry(exc: APIStatusError, method: str) -> bool: | ||
| # Never retry client errors. | ||
| if not isinstance(exc, (RateLimitError, ServerError)): | ||
| return False | ||
| # POST/PATCH are non-idempotent: a 5xx may arrive after the server already | ||
| # committed the write, so retrying would duplicate the mutation. Only retry | ||
| # them on 429 (rate-limit), where the server explicitly guarantees the | ||
| # request was not processed. | ||
| if method.upper() not in _IDEMPOTENT_METHODS and isinstance(exc, ServerError): | ||
| return False | ||
| return True | ||
|
|
||
|
|
||
| def _retry_delay(attempt: int, retry_after: float | None) -> float: | ||
| """Return how many seconds to sleep before the next attempt. | ||
|
|
||
| Honours a ``Retry-After`` header when present; otherwise uses exponential | ||
| backoff with full jitter so concurrent clients don't thunderherd. | ||
| """ | ||
| if retry_after is not None: | ||
| return retry_after | ||
| cap = min(_RETRY_BASE_DELAY * (2**attempt), _RETRY_MAX_DELAY) | ||
| return random.uniform(0, cap) | ||
|
|
||
|
|
||
| def _parse_retry_after(response: httpx.Response) -> float | None: | ||
| import email.utils | ||
|
|
||
| raw = response.headers.get("retry-after") | ||
| if raw is None: | ||
| return None | ||
| try: | ||
| return float(raw) | ||
| except ValueError: | ||
| pass | ||
| # RFC 7231 also allows an HTTP-date: "Retry-After: Wed, 21 Oct 2015 07:28:00 GMT" | ||
| try: | ||
| dt = email.utils.parsedate_to_datetime(raw) | ||
| import datetime | ||
|
|
||
| delta = (dt - datetime.datetime.now(tz=datetime.timezone.utc)).total_seconds() | ||
| return max(delta, 0.0) | ||
| except Exception: | ||
| return None | ||
|
|
||
|
|
||
| def _decode_sse(event: ServerSentEvent) -> StreamEvent: | ||
| try: | ||
| decoded: Any = event.json() if event.data else {} | ||
|
|
@@ -71,11 +131,13 @@ def __init__( | |
| user_id: str | None, | ||
| timeout: float, | ||
| stream_timeout: float | None, | ||
| max_retries: int = DEFAULT_MAX_RETRIES, | ||
| ) -> None: | ||
| self.access_token = access_token | ||
| self.user_id = user_id | ||
| self._stream_base_url = stream_base_url.rstrip("/") | ||
| self._stream_timeout = stream_timeout | ||
| self._max_retries = max_retries | ||
| self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout) | ||
|
|
||
| def close(self) -> None: | ||
|
|
@@ -119,15 +181,18 @@ def post_to_stream_host(self, path: str, *, json: Any = None) -> Any: | |
| # the api host has no handler for them. | ||
| headers = _auth_headers(self.access_token, self.user_id) | ||
| headers["Content-Type"] = "application/json" | ||
| response = self._client.post( | ||
| f"{self._stream_base_url}/{path.lstrip('/')}", | ||
| headers=headers, | ||
| timeout=self._stream_timeout, | ||
| json=json, | ||
| ) | ||
| if response.status_code >= 400: | ||
| raise to_api_error(response) | ||
| return response.json() if response.content else None | ||
| url = f"{self._stream_base_url}/{path.lstrip('/')}" | ||
| for attempt in range(self._max_retries + 1): | ||
| response = self._client.post(url, headers=headers, timeout=self._stream_timeout, json=json) | ||
| if response.status_code < 400: | ||
| return response.json() if response.content else None | ||
| exc = to_api_error(response) | ||
| if attempt < self._max_retries and _should_retry(exc, "POST"): | ||
| delay = _retry_delay(attempt, _parse_retry_after(response)) | ||
| logger.debug("retrying stream-host request (attempt %d, delay %.2fs)", attempt + 1, delay) | ||
| time.sleep(delay) | ||
| continue | ||
| raise exc | ||
|
|
||
| def stream( | ||
| self, | ||
|
|
@@ -184,7 +249,7 @@ def stream_typed( | |
| yield response_model.model_validate_json(event.data) | ||
| except ValidationError: | ||
| # Server-side mid-stream error frames or schema-drift events | ||
| # land here. | ||
| # land here. | ||
| logger.debug("dropped non-%s SSE: %s", response_model.__name__, event.data) | ||
| continue | ||
|
|
||
|
Comment on lines
197
to
255
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Both |
||
|
|
@@ -194,10 +259,17 @@ def _request(self, method: str, path: str, **kwargs: Any) -> Any: | |
| headers = _auth_headers(self.access_token, self.user_id) | ||
| if not kwargs.get("files"): | ||
| headers["Content-Type"] = "application/json" | ||
| response = self._client.request(method, path, headers=headers, **kwargs) | ||
| if response.status_code >= 400: | ||
| raise to_api_error(response) | ||
| return response.json() if response.content else None | ||
| for attempt in range(self._max_retries + 1): | ||
| response = self._client.request(method, path, headers=headers, **kwargs) | ||
| if response.status_code < 400: | ||
| return response.json() if response.content else None | ||
| exc = to_api_error(response) | ||
| if attempt < self._max_retries and _should_retry(exc, method): | ||
| delay = _retry_delay(attempt, _parse_retry_after(response)) | ||
| logger.debug("retrying %s %s (attempt %d, delay %.2fs)", method, path, attempt + 1, delay) | ||
| time.sleep(delay) | ||
| continue | ||
| raise exc | ||
|
Comment on lines
+262
to
+272
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
|
|
||
| class AsyncHttpClient: | ||
|
|
@@ -212,11 +284,13 @@ def __init__( | |
| user_id: str | None, | ||
| timeout: float, | ||
| stream_timeout: float | None, | ||
| max_retries: int = DEFAULT_MAX_RETRIES, | ||
| ) -> None: | ||
| self.access_token = access_token | ||
| self.user_id = user_id | ||
| self._stream_base_url = stream_base_url.rstrip("/") | ||
| self._stream_timeout = stream_timeout | ||
| self._max_retries = max_retries | ||
| self._client = httpx.AsyncClient(base_url=base_url.rstrip("/"), timeout=timeout) | ||
|
|
||
| async def aclose(self) -> None: | ||
|
|
@@ -257,15 +331,18 @@ async def delete(self, path: str) -> Any: | |
| async def post_to_stream_host(self, path: str, *, json: Any = None) -> Any: | ||
| headers = _auth_headers(self.access_token, self.user_id) | ||
| headers["Content-Type"] = "application/json" | ||
| response = await self._client.post( | ||
| f"{self._stream_base_url}/{path.lstrip('/')}", | ||
| headers=headers, | ||
| timeout=self._stream_timeout, | ||
| json=json, | ||
| ) | ||
| if response.status_code >= 400: | ||
| raise to_api_error(response) | ||
| return response.json() if response.content else None | ||
| url = f"{self._stream_base_url}/{path.lstrip('/')}" | ||
| for attempt in range(self._max_retries + 1): | ||
| response = await self._client.post(url, headers=headers, timeout=self._stream_timeout, json=json) | ||
| if response.status_code < 400: | ||
| return response.json() if response.content else None | ||
| exc = to_api_error(response) | ||
| if attempt < self._max_retries and _should_retry(exc, "POST"): | ||
| delay = _retry_delay(attempt, _parse_retry_after(response)) | ||
| logger.debug("retrying stream-host request (attempt %d, delay %.2fs)", attempt + 1, delay) | ||
| await asyncio.sleep(delay) | ||
| continue | ||
| raise exc | ||
|
|
||
| async def stream( | ||
| self, | ||
|
|
@@ -320,15 +397,22 @@ async def stream_typed( | |
| yield response_model.model_validate_json(event.data) | ||
| except ValidationError: | ||
| # Server-side mid-stream error frames or schema-drift events | ||
| # land here. | ||
| # land here. | ||
| logger.debug("dropped non-%s SSE: %s", response_model.__name__, event.data) | ||
| continue | ||
|
|
||
| async def _request(self, method: str, path: str, **kwargs: Any) -> Any: | ||
| headers = _auth_headers(self.access_token, self.user_id) | ||
| if not kwargs.get("files"): | ||
| headers["Content-Type"] = "application/json" | ||
| response = await self._client.request(method, path, headers=headers, **kwargs) | ||
| if response.status_code >= 400: | ||
| raise to_api_error(response) | ||
| return response.json() if response.content else None | ||
| for attempt in range(self._max_retries + 1): | ||
| response = await self._client.request(method, path, headers=headers, **kwargs) | ||
| if response.status_code < 400: | ||
| return response.json() if response.content else None | ||
| exc = to_api_error(response) | ||
| if attempt < self._max_retries and _should_retry(exc, method): | ||
| delay = _retry_delay(attempt, _parse_retry_after(response)) | ||
| logger.debug("retrying %s %s (attempt %d, delay %.2fs)", method, path, attempt + 1, delay) | ||
| await asyncio.sleep(delay) | ||
| continue | ||
| raise exc | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Retry-AfterHTTP-date format is silently droppedRFC 7231 allows
Retry-Afterto be either a delay-in-seconds (Retry-After: 30) or an HTTP-date (Retry-After: Wed, 21 Oct 2015 07:28:00 GMT). When the header contains a date string,float(raw)raisesValueError, the function returnsNone, and the code falls back to exponential backoff — which forattempt=0is at most 0.5 s. If the server sent a date meaning "wait 30 s", the client will retry far too soon, burn through all retries, and still surface aRateLimitErrorto the caller while likely worsening the rate-limit situation on the server side.