Skip to content

Commit 562e3a1

Browse files
replaced the retry backoff with tenacity
1 parent 57416a7 commit 562e3a1

File tree

4 files changed

+226
-123
lines changed

4 files changed

+226
-123
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ authors = [
77
]
88
dependencies = [
99
"requests>=2.32.3",
10+
"tenacity>=8.2.0",
1011
]
1112
requires-python = ">=3.11"
1213
readme = "README.md"

src/unstract/api_deployments/client.py

Lines changed: 126 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,20 @@
1010
import logging
1111
import ntpath
1212
import os
13-
import random
1413
import time
1514
from urllib.parse import urlparse
1615

1716
import requests
1817
from requests.exceptions import ConnectionError, JSONDecodeError, Timeout
18+
from tenacity import (
19+
RetryCallState,
20+
Retrying,
21+
retry_if_exception_type,
22+
retry_if_result,
23+
stop_after_attempt,
24+
wait_exponential_jitter,
25+
)
26+
from tenacity.wait import wait_base
1927

2028
from unstract.api_deployments.utils import UnstractUtils
2129

@@ -34,6 +42,43 @@ def error_message(self):
3442
return self.value
3543

3644

45+
class _WaitRetryAfterOrExponentialJitter(wait_base):
46+
"""Wait strategy that respects Retry-After on 429, else exponential jitter.
47+
48+
For 429 responses with a valid ``Retry-After`` header the server-requested
49+
delay is used. In every other case the strategy delegates to
50+
``wait_exponential_jitter`` (additive jitter).
51+
"""
52+
53+
def __init__(
54+
self,
55+
initial: float,
56+
max: float,
57+
exp_base: float,
58+
jitter: float,
59+
) -> None:
60+
super().__init__()
61+
self._exp_jitter = wait_exponential_jitter(
62+
initial=initial, max=max, exp_base=exp_base, jitter=jitter
63+
)
64+
65+
def __call__(self, retry_state: RetryCallState) -> float:
66+
outcome = retry_state.outcome
67+
if outcome and not outcome.failed:
68+
response = outcome.result()
69+
if (
70+
response is not None
71+
and getattr(response, "status_code", None) == 429
72+
):
73+
retry_after = response.headers.get("Retry-After")
74+
if retry_after is not None:
75+
try:
76+
return float(retry_after)
77+
except (ValueError, TypeError):
78+
pass
79+
return self._exp_jitter(retry_state)
80+
81+
3782
class APIDeploymentsClient:
3883
"""A class to invoke APIs deployed on the Unstract platform."""
3984

@@ -61,6 +106,7 @@ def __init__(
61106
initial_delay: float = 2.0,
62107
max_delay: float = 60.0,
63108
backoff_factor: float = 2.0,
109+
jitter: float = 1.0,
64110
):
65111
"""Initializes the APIClient class.
66112
@@ -72,6 +118,7 @@ def __init__(
72118
initial_delay (float): Initial delay in seconds before the first retry.
73119
max_delay (float): Maximum delay in seconds between retries.
74120
backoff_factor (float): Multiplier applied to delay for each retry.
121+
jitter (float): Maximum additive jitter in seconds added to each delay.
75122
"""
76123
if logging_level == "":
77124
logging_level = os.getenv("UNSTRACT_API_CLIENT_LOGGING_LEVEL", "INFO")
@@ -102,6 +149,7 @@ def __init__(
102149
self.initial_delay = initial_delay
103150
self.max_delay = max_delay
104151
self.backoff_factor = backoff_factor
152+
self.jitter = jitter
105153

106154
def _is_retryable_status(self, status_code: int) -> bool:
107155
"""Checks whether a status code should trigger a retry.
@@ -124,37 +172,6 @@ def __save_base_url(self, full_url: str):
124172
self.base_url = parsed_url.scheme + "://" + parsed_url.netloc
125173
self.logger.debug("Base URL: " + self.base_url)
126174

127-
def _calculate_delay(self, attempt: int) -> float:
128-
"""Calculates the delay before the next retry using exponential backoff
129-
with full jitter.
130-
131-
Args:
132-
attempt (int): The current retry attempt number (0-indexed).
133-
134-
Returns:
135-
float: The delay in seconds.
136-
"""
137-
exp_delay = min(
138-
self.initial_delay * (self.backoff_factor**attempt), self.max_delay
139-
)
140-
# Full jitter: randomize between 0 and exp_delay to avoid thundering herd
141-
return random.uniform(0, exp_delay)
142-
143-
def _get_retry_delay(self, response, attempt: int) -> float:
144-
"""Returns the delay before the next retry.
145-
146-
For 429 responses, respects the Retry-After header if present.
147-
Otherwise falls back to exponential backoff with jitter.
148-
"""
149-
if response is not None and response.status_code == 429:
150-
retry_after = response.headers.get("Retry-After")
151-
if retry_after is not None:
152-
try:
153-
return float(retry_after)
154-
except (ValueError, TypeError):
155-
pass
156-
return self._calculate_delay(attempt)
157-
158175
@staticmethod
159176
def _rewind_files(files):
160177
"""Rewinds file objects so they can be re-sent on retry."""
@@ -169,6 +186,8 @@ def _rewind_files(files):
169186
def _request_with_retry(self, method: str, url: str, **kwargs) -> requests.Response:
170187
"""Makes an HTTP request with exponential backoff retry logic.
171188
189+
Uses ``tenacity`` with additive jitter and Retry-After support.
190+
172191
Args:
173192
method (str): The HTTP method (e.g., "GET", "POST").
174193
url (str): The request URL.
@@ -181,67 +200,81 @@ def _request_with_retry(self, method: str, url: str, **kwargs) -> requests.Respo
181200
ConnectionError: If a connection error persists after all retries.
182201
Timeout: If a timeout persists after all retries.
183202
"""
184-
response = None
185-
186-
for attempt in range(self.max_retries + 1):
187-
# Rewind file objects for retry attempts
188-
if attempt > 0:
189-
files = kwargs.get("files")
190-
if files:
191-
self._rewind_files(files)
192-
193-
try:
194-
response = requests.request(method, url, **kwargs)
195-
196-
if not self._is_retryable_status(response.status_code):
197-
return response
198-
199-
if attempt < self.max_retries:
200-
delay = self._get_retry_delay(response, attempt)
201-
self.logger.warning(
202-
"Request to %s returned %d. Retrying in %.1fs "
203-
"(attempt %d/%d).",
204-
url,
205-
response.status_code,
206-
delay,
207-
attempt + 1,
208-
self.max_retries,
209-
)
210-
time.sleep(delay)
211-
else:
212-
self.logger.warning(
213-
"Request to %s returned %d. Retries exhausted (%d/%d).",
214-
url,
215-
response.status_code,
216-
self.max_retries,
217-
self.max_retries,
218-
)
219-
220-
except (ConnectionError, Timeout) as exc:
221-
response = None
222-
if attempt < self.max_retries:
223-
delay = self._get_retry_delay(None, attempt)
224-
self.logger.warning(
225-
"%s during request to %s. Retrying in %.1fs "
226-
"(attempt %d/%d).",
227-
type(exc).__name__,
228-
url,
229-
delay,
230-
attempt + 1,
231-
self.max_retries,
232-
)
233-
time.sleep(delay)
234-
else:
235-
self.logger.warning(
236-
"%s during request to %s. Retries exhausted (%d/%d).",
237-
type(exc).__name__,
238-
url,
239-
self.max_retries,
240-
self.max_retries,
241-
)
242-
raise
243-
244-
return response
203+
files = kwargs.get("files")
204+
205+
def _before_sleep(retry_state: RetryCallState):
206+
attempt = retry_state.attempt_number
207+
delay = retry_state.next_action.sleep
208+
outcome = retry_state.outcome
209+
if outcome.failed:
210+
exc = outcome.exception()
211+
self.logger.warning(
212+
"%s during request to %s. Retrying in %.1fs "
213+
"(attempt %d/%d).",
214+
type(exc).__name__,
215+
url,
216+
delay,
217+
attempt,
218+
self.max_retries,
219+
)
220+
else:
221+
response = outcome.result()
222+
self.logger.warning(
223+
"Request to %s returned %d. Retrying in %.1fs "
224+
"(attempt %d/%d).",
225+
url,
226+
response.status_code,
227+
delay,
228+
attempt,
229+
self.max_retries,
230+
)
231+
# Rewind file objects before next attempt
232+
if files:
233+
self._rewind_files(files)
234+
235+
def _retry_error_callback(retry_state: RetryCallState):
236+
outcome = retry_state.outcome
237+
if outcome.failed:
238+
exc = outcome.exception()
239+
self.logger.warning(
240+
"%s during request to %s. Retries exhausted (%d/%d).",
241+
type(exc).__name__,
242+
url,
243+
self.max_retries,
244+
self.max_retries,
245+
)
246+
raise exc
247+
response = outcome.result()
248+
self.logger.warning(
249+
"Request to %s returned %d. Retries exhausted (%d/%d).",
250+
url,
251+
response.status_code,
252+
self.max_retries,
253+
self.max_retries,
254+
)
255+
return response
256+
257+
retrier = Retrying(
258+
stop=stop_after_attempt(self.max_retries + 1),
259+
wait=_WaitRetryAfterOrExponentialJitter(
260+
initial=self.initial_delay,
261+
max=self.max_delay,
262+
exp_base=self.backoff_factor,
263+
jitter=self.jitter,
264+
),
265+
retry=(
266+
retry_if_result(
267+
lambda r: self._is_retryable_status(r.status_code)
268+
)
269+
| retry_if_exception_type((ConnectionError, Timeout))
270+
),
271+
before_sleep=_before_sleep,
272+
retry_error_callback=_retry_error_callback,
273+
sleep=time.sleep,
274+
reraise=False,
275+
)
276+
277+
return retrier(requests.request, method, url, **kwargs)
245278

246279
def structure_file(self, file_paths: list[str]) -> dict:
247280
"""Invokes the API deployed on the Unstract platform.

0 commit comments

Comments
 (0)