44import json
55import logging
66import os
7- from typing import Optional , cast
7+ from typing import Any , Optional , cast
88
99import httpx
1010from diskcache import Cache
11+ from tenacity import (
12+ retry ,
13+ retry_if_exception ,
14+ stop_after_attempt ,
15+ wait_exponential ,
16+ before_sleep_log ,
17+ RetryError ,
18+ )
1119
1220from lightspeed_evaluation .core .api .streaming_parser import parse_streaming_response
1321from lightspeed_evaluation .core .constants import (
1927logger = logging .getLogger (__name__ )
2028
2129
def _is_too_many_requests_error(exception: BaseException) -> bool:
    """Tell whether ``exception`` is an HTTP 429 (Too Many Requests) error.

    Used as the predicate for the retry policy: only an
    ``httpx.HTTPStatusError`` whose response carries status code 429 counts.

    Args:
        exception: The exception raised by the wrapped call.

    Returns:
        True for an ``httpx.HTTPStatusError`` with a 429 response,
        False for any other exception or status code.
    """
    # Guard first: only HTTPStatusError is known to expose ``.response``.
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    return exception.response.status_code == 429
36+
37+
2238class APIClient :
2339 """API client for actual data generation."""
2440
@@ -28,10 +44,6 @@ def __init__(
2844 ):
2945 """Initialize the client with configuration."""
3046 self .config = config
31- self .api_base = config .api_base
32- self .version = config .version
33- self .endpoint_type = config .endpoint_type
34- self .timeout = config .timeout
3547
3648 self .client : Optional [httpx .Client ] = None
3749
@@ -43,11 +55,27 @@ def __init__(
4355 self ._validate_endpoint_type ()
4456 self ._setup_client ()
4557
58+ # Wrap methods with retry decorator for handling 429 Too Many Requests errors
59+ retry_decorator = self ._create_retry_decorator ()
60+ self ._standard_query_with_retry = retry_decorator (self ._standard_query )
61+ self ._streaming_query_with_retry = retry_decorator (self ._streaming_query )
62+
def _create_retry_decorator(self) -> Any:
    """Build the tenacity decorator that retries calls rejected with HTTP 429.

    The policy retries only when ``_is_too_many_requests_error`` matches the
    raised exception, logs a warning before each sleep, and backs off
    exponentially between attempts.

    Returns:
        A tenacity ``retry`` decorator configured from ``self.config``.
    """
    # The initial call counts as an attempt, so allow num_retries more on top.
    max_attempts = self.config.num_retries + 1
    only_on_429 = retry_if_exception(_is_too_many_requests_error)
    # Exponential backoff (multiplier * 2^x), clamped to the 4..60 second range.
    backoff = wait_exponential(multiplier=1, min=4, max=60)
    return retry(
        retry=only_on_429,
        stop=stop_after_attempt(max_attempts),
        wait=backoff,
        before_sleep=before_sleep_log(logger, logging.WARNING),
        # With reraise disabled, exhausting all attempts raises RetryError.
        reraise=False,
    )
73+
def _validate_endpoint_type(self) -> None:
    """Ensure the configured endpoint type is one the client supports.

    Raises:
        APIError: If ``self.config.endpoint_type`` is not listed in
            ``SUPPORTED_ENDPOINT_TYPES``.
    """
    endpoint_type = self.config.endpoint_type
    if endpoint_type in SUPPORTED_ENDPOINT_TYPES:
        return
    raise APIError(
        f"Unsupported endpoint type: {endpoint_type}. "
        f"Must be one of {SUPPORTED_ENDPOINT_TYPES}"
    )
5381
@@ -57,7 +85,9 @@ def _setup_client(self) -> None:
5785 # Enable verify, currently for eval it is set to False
5886 verify = False
5987 self .client = httpx .Client (
60- base_url = self .api_base , verify = verify , timeout = self .timeout
88+ base_url = self .config .api_base ,
89+ verify = verify ,
90+ timeout = self .config .timeout ,
6191 )
6292 self .client .headers .update ({"Content-Type" : "application/json" })
6393
@@ -88,22 +118,28 @@ def query(
88118 if not self .client :
89119 raise APIError ("API client not initialized" )
90120
91- api_request = self ._prepare_request (query , conversation_id , attachments )
92- if self .config .cache_enabled :
93- cached_response = self ._get_cached_response (api_request )
94- if cached_response is not None :
95- logger .debug ("Returning cached response for query: '%s'" , query )
96- return cached_response
97-
98- if self .endpoint_type == "streaming" :
99- response = self ._streaming_query (api_request )
100- else :
101- response = self ._standard_query (api_request )
102-
103- if self .config .cache_enabled :
104- self ._add_response_to_cache (api_request , response )
105-
106- return response
121+ try :
122+ api_request = self ._prepare_request (query , conversation_id , attachments )
123+ if self .config .cache_enabled :
124+ cached_response = self ._get_cached_response (api_request )
125+ if cached_response is not None :
126+ logger .debug ("Returning cached response for query: '%s'" , query )
127+ return cached_response
128+
129+ if self .config .endpoint_type == "streaming" :
130+ response = self ._streaming_query_with_retry (api_request )
131+ else :
132+ response = self ._standard_query_with_retry (api_request )
133+
134+ if self .config .cache_enabled :
135+ self ._add_response_to_cache (api_request , response )
136+
137+ return response
138+ except RetryError as e :
139+ raise APIError (
140+ f"Maximum retry attempts ({ self .config .num_retries } ) reached "
141+ "due to persistent rate limiting (HTTP 429)."
142+ ) from e
107143
108144 def _prepare_request (
109145 self ,
@@ -123,12 +159,12 @@ def _prepare_request(
123159 )
124160
125161 def _standard_query (self , api_request : APIRequest ) -> APIResponse :
126- """Query the API using non-streaming endpoint."""
162+ """Query the API using non-streaming endpoint with retry on 429 ."""
127163 if not self .client :
128164 raise APIError ("HTTP client not initialized" )
129165 try :
130166 response = self .client .post (
131- f"/{ self .version } /query" ,
167+ f"/{ self .config . version } /query" ,
132168 json = api_request .model_dump (exclude_none = True ),
133169 )
134170 response .raise_for_status ()
@@ -165,8 +201,11 @@ def _standard_query(self, api_request: APIRequest) -> APIResponse:
165201 return APIResponse .from_raw_response (response_data )
166202
167203 except httpx .TimeoutException as e :
168- raise self ._handle_timeout_error ("standard" , self .timeout ) from e
204+ raise self ._handle_timeout_error ("standard" , self .config . timeout ) from e
169205 except httpx .HTTPStatusError as e :
206+ # Re-raise 429 errors without conversion to allow retry decorator to handle them
207+ if e .response .status_code == 429 :
208+ raise
170209 raise self ._handle_http_error (e ) from e
171210 except ValueError as e :
172211 raise self ._handle_validation_error (e ) from e
@@ -182,17 +221,20 @@ def _streaming_query(self, api_request: APIRequest) -> APIResponse:
182221 try :
183222 with self .client .stream (
184223 "POST" ,
185- f"/{ self .version } /streaming_query" ,
224+ f"/{ self .config . version } /streaming_query" ,
186225 json = api_request .model_dump (exclude_none = True ),
187226 ) as response :
188227 self ._handle_response_errors (response )
189228 raw_data = parse_streaming_response (response )
190229 return APIResponse .from_raw_response (raw_data )
191230
192231 except httpx .TimeoutException as e :
193- raise self ._handle_timeout_error ("streaming" , self .timeout ) from e
232+ raise self ._handle_timeout_error ("streaming" , self .config . timeout ) from e
194233 except httpx .HTTPStatusError as e :
195- raise APIError (str (e )) from e
234+ # Re-raise 429 errors without conversion to allow retry decorator to handle them
235+ if e .response .status_code == 429 :
236+ raise
237+ raise self ._handle_http_error (e ) from e
196238 except ValueError as e :
197239 raise self ._handle_validation_error (e ) from e
198240 except APIError :
0 commit comments