Skip to content

Commit def2fbf

Browse files
committed
适应企业版和标准版geminicli
1 parent 2acbad1 commit def2fbf

5 files changed

Lines changed: 81 additions & 125 deletions

File tree

src/auth.py

Lines changed: 59 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Any, Dict, Optional
1313
from urllib.parse import parse_qs, urlparse
1414

15-
from config import get_config_value, get_antigravity_api_url, get_code_assist_endpoint
15+
from config import get_config_value, get_antigravity_api_url
1616
from log import log
1717

1818
from .google_oauth_api import (
@@ -33,7 +33,6 @@
3333
CLIENT_ID,
3434
CLIENT_SECRET,
3535
SCOPES,
36-
GEMINICLI_USER_AGENT,
3736
TOKEN_URL,
3837
)
3938

@@ -691,66 +690,46 @@ async def asyncio_complete_auth_flow(
691690

692691
# 如果需要自动检测项目ID且没有提供项目ID(标准模式)
693692
if flow_data.get("auto_project_detection", False) and not project_id:
694-
log.info("标准模式:从API获取project_id...")
695-
# 使用API获取project_id(使用标准模式的User-Agent)
696-
code_assist_url = await get_code_assist_endpoint()
697-
project_id, subscription_tier = await fetch_project_id_and_tier(
698-
credentials.access_token,
699-
GEMINICLI_USER_AGENT,
700-
code_assist_url
701-
)
702-
if project_id:
703-
flow_data["project_id"] = project_id
704-
log.info(f"成功从API获取project_id: {project_id}")
705-
# 自动启用必需的API服务
706-
log.info("正在自动启用必需的API服务...")
707-
await enable_required_apis(credentials, project_id)
708-
else:
709-
log.warning("无法从API获取project_id,回退到项目列表获取方式")
710-
# 回退到原来的项目列表获取方式
711-
user_projects = await get_user_projects(credentials)
712-
713-
if user_projects:
714-
# 如果只有一个项目,自动使用
715-
if len(user_projects) == 1:
716-
# Google API returns projectId in camelCase
717-
project_id = user_projects[0].get("projectId")
718-
if project_id:
719-
flow_data["project_id"] = project_id
720-
log.info(f"自动选择唯一项目: {project_id}")
721-
# 自动启用必需的API服务
722-
log.info("正在自动启用必需的API服务...")
723-
await enable_required_apis(credentials, project_id)
724-
# 如果有多个项目,尝试选择默认项目
725-
else:
726-
project_id = await select_default_project(user_projects)
727-
if project_id:
728-
flow_data["project_id"] = project_id
729-
log.info(f"自动选择默认项目: {project_id}")
730-
# 自动启用必需的API服务
731-
log.info("正在自动启用必需的API服务...")
732-
await enable_required_apis(credentials, project_id)
733-
else:
734-
# 返回项目列表让用户选择
735-
return {
736-
"success": False,
737-
"error": "请从以下项目中选择一个",
738-
"requires_project_selection": True,
739-
"available_projects": [
740-
{
741-
# Google API returns projectId in camelCase
742-
"project_id": p.get("projectId"),
743-
"name": p.get("displayName") or p.get("projectId"),
744-
"projectNumber": p.get("projectNumber"),
745-
}
746-
for p in user_projects
747-
],
748-
}
693+
log.info("标准模式:通过项目列表获取project_id...")
694+
user_projects = await get_user_projects(credentials)
695+
696+
if user_projects:
697+
# 如果只有一个项目,自动使用
698+
if len(user_projects) == 1:
699+
project_id = user_projects[0].get("projectId")
700+
if project_id:
701+
flow_data["project_id"] = project_id
702+
log.info(f"自动选择唯一项目: {project_id}")
703+
log.info("正在自动启用必需的API服务...")
704+
await enable_required_apis(credentials, project_id)
705+
# 如果有多个项目,尝试选择默认项目
749706
else:
750-
# 如果无法获取项目列表,使用默认project_id
751-
project_id = DEFAULT_PROJECT_ID
752-
flow_data["project_id"] = project_id
753-
log.warning(f"无法获取项目列表,使用默认project_id: {project_id}")
707+
project_id = await select_default_project(user_projects)
708+
if project_id:
709+
flow_data["project_id"] = project_id
710+
log.info(f"自动选择默认项目: {project_id}")
711+
log.info("正在自动启用必需的API服务...")
712+
await enable_required_apis(credentials, project_id)
713+
else:
714+
# 返回项目列表让用户选择
715+
return {
716+
"success": False,
717+
"error": "请从以下项目中选择一个",
718+
"requires_project_selection": True,
719+
"available_projects": [
720+
{
721+
"project_id": p.get("projectId"),
722+
"name": p.get("displayName") or p.get("projectId"),
723+
"projectNumber": p.get("projectNumber"),
724+
}
725+
for p in user_projects
726+
],
727+
}
728+
else:
729+
# 如果无法获取项目列表,使用默认project_id
730+
project_id = DEFAULT_PROJECT_ID
731+
flow_data["project_id"] = project_id
732+
log.warning(f"无法获取项目列表,使用默认project_id: {project_id}")
754733
elif project_id:
755734
# 如果已经有项目ID(手动提供或环境检测),也尝试启用API服务
756735
log.info("正在为已提供的项目ID自动启用必需的API服务...")
@@ -868,45 +847,28 @@ async def complete_auth_flow_from_callback_url(
868847
subscription_tier = None
869848

870849
if not project_id:
871-
# 尝试使用fetch_project_id_and_tier自动获取项目ID
850+
# 通过项目列表获取项目ID
872851
try:
873-
log.info("标准模式:从API获取project_id...")
874-
code_assist_url = await get_code_assist_endpoint()
875-
detected_project_id, subscription_tier = await fetch_project_id_and_tier(
876-
credentials.access_token,
877-
GEMINICLI_USER_AGENT,
878-
code_assist_url
879-
)
880-
if detected_project_id:
881-
auto_detected = True
882-
log.info(f"成功从API获取project_id: {detected_project_id}, tier: {subscription_tier}")
883-
else:
884-
log.warning("无法从API获取project_id,回退到项目列表获取方式")
885-
# 回退到原来的项目列表获取方式
886-
projects = await get_user_projects(credentials)
887-
if projects:
888-
if len(projects) == 1:
889-
# 只有一个项目,自动使用
890-
# Google API returns projectId in camelCase
891-
detected_project_id = projects[0]["projectId"]
892-
auto_detected = True
893-
log.info(f"自动检测到唯一项目ID: {detected_project_id}")
894-
else:
895-
# 多个项目,自动选择第一个
896-
# Google API returns projectId in camelCase
897-
detected_project_id = projects[0]["projectId"]
898-
auto_detected = True
899-
log.info(
900-
f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}"
901-
)
902-
log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}")
852+
log.info("标准模式:通过项目列表获取project_id...")
853+
projects = await get_user_projects(credentials)
854+
if projects:
855+
if len(projects) == 1:
856+
detected_project_id = projects[0]["projectId"]
857+
auto_detected = True
858+
log.info(f"自动检测到唯一项目ID: {detected_project_id}")
903859
else:
904-
# 没有项目访问权限,使用默认project_id
905-
detected_project_id = DEFAULT_PROJECT_ID
906-
auto_detected = False
907-
log.warning(f"未检测到可访问项目,使用默认project_id: {detected_project_id}")
860+
detected_project_id = projects[0]["projectId"]
861+
auto_detected = True
862+
log.info(
863+
f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}"
864+
)
865+
log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}")
866+
else:
867+
detected_project_id = DEFAULT_PROJECT_ID
868+
auto_detected = False
869+
log.warning(f"未检测到可访问项目,使用默认project_id: {detected_project_id}")
908870
except Exception as e:
909-
log.warning(f"自动检测项目ID失败: {e},使用默认project_id")
871+
log.warning(f"获取项目列表失败: {e},使用默认project_id")
910872
detected_project_id = DEFAULT_PROJECT_ID
911873
auto_detected = False
912874
else:

src/panel/creds.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from src.storage_adapter import get_storage_adapter
2323
from src.utils import verify_panel_token, GEMINICLI_USER_AGENT, ANTIGRAVITY_USER_AGENT
2424
from src.api.antigravity import fetch_quota_info
25-
from src.google_oauth_api import Credentials, fetch_project_id_and_tier
25+
from src.google_oauth_api import Credentials, fetch_project_id_and_tier, get_user_projects, select_default_project, enable_required_apis
2626
from config import get_code_assist_endpoint, get_antigravity_api_url
2727
from .utils import validate_mode
2828

@@ -557,29 +557,35 @@ async def verify_credential_project_common(filename: str, mode: str = "geminicli
557557
credential_data = credentials.to_dict()
558558
await storage_adapter.store_credential(filename, credential_data, mode=mode)
559559

560-
# 获取API端点和对应的User-Agent
560+
# 重新获取project id(仅 antigravity 模式请求积分)
561561
if mode == "antigravity":
562562
api_base_url = await get_antigravity_api_url()
563563
user_agent = ANTIGRAVITY_USER_AGENT
564-
else:
565-
api_base_url = await get_code_assist_endpoint()
566-
user_agent = GEMINICLI_USER_AGENT
567-
568-
# 重新获取project id(仅 antigravity 模式请求积分)
569-
if mode == "antigravity":
570564
project_id, subscription_tier, credit_amount = await fetch_project_id_and_tier(
571565
access_token=credentials.access_token,
572566
user_agent=user_agent,
573567
api_base_url=api_base_url,
574568
include_credits=True,
575569
)
576570
else:
577-
project_id, subscription_tier = await fetch_project_id_and_tier(
578-
access_token=credentials.access_token,
579-
user_agent=user_agent,
580-
api_base_url=api_base_url,
581-
)
571+
# geminicli 模式:通过项目列表获取 project_id
582572
credit_amount = None
573+
subscription_tier = None
574+
user_projects = await get_user_projects(credentials)
575+
if user_projects:
576+
if len(user_projects) == 1:
577+
project_id = user_projects[0].get("projectId")
578+
else:
579+
project_id = await select_default_project(user_projects)
580+
else:
581+
project_id = None
582+
583+
if project_id:
584+
log.info(f"正在为项目 {project_id} 启用必需的API服务...")
585+
try:
586+
await enable_required_apis(credentials, project_id)
587+
except Exception as e:
588+
log.warning(f"启用API服务失败: {e}")
583589

584590
if project_id:
585591
credential_data["project_id"] = project_id

src/storage/mongodb_manager.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ async def get_next_available_credential(
518518
# Redis 快速路径:根据模型名派生过滤标志,直接在 Redis 分桶中筛选
519519
if self._redis_enabled:
520520
model_lower = model_name.lower() if model_name else ""
521-
exclude_free = mode == "geminicli" and "pro" in model_lower
521+
exclude_free = False
522522
preview_only = mode == "geminicli" and "preview" in model_lower
523523
result = await self._get_next_available_from_redis(
524524
mode, model_name, exclude_free_tier=exclude_free, preview_only=preview_only
@@ -536,10 +536,6 @@ async def get_next_available_credential(
536536
# 构建普通查询(避免 $sample 聚合导致全集合扫描)
537537
match_query: Dict[str, Any] = {"disabled": False}
538538

539-
# pro 模型只允许非 free tier 凭证
540-
if mode == "geminicli" and model_name and "pro" in model_name.lower():
541-
match_query["tier"] = {"$ne": "free"}
542-
543539
# preview 模型只允许 preview=True 的凭证
544540
if mode == "geminicli" and model_name and "preview" in model_name.lower():
545541
match_query["preview"] = True

src/storage/psql_manager.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,10 @@ async def get_next_available_credential(
254254

255255
async with self._pool.acquire() as conn:
256256
if mode == "geminicli":
257-
tier_clause = ""
258-
if model_name and "pro" in model_name.lower():
259-
tier_clause = "AND (tier IS NULL OR tier != 'free')"
260-
261257
rows = await conn.fetch(f"""
262258
SELECT filename, credential_data, model_cooldowns, preview
263259
FROM {table_name}
264-
WHERE disabled = 0 {tier_clause}
260+
WHERE disabled = 0
265261
ORDER BY RANDOM()
266262
""")
267263

src/storage/sqlite_manager.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,14 +393,10 @@ async def get_next_available_credential(
393393
current_time = time.time()
394394

395395
if mode == "geminicli":
396-
tier_clause = ""
397-
if model_name and "pro" in model_name.lower():
398-
tier_clause = "AND (tier IS NULL OR tier != 'free')"
399-
400396
async with db.execute(f"""
401397
SELECT filename, credential_data, model_cooldowns, preview
402398
FROM {table_name}
403-
WHERE disabled = 0 {tier_clause}
399+
WHERE disabled = 0
404400
ORDER BY RANDOM()
405401
""") as cursor:
406402
rows = await cursor.fetchall()

0 commit comments

Comments
 (0)