22import copy
33import inspect
44import os
5+ import time
56import traceback
67from pathlib import Path
78from typing import Any
4546)
4647
4748MAX_FILE_BYTES = 500 * 1024 * 1024
49+ OPENAI_OAUTH_FLOW_TTL_SECONDS = 10 * 60
4850
4951
5052def try_cast (value : Any , type_ : str ):
@@ -350,7 +352,7 @@ def __init__(
350352 self ._logo_token_cache = {} # 缓存logo token,避免重复注册
351353 self .acm = core_lifecycle .astrbot_config_mgr
352354 self .ucr = core_lifecycle .umop_config_router
353- self ._provider_source_oauth_flows : dict [str , dict [str , str ]] = {}
355+ self ._provider_source_oauth_flows : dict [str , dict [str , Any ]] = {}
354356 self .routes = {
355357 "/config/abconf/new" : ("POST" , self .create_abconf ),
356358 "/config/abconf" : ("GET" , self .get_abconf ),
@@ -427,6 +429,25 @@ def _is_openai_oauth_supported_source(self, provider_source: dict) -> bool:
427429 and provider_source .get ("type" ) == "openai_oauth_chat_completion"
428430 )
429431
432+ def _cleanup_expired_provider_source_oauth_flows (self ) -> None :
433+ now = time .time ()
434+ expired_source_ids = [
435+ source_id
436+ for source_id , flow in self ._provider_source_oauth_flows .items ()
437+ if now - float (flow .get ("created_at" ) or 0 ) > OPENAI_OAUTH_FLOW_TTL_SECONDS
438+ ]
439+ for source_id in expired_source_ids :
440+ self ._provider_source_oauth_flows .pop (source_id , None )
441+
442+ def _create_provider_source_oauth_flow (self ) -> dict [str , Any ]:
443+ flow = create_pkce_flow ()
444+ flow ["created_at" ] = time .time ()
445+ return flow
446+
447+ def _get_provider_source_oauth_flow (self , source_id : str ) -> dict [str , Any ] | None :
448+ self ._cleanup_expired_provider_source_oauth_flows ()
449+ return self ._provider_source_oauth_flows .get (source_id )
450+
430451 async def _reload_provider_source_providers (self , source_id : str ) -> list [str ]:
431452 prov_mgr = self .core_lifecycle .provider_manager
432453 reload_errors = []
@@ -474,7 +495,8 @@ async def start_provider_source_openai_oauth(self):
474495 _ , _ , provider_source = self ._find_provider_source (source_id )
475496 if not self ._is_openai_oauth_supported_source (provider_source ):
476497 return Response ().error ("当前 provider source 不支持 OpenAI OAuth" ).__dict__
477- flow = create_pkce_flow ()
498+ self ._cleanup_expired_provider_source_oauth_flows ()
499+ flow = self ._create_provider_source_oauth_flow ()
478500 self ._provider_source_oauth_flows [source_id ] = flow
479501 return Response ().ok (
480502 data = {
@@ -489,9 +511,11 @@ async def complete_provider_source_openai_oauth(self):
489511 auth_input = post_data .get ("input" ) or ""
490512 if not source_id :
491513 return Response ().error ("缺少 source_id" ).__dict__
492- flow = self ._provider_source_oauth_flows . get (source_id )
514+ flow = self ._get_provider_source_oauth_flow (source_id )
493515 try :
494516 _ , _ , provider_source = self ._find_provider_source (source_id )
517+ if not self ._is_openai_oauth_supported_source (provider_source ):
518+ return Response ().error ("当前 provider source 不支持 OpenAI OAuth" ).__dict__
495519 token = parse_oauth_credential_json (auth_input )
496520 if token is None :
497521 if not flow :
0 commit comments