Skip to content

Commit 89528bd

Browse files
committed
feat(sdk): align linear retry strategy with exponential
Signed-off-by: Sai Asish Y <say.apm35@gmail.com>
1 parent 277bd28 commit 89528bd

3 files changed

Lines changed: 260 additions & 53 deletions

File tree

packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/retries.py

Lines changed: 101 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -71,84 +71,135 @@ def max_delay_seconds(self) -> int:
7171
return self.max_delay.to_seconds()
7272

7373

74+
@dataclass
75+
class LinearRetryStrategyConfig:
76+
max_attempts: int = 6
77+
initial_delay: Duration = field(default_factory=lambda: Duration.from_seconds(1))
78+
increment: Duration = field(default_factory=lambda: Duration.from_seconds(1))
79+
max_delay: Duration = field(default_factory=lambda: Duration.from_minutes(5))
80+
jitter_strategy: JitterStrategy = field(default=JitterStrategy.FULL)
81+
retryable_errors: list[str | re.Pattern] | None = None
82+
retryable_error_types: list[type[Exception]] | None = None
83+
84+
@property
85+
def initial_delay_seconds(self) -> int:
86+
"""Get initial delay in seconds."""
87+
return self.initial_delay.to_seconds()
88+
89+
@property
90+
def increment_seconds(self) -> int:
91+
"""Get increment in seconds."""
92+
return self.increment.to_seconds()
93+
94+
@property
95+
def max_delay_seconds(self) -> int:
96+
"""Get max delay in seconds."""
97+
return self.max_delay.to_seconds()
98+
99+
100+
def _resolve_retryable_errors(
101+
retryable_errors: list[str | re.Pattern] | None,
102+
retryable_error_types: list[type[Exception]] | None,
103+
) -> tuple[list[str | re.Pattern], list[type[Exception]]]:
104+
"""Resolve the error filters, applying the match-all default only when neither is set."""
105+
should_use_default_errors: bool = (
106+
retryable_errors is None and retryable_error_types is None
107+
)
108+
resolved_errors: list[str | re.Pattern] = (
109+
retryable_errors
110+
if retryable_errors is not None
111+
else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else [])
112+
)
113+
resolved_error_types: list[type[Exception]] = retryable_error_types or []
114+
return resolved_errors, resolved_error_types
115+
116+
117+
def _is_error_retryable(
118+
error: Exception,
119+
retryable_errors: list[str | re.Pattern],
120+
retryable_error_types: list[type[Exception]],
121+
) -> bool:
122+
"""Return True when the error matches one of the message patterns or types."""
123+
is_retryable_error_message: bool = any(
124+
pattern.search(str(error))
125+
if isinstance(pattern, re.Pattern)
126+
else pattern in str(error)
127+
for pattern in retryable_errors
128+
)
129+
is_retryable_error_type: bool = any(
130+
isinstance(error, error_type) for error_type in retryable_error_types
131+
)
132+
return is_retryable_error_message or is_retryable_error_type
133+
134+
135+
def _finalize_delay_seconds(base_delay: float, jitter_strategy: JitterStrategy) -> int:
136+
"""Apply jitter, round up, and clamp to a minimum of 1 second."""
137+
delay_with_jitter: float = jitter_strategy.apply_jitter(base_delay)
138+
return max(1, math.ceil(delay_with_jitter))
139+
140+
74141
def create_retry_strategy(
75142
config: RetryStrategyConfig | None = None,
76143
) -> Callable[[Exception, int], RetryDecision]:
77144
if config is None:
78145
config = RetryStrategyConfig()
79146

80-
# Apply default retryableErrors only if user didn't specify either filter
81-
should_use_default_errors: bool = (
82-
config.retryable_errors is None and config.retryable_error_types is None
83-
)
84-
85-
retryable_errors: list[str | re.Pattern] = (
86-
config.retryable_errors
87-
if config.retryable_errors is not None
88-
else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else [])
147+
retryable_errors, retryable_error_types = _resolve_retryable_errors(
148+
config.retryable_errors, config.retryable_error_types
89149
)
90-
retryable_error_types: list[type[Exception]] = config.retryable_error_types or []
91150

92151
def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision:
93152
# Check if we've exceeded max attempts
94153
if attempts_made >= config.max_attempts:
95154
return RetryDecision.no_retry()
96155

97-
# Check if error is retryable based on error message
98-
is_retryable_error_message: bool = any(
99-
pattern.search(str(error))
100-
if isinstance(pattern, re.Pattern)
101-
else pattern in str(error)
102-
for pattern in retryable_errors
103-
)
104-
105-
# Check if error is retryable based on error type
106-
is_retryable_error_type: bool = any(
107-
isinstance(error, error_type) for error_type in retryable_error_types
108-
)
109-
110-
if not is_retryable_error_message and not is_retryable_error_type:
156+
if not _is_error_retryable(error, retryable_errors, retryable_error_types):
111157
return RetryDecision.no_retry()
112158

113159
# Calculate delay with exponential backoff
114160
base_delay: float = min(
115161
config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)),
116162
config.max_delay_seconds,
117163
)
118-
# Apply jitter to get final delay
119-
delay_with_jitter: float = config.jitter_strategy.apply_jitter(base_delay)
120-
# Round up and ensure minimum of 1 second
121-
final_delay: int = max(1, math.ceil(delay_with_jitter))
164+
final_delay: int = _finalize_delay_seconds(base_delay, config.jitter_strategy)
122165

123166
return RetryDecision.retry(Duration(seconds=final_delay))
124167

125168
return retry_strategy
126169

127170

128171
def create_linear_retry_strategy(
129-
max_attempts: int = 6,
130-
initial_delay: Duration | None = None,
131-
increment: Duration | None = None,
172+
config: LinearRetryStrategyConfig | None = None,
132173
) -> Callable[[Exception, int], RetryDecision]:
133-
"""Linearly increasing delay between retries: initial + increment * (attempts_made - 1).
174+
"""Linearly increasing delay between retries.
134175
135-
Mirrors the JS SDK's ``createLinearRetryStrategy``. With the defaults this
136-
yields delays of 1s, 2s, 3s, 4s, 5s. No jitter is applied and there is no
137-
upper cap on the delay; callers who need either can build their own
138-
strategy via ``create_retry_strategy``.
176+
The base delay is ``initial_delay + increment * (attempts_made - 1)``,
177+
capped at ``max_delay``, with jitter and error filtering applied the same
178+
way as :func:`create_retry_strategy`. Mirrors the JS SDK's
179+
``createLinearRetryStrategy``.
139180
"""
140-
initial: Duration = (
141-
initial_delay if initial_delay is not None else Duration.from_seconds(1)
181+
if config is None:
182+
config = LinearRetryStrategyConfig()
183+
184+
retryable_errors, retryable_error_types = _resolve_retryable_errors(
185+
config.retryable_errors, config.retryable_error_types
142186
)
143-
step: Duration = increment if increment is not None else Duration.from_seconds(1)
144187

145-
def linear_retry_strategy(_error: Exception, attempts_made: int) -> RetryDecision:
146-
if attempts_made >= max_attempts:
188+
def linear_retry_strategy(error: Exception, attempts_made: int) -> RetryDecision:
189+
if attempts_made >= config.max_attempts:
190+
return RetryDecision.no_retry()
191+
192+
if not _is_error_retryable(error, retryable_errors, retryable_error_types):
147193
return RetryDecision.no_retry()
148-
delay_seconds: int = initial.to_seconds() + step.to_seconds() * (
149-
attempts_made - 1
194+
195+
base_delay: float = min(
196+
config.initial_delay_seconds
197+
+ config.increment_seconds * (attempts_made - 1),
198+
config.max_delay_seconds,
150199
)
151-
return RetryDecision.retry(Duration(seconds=delay_seconds))
200+
final_delay: int = _finalize_delay_seconds(base_delay, config.jitter_strategy)
201+
202+
return RetryDecision.retry(Duration(seconds=final_delay))
152203

153204
return linear_retry_strategy
154205

@@ -212,9 +263,12 @@ def critical(cls) -> Callable[[Exception, int], RetryDecision]:
212263
def linear(cls) -> Callable[[Exception, int], RetryDecision]:
213264
"""Linearly increasing delay between retries: 1s, 2s, 3s, 4s, 5s."""
214265
return create_linear_retry_strategy(
215-
max_attempts=6,
216-
initial_delay=Duration.from_seconds(1),
217-
increment=Duration.from_seconds(1),
266+
LinearRetryStrategyConfig(
267+
max_attempts=6,
268+
initial_delay=Duration.from_seconds(1),
269+
increment=Duration.from_seconds(1),
270+
jitter_strategy=JitterStrategy.NONE,
271+
)
218272
)
219273

220274
@classmethod

packages/aws-durable-execution-sdk-python/tests/retries_test.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aws_durable_execution_sdk_python.config import Duration
99
from aws_durable_execution_sdk_python.retries import (
1010
JitterStrategy,
11+
LinearRetryStrategyConfig,
1112
RetryDecision,
1213
RetryPresets,
1314
RetryStrategyConfig,
@@ -580,8 +581,10 @@ def test_mixed_error_types_and_patterns():
580581
# region create_linear_retry_strategy
581582

582583

583-
def test_linear_retry_strategy_uses_additive_formula():
584-
"""Default config yields delays of 1s, 2s, 3s, 4s, 5s with no jitter."""
584+
@patch("random.random")
585+
def test_linear_retry_strategy_uses_additive_formula(mock_random):
586+
"""Default config yields additive delays of 1s, 2s, 3s, 4s, 5s."""
587+
mock_random.return_value = 1.0 # FULL jitter at the upper bound keeps the base
585588
strategy = create_linear_retry_strategy()
586589

587590
delays = [
@@ -593,7 +596,7 @@ def test_linear_retry_strategy_uses_additive_formula():
593596

594597
def test_linear_retry_strategy_stops_at_max_attempts():
595598
"""No retry once attempts_made reaches max_attempts."""
596-
strategy = create_linear_retry_strategy(max_attempts=3)
599+
strategy = create_linear_retry_strategy(LinearRetryStrategyConfig(max_attempts=3))
597600

598601
assert strategy(Exception("e"), 1).should_retry is True
599602
assert strategy(Exception("e"), 2).should_retry is True
@@ -603,9 +606,12 @@ def test_linear_retry_strategy_stops_at_max_attempts():
603606
def test_linear_retry_strategy_respects_custom_initial_and_increment():
604607
"""Custom initial_delay and increment shift the additive sequence."""
605608
strategy = create_linear_retry_strategy(
606-
max_attempts=10,
607-
initial_delay=Duration.from_seconds(2),
608-
increment=Duration.from_seconds(3),
609+
LinearRetryStrategyConfig(
610+
max_attempts=10,
611+
initial_delay=Duration.from_seconds(2),
612+
increment=Duration.from_seconds(3),
613+
jitter_strategy=JitterStrategy.NONE,
614+
)
609615
)
610616

611617
delays = [
@@ -616,6 +622,61 @@ def test_linear_retry_strategy_respects_custom_initial_and_increment():
616622
assert delays == [2, 5, 8, 11]
617623

618624

625+
def test_linear_retry_strategy_caps_at_max_delay():
626+
"""The additive delay is capped at max_delay before jitter."""
627+
strategy = create_linear_retry_strategy(
628+
LinearRetryStrategyConfig(
629+
max_attempts=10,
630+
initial_delay=Duration.from_seconds(10),
631+
increment=Duration.from_seconds(10),
632+
max_delay=Duration.from_seconds(25),
633+
jitter_strategy=JitterStrategy.NONE,
634+
)
635+
)
636+
637+
# 10, 20, then capped at 25 for the third attempt (would be 30).
638+
delays = [
639+
strategy(Exception("e"), attempt).delay_seconds for attempt in range(1, 4)
640+
]
641+
assert delays == [10, 20, 25]
642+
643+
644+
@patch("random.random")
645+
def test_linear_retry_strategy_applies_jitter(mock_random):
646+
"""FULL jitter scales the additive base delay by random()."""
647+
mock_random.return_value = 0.5
648+
strategy = create_linear_retry_strategy(
649+
LinearRetryStrategyConfig(
650+
initial_delay=Duration.from_seconds(4),
651+
increment=Duration.from_seconds(4),
652+
jitter_strategy=JitterStrategy.FULL,
653+
)
654+
)
655+
656+
# base = 4 + 4*1 = 8, full jitter = 0.5 * 8 = 4
657+
assert strategy(Exception("e"), 2).delay_seconds == 4
658+
659+
660+
def test_linear_retry_strategy_filters_by_error_message():
661+
"""Only errors matching retryable_errors are retried."""
662+
strategy = create_linear_retry_strategy(
663+
LinearRetryStrategyConfig(retryable_errors=["timeout"])
664+
)
665+
666+
assert strategy(Exception("connection timeout"), 1).should_retry is True
667+
assert strategy(Exception("permission denied"), 1).should_retry is False
668+
669+
670+
def test_linear_retry_strategy_filters_by_error_type():
671+
"""Only errors matching retryable_error_types are retried."""
672+
strategy = create_linear_retry_strategy(
673+
LinearRetryStrategyConfig(retryable_error_types=[ValueError])
674+
)
675+
676+
assert strategy(ValueError("bad"), 1).should_retry is True
677+
assert strategy(KeyError("missing"), 1).should_retry is False
678+
679+
619680
# endregion
620681

621682

packages/aws-durable-execution-sdk-python/uv.lock

Lines changed: 92 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)