Skip to content

Commit abea526

Browse files
committed
feat(provider): 增加 ChatGPT/Codex OAuth 提供商源
1 parent 55ed028 commit abea526

File tree

10 files changed

+1207
-3
lines changed

10 files changed

+1207
-3
lines changed

astrbot/core/config/default.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,25 @@ class ChatProviderTemplate(TypedDict):
11341134
"proxy": "",
11351135
"custom_headers": {},
11361136
},
1137+
"ChatGPT/Codex OAuth": {
1138+
"id": "openai_oauth",
1139+
"provider": "openai",
1140+
"type": "openai_oauth_chat_completion",
1141+
"provider_type": "chat_completion",
1142+
"enable": True,
1143+
"key": [],
1144+
"api_base": "https://chatgpt.com/backend-api/codex",
1145+
"timeout": 120,
1146+
"proxy": "",
1147+
"custom_headers": {},
1148+
"auth_mode": "openai_oauth",
1149+
"oauth_provider": "openai",
1150+
"oauth_access_token": "",
1151+
"oauth_refresh_token": "",
1152+
"oauth_expires_at": "",
1153+
"oauth_account_email": "",
1154+
"oauth_account_id": "",
1155+
},
11371156
"Google Gemini": {
11381157
"id": "google_gemini",
11391158
"provider": "google",
@@ -1863,6 +1882,41 @@ class ChatProviderTemplate(TypedDict):
18631882
"items": {},
18641883
"hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。",
18651884
},
1885+
"auth_mode": {
1886+
"description": "认证方式",
1887+
"type": "string",
1888+
"invisible": True,
1889+
},
1890+
"oauth_provider": {
1891+
"description": "OAuth 提供方",
1892+
"type": "string",
1893+
"invisible": True,
1894+
},
1895+
"oauth_access_token": {
1896+
"description": "OAuth Access Token",
1897+
"type": "string",
1898+
"invisible": True,
1899+
},
1900+
"oauth_refresh_token": {
1901+
"description": "OAuth Refresh Token",
1902+
"type": "string",
1903+
"invisible": True,
1904+
},
1905+
"oauth_expires_at": {
1906+
"description": "OAuth 过期时间",
1907+
"type": "string",
1908+
"invisible": True,
1909+
},
1910+
"oauth_account_email": {
1911+
"description": "OAuth 账号邮箱",
1912+
"type": "string",
1913+
"invisible": True,
1914+
},
1915+
"oauth_account_id": {
1916+
"description": "OAuth 账号 ID",
1917+
"type": "string",
1918+
"invisible": True,
1919+
},
18661920
"ollama_disable_thinking": {
18671921
"description": "关闭思考模式",
18681922
"type": "bool",

astrbot/core/provider/manager.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,10 @@ def dynamic_import_provider(self, type: str) -> None:
357357
from .sources.openai_source import (
358358
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
359359
)
360+
case "openai_oauth_chat_completion":
361+
from .sources.openai_oauth_source import (
362+
ProviderOpenAIOAuth as ProviderOpenAIOAuth,
363+
)
360364
case "zhipu_chat_completion":
361365
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
362366
case "groq_chat_completion":
@@ -488,10 +492,23 @@ def get_merged_provider_config(self, provider_config: dict) -> dict:
488492
break
489493

490494
if provider_source:
491-
# 合并配置,provider 的配置优先级更高
495+
# 合并配置,provider 的业务字段优先,但 provider 类型应跟随 source。
492496
merged_config = {**provider_source, **pc}
493497
# 保持 id 为 provider 的 id,而不是 source 的 id
494498
merged_config["id"] = pc["id"]
499+
merged_config["type"] = provider_source.get("type", merged_config.get("type"))
500+
merged_config["provider"] = provider_source.get("provider", merged_config.get("provider"))
501+
merged_config["provider_type"] = provider_source.get(
502+
"provider_type", merged_config.get("provider_type")
503+
)
504+
if (
505+
merged_config.get("provider") == "openai"
506+
and merged_config.get("type") == "openai_oauth_chat_completion"
507+
and merged_config.get("auth_mode") == "openai_oauth"
508+
):
509+
access_token = (merged_config.get("oauth_access_token") or "").strip()
510+
if access_token:
511+
merged_config["key"] = [access_token]
495512
pc = merged_config
496513
return pc
497514

astrbot/core/provider/oauth/__init__.py

Whitespace-only changes.
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import base64
2+
import hashlib
3+
import json
4+
import secrets
5+
from datetime import UTC, datetime, timedelta
6+
from typing import Any
7+
from urllib.parse import parse_qs, urlencode, urlparse
8+
9+
import httpx
10+
11+
OPENAI_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
12+
OPENAI_OAUTH_AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize"
13+
OPENAI_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
14+
OPENAI_OAUTH_REDIRECT_URI = "http://localhost:1455/auth/callback"
15+
OPENAI_OAUTH_SCOPE = "openid profile email offline_access"
16+
OPENAI_OAUTH_TIMEOUT = 20.0
17+
OPENAI_OAUTH_ACCOUNT_CLAIM_PATH = "https://api.openai.com/auth"
18+
19+
20+
def create_pkce_flow() -> dict[str, str]:
21+
state = secrets.token_hex(16)
22+
verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode().rstrip("=")
23+
challenge = base64.urlsafe_b64encode(
24+
hashlib.sha256(verifier.encode()).digest()
25+
).decode().rstrip("=")
26+
return {
27+
"state": state,
28+
"verifier": verifier,
29+
"challenge": challenge,
30+
"authorize_url": build_authorize_url(state, challenge),
31+
}
32+
33+
34+
def build_authorize_url(state: str, challenge: str) -> str:
35+
query = urlencode(
36+
{
37+
"response_type": "code",
38+
"client_id": OPENAI_OAUTH_CLIENT_ID,
39+
"redirect_uri": OPENAI_OAUTH_REDIRECT_URI,
40+
"scope": OPENAI_OAUTH_SCOPE,
41+
"code_challenge": challenge,
42+
"code_challenge_method": "S256",
43+
"state": state,
44+
"id_token_add_organizations": "true",
45+
"codex_cli_simplified_flow": "true",
46+
"originator": "codex_cli_rs",
47+
}
48+
)
49+
return f"{OPENAI_OAUTH_AUTHORIZE_URL}?{query}"
50+
51+
52+
def parse_authorization_input(raw: str) -> tuple[str, str]:
53+
value = (raw or "").strip()
54+
if not value:
55+
raise ValueError("empty input")
56+
if "#" in value:
57+
code, state = value.split("#", 1)
58+
return code.strip(), state.strip()
59+
if "code=" in value:
60+
parsed = urlparse(value)
61+
if parsed.query:
62+
query = parse_qs(parsed.query)
63+
return query.get("code", [""])[0].strip(), query.get("state", [""])[0].strip()
64+
query = parse_qs(value)
65+
return query.get("code", [""])[0].strip(), query.get("state", [""])[0].strip()
66+
return value, ""
67+
68+
69+
def parse_oauth_credential_json(raw: str) -> dict[str, Any] | None:
70+
value = (raw or "").strip()
71+
if not value.startswith("{"):
72+
return None
73+
try:
74+
data = json.loads(value)
75+
except Exception as exc:
76+
raise ValueError(f"OAuth JSON 凭据解析失败: {exc}") from exc
77+
if not isinstance(data, dict):
78+
raise ValueError("OAuth JSON 凭据必须是对象")
79+
access_token = str(data.get("access_token") or "").strip()
80+
if not access_token:
81+
raise ValueError("OAuth JSON 凭据缺少 access_token")
82+
refresh_token = str(data.get("refresh_token") or "").strip()
83+
expires_at = _normalize_expires_at(
84+
data.get("expired") or data.get("expires_at") or data.get("expires"),
85+
)
86+
account_id = str(data.get("account_id") or "").strip() or extract_account_id_from_jwt(access_token)
87+
email = str(data.get("email") or "").strip() or extract_email_from_jwt(access_token)
88+
return {
89+
"access_token": access_token,
90+
"refresh_token": refresh_token,
91+
"expires_at": expires_at,
92+
"email": email,
93+
"account_id": account_id,
94+
"raw": data,
95+
}
96+
97+
98+
async def exchange_authorization_code(
99+
code: str,
100+
verifier: str,
101+
proxy_url: str = "",
102+
) -> dict[str, Any]:
103+
payload = {
104+
"grant_type": "authorization_code",
105+
"client_id": OPENAI_OAUTH_CLIENT_ID,
106+
"code": code.strip(),
107+
"code_verifier": verifier.strip(),
108+
"redirect_uri": OPENAI_OAUTH_REDIRECT_URI,
109+
}
110+
return await _request_token(payload, proxy_url)
111+
112+
113+
async def refresh_access_token(
114+
refresh_token: str,
115+
proxy_url: str = "",
116+
) -> dict[str, Any]:
117+
payload = {
118+
"grant_type": "refresh_token",
119+
"client_id": OPENAI_OAUTH_CLIENT_ID,
120+
"refresh_token": refresh_token.strip(),
121+
}
122+
return await _request_token(payload, proxy_url)
123+
124+
125+
async def _request_token(payload: dict[str, str], proxy_url: str = "") -> dict[str, Any]:
126+
async with httpx.AsyncClient(proxy=proxy_url or None, timeout=OPENAI_OAUTH_TIMEOUT) as client:
127+
response = await client.post(
128+
OPENAI_OAUTH_TOKEN_URL,
129+
data=payload,
130+
headers={
131+
"Accept": "application/json",
132+
"Content-Type": "application/x-www-form-urlencoded",
133+
},
134+
)
135+
data = response.json()
136+
if response.status_code < 200 or response.status_code >= 300:
137+
raise ValueError(f"oauth token request failed: status={response.status_code}, body={data}")
138+
access_token = (data.get("access_token") or "").strip()
139+
refresh_token = (data.get("refresh_token") or "").strip()
140+
expires_in = int(data.get("expires_in") or 0)
141+
if not access_token or not refresh_token or expires_in <= 0:
142+
raise ValueError("oauth token response missing required fields")
143+
expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)
144+
return {
145+
"access_token": access_token,
146+
"refresh_token": refresh_token,
147+
"expires_at": expires_at.isoformat(),
148+
"email": extract_email_from_jwt(access_token),
149+
"account_id": extract_account_id_from_jwt(access_token),
150+
"raw": data,
151+
}
152+
153+
154+
def extract_email_from_jwt(token: str) -> str:
155+
claims = decode_jwt_claims(token)
156+
email = claims.get("email")
157+
return email.strip() if isinstance(email, str) else ""
158+
159+
160+
def extract_account_id_from_jwt(token: str) -> str:
161+
claims = decode_jwt_claims(token)
162+
raw = claims.get(OPENAI_OAUTH_ACCOUNT_CLAIM_PATH)
163+
if not isinstance(raw, dict):
164+
return ""
165+
account_id = raw.get("chatgpt_account_id")
166+
return account_id.strip() if isinstance(account_id, str) else ""
167+
168+
169+
def decode_jwt_claims(token: str) -> dict[str, Any]:
170+
parts = token.split(".")
171+
if len(parts) < 2:
172+
return {}
173+
payload = parts[1]
174+
padding = "=" * (-len(payload) % 4)
175+
try:
176+
decoded = base64.urlsafe_b64decode(payload + padding)
177+
obj = json.loads(decoded.decode())
178+
return obj if isinstance(obj, dict) else {}
179+
except Exception:
180+
return {}
181+
182+
183+
def _normalize_expires_at(value: Any) -> str:
184+
if value is None:
185+
return ""
186+
if isinstance(value, (int, float)):
187+
try:
188+
return datetime.fromtimestamp(float(value), UTC).isoformat()
189+
except Exception:
190+
return ""
191+
if isinstance(value, str):
192+
stripped = value.strip()
193+
if not stripped:
194+
return ""
195+
try:
196+
if stripped.endswith("Z"):
197+
stripped = stripped[:-1] + "+00:00"
198+
return datetime.fromisoformat(stripped).isoformat()
199+
except Exception:
200+
return value.strip()
201+
return ""

0 commit comments

Comments
 (0)