Skip to content

Commit 646b406

Browse files
committed
Add support for overridable limits for ClientTimeout
1 parent 43dadb6 commit 646b406

5 files changed

Lines changed: 29 additions & 7 deletions

File tree

llms/extensions/providers/anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ async def chat(self, chat, context=None):
174174
self.chat_url,
175175
headers=self.headers,
176176
data=json.dumps(anthropic_request),
177-
timeout=aiohttp.ClientTimeout(total=120),
177+
timeout=ctx.get_client_timeout(),
178178
) as response:
179179
return ctx.log_json(
180180
self.to_response(await self.response_json(response), chat, started_at, context=context)

llms/extensions/providers/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ async def chat(self, chat, context=None):
378378
gemini_chat_url,
379379
headers=self.headers,
380380
data=json.dumps(gemini_chat),
381-
timeout=aiohttp.ClientTimeout(total=120),
381+
timeout=ctx.get_client_timeout(),
382382
) as res:
383383
obj = await self.response_json(res)
384384
if context is not None:

llms/extensions/providers/nvidia.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async def chat(self, chat, provider=None, context=None):
9696
gen_url,
9797
headers=headers,
9898
data=json.dumps(gen_request),
99-
timeout=aiohttp.ClientTimeout(total=120),
99+
timeout=ctx.get_client_timeout(),
100100
) as response:
101101
return self.to_response(await self.response_json(response), chat, started_at, context=context)
102102

llms/llms.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@
124124
}
125125
},
126126
"limits": {
127+
"client_timeout": 120,
127128
"client_max_size": 20971520
128129
},
129130
"convert": {

llms/main.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@
6262
MOCK = os.getenv("MOCK") == "1"
6363
MOCK_DIR = os.getenv("MOCK_DIR")
6464
DISABLE_EXTENSIONS = (os.getenv("LLMS_DISABLE") or "").split(",")
65+
DEFAULT_LIMITS = {
66+
"client_timeout": 120,
67+
"client_max_size": 20971520,
68+
}
6569
g_config_path = None
6670
g_config = None
6771
g_providers = None
@@ -475,7 +479,7 @@ async def download_file(url):
475479

476480
async def session_download_file(session, url, default_mimetype="application/octet-stream"):
477481
try:
478-
async with session.get(url, timeout=aiohttp.ClientTimeout(total=120)) as response:
482+
async with session.get(url, timeout=get_client_timeout()) as response:
479483
response.raise_for_status()
480484
content = await response.read()
481485
mimetype = response.headers.get("Content-Type")
@@ -1294,7 +1298,7 @@ async def chat(self, chat, context=None):
12941298
async with aiohttp.ClientSession() as session:
12951299
started_at = time.time()
12961300
async with session.post(
1297-
self.chat_url, headers=self.headers, data=json.dumps(chat), timeout=aiohttp.ClientTimeout(total=120)
1301+
self.chat_url, headers=self.headers, data=json.dumps(chat), timeout=get_client_timeout()
12981302
) as response:
12991303
chat["metadata"] = metadata
13001304
return self.to_response(await response_json(response), chat, started_at, context=context)
@@ -1361,7 +1365,7 @@ async def get_models(self):
13611365
async with aiohttp.ClientSession() as session:
13621366
_log(f"GET {self.api}/api/tags")
13631367
async with session.get(
1364-
f"{self.api}/api/tags", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
1368+
f"{self.api}/api/tags", headers=self.headers, timeout=get_client_timeout()
13651369
) as response:
13661370
data = await response_json(response)
13671371
for model in data.get("models", []):
@@ -1422,7 +1426,7 @@ async def get_models(self):
14221426
async with aiohttp.ClientSession() as session:
14231427
_log(f"GET {self.api}/models")
14241428
async with session.get(
1425-
f"{self.api}/models", headers=self.headers, timeout=aiohttp.ClientTimeout(total=120)
1429+
f"{self.api}/models", headers=self.headers, timeout=get_client_timeout()
14261430
) as response:
14271431
data = await response_json(response)
14281432
for model in data.get("data", []):
@@ -2833,6 +2837,12 @@ def check_auth(self, request: web.Request) -> Tuple[bool, Optional[Dict[str, Any
28332837
return False, None
28342838

28352839

2840+
def get_client_timeout(app=None):
2841+
app = app or g_app
2842+
timeout = app.limits.get("client_timeout", 120) if app else 120
2843+
return aiohttp.ClientTimeout(total=timeout)
2844+
2845+
28362846
class AppExtensions:
28372847
"""
28382848
APIs extensions can use to extend the app
@@ -2842,6 +2852,7 @@ def __init__(self, cli_args: argparse.Namespace, extra_args: Dict[str, Any]):
28422852
self.cli_args = cli_args
28432853
self.extra_args = extra_args
28442854
self.config = None
2855+
self.limits = DEFAULT_LIMITS
28452856
self.error_auth_required = create_error_response("Authentication required", "Unauthorized")
28462857
self.ui_extensions = []
28472858
self.chat_request_filters = []
@@ -2921,6 +2932,12 @@ def __init__(self, cli_args: argparse.Namespace, extra_args: Dict[str, Any]):
29212932

29222933
def set_config(self, config: Dict[str, Any]):
29232934
self.config = config
2935+
self.limits = self.config.get("limits", DEFAULT_LIMITS)
2936+
self.limits["client_timeout"] = self.limits.get("client_timeout", 120)
2937+
self.limits["client_max_size"] = self.limits.get("client_max_size", 20971520)
2938+
2939+
def get_client_timeout(self):
2940+
return get_client_timeout(self)
29242941

29252942
def set_allowed_directories(
29262943
self, directories: List[Annotated[str, "List of absolute paths that are allowed to be accessed."]]
@@ -3080,6 +3097,7 @@ class ExtensionContext:
30803097
def __init__(self, app: AppExtensions, path: str):
30813098
self.app = app
30823099
self.config = app.config
3100+
self.limits = app.limits
30833101
self.cli_args = app.cli_args
30843102
self.extra_args = app.extra_args
30853103
self.error_auth_required = app.error_auth_required
@@ -3098,6 +3116,9 @@ def __init__(self, app: AppExtensions, path: str):
30983116
self.oauth_states = app.oauth_states
30993117
self.disabled = False
31003118

3119+
def get_client_timeout(self):
3120+
return self.app.get_client_timeout()
3121+
31013122
def add_auth_provider(self, auth_provider: AuthProvider) -> None:
31023123
"""Add an authentication provider."""
31033124
self.app.add_auth_provider(auth_provider)

0 commit comments

Comments
 (0)