Skip to content

Commit 99037a8

Browse files
committed
fix(ci): 修复格式检查与告警
1 parent aec1a71 commit 99037a8

File tree

4 files changed

+202
-72
lines changed

4 files changed

+202
-72
lines changed

astrbot/core/provider/manager.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,12 @@ def get_merged_provider_config(self, provider_config: dict) -> dict:
496496
merged_config = {**provider_source, **pc}
497497
# 保持 id 为 provider 的 id,而不是 source 的 id
498498
merged_config["id"] = pc["id"]
499-
merged_config["type"] = provider_source.get("type", merged_config.get("type"))
500-
merged_config["provider"] = provider_source.get("provider", merged_config.get("provider"))
499+
merged_config["type"] = provider_source.get(
500+
"type", merged_config.get("type")
501+
)
502+
merged_config["provider"] = provider_source.get(
503+
"provider", merged_config.get("provider")
504+
)
501505
merged_config["provider_type"] = provider_source.get(
502506
"provider_type", merged_config.get("provider_type")
503507
)
@@ -506,7 +510,9 @@ def get_merged_provider_config(self, provider_config: dict) -> dict:
506510
and merged_config.get("type") == "openai_oauth_chat_completion"
507511
and merged_config.get("auth_mode") == "openai_oauth"
508512
):
509-
access_token = (merged_config.get("oauth_access_token") or "").strip()
513+
access_token = (
514+
merged_config.get("oauth_access_token") or ""
515+
).strip()
510516
if access_token:
511517
merged_config["key"] = [access_token]
512518
pc = merged_config
@@ -527,7 +533,7 @@ def _resolve_env_key_list(self, provider_config: dict) -> dict:
527533
if env_val is None:
528534
provider_id = provider_config.get("id")
529535
logger.warning(
530-
f"Provider {provider_id} 配置项 key[{idx}] 使用环境变量 {env_key} 但未设置。",
536+
f"Provider {provider_id} 配置项 key[{idx}] 使用的环境变量未设置。",
531537
)
532538
resolved_keys.append("")
533539
else:

astrbot/core/provider/oauth/openai_oauth.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
def create_pkce_flow() -> dict[str, str]:
2121
state = secrets.token_hex(16)
2222
verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode().rstrip("=")
23-
challenge = base64.urlsafe_b64encode(
24-
hashlib.sha256(verifier.encode()).digest()
25-
).decode().rstrip("=")
23+
challenge = (
24+
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest())
25+
.decode()
26+
.rstrip("=")
27+
)
2628
return {
2729
"state": state,
2830
"verifier": verifier,
@@ -57,7 +59,9 @@ def parse_authorization_input(raw: str) -> tuple[str, str]:
5759
parsed = urlparse(value)
5860
if parsed.query:
5961
query = parse_qs(parsed.query)
60-
return query.get("code", [""])[0].strip(), query.get("state", [""])[0].strip()
62+
return query.get("code", [""])[0].strip(), query.get("state", [""])[
63+
0
64+
].strip()
6165
query = parse_qs(value)
6266
return query.get("code", [""])[0].strip(), query.get("state", [""])[0].strip()
6367
if "#" in value:
@@ -83,7 +87,9 @@ def parse_oauth_credential_json(raw: str) -> dict[str, Any] | None:
8387
expires_at = _normalize_expires_at(
8488
data.get("expired") or data.get("expires_at") or data.get("expires"),
8589
)
86-
account_id = str(data.get("account_id") or "").strip() or extract_account_id_from_jwt(access_token)
90+
account_id = str(
91+
data.get("account_id") or ""
92+
).strip() or extract_account_id_from_jwt(access_token)
8793
email = str(data.get("email") or "").strip() or extract_email_from_jwt(access_token)
8894
return {
8995
"access_token": access_token,
@@ -122,8 +128,12 @@ async def refresh_access_token(
122128
return await _request_token(payload, proxy_url)
123129

124130

125-
async def _request_token(payload: dict[str, str], proxy_url: str = "") -> dict[str, Any]:
126-
async with httpx.AsyncClient(proxy=proxy_url or None, timeout=OPENAI_OAUTH_TIMEOUT) as client:
131+
async def _request_token(
132+
payload: dict[str, str], proxy_url: str = ""
133+
) -> dict[str, Any]:
134+
async with httpx.AsyncClient(
135+
proxy=proxy_url or None, timeout=OPENAI_OAUTH_TIMEOUT
136+
) as client:
127137
response = await client.post(
128138
OPENAI_OAUTH_TOKEN_URL,
129139
data=payload,
@@ -134,7 +144,9 @@ async def _request_token(payload: dict[str, str], proxy_url: str = "") -> dict[s
134144
)
135145
data = response.json()
136146
if response.status_code < 200 or response.status_code >= 300:
137-
raise ValueError(f"oauth token request failed: status={response.status_code}, body={data}")
147+
raise ValueError(
148+
f"oauth token request failed: status={response.status_code}, body={data}"
149+
)
138150
access_token = (data.get("access_token") or "").strip()
139151
refresh_token = (data.get("refresh_token") or "").strip()
140152
expires_in = int(data.get("expires_in") or 0)

astrbot/core/provider/sources/openai_oauth_source.py

Lines changed: 108 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,16 @@ def __init__(self, provider_config, provider_settings) -> None:
2525
super().__init__(patched_config, provider_settings)
2626
self.provider_config = patched_config
2727
self.api_keys = [access_token] if access_token else self.api_keys
28-
self.chosen_api_key = access_token or (self.api_keys[0] if self.api_keys else "")
28+
self.chosen_api_key = access_token or (
29+
self.api_keys[0] if self.api_keys else ""
30+
)
2931
self.account_id = (
3032
patched_config.get("oauth_account_id")
3133
or patched_config.get("account_id")
3234
or ""
3335
).strip()
3436
self.base_url = (
35-
patched_config.get("api_base")
36-
or "https://chatgpt.com/backend-api/codex"
37+
patched_config.get("api_base") or "https://chatgpt.com/backend-api/codex"
3738
).rstrip("/")
3839

3940
async def get_models(self):
@@ -44,12 +45,18 @@ async def get_models(self):
4445
return []
4546

4647
async def _request_backend(self, payload: dict[str, Any]) -> dict[str, Any]:
47-
access_token = (self.provider_config.get("oauth_access_token") or self.chosen_api_key or "").strip()
48-
account_id = (self.provider_config.get("oauth_account_id") or self.account_id or "").strip()
48+
access_token = (
49+
self.provider_config.get("oauth_access_token") or self.chosen_api_key or ""
50+
).strip()
51+
account_id = (
52+
self.provider_config.get("oauth_account_id") or self.account_id or ""
53+
).strip()
4954
if not access_token:
5055
raise Exception("当前 OAuth Source 尚未绑定 access token")
5156
if not account_id:
52-
raise Exception("当前 OAuth Source 缺少 chatgpt_account_id,请重新绑定或导入完整 JSON 凭据")
57+
raise Exception(
58+
"当前 OAuth Source 缺少 chatgpt_account_id,请重新绑定或导入完整 JSON 凭据"
59+
)
5360

5461
headers = {
5562
"Authorization": f"Bearer {access_token}",
@@ -88,7 +95,9 @@ def _format_backend_error(self, status_code: int, text: str) -> str:
8895
data = json.loads(stripped)
8996
return f"Codex backend request failed: status={status_code}, body={data}"
9097
except Exception:
91-
return f"Codex backend request failed: status={status_code}, body={stripped}"
98+
return (
99+
f"Codex backend request failed: status={status_code}, body={stripped}"
100+
)
92101

93102
def _parse_backend_response(self, text: str) -> dict[str, Any]:
94103
completed_response: dict[str, Any] | None = None
@@ -123,10 +132,14 @@ def _parse_backend_response(self, text: str) -> dict[str, Any]:
123132
if stripped.startswith("{"):
124133
data = json.loads(stripped)
125134
if isinstance(data, dict):
126-
if data.get("type") == "response.completed" and isinstance(data.get("response"), dict):
135+
if data.get("type") == "response.completed" and isinstance(
136+
data.get("response"), dict
137+
):
127138
return data["response"]
128139
return data
129-
raise Exception("Codex backend response did not contain response.completed event")
140+
raise Exception(
141+
"Codex backend response did not contain response.completed event"
142+
)
130143

131144
def _convert_message_content(self, raw_content: Any) -> str | list[dict[str, Any]]:
132145
if isinstance(raw_content, str):
@@ -235,7 +248,9 @@ def _convert_messages_to_backend_input(
235248
if not name or not call_id:
236249
continue
237250
if not isinstance(arguments, str):
238-
arguments = json.dumps(arguments, ensure_ascii=False, default=str)
251+
arguments = json.dumps(
252+
arguments, ensure_ascii=False, default=str
253+
)
239254
response_items.append(
240255
{
241256
"type": "function_call",
@@ -244,7 +259,9 @@ def _convert_messages_to_backend_input(
244259
"arguments": arguments,
245260
}
246261
)
247-
return "\n\n".join(part for part in instructions_parts if part).strip(), response_items
262+
return "\n\n".join(
263+
part for part in instructions_parts if part
264+
).strip(), response_items
248265

249266
def _extract_response_usage(self, usage: Any) -> TokenUsage | None:
250267
if usage is None:
@@ -265,7 +282,9 @@ def _extract_response_usage(self, usage: Any) -> TokenUsage | None:
265282
output=output_tokens,
266283
)
267284

268-
def _convert_tools_to_backend_format(self, tool_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
285+
def _convert_tools_to_backend_format(
286+
self, tool_list: list[dict[str, Any]]
287+
) -> list[dict[str, Any]]:
269288
backend_tools: list[dict[str, Any]] = []
270289
for tool in tool_list:
271290
if not isinstance(tool, dict):
@@ -283,7 +302,8 @@ def _convert_tools_to_backend_format(self, tool_list: list[dict[str, Any]]) -> l
283302
"type": "function",
284303
"name": name,
285304
"description": str(function.get("description") or "").strip(),
286-
"parameters": function.get("parameters") or {"type": "object", "properties": {}},
305+
"parameters": function.get("parameters")
306+
or {"type": "object", "properties": {}},
287307
}
288308
backend_tools.append(backend_tool)
289309
return backend_tools
@@ -298,43 +318,95 @@ async def _parse_responses_completion(self, response: Any, tools) -> LLMResponse
298318
if output_text:
299319
llm_response.result_chain = MessageChain().message(output_text)
300320

301-
output_items = list(response.get("output", []) if isinstance(response, dict) else getattr(response, "output", []) or [])
321+
output_items = list(
322+
response.get("output", [])
323+
if isinstance(response, dict)
324+
else getattr(response, "output", []) or []
325+
)
302326
reasoning_parts: list[str] = []
303327
tool_args: list[dict[str, Any]] = []
304328
tool_names: list[str] = []
305329
tool_ids: list[str] = []
306330

307331
for item in output_items:
308-
item_type = item.get("type") if isinstance(item, dict) else getattr(item, "type", None)
332+
item_type = (
333+
item.get("type")
334+
if isinstance(item, dict)
335+
else getattr(item, "type", None)
336+
)
309337
if item_type == "reasoning":
310-
summaries = item.get("summary", []) if isinstance(item, dict) else getattr(item, "summary", []) or []
338+
summaries = (
339+
item.get("summary", [])
340+
if isinstance(item, dict)
341+
else getattr(item, "summary", []) or []
342+
)
311343
for summary in summaries:
312-
text = summary.get("text") if isinstance(summary, dict) else getattr(summary, "text", None)
344+
text = (
345+
summary.get("text")
346+
if isinstance(summary, dict)
347+
else getattr(summary, "text", None)
348+
)
313349
if text:
314350
reasoning_parts.append(str(text))
315351
elif item_type == "function_call" and tools is not None:
316-
arguments = item.get("arguments", "{}") if isinstance(item, dict) else getattr(item, "arguments", "{}")
352+
arguments = (
353+
item.get("arguments", "{}")
354+
if isinstance(item, dict)
355+
else getattr(item, "arguments", "{}")
356+
)
317357
try:
318-
parsed_args = json.loads(arguments) if isinstance(arguments, str) else arguments
358+
parsed_args = (
359+
json.loads(arguments)
360+
if isinstance(arguments, str)
361+
else arguments
362+
)
319363
except Exception:
320364
parsed_args = {}
321365
tool_args.append(parsed_args if isinstance(parsed_args, dict) else {})
322-
tool_names.append(str(item.get("name", "") if isinstance(item, dict) else getattr(item, "name", "") or ""))
323-
tool_ids.append(str(item.get("call_id", "") if isinstance(item, dict) else getattr(item, "call_id", "") or ""))
366+
tool_names.append(
367+
str(
368+
item.get("name", "")
369+
if isinstance(item, dict)
370+
else getattr(item, "name", "") or ""
371+
)
372+
)
373+
tool_ids.append(
374+
str(
375+
item.get("call_id", "")
376+
if isinstance(item, dict)
377+
else getattr(item, "call_id", "") or ""
378+
)
379+
)
324380
elif item_type == "message" and not output_text:
325-
content_items = item.get("content", []) if isinstance(item, dict) else getattr(item, "content", []) or []
381+
content_items = (
382+
item.get("content", [])
383+
if isinstance(item, dict)
384+
else getattr(item, "content", []) or []
385+
)
326386
item_text_parts: list[str] = []
327387
for content in content_items:
328-
ctype = content.get("type") if isinstance(content, dict) else getattr(content, "type", None)
388+
ctype = (
389+
content.get("type")
390+
if isinstance(content, dict)
391+
else getattr(content, "type", None)
392+
)
329393
if ctype in {"output_text", "text"}:
330-
text = content.get("text") if isinstance(content, dict) else getattr(content, "text", None)
394+
text = (
395+
content.get("text")
396+
if isinstance(content, dict)
397+
else getattr(content, "text", None)
398+
)
331399
if text:
332400
item_text_parts.append(str(text))
333401
if item_text_parts:
334-
llm_response.result_chain = MessageChain().message("".join(item_text_parts).strip())
402+
llm_response.result_chain = MessageChain().message(
403+
"".join(item_text_parts).strip()
404+
)
335405

336406
if reasoning_parts:
337-
llm_response.reasoning_content = "\n".join(part for part in reasoning_parts if part)
407+
llm_response.reasoning_content = "\n".join(
408+
part for part in reasoning_parts if part
409+
)
338410

339411
if tool_args:
340412
llm_response.role = "tool"
@@ -346,10 +418,18 @@ async def _parse_responses_completion(self, response: Any, tools) -> LLMResponse
346418
raise Exception(f"账号态 responses 响应无法解析:{response}。")
347419

348420
llm_response.raw_completion = response
349-
response_id = response.get("id") if isinstance(response, dict) else getattr(response, "id", None)
421+
response_id = (
422+
response.get("id")
423+
if isinstance(response, dict)
424+
else getattr(response, "id", None)
425+
)
350426
if response_id:
351427
llm_response.id = response_id
352-
usage = self._extract_response_usage(response.get("usage") if isinstance(response, dict) else getattr(response, "usage", None))
428+
usage = self._extract_response_usage(
429+
response.get("usage")
430+
if isinstance(response, dict)
431+
else getattr(response, "usage", None)
432+
)
353433
if usage is not None:
354434
llm_response.usage = usage
355435
return llm_response

0 commit comments

Comments
 (0)