77import json
88import logging
99import os
10- import urllib .request
11- from asyncio .tasks import Task
1210from typing import Any , Dict , List , Optional , Union
1311
1412import aiohttp
1715import requests
1816import tqdm
1917import tqdm .notebook as tqdm_notebook
18+ import time
2019
2120from nucleus .url_utils import sanitize_string_args
2221
105104)
106105
107106
107+ class RetryStrategy :
108+ statuses = {503 , 504 }
109+ sleep_times = [1 , 3 , 9 ]
110+
111+
108112class NucleusClient :
109113 """
110114 Nucleus client.
@@ -511,28 +515,41 @@ async def _make_files_request(
511515 content_type = file [1 ][2 ],
512516 )
513517
514- async with session .post (
515- endpoint ,
516- data = form ,
517- auth = aiohttp .BasicAuth (self .api_key , "" ),
518- timeout = DEFAULT_NETWORK_TIMEOUT_SEC ,
519- ) as response :
520- logger .info ("API request has response code %s" , response .status )
521-
522- try :
523- data = await response .json ()
524- except aiohttp .client_exceptions .ContentTypeError :
525- # In case of 404, the server returns text
526- data = await response .text ()
527-
528- if not response .ok :
529- self .handle_bad_response (
530- endpoint ,
531- session .post ,
532- aiohttp_response = (response .status , response .reason , data ),
518+ for sleep_time in RetryStrategy .sleep_times + ["" ]:
519+ async with session .post (
520+ endpoint ,
521+ data = form ,
522+ auth = aiohttp .BasicAuth (self .api_key , "" ),
523+ timeout = DEFAULT_NETWORK_TIMEOUT_SEC ,
524+ ) as response :
525+ logger .info (
526+ "API request has response code %s" , response .status
533527 )
534528
535- return data
529+ try :
530+ data = await response .json ()
531+ except aiohttp .client_exceptions .ContentTypeError :
532+ # In case of 404, the server returns text
533+ data = await response .text ()
534+ if (
535+ response .status in RetryStrategy .statuses
536+ and sleep_time != ""
537+ ):
538+ time .sleep (sleep_time )
539+ continue
540+
541+ if not response .ok :
542+ self .handle_bad_response (
543+ endpoint ,
544+ session .post ,
545+ aiohttp_response = (
546+ response .status ,
547+ response .reason ,
548+ data ,
549+ ),
550+ )
551+
552+ return data
536553
537554 def _process_append_requests (
538555 self ,
@@ -1191,14 +1208,20 @@ def make_request(
11911208
11921209 logger .info ("Posting to %s" , endpoint )
11931210
1194- response = requests_command (
1195- endpoint ,
1196- json = payload ,
1197- headers = {"Content-Type" : "application/json" },
1198- auth = (self .api_key , "" ),
1199- timeout = DEFAULT_NETWORK_TIMEOUT_SEC ,
1200- )
1201- logger .info ("API request has response code %s" , response .status_code )
1211+ for retry_wait_time in RetryStrategy .sleep_times :
1212+ response = requests_command (
1213+ endpoint ,
1214+ json = payload ,
1215+ headers = {"Content-Type" : "application/json" },
1216+ auth = (self .api_key , "" ),
1217+ timeout = DEFAULT_NETWORK_TIMEOUT_SEC ,
1218+ )
1219+ logger .info (
1220+ "API request has response code %s" , response .status_code
1221+ )
1222+ if response .status_code not in RetryStrategy .statuses :
1223+ break
1224+ time .sleep (retry_wait_time )
12021225
12031226 if not response .ok :
12041227 self .handle_bad_response (endpoint , requests_command , response )
@@ -1214,4 +1237,4 @@ def handle_bad_response(
12141237 ):
12151238 raise NucleusAPIError (
12161239 endpoint , requests_command , requests_response , aiohttp_response
1217- )
1240+ )
0 commit comments