|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -import math |
6 | 5 | import re |
7 | 6 | from dataclasses import dataclass, field |
8 | 7 | from typing import TYPE_CHECKING, Generic, TypeVar |
@@ -71,84 +70,129 @@ def max_delay_seconds(self) -> int: |
71 | 70 | return self.max_delay.to_seconds() |
72 | 71 |
|
73 | 72 |
|
| 73 | +@dataclass |
| 74 | +class LinearRetryStrategyConfig: |
| 75 | + max_attempts: int = 6 |
| 76 | + initial_delay: Duration = field(default_factory=lambda: Duration.from_seconds(1)) |
| 77 | + increment: Duration = field(default_factory=lambda: Duration.from_seconds(1)) |
| 78 | + max_delay: Duration = field(default_factory=lambda: Duration.from_minutes(5)) |
| 79 | + jitter_strategy: JitterStrategy = field(default=JitterStrategy.FULL) |
| 80 | + retryable_errors: list[str | re.Pattern] | None = None |
| 81 | + retryable_error_types: list[type[Exception]] | None = None |
| 82 | + |
| 83 | + @property |
| 84 | + def initial_delay_seconds(self) -> int: |
| 85 | + """Get initial delay in seconds.""" |
| 86 | + return self.initial_delay.to_seconds() |
| 87 | + |
| 88 | + @property |
| 89 | + def increment_seconds(self) -> int: |
| 90 | + """Get increment in seconds.""" |
| 91 | + return self.increment.to_seconds() |
| 92 | + |
| 93 | + @property |
| 94 | + def max_delay_seconds(self) -> int: |
| 95 | + """Get max delay in seconds.""" |
| 96 | + return self.max_delay.to_seconds() |
| 97 | + |
| 98 | + |
| 99 | +def _resolve_retryable_errors( |
| 100 | + retryable_errors: list[str | re.Pattern] | None, |
| 101 | + retryable_error_types: list[type[Exception]] | None, |
| 102 | +) -> tuple[list[str | re.Pattern], list[type[Exception]]]: |
| 103 | + """Resolve the error filters, applying the match-all default only when neither is set.""" |
| 104 | + should_use_default_errors: bool = ( |
| 105 | + retryable_errors is None and retryable_error_types is None |
| 106 | + ) |
| 107 | + resolved_errors: list[str | re.Pattern] = ( |
| 108 | + retryable_errors |
| 109 | + if retryable_errors is not None |
| 110 | + else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else []) |
| 111 | + ) |
| 112 | + resolved_error_types: list[type[Exception]] = retryable_error_types or [] |
| 113 | + return resolved_errors, resolved_error_types |
| 114 | + |
| 115 | + |
| 116 | +def _is_error_retryable( |
| 117 | + error: Exception, |
| 118 | + retryable_errors: list[str | re.Pattern], |
| 119 | + retryable_error_types: list[type[Exception]], |
| 120 | +) -> bool: |
| 121 | + """Return True when the error matches one of the message patterns or types.""" |
| 122 | + is_retryable_error_message: bool = any( |
| 123 | + pattern.search(str(error)) |
| 124 | + if isinstance(pattern, re.Pattern) |
| 125 | + else pattern in str(error) |
| 126 | + for pattern in retryable_errors |
| 127 | + ) |
| 128 | + is_retryable_error_type: bool = any( |
| 129 | + isinstance(error, error_type) for error_type in retryable_error_types |
| 130 | + ) |
| 131 | + return is_retryable_error_message or is_retryable_error_type |
| 132 | + |
| 133 | + |
74 | 134 | def create_retry_strategy( |
75 | 135 | config: RetryStrategyConfig | None = None, |
76 | 136 | ) -> Callable[[Exception, int], RetryDecision]: |
77 | 137 | if config is None: |
78 | 138 | config = RetryStrategyConfig() |
79 | 139 |
|
80 | | - # Apply default retryableErrors only if user didn't specify either filter |
81 | | - should_use_default_errors: bool = ( |
82 | | - config.retryable_errors is None and config.retryable_error_types is None |
83 | | - ) |
84 | | - |
85 | | - retryable_errors: list[str | re.Pattern] = ( |
86 | | - config.retryable_errors |
87 | | - if config.retryable_errors is not None |
88 | | - else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else []) |
| 140 | + retryable_errors, retryable_error_types = _resolve_retryable_errors( |
| 141 | + config.retryable_errors, config.retryable_error_types |
89 | 142 | ) |
90 | | - retryable_error_types: list[type[Exception]] = config.retryable_error_types or [] |
91 | 143 |
|
92 | 144 | def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision: |
93 | 145 | # Check if we've exceeded max attempts |
94 | 146 | if attempts_made >= config.max_attempts: |
95 | 147 | return RetryDecision.no_retry() |
96 | 148 |
|
97 | | - # Check if error is retryable based on error message |
98 | | - is_retryable_error_message: bool = any( |
99 | | - pattern.search(str(error)) |
100 | | - if isinstance(pattern, re.Pattern) |
101 | | - else pattern in str(error) |
102 | | - for pattern in retryable_errors |
103 | | - ) |
104 | | - |
105 | | - # Check if error is retryable based on error type |
106 | | - is_retryable_error_type: bool = any( |
107 | | - isinstance(error, error_type) for error_type in retryable_error_types |
108 | | - ) |
109 | | - |
110 | | - if not is_retryable_error_message and not is_retryable_error_type: |
| 149 | + if not _is_error_retryable(error, retryable_errors, retryable_error_types): |
111 | 150 | return RetryDecision.no_retry() |
112 | 151 |
|
113 | 152 | # Calculate delay with exponential backoff |
114 | 153 | base_delay: float = min( |
115 | 154 | config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)), |
116 | 155 | config.max_delay_seconds, |
117 | 156 | ) |
118 | | - # Apply jitter to get final delay |
119 | | - delay_with_jitter: float = config.jitter_strategy.apply_jitter(base_delay) |
120 | | - # Round up and ensure minimum of 1 second |
121 | | - final_delay: int = max(1, math.ceil(delay_with_jitter)) |
| 157 | + final_delay: int = config.jitter_strategy.finalize_delay(base_delay) |
122 | 158 |
|
123 | 159 | return RetryDecision.retry(Duration(seconds=final_delay)) |
124 | 160 |
|
125 | 161 | return retry_strategy |
126 | 162 |
|
127 | 163 |
|
128 | 164 | def create_linear_retry_strategy( |
129 | | - max_attempts: int = 6, |
130 | | - initial_delay: Duration | None = None, |
131 | | - increment: Duration | None = None, |
| 165 | + config: LinearRetryStrategyConfig | None = None, |
132 | 166 | ) -> Callable[[Exception, int], RetryDecision]: |
133 | | - """Linearly increasing delay between retries: initial + increment * (attempts_made - 1). |
| 167 | + """Linearly increasing delay between retries. |
134 | 168 |
|
135 | | - Mirrors the JS SDK's ``createLinearRetryStrategy``. With the defaults this |
136 | | - yields delays of 1s, 2s, 3s, 4s, 5s. No jitter is applied and there is no |
137 | | - upper cap on the delay; callers who need either can build their own |
138 | | - strategy via ``create_retry_strategy``. |
| 169 | + The base delay is ``initial_delay + increment * (attempts_made - 1)``, |
| 170 | + capped at ``max_delay``, with jitter and error filtering applied the same |
| 171 | + way as :func:`create_retry_strategy`. Mirrors the JS SDK's |
| 172 | + ``createLinearRetryStrategy``. |
139 | 173 | """ |
140 | | - initial: Duration = ( |
141 | | - initial_delay if initial_delay is not None else Duration.from_seconds(1) |
| 174 | + if config is None: |
| 175 | + config = LinearRetryStrategyConfig() |
| 176 | + |
| 177 | + retryable_errors, retryable_error_types = _resolve_retryable_errors( |
| 178 | + config.retryable_errors, config.retryable_error_types |
142 | 179 | ) |
143 | | - step: Duration = increment if increment is not None else Duration.from_seconds(1) |
144 | 180 |
|
145 | | - def linear_retry_strategy(_error: Exception, attempts_made: int) -> RetryDecision: |
146 | | - if attempts_made >= max_attempts: |
| 181 | + def linear_retry_strategy(error: Exception, attempts_made: int) -> RetryDecision: |
| 182 | + if attempts_made >= config.max_attempts: |
| 183 | + return RetryDecision.no_retry() |
| 184 | + |
| 185 | + if not _is_error_retryable(error, retryable_errors, retryable_error_types): |
147 | 186 | return RetryDecision.no_retry() |
148 | | - delay_seconds: int = initial.to_seconds() + step.to_seconds() * ( |
149 | | - attempts_made - 1 |
| 187 | + |
| 188 | + base_delay: float = min( |
| 189 | + config.initial_delay_seconds |
| 190 | + + config.increment_seconds * (attempts_made - 1), |
| 191 | + config.max_delay_seconds, |
150 | 192 | ) |
151 | | - return RetryDecision.retry(Duration(seconds=delay_seconds)) |
| 193 | + final_delay: int = config.jitter_strategy.finalize_delay(base_delay) |
| 194 | + |
| 195 | + return RetryDecision.retry(Duration(seconds=final_delay)) |
152 | 196 |
|
153 | 197 | return linear_retry_strategy |
154 | 198 |
|
@@ -212,9 +256,12 @@ def critical(cls) -> Callable[[Exception, int], RetryDecision]: |
212 | 256 | def linear(cls) -> Callable[[Exception, int], RetryDecision]: |
213 | 257 | """Linearly increasing delay between retries: 1s, 2s, 3s, 4s, 5s.""" |
214 | 258 | return create_linear_retry_strategy( |
215 | | - max_attempts=6, |
216 | | - initial_delay=Duration.from_seconds(1), |
217 | | - increment=Duration.from_seconds(1), |
| 259 | + LinearRetryStrategyConfig( |
| 260 | + max_attempts=6, |
| 261 | + initial_delay=Duration.from_seconds(1), |
| 262 | + increment=Duration.from_seconds(1), |
| 263 | + jitter_strategy=JitterStrategy.NONE, |
| 264 | + ) |
218 | 265 | ) |
219 | 266 |
|
220 | 267 | @classmethod |
|
0 commit comments