11import asyncio
22import logging
3+ import math
34import random
45from dataclasses import dataclass
56from typing import Callable
@@ -15,36 +16,36 @@ class Retrier:
1516 def __init__ (
1617 self ,
1718 should_retry_on : Callable [[Exception ], bool ],
18- max_attempts : int ,
19+ max_retries : int ,
1920 min_base_delay : float = 0.1 ,
2021 max_base_delay : float = 1.0 ,
2122 ):
2223 self .should_retry_on = should_retry_on
23- self .max_attempts = max_attempts
24+ self .max_retries = max_retries
2425 self .min_base_delay = min_base_delay
2526 self .max_base_delay = max_base_delay
2627
2728 async def __call__ (self , f : Callable , * args , ** kwargs ):
28- backoffs = compute_backoffs (
29- attempts = max (self .max_attempts - 1 , 0 ),
30- min_base_delay = self .min_base_delay ,
31- max_base_delay = self .max_base_delay ,
32- )
29+ max_retries = self .max_retries
3330 attempt = 0
3431 while True :
3532 try :
3633 return await f (* args , ** kwargs )
3734 except Exception as e :
38- if attempt < len (backoffs ) and self .should_retry_on (e ):
39- delay = backoffs [attempt ]
35+ if attempt < max_retries and self .should_retry_on (e ):
36+ delay = compute_backoff (
37+ attempt ,
38+ min_base_delay = self .min_base_delay ,
39+ max_base_delay = self .max_base_delay ,
40+ )
4041 retry_after = getattr (e , "_retry_after" , None )
4142 if retry_after is not None :
4243 delay = max (delay , retry_after )
4344 logger .debug (
4445 "retrying request: error=%s backoff=%.3fs retries_remaining=%d" ,
4546 e ,
4647 delay ,
47- len ( backoffs ) - attempt - 1 ,
48+ max_retries - attempt - 1 ,
4849 )
4950 await asyncio .sleep (delay )
5051 attempt += 1
@@ -53,7 +54,7 @@ async def __call__(self, f: Callable, *args, **kwargs):
5354 "not retrying request: error=%s is_retryable=%s retries_exhausted=%s" ,
5455 e ,
5556 self .should_retry_on (e ),
56- attempt >= len ( backoffs ) ,
57+ attempt >= max_retries ,
5758 )
5859 raise e
5960
@@ -63,17 +64,17 @@ class Attempt:
6364 value : int
6465
6566
66- def compute_backoffs (
67- attempts : int ,
67+ def compute_backoff (
68+ attempt : int ,
6869 min_base_delay : float = 0.1 ,
6970 max_base_delay : float = 1.0 ,
70- ) -> list [ float ] :
71- backoffs = []
72- for n in range ( attempts ):
73- base_delay = min ( min_base_delay * 2 ** n , max_base_delay )
74- jitter = random . uniform ( 0 , base_delay )
75- backoffs . append ( base_delay + jitter )
76- return backoffs
71+ ) -> float :
72+ try :
73+ base_delay = min ( math . ldexp ( min_base_delay , attempt ), max_base_delay )
74+ except OverflowError :
75+ base_delay = max_base_delay
76+ jitter = random . uniform ( 0 , base_delay )
77+ return base_delay + jitter
7778
7879
7980def is_safe_to_retry_unary (
0 commit comments