Skip to content

Commit 48d0dd8

Browse files
SAY-5yaythomas
authored andcommitted
feat(sdk): align linear retry strategy with exponential
Signed-off-by: Sai Asish Y <say.apm35@gmail.com>
1 parent 277bd28 commit 48d0dd8

3 files changed

Lines changed: 174 additions & 54 deletions

File tree

packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import math
56
import random
67
from dataclasses import dataclass, field
78
from enum import Enum, StrEnum
@@ -589,5 +590,16 @@ def apply_jitter(self, delay: float) -> float:
589590
# Full jitter: random(0, delay)
590591
return random.random() * delay # noqa: S311
591592

593+
def finalize_delay(self, base_delay: float) -> int:
594+
"""Apply jitter, round up, and clamp to a minimum of 1 second.
595+
596+
Args:
597+
base_delay: The base delay value before jitter is applied
598+
599+
Returns:
600+
The final delay in whole seconds, at least 1
601+
"""
602+
return max(1, math.ceil(self.apply_jitter(base_delay)))
603+
592604

593605
# endregion Jitter

packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/retries.py

Lines changed: 95 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import math
65
import re
76
from dataclasses import dataclass, field
87
from typing import TYPE_CHECKING, Generic, TypeVar
@@ -71,84 +70,129 @@ def max_delay_seconds(self) -> int:
7170
return self.max_delay.to_seconds()
7271

7372

73+
@dataclass
74+
class LinearRetryStrategyConfig:
75+
max_attempts: int = 6
76+
initial_delay: Duration = field(default_factory=lambda: Duration.from_seconds(1))
77+
increment: Duration = field(default_factory=lambda: Duration.from_seconds(1))
78+
max_delay: Duration = field(default_factory=lambda: Duration.from_minutes(5))
79+
jitter_strategy: JitterStrategy = field(default=JitterStrategy.FULL)
80+
retryable_errors: list[str | re.Pattern] | None = None
81+
retryable_error_types: list[type[Exception]] | None = None
82+
83+
@property
84+
def initial_delay_seconds(self) -> int:
85+
"""Get initial delay in seconds."""
86+
return self.initial_delay.to_seconds()
87+
88+
@property
89+
def increment_seconds(self) -> int:
90+
"""Get increment in seconds."""
91+
return self.increment.to_seconds()
92+
93+
@property
94+
def max_delay_seconds(self) -> int:
95+
"""Get max delay in seconds."""
96+
return self.max_delay.to_seconds()
97+
98+
99+
def _resolve_retryable_errors(
100+
retryable_errors: list[str | re.Pattern] | None,
101+
retryable_error_types: list[type[Exception]] | None,
102+
) -> tuple[list[str | re.Pattern], list[type[Exception]]]:
103+
"""Resolve the error filters, applying the match-all default only when neither is set."""
104+
should_use_default_errors: bool = (
105+
retryable_errors is None and retryable_error_types is None
106+
)
107+
resolved_errors: list[str | re.Pattern] = (
108+
retryable_errors
109+
if retryable_errors is not None
110+
else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else [])
111+
)
112+
resolved_error_types: list[type[Exception]] = retryable_error_types or []
113+
return resolved_errors, resolved_error_types
114+
115+
116+
def _is_error_retryable(
117+
error: Exception,
118+
retryable_errors: list[str | re.Pattern],
119+
retryable_error_types: list[type[Exception]],
120+
) -> bool:
121+
"""Return True when the error matches one of the message patterns or types."""
122+
is_retryable_error_message: bool = any(
123+
pattern.search(str(error))
124+
if isinstance(pattern, re.Pattern)
125+
else pattern in str(error)
126+
for pattern in retryable_errors
127+
)
128+
is_retryable_error_type: bool = any(
129+
isinstance(error, error_type) for error_type in retryable_error_types
130+
)
131+
return is_retryable_error_message or is_retryable_error_type
132+
133+
74134
def create_retry_strategy(
75135
config: RetryStrategyConfig | None = None,
76136
) -> Callable[[Exception, int], RetryDecision]:
77137
if config is None:
78138
config = RetryStrategyConfig()
79139

80-
# Apply default retryableErrors only if user didn't specify either filter
81-
should_use_default_errors: bool = (
82-
config.retryable_errors is None and config.retryable_error_types is None
83-
)
84-
85-
retryable_errors: list[str | re.Pattern] = (
86-
config.retryable_errors
87-
if config.retryable_errors is not None
88-
else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else [])
140+
retryable_errors, retryable_error_types = _resolve_retryable_errors(
141+
config.retryable_errors, config.retryable_error_types
89142
)
90-
retryable_error_types: list[type[Exception]] = config.retryable_error_types or []
91143

92144
def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision:
93145
# Check if we've exceeded max attempts
94146
if attempts_made >= config.max_attempts:
95147
return RetryDecision.no_retry()
96148

97-
# Check if error is retryable based on error message
98-
is_retryable_error_message: bool = any(
99-
pattern.search(str(error))
100-
if isinstance(pattern, re.Pattern)
101-
else pattern in str(error)
102-
for pattern in retryable_errors
103-
)
104-
105-
# Check if error is retryable based on error type
106-
is_retryable_error_type: bool = any(
107-
isinstance(error, error_type) for error_type in retryable_error_types
108-
)
109-
110-
if not is_retryable_error_message and not is_retryable_error_type:
149+
if not _is_error_retryable(error, retryable_errors, retryable_error_types):
111150
return RetryDecision.no_retry()
112151

113152
# Calculate delay with exponential backoff
114153
base_delay: float = min(
115154
config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)),
116155
config.max_delay_seconds,
117156
)
118-
# Apply jitter to get final delay
119-
delay_with_jitter: float = config.jitter_strategy.apply_jitter(base_delay)
120-
# Round up and ensure minimum of 1 second
121-
final_delay: int = max(1, math.ceil(delay_with_jitter))
157+
final_delay: int = config.jitter_strategy.finalize_delay(base_delay)
122158

123159
return RetryDecision.retry(Duration(seconds=final_delay))
124160

125161
return retry_strategy
126162

127163

128164
def create_linear_retry_strategy(
129-
max_attempts: int = 6,
130-
initial_delay: Duration | None = None,
131-
increment: Duration | None = None,
165+
config: LinearRetryStrategyConfig | None = None,
132166
) -> Callable[[Exception, int], RetryDecision]:
133-
"""Linearly increasing delay between retries: initial + increment * (attempts_made - 1).
167+
"""Linearly increasing delay between retries.
134168
135-
Mirrors the JS SDK's ``createLinearRetryStrategy``. With the defaults this
136-
yields delays of 1s, 2s, 3s, 4s, 5s. No jitter is applied and there is no
137-
upper cap on the delay; callers who need either can build their own
138-
strategy via ``create_retry_strategy``.
169+
The base delay is ``initial_delay + increment * (attempts_made - 1)``,
170+
capped at ``max_delay``, with jitter and error filtering applied the same
171+
way as :func:`create_retry_strategy`. Mirrors the JS SDK's
172+
``createLinearRetryStrategy``.
139173
"""
140-
initial: Duration = (
141-
initial_delay if initial_delay is not None else Duration.from_seconds(1)
174+
if config is None:
175+
config = LinearRetryStrategyConfig()
176+
177+
retryable_errors, retryable_error_types = _resolve_retryable_errors(
178+
config.retryable_errors, config.retryable_error_types
142179
)
143-
step: Duration = increment if increment is not None else Duration.from_seconds(1)
144180

145-
def linear_retry_strategy(_error: Exception, attempts_made: int) -> RetryDecision:
146-
if attempts_made >= max_attempts:
181+
def linear_retry_strategy(error: Exception, attempts_made: int) -> RetryDecision:
182+
if attempts_made >= config.max_attempts:
183+
return RetryDecision.no_retry()
184+
185+
if not _is_error_retryable(error, retryable_errors, retryable_error_types):
147186
return RetryDecision.no_retry()
148-
delay_seconds: int = initial.to_seconds() + step.to_seconds() * (
149-
attempts_made - 1
187+
188+
base_delay: float = min(
189+
config.initial_delay_seconds
190+
+ config.increment_seconds * (attempts_made - 1),
191+
config.max_delay_seconds,
150192
)
151-
return RetryDecision.retry(Duration(seconds=delay_seconds))
193+
final_delay: int = config.jitter_strategy.finalize_delay(base_delay)
194+
195+
return RetryDecision.retry(Duration(seconds=final_delay))
152196

153197
return linear_retry_strategy
154198

@@ -212,9 +256,12 @@ def critical(cls) -> Callable[[Exception, int], RetryDecision]:
212256
def linear(cls) -> Callable[[Exception, int], RetryDecision]:
213257
"""Linearly increasing delay between retries: 1s, 2s, 3s, 4s, 5s."""
214258
return create_linear_retry_strategy(
215-
max_attempts=6,
216-
initial_delay=Duration.from_seconds(1),
217-
increment=Duration.from_seconds(1),
259+
LinearRetryStrategyConfig(
260+
max_attempts=6,
261+
initial_delay=Duration.from_seconds(1),
262+
increment=Duration.from_seconds(1),
263+
jitter_strategy=JitterStrategy.NONE,
264+
)
218265
)
219266

220267
@classmethod

packages/aws-durable-execution-sdk-python/tests/retries_test.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aws_durable_execution_sdk_python.config import Duration
99
from aws_durable_execution_sdk_python.retries import (
1010
JitterStrategy,
11+
LinearRetryStrategyConfig,
1112
RetryDecision,
1213
RetryPresets,
1314
RetryStrategyConfig,
@@ -580,8 +581,10 @@ def test_mixed_error_types_and_patterns():
580581
# region create_linear_retry_strategy
581582

582583

583-
def test_linear_retry_strategy_uses_additive_formula():
584-
"""Default config yields delays of 1s, 2s, 3s, 4s, 5s with no jitter."""
584+
@patch("random.random")
585+
def test_linear_retry_strategy_uses_additive_formula(mock_random):
586+
"""Default config yields additive delays of 1s, 2s, 3s, 4s, 5s."""
587+
mock_random.return_value = 1.0 # FULL jitter at the upper bound keeps the base
585588
strategy = create_linear_retry_strategy()
586589

587590
delays = [
@@ -593,7 +596,7 @@ def test_linear_retry_strategy_uses_additive_formula():
593596

594597
def test_linear_retry_strategy_stops_at_max_attempts():
595598
"""No retry once attempts_made reaches max_attempts."""
596-
strategy = create_linear_retry_strategy(max_attempts=3)
599+
strategy = create_linear_retry_strategy(LinearRetryStrategyConfig(max_attempts=3))
597600

598601
assert strategy(Exception("e"), 1).should_retry is True
599602
assert strategy(Exception("e"), 2).should_retry is True
@@ -603,9 +606,12 @@ def test_linear_retry_strategy_stops_at_max_attempts():
603606
def test_linear_retry_strategy_respects_custom_initial_and_increment():
604607
"""Custom initial_delay and increment shift the additive sequence."""
605608
strategy = create_linear_retry_strategy(
606-
max_attempts=10,
607-
initial_delay=Duration.from_seconds(2),
608-
increment=Duration.from_seconds(3),
609+
LinearRetryStrategyConfig(
610+
max_attempts=10,
611+
initial_delay=Duration.from_seconds(2),
612+
increment=Duration.from_seconds(3),
613+
jitter_strategy=JitterStrategy.NONE,
614+
)
609615
)
610616

611617
delays = [
@@ -616,6 +622,61 @@ def test_linear_retry_strategy_respects_custom_initial_and_increment():
616622
assert delays == [2, 5, 8, 11]
617623

618624

625+
def test_linear_retry_strategy_caps_at_max_delay():
626+
"""The additive delay is capped at max_delay before jitter."""
627+
strategy = create_linear_retry_strategy(
628+
LinearRetryStrategyConfig(
629+
max_attempts=10,
630+
initial_delay=Duration.from_seconds(10),
631+
increment=Duration.from_seconds(10),
632+
max_delay=Duration.from_seconds(25),
633+
jitter_strategy=JitterStrategy.NONE,
634+
)
635+
)
636+
637+
# 10, 20, then capped at 25 for the third attempt (would be 30).
638+
delays = [
639+
strategy(Exception("e"), attempt).delay_seconds for attempt in range(1, 4)
640+
]
641+
assert delays == [10, 20, 25]
642+
643+
644+
@patch("random.random")
645+
def test_linear_retry_strategy_applies_jitter(mock_random):
646+
"""FULL jitter scales the additive base delay by random()."""
647+
mock_random.return_value = 0.5
648+
strategy = create_linear_retry_strategy(
649+
LinearRetryStrategyConfig(
650+
initial_delay=Duration.from_seconds(4),
651+
increment=Duration.from_seconds(4),
652+
jitter_strategy=JitterStrategy.FULL,
653+
)
654+
)
655+
656+
# base = 4 + 4*1 = 8, full jitter = 0.5 * 8 = 4
657+
assert strategy(Exception("e"), 2).delay_seconds == 4
658+
659+
660+
def test_linear_retry_strategy_filters_by_error_message():
661+
"""Only errors matching retryable_errors are retried."""
662+
strategy = create_linear_retry_strategy(
663+
LinearRetryStrategyConfig(retryable_errors=["timeout"])
664+
)
665+
666+
assert strategy(Exception("connection timeout"), 1).should_retry is True
667+
assert strategy(Exception("permission denied"), 1).should_retry is False
668+
669+
670+
def test_linear_retry_strategy_filters_by_error_type():
671+
"""Only errors matching retryable_error_types are retried."""
672+
strategy = create_linear_retry_strategy(
673+
LinearRetryStrategyConfig(retryable_error_types=[ValueError])
674+
)
675+
676+
assert strategy(ValueError("bad"), 1).should_retry is True
677+
assert strategy(KeyError("missing"), 1).should_retry is False
678+
679+
619680
# endregion
620681

621682

0 commit comments

Comments
 (0)