Skip to content

Commit 89ba250

Browse files
authored
fix: uncapped and eager backoff calculation (#52)
1 parent bb1043b commit 89ba250

6 files changed

Lines changed: 75 additions & 189 deletions

File tree

src/s2_sdk/_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __init__(
125125
self._basin_clients: dict[str, HttpClient] = {}
126126
self._retrier = Retrier(
127127
should_retry_on=http_retry_on,
128-
max_attempts=retry.max_attempts,
128+
max_retries=retry._max_retries(),
129129
min_base_delay=retry.min_base_delay.total_seconds(),
130130
max_base_delay=retry.max_base_delay.total_seconds(),
131131
)
@@ -568,7 +568,7 @@ def __init__(
568568
self._compression = compression
569569
self._retrier = Retrier(
570570
should_retry_on=http_retry_on,
571-
max_attempts=retry.max_attempts,
571+
max_retries=retry._max_retries(),
572572
min_base_delay=retry.min_base_delay.total_seconds(),
573573
max_base_delay=retry.max_base_delay.total_seconds(),
574574
)
@@ -806,15 +806,15 @@ def __init__(
806806
self._encryption_key = encryption_key
807807
self._retrier = Retrier(
808808
should_retry_on=http_retry_on,
809-
max_attempts=retry.max_attempts,
809+
max_retries=retry._max_retries(),
810810
min_base_delay=retry.min_base_delay.total_seconds(),
811811
max_base_delay=retry.max_base_delay.total_seconds(),
812812
)
813813
self._append_retrier = Retrier(
814814
should_retry_on=lambda e: is_safe_to_retry_unary(
815815
e, retry.append_retry_policy
816816
),
817-
max_attempts=retry.max_attempts,
817+
max_retries=retry._max_retries(),
818818
min_base_delay=retry.min_base_delay.total_seconds(),
819819
max_base_delay=retry.max_base_delay.total_seconds(),
820820
)

src/s2_sdk/_retrier.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
import math
34
import random
45
from dataclasses import dataclass
56
from typing import Callable
@@ -15,36 +16,36 @@ class Retrier:
1516
def __init__(
1617
self,
1718
should_retry_on: Callable[[Exception], bool],
18-
max_attempts: int,
19+
max_retries: int,
1920
min_base_delay: float = 0.1,
2021
max_base_delay: float = 1.0,
2122
):
2223
self.should_retry_on = should_retry_on
23-
self.max_attempts = max_attempts
24+
self.max_retries = max_retries
2425
self.min_base_delay = min_base_delay
2526
self.max_base_delay = max_base_delay
2627

2728
async def __call__(self, f: Callable, *args, **kwargs):
28-
backoffs = compute_backoffs(
29-
attempts=max(self.max_attempts - 1, 0),
30-
min_base_delay=self.min_base_delay,
31-
max_base_delay=self.max_base_delay,
32-
)
29+
max_retries = self.max_retries
3330
attempt = 0
3431
while True:
3532
try:
3633
return await f(*args, **kwargs)
3734
except Exception as e:
38-
if attempt < len(backoffs) and self.should_retry_on(e):
39-
delay = backoffs[attempt]
35+
if attempt < max_retries and self.should_retry_on(e):
36+
delay = compute_backoff(
37+
attempt,
38+
min_base_delay=self.min_base_delay,
39+
max_base_delay=self.max_base_delay,
40+
)
4041
retry_after = getattr(e, "_retry_after", None)
4142
if retry_after is not None:
4243
delay = max(delay, retry_after)
4344
logger.debug(
4445
"retrying request: error=%s backoff=%.3fs retries_remaining=%d",
4546
e,
4647
delay,
47-
len(backoffs) - attempt - 1,
48+
max_retries - attempt - 1,
4849
)
4950
await asyncio.sleep(delay)
5051
attempt += 1
@@ -53,7 +54,7 @@ async def __call__(self, f: Callable, *args, **kwargs):
5354
"not retrying request: error=%s is_retryable=%s retries_exhausted=%s",
5455
e,
5556
self.should_retry_on(e),
56-
attempt >= len(backoffs),
57+
attempt >= max_retries,
5758
)
5859
raise e
5960

@@ -63,17 +64,17 @@ class Attempt:
6364
value: int
6465

6566

66-
def compute_backoffs(
67-
attempts: int,
67+
def compute_backoff(
68+
attempt: int,
6869
min_base_delay: float = 0.1,
6970
max_base_delay: float = 1.0,
70-
) -> list[float]:
71-
backoffs = []
72-
for n in range(attempts):
73-
base_delay = min(min_base_delay * 2**n, max_base_delay)
74-
jitter = random.uniform(0, base_delay)
75-
backoffs.append(base_delay + jitter)
76-
return backoffs
71+
) -> float:
72+
try:
73+
base_delay = min(math.ldexp(min_base_delay, attempt), max_base_delay)
74+
except OverflowError:
75+
base_delay = max_base_delay
76+
jitter = random.uniform(0, base_delay)
77+
return base_delay + jitter
7778

7879

7980
def is_safe_to_retry_unary(

src/s2_sdk/_s2s/_append_session.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from s2_sdk._exceptions import ReadTimeoutError, S2ClientError
1010
from s2_sdk._frame_signal import FrameSignal
1111
from s2_sdk._mappers import append_ack_from_proto, append_input_to_proto
12-
from s2_sdk._retrier import Attempt, compute_backoffs, is_safe_to_retry_session
12+
from s2_sdk._retrier import Attempt, compute_backoff, is_safe_to_retry_session
1313
from s2_sdk._s2s import _stream_records_path
1414
from s2_sdk._s2s._protocol import (
1515
Message,
@@ -64,11 +64,9 @@ async def pipe_inputs():
6464

6565
async def retrying_inner():
6666
inflight_inputs: deque[_InflightInput] = deque()
67-
backoffs = compute_backoffs(
68-
retry._max_retries(),
69-
min_base_delay=retry.min_base_delay.total_seconds(),
70-
max_base_delay=retry.max_base_delay.total_seconds(),
71-
)
67+
max_retries = retry._max_retries()
68+
min_base_delay = retry.min_base_delay.total_seconds()
69+
max_base_delay = retry.max_base_delay.total_seconds()
7270
attempt = Attempt(0)
7371
try:
7472
while True:
@@ -92,26 +90,30 @@ async def retrying_inner():
9290
return
9391
except Exception as e:
9492
has_inflight = len(inflight_inputs) > 0
95-
if attempt.value < len(backoffs) and is_safe_to_retry_session(
93+
if attempt.value < max_retries and is_safe_to_retry_session(
9694
e,
9795
retry.append_retry_policy,
9896
has_inflight,
9997
frame_signal,
10098
):
101-
backoff = backoffs[attempt.value]
99+
backoff = compute_backoff(
100+
attempt.value,
101+
min_base_delay=min_base_delay,
102+
max_base_delay=max_base_delay,
103+
)
102104
logger.debug(
103105
"retrying append session: error=%s backoff=%.3fs retries_remaining=%d",
104106
e,
105107
backoff,
106-
len(backoffs) - attempt.value - 1,
108+
max_retries - attempt.value - 1,
107109
)
108110
await asyncio.sleep(backoff)
109111
attempt.value += 1
110112
else:
111113
logger.debug(
112114
"not retrying append session: error=%s retries_exhausted=%s",
113115
e,
114-
attempt.value >= len(backoffs),
116+
attempt.value >= max_retries,
115117
)
116118
raise
117119
finally:

src/s2_sdk/_s2s/_read_session.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from s2_sdk._client import HttpClient
99
from s2_sdk._exceptions import ReadTimeoutError
1010
from s2_sdk._mappers import read_batch_from_proto, read_limit_params, read_start_params
11-
from s2_sdk._retrier import Attempt, compute_backoffs, http_retry_on
11+
from s2_sdk._retrier import Attempt, compute_backoff, http_retry_on
1212
from s2_sdk._s2s import _stream_records_path
1313
from s2_sdk._s2s._protocol import parse_error_info, read_messages
1414
from s2_sdk._types import (
@@ -40,11 +40,9 @@ async def run_read_session(
4040
encryption_key: str | None = None,
4141
) -> AsyncIterable[ReadBatch]:
4242
params = _build_read_params(start, limit, until_timestamp, clamp_to_tail, wait)
43-
backoffs = compute_backoffs(
44-
retry._max_retries(),
45-
min_base_delay=retry.min_base_delay.total_seconds(),
46-
max_base_delay=retry.max_base_delay.total_seconds(),
47-
)
43+
max_retries = retry._max_retries()
44+
min_base_delay = retry.min_base_delay.total_seconds()
45+
max_base_delay = retry.max_base_delay.total_seconds()
4846
attempt = Attempt(0)
4947

5048
remaining_count = limit.count if limit and limit.count is not None else None
@@ -122,13 +120,17 @@ async def run_read_session(
122120

123121
return
124122
except Exception as e:
125-
if attempt.value < len(backoffs) and http_retry_on(e):
126-
backoff = backoffs[attempt.value]
123+
if attempt.value < max_retries and http_retry_on(e):
124+
backoff = compute_backoff(
125+
attempt.value,
126+
min_base_delay=min_base_delay,
127+
max_base_delay=max_base_delay,
128+
)
127129
logger.debug(
128130
"retrying read session: error=%s backoff=%.3fs retries_remaining=%d",
129131
e,
130132
backoff,
131-
len(backoffs) - attempt.value - 1,
133+
max_retries - attempt.value - 1,
132134
)
133135
await asyncio.sleep(backoff)
134136
attempt.value += 1
@@ -137,7 +139,7 @@ async def run_read_session(
137139
"not retrying read session: error=%s is_retryable=%s retries_exhausted=%s",
138140
e,
139141
http_retry_on(e),
140-
attempt.value >= len(backoffs),
142+
attempt.value >= max_retries,
141143
)
142144
raise e
143145

tests/test_retrier.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import sys
2+
3+
import pytest
4+
5+
from s2_sdk._retrier import compute_backoff
6+
7+
8+
class TestComputeBackoff:
9+
@pytest.mark.parametrize(
10+
("attempt", "expected_min", "expected_max"),
11+
[
12+
(0, 0.1, 0.2),
13+
(1, 0.2, 0.4),
14+
(2, 0.4, 0.8),
15+
(3, 0.8, 1.6),
16+
(4, 1.0, 2.0),
17+
(5, 1.0, 2.0),
18+
],
19+
)
20+
def test_backoff_range(self, attempt, expected_min, expected_max):
21+
backoff = compute_backoff(attempt, min_base_delay=0.1, max_base_delay=1.0)
22+
assert expected_min <= backoff <= expected_max
23+
24+
def test_backoff_caps_for_max_int_attempt(self):
25+
backoff = compute_backoff(sys.maxsize, min_base_delay=0.1, max_base_delay=1.0)
26+
assert 1.0 <= backoff <= 2.0

0 commit comments

Comments
 (0)