Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/gumloop/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import httpx

from gumloop._http import DEFAULT_MAX_RETRIES
from gumloop._http import AsyncHttpClient
from gumloop._http import HttpClient
from gumloop.oauth import OAuth
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
stream_base_url: str | None = None,
timeout: float = DEFAULT_TIMEOUT,
stream_timeout: float | None = DEFAULT_STREAM_TIMEOUT,
max_retries: int = DEFAULT_MAX_RETRIES,
) -> None:
self.api_key = api_key
self.access_token = access_token or api_key or os.environ.get("GUMLOOP_ACCESS_TOKEN")
Expand All @@ -67,6 +69,7 @@ def __init__(
self.stream_base_url = (stream_base_url or _derive_stream_base_url(self.base_url)).rstrip("/")
self.timeout = timeout
self.stream_timeout = stream_timeout
self.max_retries = max_retries

self._http = HttpClient(
base_url=self.base_url,
Expand All @@ -75,6 +78,7 @@ def __init__(
user_id=self.user_id,
timeout=self.timeout,
stream_timeout=self.stream_timeout,
max_retries=self.max_retries,
)

self.agents = Agents(self._http)
Expand Down Expand Up @@ -110,6 +114,7 @@ def __init__(
stream_base_url: str | None = None,
timeout: float = DEFAULT_TIMEOUT,
stream_timeout: float | None = DEFAULT_STREAM_TIMEOUT,
max_retries: int = DEFAULT_MAX_RETRIES,
) -> None:
self.api_key = api_key
self.access_token = access_token or api_key or os.environ.get("GUMLOOP_ACCESS_TOKEN")
Expand All @@ -119,6 +124,7 @@ def __init__(
self.stream_base_url = (stream_base_url or _derive_stream_base_url(self.base_url)).rstrip("/")
self.timeout = timeout
self.stream_timeout = stream_timeout
self.max_retries = max_retries

self._http = AsyncHttpClient(
base_url=self.base_url,
Expand All @@ -127,6 +133,7 @@ def __init__(
user_id=self.user_id,
timeout=self.timeout,
stream_timeout=self.stream_timeout,
max_retries=self.max_retries,
)

self.agents = AsyncAgents(self._http)
Expand Down
219 changes: 147 additions & 72 deletions src/gumloop/_http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import asyncio
import logging
import random
import time
from collections.abc import AsyncIterator
from collections.abc import Iterator
from collections.abc import Mapping
Expand All @@ -13,14 +16,59 @@
from pydantic import BaseModel
from pydantic import ValidationError

from gumloop.errors import APIStatusError
from gumloop.errors import AuthenticationError
from gumloop.errors import RateLimitError
from gumloop.errors import ServerError
from gumloop.errors import to_api_error
from gumloop.types import StreamEvent

logger = logging.getLogger(__name__)

_DONE_SENTINEL = "[DONE]"
_T = TypeVar("_T", bound=BaseModel)
DEFAULT_MAX_RETRIES = 2
_RETRY_BASE_DELAY = 0.5
_RETRY_MAX_DELAY = 60.0
# Same retry rules as _request(): idempotent methods retry on both 429 and 5xx;
# POST/PATCH only retry on 429 (server guarantees the request was not processed).
_IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "DELETE", "OPTIONS", "PUT"})


def _should_retry(exc: APIStatusError, method: str) -> bool:
if not isinstance(exc, (RateLimitError, ServerError)):
return False
# POST/PATCH are non-idempotent: only retry on 429 where the server
# guarantees the request was not processed.
if method.upper() not in _IDEMPOTENT_METHODS and isinstance(exc, ServerError):
return False
return True


def _parse_retry_after(response: httpx.Response) -> float | None:
import datetime
import email.utils

raw = response.headers.get("retry-after")
if raw is None:
return None
try:
return float(raw)
except ValueError:
pass
try:
dt = email.utils.parsedate_to_datetime(raw)
delta = (dt - datetime.datetime.now(tz=datetime.timezone.utc)).total_seconds()
return max(delta, 0.0)
except Exception:
return None


def _retry_delay(attempt: int, retry_after: float | None) -> float:
if retry_after is not None:
return retry_after
cap = min(_RETRY_BASE_DELAY * (2**attempt), _RETRY_MAX_DELAY)
return random.uniform(0, cap)


def _auth_headers(access_token: str | None, user_id: str | None) -> dict[str, str]:
Expand Down Expand Up @@ -71,11 +119,13 @@ def __init__(
user_id: str | None,
timeout: float,
stream_timeout: float | None,
max_retries: int = DEFAULT_MAX_RETRIES,
) -> None:
self.access_token = access_token
self.user_id = user_id
self._stream_base_url = stream_base_url.rstrip("/")
self._stream_timeout = stream_timeout
self._max_retries = max_retries
self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout)

def close(self) -> None:
Expand Down Expand Up @@ -137,20 +187,28 @@ def stream(
json: Any = None,
params: Mapping[str, Any] | None = None,
) -> Iterator[StreamEvent]:
# Retry only covers the initial connection handshake — once the server
# starts sending events we are committed to this response and cannot
# restart without re-delivering events to the caller.
headers = {**_auth_headers(self.access_token, self.user_id), "Accept": "text/event-stream"}
with self._client.stream(
method,
f"{self._stream_base_url}/{path.lstrip('/')}",
headers=headers,
timeout=self._stream_timeout,
json=json,
params=_omit_none_params(params),
) as response:
if response.status_code >= 400:
response.read()
raise to_api_error(response)
for event in EventSource(response).iter_sse():
yield _decode_sse(event)
url = f"{self._stream_base_url}/{path.lstrip('/')}"
for attempt in range(self._max_retries + 1):
with self._client.stream(
method, url, headers=headers, timeout=self._stream_timeout,
json=json, params=_omit_none_params(params),
) as response:
if response.status_code >= 400:
response.read()
exc = to_api_error(response)
if attempt < self._max_retries and _should_retry(exc, method):
delay = _retry_delay(attempt, _parse_retry_after(response))
logger.debug("retrying stream (attempt %d, delay %.2fs)", attempt + 1, delay)
time.sleep(delay)
continue
raise exc
for event in EventSource(response).iter_sse():
yield _decode_sse(event)
return

def stream_typed(
self,
Expand All @@ -164,29 +222,34 @@ def stream_typed(
# Skips the StreamEvent envelope and honors OpenRouter's `data: [DONE]`
# terminator. Unparseable events (keep-alives, comments) are skipped.
headers = {**_auth_headers(self.access_token, self.user_id), "Accept": "text/event-stream"}
with self._client.stream(
method,
f"{self._stream_base_url}/{path.lstrip('/')}",
headers=headers,
timeout=self._stream_timeout,
json=json,
params=_omit_none_params(params),
) as response:
if response.status_code >= 400:
response.read()
raise to_api_error(response)
for event in EventSource(response).iter_sse():
if event.data == _DONE_SENTINEL:
return
if not event.data:
continue
try:
yield response_model.model_validate_json(event.data)
except ValidationError:
# Server-side mid-stream error frames or schema-drift events
# land here.
logger.debug("dropped non-%s SSE: %s", response_model.__name__, event.data)
continue
url = f"{self._stream_base_url}/{path.lstrip('/')}"
for attempt in range(self._max_retries + 1):
with self._client.stream(
method, url, headers=headers, timeout=self._stream_timeout,
json=json, params=_omit_none_params(params),
) as response:
if response.status_code >= 400:
response.read()
exc = to_api_error(response)
if attempt < self._max_retries and _should_retry(exc, method):
delay = _retry_delay(attempt, _parse_retry_after(response))
logger.debug("retrying stream_typed (attempt %d, delay %.2fs)", attempt + 1, delay)
time.sleep(delay)
continue
raise exc
for event in EventSource(response).iter_sse():
if event.data == _DONE_SENTINEL:
return
if not event.data:
continue
try:
yield response_model.model_validate_json(event.data)
except ValidationError:
# Server-side mid-stream error frames or schema-drift events
# land here.
logger.debug("dropped non-%s SSE: %s", response_model.__name__, event.data)
continue
return

def _request(self, method: str, path: str, **kwargs: Any) -> Any:
# Headers are rebuilt per request so ``access_token`` / ``user_id``
Expand All @@ -212,11 +275,13 @@ def __init__(
user_id: str | None,
timeout: float,
stream_timeout: float | None,
max_retries: int = DEFAULT_MAX_RETRIES,
) -> None:
self.access_token = access_token
self.user_id = user_id
self._stream_base_url = stream_base_url.rstrip("/")
self._stream_timeout = stream_timeout
self._max_retries = max_retries
self._client = httpx.AsyncClient(base_url=base_url.rstrip("/"), timeout=timeout)

async def aclose(self) -> None:
Expand Down Expand Up @@ -276,19 +341,24 @@ async def stream(
params: Mapping[str, Any] | None = None,
) -> AsyncIterator[StreamEvent]:
headers = {**_auth_headers(self.access_token, self.user_id), "Accept": "text/event-stream"}
async with self._client.stream(
method,
f"{self._stream_base_url}/{path.lstrip('/')}",
headers=headers,
timeout=self._stream_timeout,
json=json,
params=_omit_none_params(params),
) as response:
if response.status_code >= 400:
await response.aread()
raise to_api_error(response)
async for event in EventSource(response).aiter_sse():
yield _decode_sse(event)
url = f"{self._stream_base_url}/{path.lstrip('/')}"
for attempt in range(self._max_retries + 1):
async with self._client.stream(
method, url, headers=headers, timeout=self._stream_timeout,
json=json, params=_omit_none_params(params),
) as response:
if response.status_code >= 400:
await response.aread()
exc = to_api_error(response)
if attempt < self._max_retries and _should_retry(exc, method):
delay = _retry_delay(attempt, _parse_retry_after(response))
logger.debug("retrying stream (attempt %d, delay %.2fs)", attempt + 1, delay)
await asyncio.sleep(delay)
continue
raise exc
async for event in EventSource(response).aiter_sse():
yield _decode_sse(event)
return

async def stream_typed(
self,
Expand All @@ -300,29 +370,34 @@ async def stream_typed(
params: Mapping[str, Any] | None = None,
) -> AsyncIterator[_T]:
headers = {**_auth_headers(self.access_token, self.user_id), "Accept": "text/event-stream"}
async with self._client.stream(
method,
f"{self._stream_base_url}/{path.lstrip('/')}",
headers=headers,
timeout=self._stream_timeout,
json=json,
params=_omit_none_params(params),
) as response:
if response.status_code >= 400:
await response.aread()
raise to_api_error(response)
async for event in EventSource(response).aiter_sse():
if event.data == _DONE_SENTINEL:
return
if not event.data:
continue
try:
yield response_model.model_validate_json(event.data)
except ValidationError:
# Server-side mid-stream error frames or schema-drift events
# land here.
logger.debug("dropped non-%s SSE: %s", response_model.__name__, event.data)
continue
url = f"{self._stream_base_url}/{path.lstrip('/')}"
for attempt in range(self._max_retries + 1):
async with self._client.stream(
method, url, headers=headers, timeout=self._stream_timeout,
json=json, params=_omit_none_params(params),
) as response:
if response.status_code >= 400:
await response.aread()
exc = to_api_error(response)
if attempt < self._max_retries and _should_retry(exc, method):
delay = _retry_delay(attempt, _parse_retry_after(response))
logger.debug("retrying stream_typed (attempt %d, delay %.2fs)", attempt + 1, delay)
await asyncio.sleep(delay)
continue
raise exc
async for event in EventSource(response).aiter_sse():
if event.data == _DONE_SENTINEL:
return
if not event.data:
continue
try:
yield response_model.model_validate_json(event.data)
except ValidationError:
# Server-side mid-stream error frames or schema-drift events
# land here.
logger.debug("dropped non-%s SSE: %s", response_model.__name__, event.data)
continue
return

async def _request(self, method: str, path: str, **kwargs: Any) -> Any:
headers = _auth_headers(self.access_token, self.user_id)
Expand Down
Loading