-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path_retry.py
More file actions
134 lines (112 loc) · 4.65 KB
/
_retry.py
File metadata and controls
134 lines (112 loc) · 4.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Retry logic with exponential backoff and jitter.
Provides both sync (with_retry) and async (async_with_retry) variants.
"""
from __future__ import annotations

import asyncio
import random
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field, replace
from typing import TypeVar

import grpc
import grpc.aio
T = TypeVar("T")  # Return type of the callable wrapped by with_retry / async_with_retry.
@dataclass(frozen=True, slots=True)
class RetryConfig:
    """Immutable policy describing how a failed gRPC call is retried.

    Attributes:
        max_attempts: Upper bound on the number of calls made, counting the
            initial one.
        initial_backoff: Base delay, in seconds, before the first retry.
        max_backoff: Cap, in seconds, on the base delay as it grows between
            retries (jitter is applied on top of the capped value).
        multiplier: Factor the base delay is multiplied by after each retry.
        retryable_codes: gRPC status codes treated as transient; an error with
            any other code is raised immediately.
        total_timeout: Optional wall-clock budget, in seconds, shared by all
            attempts. Sleeps between attempts are clipped to whatever budget
            remains, and no further attempt starts once it is spent. ``None``
            disables the global limit.
    """

    max_attempts: int = 3
    initial_backoff: float = 0.1  # seconds
    max_backoff: float = 5.0  # seconds
    multiplier: float = 2.0
    # A tuple is immutable, so it is safe to use directly as the default value.
    retryable_codes: tuple[grpc.StatusCode, ...] = (
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
    )
    total_timeout: float | None = None
def write_safe_config(base: RetryConfig | None) -> RetryConfig | None:
    """Return a retry config safe for non-idempotent writes.

    Strips ``DEADLINE_EXCEEDED`` from the retryable codes: a timeout does not
    guarantee the server hasn't already applied the write, so retrying risks a
    double-apply. Every other setting of ``base`` is carried over unchanged.

    Args:
        base: The retry config to derive from, or ``None`` if retry is disabled.

    Returns:
        ``None`` when ``base`` is ``None`` or when no retryable codes remain
        after stripping (either way, callers passing the result to
        ``with_retry`` get a single un-retried call); otherwise a copy of
        ``base`` whose ``retryable_codes`` excludes ``DEADLINE_EXCEEDED``.
    """
    if base is None:
        return None
    safe_codes = tuple(
        code
        for code in base.retryable_codes
        if code != grpc.StatusCode.DEADLINE_EXCEEDED
    )
    if not safe_codes:
        return None
    # dataclasses.replace copies every field, so settings added to RetryConfig
    # in the future are preserved instead of being silently reset to defaults.
    return replace(base, retryable_codes=safe_codes)
def with_retry(config: RetryConfig | None, fn: Callable[[], T]) -> T:
    """Execute ``fn`` with retry on transient gRPC errors (sync).

    Args:
        config: Retry policy. ``None`` disables retrying: ``fn`` is called
            exactly once and any exception propagates unchanged.
        fn: Zero-argument callable performing the RPC.

    Returns:
        The value returned by the first successful call to ``fn``.

    Raises:
        grpc.RpcError: Immediately if its status code is not in
            ``config.retryable_codes``; otherwise the most recent retryable
            error once the attempts or the ``total_timeout`` budget run out.

    The first attempt is always made, even when ``config.max_attempts`` is
    below 1 or ``config.total_timeout`` is zero or negative.
    """
    if config is None:
        return fn()
    deadline = (
        time.monotonic() + config.total_timeout
        if config.total_timeout is not None
        else None
    )
    # Clamp to at least one attempt: with max_attempts <= 0 the loop would
    # never run and the trailing ``raise last_err`` would raise ``None``.
    max_attempts = max(config.max_attempts, 1)
    last_err: Exception | None = None
    backoff = config.initial_backoff
    for attempt in range(max_attempts):
        # The budget only gates *retries*. Checking it before the first
        # attempt would leave last_err unset and end in ``raise None``.
        if attempt > 0 and deadline is not None and time.monotonic() >= deadline:
            break
        try:
            return fn()
        except grpc.RpcError as e:
            code = e.code()
            if code not in config.retryable_codes or attempt == max_attempts - 1:
                raise
            last_err = e
            # Randomize the delay (jitter) by a factor in [0.5, 1.5].
            sleep_time = backoff * random.uniform(0.5, 1.5)
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise
                # Never sleep past the overall budget.
                sleep_time = min(sleep_time, remaining)
            time.sleep(sleep_time)
            backoff = min(backoff * config.multiplier, config.max_backoff)
    # Reached only when the total_timeout budget expires between attempts;
    # last_err is then the most recent retryable error.
    raise last_err  # type: ignore[misc]
async def async_with_retry(config: RetryConfig | None, fn: Callable[[], Awaitable[T]]) -> T:
    """Execute ``fn`` with retry on transient gRPC errors (async).

    Args:
        config: Retry policy. ``None`` disables retrying: ``fn`` is awaited
            exactly once and any exception propagates unchanged.
        fn: Zero-argument callable returning an awaitable that performs the RPC.

    Returns:
        The result of the first successful ``await fn()``.

    Raises:
        grpc.aio.AioRpcError: Immediately if its status code is not in
            ``config.retryable_codes``; otherwise the most recent retryable
            error once the attempts or the ``total_timeout`` budget run out.

    The first attempt is always made, even when ``config.max_attempts`` is
    below 1 or ``config.total_timeout`` is zero or negative.
    """
    if config is None:
        return await fn()
    deadline = (
        time.monotonic() + config.total_timeout
        if config.total_timeout is not None
        else None
    )
    # Clamp to at least one attempt: with max_attempts <= 0 the loop would
    # never run and the trailing ``raise last_err`` would raise ``None``.
    max_attempts = max(config.max_attempts, 1)
    last_err: Exception | None = None
    backoff = config.initial_backoff
    for attempt in range(max_attempts):
        # The budget only gates *retries*. Checking it before the first
        # attempt would leave last_err unset and end in ``raise None``.
        if attempt > 0 and deadline is not None and time.monotonic() >= deadline:
            break
        try:
            return await fn()
        except grpc.aio.AioRpcError as e:
            code = e.code()
            if code not in config.retryable_codes or attempt == max_attempts - 1:
                raise
            last_err = e
            # Randomize the delay (jitter) by a factor in [0.5, 1.5].
            sleep_time = backoff * random.uniform(0.5, 1.5)
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise
                # Never sleep past the overall budget.
                sleep_time = min(sleep_time, remaining)
            await asyncio.sleep(sleep_time)
            backoff = min(backoff * config.multiplier, config.max_backoff)
    # Reached only when the total_timeout budget expires between attempts;
    # last_err is then the most recent retryable error.
    raise last_err  # type: ignore[misc]