Skip to content

Commit aec1a71

Browse files
committed
fix(provider): 修复 OAuth 流程边界问题
1 parent abea526 commit aec1a71

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

astrbot/core/provider/oauth/openai_oauth.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ def parse_authorization_input(raw: str) -> tuple[str, str]:
5353
value = (raw or "").strip()
5454
if not value:
5555
raise ValueError("empty input")
56-
if "#" in value:
57-
code, state = value.split("#", 1)
58-
return code.strip(), state.strip()
5956
if "code=" in value:
6057
parsed = urlparse(value)
6158
if parsed.query:
6259
query = parse_qs(parsed.query)
6360
return query.get("code", [""])[0].strip(), query.get("state", [""])[0].strip()
6461
query = parse_qs(value)
6562
return query.get("code", [""])[0].strip(), query.get("state", [""])[0].strip()
63+
if "#" in value:
64+
code, state = value.split("#", 1)
65+
return code.strip(), state.strip()
6666
return value, ""
6767

6868

astrbot/core/provider/sources/openai_oauth_source.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ def __init__(self, provider_config, provider_settings) -> None:
3535
patched_config.get("api_base")
3636
or "https://chatgpt.com/backend-api/codex"
3737
).rstrip("/")
38-
self.http_client = httpx.AsyncClient(
39-
proxy=patched_config.get("proxy") or None,
40-
timeout=self.timeout,
41-
follow_redirects=True,
42-
)
4338

4439
async def get_models(self):
4540
logger.info(
@@ -69,12 +64,17 @@ async def _request_backend(self, payload: dict[str, Any]) -> dict[str, Any]:
6964
for key, value in custom_headers.items():
7065
headers[str(key)] = str(value)
7166

72-
response = await self.http_client.post(
73-
f"{self.base_url}/responses",
74-
headers=headers,
75-
json=payload,
76-
)
77-
raw_text = await response.aread()
67+
async with httpx.AsyncClient(
68+
proxy=self.provider_config.get("proxy") or None,
69+
timeout=self.timeout,
70+
follow_redirects=True,
71+
) as client:
72+
response = await client.post(
73+
f"{self.base_url}/responses",
74+
headers=headers,
75+
json=payload,
76+
)
77+
raw_text = await response.aread()
7878
text = raw_text.decode("utf-8", errors="replace")
7979
if response.status_code < 200 or response.status_code >= 300:
8080
raise Exception(self._format_backend_error(response.status_code, text))

astrbot/dashboard/routes/config.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import copy
33
import inspect
44
import os
5+
import time
56
import traceback
67
from pathlib import Path
78
from typing import Any
@@ -45,6 +46,7 @@
4546
)
4647

4748
MAX_FILE_BYTES = 500 * 1024 * 1024
49+
OPENAI_OAUTH_FLOW_TTL_SECONDS = 10 * 60
4850

4951

5052
def try_cast(value: Any, type_: str):
@@ -350,7 +352,7 @@ def __init__(
350352
self._logo_token_cache = {} # 缓存logo token,避免重复注册
351353
self.acm = core_lifecycle.astrbot_config_mgr
352354
self.ucr = core_lifecycle.umop_config_router
353-
self._provider_source_oauth_flows: dict[str, dict[str, str]] = {}
355+
self._provider_source_oauth_flows: dict[str, dict[str, Any]] = {}
354356
self.routes = {
355357
"/config/abconf/new": ("POST", self.create_abconf),
356358
"/config/abconf": ("GET", self.get_abconf),
@@ -427,6 +429,25 @@ def _is_openai_oauth_supported_source(self, provider_source: dict) -> bool:
427429
and provider_source.get("type") == "openai_oauth_chat_completion"
428430
)
429431

432+
def _cleanup_expired_provider_source_oauth_flows(self) -> None:
433+
now = time.time()
434+
expired_source_ids = [
435+
source_id
436+
for source_id, flow in self._provider_source_oauth_flows.items()
437+
if now - float(flow.get("created_at") or 0) > OPENAI_OAUTH_FLOW_TTL_SECONDS
438+
]
439+
for source_id in expired_source_ids:
440+
self._provider_source_oauth_flows.pop(source_id, None)
441+
442+
def _create_provider_source_oauth_flow(self) -> dict[str, Any]:
443+
flow = create_pkce_flow()
444+
flow["created_at"] = time.time()
445+
return flow
446+
447+
def _get_provider_source_oauth_flow(self, source_id: str) -> dict[str, Any] | None:
448+
self._cleanup_expired_provider_source_oauth_flows()
449+
return self._provider_source_oauth_flows.get(source_id)
450+
430451
async def _reload_provider_source_providers(self, source_id: str) -> list[str]:
431452
prov_mgr = self.core_lifecycle.provider_manager
432453
reload_errors = []
@@ -474,7 +495,8 @@ async def start_provider_source_openai_oauth(self):
474495
_, _, provider_source = self._find_provider_source(source_id)
475496
if not self._is_openai_oauth_supported_source(provider_source):
476497
return Response().error("当前 provider source 不支持 OpenAI OAuth").__dict__
477-
flow = create_pkce_flow()
498+
self._cleanup_expired_provider_source_oauth_flows()
499+
flow = self._create_provider_source_oauth_flow()
478500
self._provider_source_oauth_flows[source_id] = flow
479501
return Response().ok(
480502
data={
@@ -489,9 +511,11 @@ async def complete_provider_source_openai_oauth(self):
489511
auth_input = post_data.get("input") or ""
490512
if not source_id:
491513
return Response().error("缺少 source_id").__dict__
492-
flow = self._provider_source_oauth_flows.get(source_id)
514+
flow = self._get_provider_source_oauth_flow(source_id)
493515
try:
494516
_, _, provider_source = self._find_provider_source(source_id)
517+
if not self._is_openai_oauth_supported_source(provider_source):
518+
return Response().error("当前 provider source 不支持 OpenAI OAuth").__dict__
495519
token = parse_oauth_credential_json(auth_input)
496520
if token is None:
497521
if not flow:

0 commit comments

Comments
 (0)