diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 0fd62b5991..409d703cf7 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -18,7 +18,7 @@ Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to a A clear and concise description of what the bug is. **CLI Type** -What type of CLI account do you use? (gemini-cli, gemini, codex, claude code or openai-compatibility) +What type of CLI account do you use? (gemini, codex, claude code or openai-compatibility) **Model Name** What model are you using? (example: gemini-2.5-pro, claude-sonnet-4-20250514, gpt-5, etc.) diff --git a/.gitignore b/.gitignore index 9824a36d8d..728fa95950 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ static/* # Authentication data auths/* +/auths !auths/.gitkeep # Documentation @@ -38,6 +39,7 @@ GEMINI.md .worktrees/ .codex/* .claude/* +.claude .gemini/* .serena/* .agent/* diff --git a/README.md b/README.md index 8410b6b65c..7dac54eef1 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ PackyCode provides special discounts for our software users: register using AICodeMirror -Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via this link to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off! +Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. 
Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via this link to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off! BmoPlus @@ -66,7 +66,6 @@ PackyCode provides special discounts for our software users: register using [!NOTE] > If you developed a project based on CLIProxyAPI, please open a PR to add it to this list. diff --git a/README_CN.md b/README_CN.md index 071366e931..78bb365e75 100644 --- a/README_CN.md +++ b/README_CN.md @@ -24,7 +24,7 @@ PackyCode 为本软件用户提供了特别优惠:使用AICodeMirror -感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接注册的用户,可享受首充8折,企业客户最高可享 7.5 折! +感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接注册的用户,可享受首充8折,企业客户最高可享 7.5 折! BmoPlus @@ -67,7 +67,6 @@ PackyCode 为本软件用户提供了特别优惠:使用AICodeMirror -AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:こちらのリンクから登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます! +AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:こちらのリンクから登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます! 
BmoPlus @@ -66,7 +66,6 @@ PackyCodeは当ソフトウェアのユーザーに特別割引を提供して - シンプルなCLI認証フロー(Gemini、OpenAI、Claude、Grok) - Generative Language APIキーのサポート - AI Studioビルドのマルチアカウント負荷分散 -- Gemini CLIのマルチアカウント負荷分散 - Claude Codeのマルチアカウント負荷分散 - OpenAI Codexのマルチアカウント負荷分散 - Grok Buildのマルチアカウント負荷分散 @@ -153,7 +152,7 @@ PowerShellスクリプトで実装されたWindowsトレイアプリケーショ ### [霖君](https://github.com/wangdabaoqq/LinJun) -霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codexなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能 +霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini、OpenAI Codexなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能 ### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard) @@ -185,11 +184,11 @@ AIコーディングアシスタント向けのマルチエージェントオー ### [Tunnel Agent](https://github.com/Villoh/tunnel-agent) -CLIProxyAPIとPerplexity WebUI Scraperをひとつのインターフェースで管理するWindowsデスクトップUI。QuotioとVibeProxyにインスパイアされ、OAuthプロバイダー(Claude、Gemini CLI、Codex、Kimi、Antigravity)、カスタムAPIキー、Perplexityセッションアカウントを接続し、任意のコーディングエージェントをローカルエンドポイントに向けることができます。 +CLIProxyAPIとPerplexity WebUI Scraperをひとつのインターフェースで管理するWindowsデスクトップUI。QuotioとVibeProxyにインスパイアされ、OAuthプロバイダー(Claude、Gemini、Codex、Kimi、Antigravity)、カスタムAPIキー、Perplexityセッションアカウントを接続し、任意のコーディングエージェントをローカルエンドポイントに向けることができます。 ### [Quotio Desktop](https://github.com/xiaocoss/quotio-desktop) -Quotio のクロスプラットフォーム(Tauri)移植版(Windows / macOS / Linux 対応)。CLIProxyAPI 経由で複数の AI アカウント(Codex、Claude Code、GitHub Copilot、Gemini CLI、Antigravity、Kiro、Cursor、Trae、GLM)のプールを管理し、アカウントごとの 5 時間 / 週間クォータバー、Codex のリセットクレジットとワンクリックリセット、スマートスケジューリング、使用統計、Codex マルチインスタンスに対応。API キー不要。 +Quotio のクロスプラットフォーム(Tauri)移植版(Windows / macOS / Linux 対応)。CLIProxyAPI 経由で複数の AI アカウント(Codex、Claude Code、GitHub Copilot、Gemini、Antigravity、Kiro、Cursor、Trae、GLM)のプールを管理し、アカウントごとの 5 時間 / 週間クォータバー、Codex のリセットクレジットとワンクリックリセット、スマートスケジューリング、使用統計、Codex マルチインスタンスに対応。API キー不要。 > [!NOTE] > 
CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。 diff --git a/cmd/server/main.go b/cmd/server/main.go index a56ba3072a..dde0678c79 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -70,7 +70,6 @@ func main() { fmt.Printf("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s\n", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate) // Command-line flags to control the application's behavior. - var login bool var codexLogin bool var codexDeviceLogin bool var claudeLogin bool @@ -79,7 +78,6 @@ func main() { var antigravityLogin bool var kimiLogin bool var xaiLogin bool - var projectID string var vertexImport string var vertexImportPrefix string var configPath string @@ -91,7 +89,6 @@ func main() { var localModel bool // Define command-line flags for different operation modes. - flag.BoolVar(&login, "login", false, "Login Google Account") flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth") flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow") flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth") @@ -100,7 +97,6 @@ func main() { flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth") flag.BoolVar(&xaiLogin, "xai-login", false, "Login to xAI using OAuth") - flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)") @@ -531,7 +527,7 @@ func main() { CallbackPort: oauthCallbackPort, } - commandMode := vertexImport != "" || login || antigravityLogin || codexLogin || codexDeviceLogin || claudeLogin || kimiLogin || 
xaiLogin + commandMode := vertexImport != "" || antigravityLogin || codexLogin || codexDeviceLogin || claudeLogin || kimiLogin || xaiLogin cloudConfigMissing := isCloudDeploy && !configFileExists homeMode := configLoadedFromHome || (cfg != nil && cfg.Home.Enabled) if shouldStartExampleAPIKeyWarningServer(cfg, commandMode, tuiMode, standalone, cloudConfigMissing, homeMode) { @@ -569,9 +565,6 @@ func main() { if vertexImport != "" { // Handle Vertex service account import cmd.DoVertexImport(cfg, vertexImport, vertexImportPrefix) - } else if login { - // Handle Google/Gemini login - cmd.DoLogin(cfg, projectID, options) } else if antigravityLogin { // Handle Antigravity login cmd.DoAntigravityLogin(cfg, options) diff --git a/config.example.yaml b/config.example.yaml index 101f23916f..c480b2f531 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -181,10 +181,6 @@ codex: # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: true -# When true, enable Gemini CLI internal endpoints (/v1internal:*). -# Default is false for safety. -enable-gemini-cli-endpoint: false - # When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts. nonstream-keepalive-interval: 0 # Streaming behavior (SSE keep-alives + safe bootstrap retries). @@ -260,6 +256,7 @@ nonstream-keepalive-interval: 0 # - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219) # - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking) # - "*haiku*" # wildcard matching substring (e.g. 
claude-3-5-haiku-20241022) +# rebuild-mid-system-message: false # optional: default is false; when true, move messages with role "system" into the top-level Claude system field # cloak: # optional: request cloaking for non-Claude-Code clients # mode: "auto" # "auto" (default): cloak only when client is not Claude Code # # "always": always apply cloaking @@ -352,17 +349,24 @@ nonstream-keepalive-interval: 0 # Global OAuth model name aliases (per channel) # These aliases rename model IDs for both model listing and request routing. -# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai. +# Supported channels: vertex, aistudio, antigravity, claude, codex, kimi, xai. # NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, or vertex-api-key. # NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping # client-visible names can become ambiguous across providers. For strict backend pinning, use # unique aliases/prefixes or avoid overlapping names. # You can repeat the same name with different aliases to expose multiple client model names. +# Per-auth OAuth aliases can also be stored in an OAuth auth JSON file as "model-aliases". +# They apply only to that selected auth and take precedence over global aliases for the same client-visible alias. 
+# Example auth JSON: +# { +# "type": "codex", +# "email": "user@example.com", +# "model-aliases": [ +# {"name": "gpt-5.3-codex-spark", "alias": "gpt-5.5"}, +# {"name": "gpt-5.3-codex-spark", "alias": "gpt-5.4"} +# ] +# } # oauth-model-alias: -# gemini-cli: -# - name: "gemini-2.5-pro" # original model name under this channel -# alias: "g2.5p" # client-visible alias -# fork: true # when true, keep original and also add the alias as an extra model (default: false) # vertex: # - name: "gemini-2.5-pro" # alias: "g2.5p" @@ -390,11 +394,6 @@ nonstream-keepalive-interval: 0 # OAuth provider excluded models # oauth-excluded-models: -# gemini-cli: -# - "gemini-2.5-pro" # exclude specific models (exact match) -# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro) -# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview) -# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite) # vertex: # - "gemini-3-pro-preview" # aistudio: diff --git a/go.mod b/go.mod index 3418dbadd5..c83d19ce95 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 - github.com/jackc/pgx/v5 v5.7.6 + github.com/jackc/pgx/v5 v5.9.2 github.com/joho/godotenv v1.5.1 github.com/klauspost/compress v1.17.4 github.com/minio/minio-go/v7 v7.0.66 diff --git a/go.sum b/go.sum index 5f0a03fbef..d9f1ac7f8a 100644 --- a/go.sum +++ b/go.sum @@ -104,6 +104,8 @@ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7Ulw github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 
v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= diff --git a/internal/api/handlers/management/api_key_usage.go b/internal/api/handlers/management/api_key_usage.go index dbe6fbd998..88ee8b326a 100644 --- a/internal/api/handlers/management/api_key_usage.go +++ b/internal/api/handlers/management/api_key_usage.go @@ -40,6 +40,19 @@ func mergeRecentRequestBuckets(dst, src []coreauth.RecentRequestBucket) []coreau return dst } +func apiKeyUsageProviderKey(auth *coreauth.Auth) string { + provider := strings.ToLower(strings.TrimSpace(auth.Provider)) + if auth.Attributes != nil { + if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" { + provider = strings.ToLower(compatName) + } + } + if provider == "" { + return "unknown" + } + return provider +} + // GetAPIKeyUsage returns recent request buckets for all in-memory api_key auths, // grouped by provider and keyed by "base_url|api_key". 
func (h *Handler) GetAPIKeyUsage(c *gin.Context) { @@ -78,10 +91,7 @@ func (h *Handler) GetAPIKeyUsage(c *gin.Context) { } } compositeKey := baseURL + "|" + apiKey - provider := strings.ToLower(strings.TrimSpace(auth.Provider)) - if provider == "" { - provider = "unknown" - } + provider := apiKeyUsageProviderKey(auth) recent := auth.RecentRequestsSnapshot(now) providerBucket, ok := out[provider] diff --git a/internal/api/handlers/management/api_key_usage_test.go b/internal/api/handlers/management/api_key_usage_test.go index 70d9b11e92..c933e74e67 100644 --- a/internal/api/handlers/management/api_key_usage_test.go +++ b/internal/api/handlers/management/api_key_usage_test.go @@ -92,3 +92,51 @@ func TestGetAPIKeyUsage_GroupsByProviderAndAPIKey(t *testing.T) { t.Fatalf("claude totals = %d/%d, want 1/0", claudeSuccess, claudeFailed) } } + +func TestGetAPIKeyUsage_GroupsOpenAICompatibleByCompatName(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + + manager := coreauth.NewManager(nil, nil, nil) + if _, err := manager.Register(context.Background(), &coreauth.Auth{ + ID: "vast-auth", + Provider: "openai-compatible-vast", + Attributes: map[string]string{ + "api_key": "vast-key", + "base_url": "https://www.vastnum.com/v1", + "compat_name": "VAST", + }, + }); err != nil { + t.Fatalf("register vast auth: %v", err) + } + + manager.MarkResult(context.Background(), coreauth.Result{AuthID: "vast-auth", Provider: "openai-compatible-vast", Model: "gpt-5", Success: true}) + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager) + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodGet, "/v0/management/api-key-usage", nil) + ginCtx.Request = req + h.GetAPIKeyUsage(ginCtx) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var payload map[string]map[string]apiKeyUsageEntry + if err := 
json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode payload: %v", err) + } + + if _, exists := payload["openai-compatible-vast"]; exists { + t.Fatalf("unexpected namespaced provider bucket in payload: %#v", payload) + } + vastBucket, exists := payload["vast"] + if !exists { + t.Fatalf("missing compat provider bucket in payload: %#v", payload) + } + vastEntry := vastBucket["https://www.vastnum.com/v1|vast-key"] + if vastEntry.Success != 1 || vastEntry.Failed != 0 { + t.Fatalf("vast totals = %d/%d, want 1/0", vastEntry.Success, vastEntry.Failed) + } +} diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go index f10850701a..334099c423 100644 --- a/internal/api/handlers/management/api_tools.go +++ b/internal/api/handlers/management/api_tools.go @@ -12,27 +12,13 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/geminicli" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v7/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" ) const defaultAPICallTimeout = 60 * time.Second -const ( - geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - geminiOAuthClientSecret = "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" -) - -var geminiOAuthScopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", -} - const ( antigravityOAuthClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityOAuthClientSecret = "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" @@ -240,11 +226,6 @@ func tokenValueForAuth(auth *coreauth.Auth) string { return v } } - if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil { 
- if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" { - return v - } - } return "" } @@ -253,12 +234,7 @@ func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) return "", nil } - provider := strings.ToLower(strings.TrimSpace(auth.Provider)) - if provider == "gemini-cli" { - token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth) - return token, errToken - } - if provider == "antigravity" { + if strings.EqualFold(strings.TrimSpace(auth.Provider), "antigravity") { token, errToken := h.refreshAntigravityOAuthAccessToken(ctx, auth) return token, errToken } @@ -266,76 +242,6 @@ func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) return tokenValueForAuth(auth), nil } -func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) { - if ctx == nil { - ctx = context.Background() - } - if auth == nil { - return "", nil - } - - metadata, updater := geminiOAuthMetadata(auth) - if len(metadata) == 0 { - return "", fmt.Errorf("gemini oauth metadata missing") - } - - base := make(map[string]any) - if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil { - base = cloneMap(tokenRaw) - } - - var token oauth2.Token - if len(base) > 0 { - if raw, errMarshal := json.Marshal(base); errMarshal == nil { - _ = json.Unmarshal(raw, &token) - } - } - - if token.AccessToken == "" { - token.AccessToken = stringValue(metadata, "access_token") - } - if token.RefreshToken == "" { - token.RefreshToken = stringValue(metadata, "refresh_token") - } - if token.TokenType == "" { - token.TokenType = stringValue(metadata, "token_type") - } - if token.Expiry.IsZero() { - if expiry := stringValue(metadata, "expiry"); expiry != "" { - if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil { - token.Expiry = ts - } - } - } - - conf := &oauth2.Config{ - ClientID: geminiOAuthClientID, - ClientSecret: geminiOAuthClientSecret, - Scopes: 
geminiOAuthScopes, - Endpoint: google.Endpoint, - } - - ctxToken := ctx - httpClient := &http.Client{ - Timeout: defaultAPICallTimeout, - Transport: h.apiCallTransport(auth), - } - ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient) - - src := conf.TokenSource(ctxToken, &token) - currentToken, errToken := src.Token() - if errToken != nil { - return "", errToken - } - - merged := buildOAuthTokenMap(base, currentToken) - fields := buildOAuthTokenFields(currentToken, merged) - if updater != nil { - updater(fields) - } - return strings.TrimSpace(currentToken.AccessToken), nil -} - func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) { if ctx == nil { ctx = context.Background() @@ -491,24 +397,6 @@ func int64Value(raw any) int64 { return 0 } -func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) { - if auth == nil { - return nil, nil - } - if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil { - snapshot := shared.MetadataSnapshot() - return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) } - } - return auth.Metadata, func(fields map[string]any) { - if auth.Metadata == nil { - auth.Metadata = make(map[string]any) - } - for k, v := range fields { - auth.Metadata[k] = v - } - } -} - func stringValue(metadata map[string]any, key string) string { if len(metadata) == 0 || key == "" { return "" @@ -519,56 +407,6 @@ func stringValue(metadata map[string]any, key string) string { return "" } -func cloneMap(in map[string]any) map[string]any { - if len(in) == 0 { - return nil - } - out := make(map[string]any, len(in)) - for k, v := range in { - out[k] = v - } - return out -} - -func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any { - merged := cloneMap(base) - if merged == nil { - merged = make(map[string]any) - } - if tok == nil { - return merged - } - if raw, errMarshal := json.Marshal(tok); errMarshal == 
nil { - var tokenMap map[string]any - if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil { - for k, v := range tokenMap { - merged[k] = v - } - } - } - return merged -} - -func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any { - fields := make(map[string]any, 5) - if tok != nil && tok.AccessToken != "" { - fields["access_token"] = tok.AccessToken - } - if tok != nil && tok.TokenType != "" { - fields["token_type"] = tok.TokenType - } - if tok != nil && tok.RefreshToken != "" { - fields["refresh_token"] = tok.RefreshToken - } - if tok != nil && !tok.Expiry.IsZero() { - fields["expiry"] = tok.Expiry.Format(time.RFC3339) - } - if len(merged) > 0 { - fields["token"] = cloneMap(merged) - } - return fields -} - func tokenValueFromMetadata(metadata map[string]any) string { if len(metadata) == 0 { return "" diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 162f1fa8ea..a960b58616 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -25,12 +25,11 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/antigravity" "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/claude" "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex" - geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini" "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi" xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" "github.com/router-for-me/CLIProxyAPI/v7/internal/util" "github.com/router-for-me/CLIProxyAPI/v7/internal/watcher/synthesizer" @@ -39,18 +38,13 @@ import ( 
"github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" ) var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( anthropicCallbackPort = 54545 - geminiCallbackPort = 8085 codexCallbackPort = 1455 - geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com" - geminiCLIVersion = "v1internal" ) type callbackForwarder struct { @@ -70,6 +64,7 @@ var ( callbackForwarders = make(map[int]*callbackForwarder) errAuthFileMustBeJSON = errors.New("auth file must be .json") errAuthFileNotFound = errors.New("auth file not found") + errPluginVirtualAuth = errors.New("plugin virtual auth cannot be modified directly; edit or delete the source auth file") newCodexOAuthService = func(cfg *config.Config) codexOAuthService { return codex.NewCodexAuth(cfg) } ) @@ -621,9 +616,6 @@ func authProjectID(auth *coreauth.Auth) string { if projectID := strings.TrimSpace(auth.Attributes["project_id"]); projectID != "" { return projectID } - if projectID := strings.TrimSpace(auth.Attributes["gemini_virtual_project"]); projectID != "" { - return projectID - } } return "" } @@ -1041,6 +1033,9 @@ func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) targetID := "" if targetAuth := h.findAuthForDelete(name); targetAuth != nil { + if !isPluginVirtualSourceDelete(name, targetAuth) { + return filepath.Base(name), http.StatusConflict, errPluginVirtualAuth + } targetID = strings.TrimSpace(targetAuth.ID) if path := strings.TrimSpace(authAttribute(targetAuth, "path")); path != "" { targetPath = path @@ -1060,14 +1055,24 @@ func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil { return filepath.Base(name), http.StatusInternalServerError, 
errDeleteRecord } - if targetID != "" { - h.removeAuth(ctx, targetID) - } else { - h.removeAuth(ctx, targetPath) - } + h.removeAuthsForPath(ctx, targetPath, targetID) return filepath.Base(name), http.StatusOK, nil } +func isPluginVirtualSourceDelete(name string, auth *coreauth.Auth) bool { + if !coreauth.IsPluginVirtualAuth(auth) { + return true + } + sourcePath := strings.TrimSpace(authAttribute(auth, coreauth.AttributeVirtualSource)) + if sourcePath == "" { + sourcePath = strings.TrimSpace(authAttribute(auth, "path")) + } + if sourcePath == "" { + return false + } + return strings.EqualFold(filepath.Base(strings.TrimSpace(name)), filepath.Base(sourcePath)) +} + func (h *Handler) findAuthForDelete(name string) *coreauth.Auth { if h == nil || h.authManager == nil { return nil @@ -1275,6 +1280,10 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"}) return } + if coreauth.IsPluginVirtualAuth(targetAuth) { + c.JSON(http.StatusConflict, gin.H{"error": errPluginVirtualAuth.Error()}) + return + } if coreauth.IsConfigAPIKeyAuth(targetAuth) { h.mu.Lock() @@ -1378,6 +1387,10 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"}) return } + if coreauth.IsPluginVirtualAuth(targetAuth) { + c.JSON(http.StatusConflict, gin.H{"error": errPluginVirtualAuth.Error()}) + return + } changed := false touchedRoots := make(map[string]struct{}, len(req)) @@ -1707,6 +1720,53 @@ func (h *Handler) removeAuth(ctx context.Context, id string) { h.authManager.Remove(ctx, authID) } +func (h *Handler) removeAuthsForPath(ctx context.Context, path string, fallbackID string) { + if h == nil || h.authManager == nil { + return + } + removed := false + for _, auth := range h.authManager.List() { + if auth == nil { + continue + } + if sameAuthFilePath(authAttribute(auth, "path"), path) || sameAuthFilePath(authAttribute(auth, coreauth.AttributeVirtualSource), 
path) { + h.removeAuth(ctx, auth.ID) + removed = true + } + } + if removed { + return + } + if strings.TrimSpace(fallbackID) != "" { + h.removeAuth(ctx, fallbackID) + return + } + h.removeAuth(ctx, path) +} + +func sameAuthFilePath(left, right string) bool { + left = cleanAuthFilePath(left) + right = cleanAuthFilePath(right) + if left == "" || right == "" { + return false + } + if runtime.GOOS == "windows" { + return strings.EqualFold(left, right) + } + return left == right +} + +func cleanAuthFilePath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + if abs, errAbs := filepath.Abs(path); errAbs == nil && strings.TrimSpace(abs) != "" { + path = abs + } + return filepath.Clean(path) +} + func (h *Handler) deleteTokenRecord(ctx context.Context, path string) error { if strings.TrimSpace(path) == "" { return fmt.Errorf("auth path is empty") @@ -1750,7 +1810,7 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s } savedPath, errSave := store.Save(ctx, record) if errSave != nil { - return "", errSave + return savedPath, errSave } if h.postAuthPersistHook != nil { if errHook := h.postAuthPersistHook(ctx, record); errHook != nil { @@ -1904,264 +1964,6 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } -func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { - ctx := context.Background() - ctx = PopulateAuthContext(ctx, c) - proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) - ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient) - - // Optional project ID from query - projectID := c.Query("project_id") - - fmt.Println("Initializing Google authentication...") - - // OAuth2 configuration using exported constants from internal/auth/gemini - conf := &oauth2.Config{ - ClientID: geminiAuth.ClientID, - ClientSecret: geminiAuth.ClientSecret, - RedirectURL: 
fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort), - Scopes: geminiAuth.Scopes, - Endpoint: google.Endpoint, - } - - // Build authorization URL and return it immediately - state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) - authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - - RegisterOAuthSession(state, "gemini") - - isWebUI := isWebUIRequest(c) - var forwarder *callbackForwarder - if isWebUI { - targetURL, errTarget := h.managementCallbackURL("/google/callback") - if errTarget != nil { - log.WithError(errTarget).Error("failed to compute gemini callback target") - c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) - return - } - var errStart error - if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { - log.WithError(errStart).Error("failed to start gemini callback forwarder") - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) - return - } - } - - go func() { - if isWebUI { - defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder) - } - - // Wait for callback file written by server route - waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state)) - fmt.Println("Waiting for authentication callback...") - deadline := time.Now().Add(5 * time.Minute) - var authCode string - for { - if !IsOAuthSessionPending(state, "gemini") { - return - } - if time.Now().After(deadline) { - log.Error("oauth flow timed out") - SetOAuthSessionError(state, "OAuth flow timed out") - return - } - if data, errR := os.ReadFile(waitFile); errR == nil { - var m map[string]string - _ = json.Unmarshal(data, &m) - _ = os.Remove(waitFile) - if errStr := m["error"]; errStr != "" { - log.Errorf("Authentication failed: %s", errStr) - SetOAuthSessionError(state, "Authentication failed") - return - } - authCode = m["code"] - if authCode == "" { 
- log.Errorf("Authentication failed: code not found") - SetOAuthSessionError(state, "Authentication failed: code not found") - return - } - break - } - time.Sleep(500 * time.Millisecond) - } - - // Exchange authorization code for token - token, err := conf.Exchange(ctx, authCode) - if err != nil { - log.Errorf("Failed to exchange token: %v", err) - SetOAuthSessionError(state, "Failed to exchange token") - return - } - - requestedProjectID := strings.TrimSpace(projectID) - - // Create token storage (mirrors internal/auth/gemini createTokenStorage) - authHTTPClient := conf.Client(ctx, token) - req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) - if errNewRequest != nil { - log.Errorf("Could not get user info: %v", errNewRequest) - SetOAuthSessionError(state, "Could not get user info") - return - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) - - resp, errDo := authHTTPClient.Do(req) - if errDo != nil { - log.Errorf("Failed to execute request: %v", errDo) - SetOAuthSessionError(state, "Failed to execute request") - return - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Printf("warn: failed to close response body: %v", errClose) - } - }() - - bodyBytes, _ := io.ReadAll(resp.Body) - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) - return - } - - email := gjson.GetBytes(bodyBytes, "email").String() - if email != "" { - fmt.Printf("Authenticated user email: %s\n", email) - } else { - fmt.Println("Failed to get user email from token") - } - - // Marshal/unmarshal oauth2.Token to generic map and enrich fields - var ifToken map[string]any - jsonData, _ := json.Marshal(token) - 
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { - log.Errorf("Failed to unmarshal token: %v", errUnmarshal) - SetOAuthSessionError(state, "Failed to unmarshal token") - return - } - - ifToken["token_uri"] = "https://oauth2.googleapis.com/token" - ifToken["client_id"] = geminiAuth.ClientID - ifToken["client_secret"] = geminiAuth.ClientSecret - ifToken["scopes"] = geminiAuth.Scopes - ifToken["universe_domain"] = "googleapis.com" - - ts := geminiAuth.GeminiTokenStorage{ - Token: ifToken, - ProjectID: requestedProjectID, - Email: email, - Auto: requestedProjectID == "", - } - - // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings - gemAuth := geminiAuth.NewGeminiAuth() - gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{ - NoBrowser: true, - }) - if errGetClient != nil { - log.Errorf("failed to get authenticated client: %v", errGetClient) - SetOAuthSessionError(state, "Failed to get authenticated client") - return - } - fmt.Println("Authentication successful.") - - if strings.EqualFold(requestedProjectID, "ALL") { - ts.Auto = false - projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) - if errAll != nil { - log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll)) - return - } - if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { - log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) - SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify)) - return - } - ts.ProjectID = strings.Join(projects, ",") - ts.Checked = true - } else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") { - ts.Auto = false - if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil { - log.Errorf("Google One auto-discovery failed: %v", errSetup) - 
SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup)) - return - } - if strings.TrimSpace(ts.ProjectID) == "" { - log.Error("Google One auto-discovery returned empty project ID") - SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID") - return - } - isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) - if errCheck != nil { - log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck)) - return - } - ts.Checked = isChecked - if !isChecked { - log.Error("Cloud AI API is not enabled for the auto-discovered project") - SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID)) - return - } - } else { - if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { - log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure)) - return - } - - if strings.TrimSpace(ts.ProjectID) == "" { - log.Error("Onboarding did not return a project ID") - SetOAuthSessionError(state, "Failed to resolve project ID") - return - } - - isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) - if errCheck != nil { - log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck)) - return - } - ts.Checked = isChecked - if !isChecked { - log.Error("Cloud AI API is not enabled for the selected project") - SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID)) - return - } - } - - recordMetadata := map[string]any{ - "email": ts.Email, - "project_id": ts.ProjectID, - "auto": ts.Auto, - "checked": ts.Checked, - } - - fileName := 
geminiAuth.CredentialFileName(ts.Email, ts.ProjectID, true) - record := &coreauth.Auth{ - ID: fileName, - Provider: "gemini", - FileName: fileName, - Storage: &ts, - Metadata: recordMetadata, - } - savedPath, errSave := h.saveTokenRecord(ctx, record) - if errSave != nil { - log.Errorf("Failed to save token to file: %v", errSave) - SetOAuthSessionError(state, "Failed to save token to file") - return - } - - CompleteOAuthSession(state) - fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) - }() - - c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) -} - func (h *Handler) RequestCodexToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) @@ -2725,383 +2527,6 @@ func (h *Handler) RequestKimiToken(c *gin.Context) { c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } -type projectSelectionRequiredError struct{} - -func (e *projectSelectionRequiredError) Error() string { - return "gemini cli: project selection required" -} - -func ensureGeminiProjectAndOnboard(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error { - if storage == nil { - return fmt.Errorf("gemini storage is nil") - } - - trimmedRequest := strings.TrimSpace(requestedProject) - if trimmedRequest == "" { - projects, errProjects := fetchGCPProjects(ctx, httpClient) - if errProjects != nil { - return fmt.Errorf("fetch project list: %w", errProjects) - } - if len(projects) == 0 { - return fmt.Errorf("no Google Cloud projects available for this account") - } - trimmedRequest = strings.TrimSpace(projects[0].ProjectID) - if trimmedRequest == "" { - return fmt.Errorf("resolved project id is empty") - } - storage.Auto = true - } else { - storage.Auto = false - } - - if err := performGeminiCLISetup(ctx, httpClient, storage, trimmedRequest); err != nil { - return err - } - - if strings.TrimSpace(storage.ProjectID) == "" { - storage.ProjectID = 
trimmedRequest - } - - return nil -} - -func onboardAllGeminiProjects(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage) ([]string, error) { - projects, errProjects := fetchGCPProjects(ctx, httpClient) - if errProjects != nil { - return nil, fmt.Errorf("fetch project list: %w", errProjects) - } - if len(projects) == 0 { - return nil, fmt.Errorf("no Google Cloud projects available for this account") - } - activated := make([]string, 0, len(projects)) - seen := make(map[string]struct{}, len(projects)) - for _, project := range projects { - candidate := strings.TrimSpace(project.ProjectID) - if candidate == "" { - continue - } - if _, dup := seen[candidate]; dup { - continue - } - if err := performGeminiCLISetup(ctx, httpClient, storage, candidate); err != nil { - return nil, fmt.Errorf("onboard project %s: %w", candidate, err) - } - finalID := strings.TrimSpace(storage.ProjectID) - if finalID == "" { - finalID = candidate - } - activated = append(activated, finalID) - seen[candidate] = struct{}{} - } - if len(activated) == 0 { - return nil, fmt.Errorf("no Google Cloud projects available for this account") - } - return activated, nil -} - -func ensureGeminiProjectsEnabled(ctx context.Context, httpClient *http.Client, projectIDs []string) error { - for _, pid := range projectIDs { - trimmed := strings.TrimSpace(pid) - if trimmed == "" { - continue - } - isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, trimmed) - if errCheck != nil { - return fmt.Errorf("project %s: %w", trimmed, errCheck) - } - if !isChecked { - return fmt.Errorf("project %s: Cloud AI API not enabled", trimmed) - } - } - return nil -} - -func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error { - metadata := map[string]string{ - "ideType": "IDE_UNSPECIFIED", - "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI", - } - - trimmedRequest := 
strings.TrimSpace(requestedProject) - explicitProject := trimmedRequest != "" - - loadReqBody := map[string]any{ - "metadata": metadata, - } - if explicitProject { - loadReqBody["cloudaicompanionProject"] = trimmedRequest - } - - var loadResp map[string]any - if errLoad := callGeminiCLI(ctx, httpClient, "loadCodeAssist", loadReqBody, &loadResp); errLoad != nil { - return fmt.Errorf("load code assist: %w", errLoad) - } - - tierID := "legacy-tier" - if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { - for _, rawTier := range tiers { - tier, okTier := rawTier.(map[string]any) - if !okTier { - continue - } - if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { - if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { - tierID = strings.TrimSpace(id) - break - } - } - } - } - - projectID := trimmedRequest - if projectID == "" { - if id, okProject := loadResp["cloudaicompanionProject"].(string); okProject { - projectID = strings.TrimSpace(id) - } - if projectID == "" { - if projectMap, okProject := loadResp["cloudaicompanionProject"].(map[string]any); okProject { - if id, okID := projectMap["id"].(string); okID { - projectID = strings.TrimSpace(id) - } - } - } - } - if projectID == "" { - // Auto-discovery: try onboardUser without specifying a project - // to let Google auto-provision one (matches Gemini CLI headless behavior - // and Antigravity's FetchProjectID pattern). 
- autoOnboardReq := map[string]any{ - "tierId": tierID, - "metadata": metadata, - } - - autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second) - defer autoCancel() - for attempt := 1; ; attempt++ { - var onboardResp map[string]any - if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil { - return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard) - } - - if done, okDone := onboardResp["done"].(bool); okDone && done { - if resp, okResp := onboardResp["response"].(map[string]any); okResp { - switch v := resp["cloudaicompanionProject"].(type) { - case string: - projectID = strings.TrimSpace(v) - case map[string]any: - if id, okID := v["id"].(string); okID { - projectID = strings.TrimSpace(id) - } - } - } - break - } - - log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt) - select { - case <-autoCtx.Done(): - return &projectSelectionRequiredError{} - case <-time.After(2 * time.Second): - } - } - - if projectID == "" { - return &projectSelectionRequiredError{} - } - log.Infof("Auto-discovered project ID via onboarding: %s", projectID) - } - - onboardReqBody := map[string]any{ - "tierId": tierID, - "metadata": metadata, - "cloudaicompanionProject": projectID, - } - - storage.ProjectID = projectID - - for { - var onboardResp map[string]any - if errOnboard := callGeminiCLI(ctx, httpClient, "onboardUser", onboardReqBody, &onboardResp); errOnboard != nil { - return fmt.Errorf("onboard user: %w", errOnboard) - } - - if done, okDone := onboardResp["done"].(bool); okDone && done { - responseProjectID := "" - if resp, okResp := onboardResp["response"].(map[string]any); okResp { - switch projectValue := resp["cloudaicompanionProject"].(type) { - case map[string]any: - if id, okID := projectValue["id"].(string); okID { - responseProjectID = strings.TrimSpace(id) - } - case string: - responseProjectID = strings.TrimSpace(projectValue) - } - } - - finalProjectID := projectID - if 
responseProjectID != "" { - if explicitProject && !strings.EqualFold(responseProjectID, projectID) { - log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID) - log.Infof("Using backend project ID: %s", responseProjectID) - } - finalProjectID = responseProjectID - } - - storage.ProjectID = strings.TrimSpace(finalProjectID) - if storage.ProjectID == "" { - storage.ProjectID = strings.TrimSpace(projectID) - } - if storage.ProjectID == "" { - return fmt.Errorf("onboard user completed without project id") - } - log.Infof("Onboarding complete. Using Project ID: %s", storage.ProjectID) - return nil - } - - log.Println("Onboarding in progress, waiting 5 seconds...") - time.Sleep(5 * time.Second) - } -} - -func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string, body any, result any) error { - endPointURL := fmt.Sprintf("%s/%s:%s", geminiCLIEndpoint, geminiCLIVersion, endpoint) - if strings.HasPrefix(endpoint, "operations/") { - endPointURL = fmt.Sprintf("%s/%s", geminiCLIEndpoint, endpoint) - } - - var reader io.Reader - if body != nil { - rawBody, errMarshal := json.Marshal(body) - if errMarshal != nil { - return fmt.Errorf("marshal request body: %w", errMarshal) - } - reader = bytes.NewReader(rawBody) - } - - req, errRequest := http.NewRequestWithContext(ctx, http.MethodPost, endPointURL, reader) - if errRequest != nil { - return fmt.Errorf("create request: %w", errRequest) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) - - resp, errDo := httpClient.Do(req) - if errDo != nil { - return fmt.Errorf("execute request: %w", errDo) - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - }() - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("api request 
failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) - } - - if result == nil { - _, _ = io.Copy(io.Discard, resp.Body) - return nil - } - - if errDecode := json.NewDecoder(resp.Body).Decode(result); errDecode != nil { - return fmt.Errorf("decode response body: %w", errDecode) - } - - return nil -} - -func fetchGCPProjects(ctx context.Context, httpClient *http.Client) ([]interfaces.GCPProjectProjects, error) { - req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil) - if errRequest != nil { - return nil, fmt.Errorf("could not create project list request: %w", errRequest) - } - - resp, errDo := httpClient.Do(req) - if errDo != nil { - return nil, fmt.Errorf("failed to execute project list request: %w", errDo) - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - }() - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) - } - - var projects interfaces.GCPProject - if errDecode := json.NewDecoder(resp.Body).Decode(&projects); errDecode != nil { - return nil, fmt.Errorf("failed to unmarshal project list: %w", errDecode) - } - - return projects.Projects, nil -} - -func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projectID string) (bool, error) { - serviceUsageURL := "https://serviceusage.googleapis.com" - requiredServices := []string{ - "cloudaicompanion.googleapis.com", - } - for _, service := range requiredServices { - checkURL := fmt.Sprintf("%s/v1/projects/%s/services/%s", serviceUsageURL, projectID, service) - req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, checkURL, nil) - if errRequest != nil { - return false, fmt.Errorf("failed to 
create request: %w", errRequest) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) - resp, errDo := httpClient.Do(req) - if errDo != nil { - return false, fmt.Errorf("failed to execute request: %w", errDo) - } - - if resp.StatusCode == http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - if gjson.GetBytes(bodyBytes, "state").String() == "ENABLED" { - _ = resp.Body.Close() - continue - } - } - _ = resp.Body.Close() - - enableURL := fmt.Sprintf("%s/v1/projects/%s/services/%s:enable", serviceUsageURL, projectID, service) - req, errRequest = http.NewRequestWithContext(ctx, http.MethodPost, enableURL, strings.NewReader("{}")) - if errRequest != nil { - return false, fmt.Errorf("failed to create request: %w", errRequest) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) - resp, errDo = httpClient.Do(req) - if errDo != nil { - return false, fmt.Errorf("failed to execute request: %w", errDo) - } - - bodyBytes, _ := io.ReadAll(resp.Body) - errMessage := string(bodyBytes) - errMessageResult := gjson.GetBytes(bodyBytes, "error.message") - if errMessageResult.Exists() { - errMessage = errMessageResult.String() - } - if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { - _ = resp.Body.Close() - continue - } else if resp.StatusCode == http.StatusBadRequest { - _ = resp.Body.Close() - if strings.Contains(strings.ToLower(errMessage), "already enabled") { - continue - } - } - _ = resp.Body.Close() - return false, fmt.Errorf("project activation required: %s", errMessage) - } - return true, nil -} - func (h *Handler) GetAuthStatus(c *gin.Context) { state := strings.TrimSpace(c.Query("state")) if state == "" { @@ -3151,13 +2576,13 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "error", "error": message}) return case pluginapi.AuthLoginStatusSuccess: - record := 
host.AuthDataToCoreAuth(resp.Auth, "", "") - if record == nil { + records := pluginLoginPollAuths(host, resp) + if len(records) == 0 { SetOAuthSessionError(state, "Authentication failed") c.JSON(http.StatusOK, gin.H{"status": "error", "error": "Authentication failed"}) return } - if _, errSave := h.saveTokenRecord(ctx, record); errSave != nil { + if errSave := h.savePluginLoginRecords(ctx, records); errSave != nil { log.WithError(errSave).WithField("provider", provider).Error("failed to save plugin auth tokens") SetOAuthSessionError(state, "Failed to save authentication tokens") c.JSON(http.StatusOK, gin.H{"status": "error", "error": "Failed to save authentication tokens"}) @@ -3175,6 +2600,53 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "wait"}) } +func pluginLoginPollAuths(host *pluginhost.Host, resp pluginapi.AuthLoginPollResponse) []*coreauth.Auth { + if host == nil { + return nil + } + authDatas := resp.Auths + if len(authDatas) == 0 { + authDatas = []pluginapi.AuthData{resp.Auth} + } + records := make([]*coreauth.Auth, 0, len(authDatas)) + for _, authData := range authDatas { + record := host.AuthDataToCoreAuth(authData, "", "") + if record == nil { + return nil + } + records = append(records, record) + } + return records +} + +func (h *Handler) savePluginLoginRecords(ctx context.Context, records []*coreauth.Auth) error { + savedPaths := make([]string, 0, len(records)) + for _, record := range records { + savedPath, errSave := h.saveTokenRecord(ctx, record) + if strings.TrimSpace(savedPath) != "" { + savedPaths = append(savedPaths, savedPath) + } + if errSave != nil { + h.rollbackSavedTokenRecords(ctx, savedPaths) + return errSave + } + } + return nil +} + +func (h *Handler) rollbackSavedTokenRecords(ctx context.Context, savedPaths []string) { + for i := len(savedPaths) - 1; i >= 0; i-- { + path := strings.TrimSpace(savedPaths[i]) + if path == "" { + continue + } + if errDelete := h.deleteTokenRecord(ctx, path); 
errDelete != nil { + log.WithError(errDelete).WithField("path", path).Warn("failed to roll back plugin auth token") + } + h.removeAuthsForPath(ctx, path, path) + } +} + // PopulateAuthContext extracts request info and adds it to the context func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context { info := &coreauth.RequestInfo{ diff --git a/internal/api/handlers/management/auth_files_plugin_oauth_test.go b/internal/api/handlers/management/auth_files_plugin_oauth_test.go new file mode 100644 index 0000000000..29acc76266 --- /dev/null +++ b/internal/api/handlers/management/auth_files_plugin_oauth_test.go @@ -0,0 +1,213 @@ +package management + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" + coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" +) + +func TestPluginLoginPollAuthsExpandsMultipleAuths(t *testing.T) { + host := pluginhost.New() + resp := pluginapi.AuthLoginPollResponse{ + Status: pluginapi.AuthLoginStatusSuccess, + Auths: []pluginapi.AuthData{ + { + Provider: "gemini-cli", + ID: "geminicli.json", + FileName: "geminicli.json", + StorageJSON: []byte(`{"type":"gemini-cli"}`), + }, + { + Provider: "gemini-cli", + ID: "geminicli-project-a.json", + FileName: "geminicli-project-a.json", + StorageJSON: []byte(`{"type":"gemini-cli","project_id":"project-a"}`), + Metadata: map[string]any{"project_id": "project-a"}, + }, + }, + } + + records := pluginLoginPollAuths(host, resp) + if len(records) != 2 { + t.Fatalf("pluginLoginPollAuths() len = %d, want two records", len(records)) + } + if records[0].ID != "geminicli.json" || records[1].ID != "geminicli-project-a.json" { + t.Fatalf("records = %#v, want both plugin auths", records) + } + if 
gotProject := records[1].Metadata["project_id"]; gotProject != "project-a" { + t.Fatalf("project_id = %#v, want project-a", gotProject) + } +} + +func TestSavePluginLoginRecordsRollsBackSavedAuthsOnFailure(t *testing.T) { + store := &pluginLoginRollbackStore{failAt: 2} + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil) + h.tokenStore = store + + records := []*coreauth.Auth{ + { + ID: "geminicli.json", + FileName: "geminicli.json", + Provider: "gemini-cli", + Metadata: map[string]any{"type": "gemini-cli"}, + }, + { + ID: "geminicli-project-a.json", + FileName: "geminicli-project-a.json", + Provider: "gemini-cli", + Metadata: map[string]any{"type": "gemini-cli", "project_id": "project-a"}, + }, + } + + errSave := h.savePluginLoginRecords(context.Background(), records) + if errSave == nil { + t.Fatal("savePluginLoginRecords() error = nil, want rollback-triggering error") + } + if len(store.saved) != 2 { + t.Fatalf("saved len = %d, want two attempted saves", len(store.saved)) + } + if !store.deleted["geminicli.json"] || !store.deleted["geminicli-project-a.json"] { + t.Fatalf("deleted = %#v, want both saved auths rolled back", store.deleted) + } +} + +func TestPatchPluginVirtualAuthStatusReturnsConflict(t *testing.T) { + manager := coreauth.NewManager(nil, nil, nil) + auth := pluginVirtualAuthForTest(t.TempDir(), "source.json", "auth-1") + if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("register virtual auth: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager) + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/status", strings.NewReader(`{"name":"auth-1","disabled":true}`)) + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + + h.PatchAuthFileStatus(ctx) + + if rec.Code != http.StatusConflict { + t.Fatalf("status = 
%d, want %d body=%s", rec.Code, http.StatusConflict, rec.Body.String()) + } +} + +func TestPatchPluginVirtualAuthFieldsReturnsConflict(t *testing.T) { + manager := coreauth.NewManager(nil, nil, nil) + auth := pluginVirtualAuthForTest(t.TempDir(), "source.json", "auth-1") + if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("register virtual auth: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager) + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPatch, "/v0/management/auth-files/fields", strings.NewReader(`{"name":"auth-1","note":"hello"}`)) + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + + h.PatchAuthFileFields(ctx) + + if rec.Code != http.StatusConflict { + t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusConflict, rec.Body.String()) + } +} + +func TestDeletePluginVirtualSourceRemovesExpandedRuntimeAuths(t *testing.T) { + authDir := t.TempDir() + fileName := "source.json" + filePath := filepath.Join(authDir, fileName) + if errWrite := os.WriteFile(filePath, []byte(`{"type":"gemini-cli"}`), 0o600); errWrite != nil { + t.Fatalf("write source auth file: %v", errWrite) + } + + manager := coreauth.NewManager(nil, nil, nil) + for _, id := range []string{"auth-1", "auth-2"} { + auth := pluginVirtualAuthForTest(authDir, fileName, id) + if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("register virtual auth %s: %v", id, errRegister) + } + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + h.tokenStore = &memoryAuthStore{} + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil) + ctx.Request = req + + h.DeleteAuthFile(ctx) + + if rec.Code != http.StatusOK { 
+ t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if _, errStat := os.Stat(filePath); !os.IsNotExist(errStat) { + t.Fatalf("expected source auth file to be removed, stat err: %v", errStat) + } + for _, id := range []string{"auth-1", "auth-2"} { + if _, ok := manager.GetByID(id); ok { + t.Fatalf("expected virtual auth %s to be removed", id) + } + } +} + +func pluginVirtualAuthForTest(authDir, fileName, id string) *coreauth.Auth { + filePath := filepath.Join(authDir, fileName) + auth := &coreauth.Auth{ + ID: id, + FileName: fileName, + Provider: "gemini-cli", + Attributes: map[string]string{ + "path": filePath, + }, + Metadata: map[string]any{ + "type": "gemini-cli", + }, + } + coreauth.MarkPluginVirtualAuth(auth, filePath, 0) + return auth +} + +type pluginLoginRollbackStore struct { + failAt int + saved []string + deleted map[string]bool +} + +func (s *pluginLoginRollbackStore) List(context.Context) ([]*coreauth.Auth, error) { + return nil, nil +} + +func (s *pluginLoginRollbackStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) { + path := strings.TrimSpace(auth.FileName) + if path == "" { + path = strings.TrimSpace(auth.ID) + } + s.saved = append(s.saved, path) + if len(s.saved) == s.failAt { + return path, errors.New("save failed after write") + } + return path, nil +} + +func (s *pluginLoginRollbackStore) Delete(_ context.Context, id string) error { + if s.deleted == nil { + s.deleted = make(map[string]bool) + } + s.deleted[id] = true + return nil +} + +func (s *pluginLoginRollbackStore) SetBaseDir(string) {} diff --git a/internal/api/handlers/management/auth_files_project_id_test.go b/internal/api/handlers/management/auth_files_project_id_test.go index 3bacc9a4c9..870b61cbed 100644 --- a/internal/api/handlers/management/auth_files_project_id_test.go +++ b/internal/api/handlers/management/auth_files_project_id_test.go @@ -18,9 +18,9 @@ func TestListAuthFiles_IncludesProjectIDFromManager(t *testing.T) { 
t.Setenv("MANAGEMENT_PASSWORD", "") authDir := t.TempDir() - fileName := "gemini-user@example.com-project-a.json" + fileName := "antigravity-user@example.com-project-a.json" filePath := filepath.Join(authDir, fileName) - if errWrite := os.WriteFile(filePath, []byte(`{"type":"gemini","email":"user@example.com","project_id":"project-a"}`), 0o600); errWrite != nil { + if errWrite := os.WriteFile(filePath, []byte(`{"type":"antigravity","email":"user@example.com","project_id":"project-a"}`), 0o600); errWrite != nil { t.Fatalf("failed to write auth file: %v", errWrite) } @@ -28,13 +28,13 @@ func TestListAuthFiles_IncludesProjectIDFromManager(t *testing.T) { record := &coreauth.Auth{ ID: fileName, FileName: fileName, - Provider: "gemini-cli", + Provider: "antigravity", Status: coreauth.StatusActive, Attributes: map[string]string{ "path": filePath, }, Metadata: map[string]any{ - "type": "gemini", + "type": "antigravity", "email": "user@example.com", "project_id": "project-a", }, @@ -56,8 +56,8 @@ func TestListAuthFilesFromDisk_IncludesProjectID(t *testing.T) { t.Setenv("MANAGEMENT_PASSWORD", "") authDir := t.TempDir() - filePath := filepath.Join(authDir, "gemini-user@example.com-project-a.json") - if errWrite := os.WriteFile(filePath, []byte(`{"type":"gemini","email":"user@example.com","project_id":"project-a"}`), 0o600); errWrite != nil { + filePath := filepath.Join(authDir, "antigravity-user@example.com-project-a.json") + if errWrite := os.WriteFile(filePath, []byte(`{"type":"antigravity","email":"user@example.com","project_id":"project-a"}`), 0o600); errWrite != nil { t.Fatalf("failed to write auth file: %v", errWrite) } diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index d9050b4c83..fb4c67d213 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -307,13 +307,14 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) { } func (h *Handler) 
PatchClaudeKey(c *gin.Context) { type claudeKeyPatch struct { - APIKey *string `json:"api-key"` - Prefix *string `json:"prefix"` - BaseURL *string `json:"base-url"` - ProxyURL *string `json:"proxy-url"` - Models *[]config.ClaudeModel `json:"models"` - Headers *map[string]string `json:"headers"` - ExcludedModels *[]string `json:"excluded-models"` + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Models *[]config.ClaudeModel `json:"models"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + RebuildMidSystemMessage *bool `json:"rebuild-mid-system-message"` } var body struct { Index *int `json:"index"` @@ -367,6 +368,9 @@ func (h *Handler) PatchClaudeKey(c *gin.Context) { if body.Value.ExcludedModels != nil { entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) } + if body.Value.RebuildMidSystemMessage != nil { + entry.RebuildMidSystemMessage = *body.Value.RebuildMidSystemMessage + } normalizeClaudeKey(&entry) h.cfg.ClaudeKey[targetIndex] = entry h.cfg.SanitizeClaudeKeys() diff --git a/internal/api/handlers/management/oauth_callback.go b/internal/api/handlers/management/oauth_callback.go index 251f999e07..832462db93 100644 --- a/internal/api/handlers/management/oauth_callback.go +++ b/internal/api/handlers/management/oauth_callback.go @@ -25,14 +25,26 @@ func (h *Handler) PostOAuthCallback(c *gin.Context) { } var req oauthCallbackRequest - if err := c.ShouldBindJSON(&req); err != nil { + if errBindJSON := c.ShouldBindJSON(&req); errBindJSON != nil { c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"}) return } + h.handleOAuthCallback(c, req) +} - canonicalProvider, err := NormalizeOAuthProvider(req.Provider) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"}) +func (h *Handler) GetOAuthCallback(c *gin.Context) { + req 
:= oauthCallbackRequest{ + Provider: strings.TrimSpace(c.Query("provider")), + Code: strings.TrimSpace(c.Query("code")), + State: strings.TrimSpace(c.Query("state")), + Error: firstNonEmpty(c.Query("error"), c.Query("error_description")), + } + h.handleOAuthCallback(c, req) +} + +func (h *Handler) handleOAuthCallback(c *gin.Context, req oauthCallbackRequest) { + if h == nil || h.cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"}) return } @@ -74,11 +86,26 @@ func (h *Handler) PostOAuthCallback(c *gin.Context) { return } - sessionProvider, sessionStatus, ok := GetOAuthSession(state) + sessionProvider, sessionStatus, isPlugin, _, ok := GetOAuthSessionDetails(state) if !ok { c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"}) return } + provider := strings.TrimSpace(req.Provider) + if provider == "" { + provider = sessionProvider + } + var canonicalProvider string + var errNormalize error + if isPlugin { + canonicalProvider, errNormalize = NormalizePluginOAuthCallbackProvider(provider) + } else { + canonicalProvider, errNormalize = NormalizeOAuthCallbackProvider(provider) + } + if errNormalize != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"}) + return + } if sessionStatus != "" { c.JSON(http.StatusConflict, gin.H{"status": "error", "error": sessionStatus}) return @@ -105,3 +132,13 @@ func (h *Handler) PostOAuthCallback(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) } + +func firstNonEmpty(values ...string) string { + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed != "" { + return trimmed + } + } + return "" +} diff --git a/internal/api/handlers/management/oauth_callback_test.go b/internal/api/handlers/management/oauth_callback_test.go index 065f89f0c7..832423bb20 100644 --- a/internal/api/handlers/management/oauth_callback_test.go +++ 
b/internal/api/handlers/management/oauth_callback_test.go @@ -50,6 +50,72 @@ func TestPostOAuthCallbackCreatesMissingAuthDir(t *testing.T) { } } +func TestGetOAuthCallbackWritesPluginProviderCallback(t *testing.T) { + authDir := filepath.Join(t.TempDir(), "missing-auth") + state := "test-geminicli-state" + if errRegister := RegisterPluginOAuthSession(state, "gemini-cli", nil); errRegister != nil { + t.Fatalf("register plugin oauth session: %v", errRegister) + } + defer CompleteOAuthSession(state) + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil) + router := gin.New() + router.GET("/v0/management/oauth-callback", h.GetOAuthCallback) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/oauth-callback?state="+state+"&code=test-code", nil) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, w.Code, w.Body.String()) + } + + callbackPath := filepath.Join(authDir, ".oauth-gemini-cli-"+state+".oauth") + data, errRead := os.ReadFile(callbackPath) + if errRead != nil { + t.Fatalf("expected callback file to be written: %v", errRead) + } + + var payload oauthCallbackFilePayload + if errUnmarshal := json.Unmarshal(data, &payload); errUnmarshal != nil { + t.Fatalf("failed to decode callback payload: %v", errUnmarshal) + } + if payload.State != state || payload.Code != "test-code" || payload.Error != "" { + t.Fatalf("unexpected callback payload: %+v", payload) + } +} + +func TestGetOAuthCallbackDoesNotAliasPluginProvider(t *testing.T) { + authDir := filepath.Join(t.TempDir(), "missing-auth") + state := "test-openai-plugin-state" + if errRegister := RegisterPluginOAuthSession(state, "openai", nil); errRegister != nil { + t.Fatalf("register plugin oauth session: %v", errRegister) + } + defer CompleteOAuthSession(state) + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil) + router := gin.New() + 
router.GET("/v0/management/oauth-callback", h.GetOAuthCallback) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/oauth-callback?state="+state+"&code=test-code", nil) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, w.Code, w.Body.String()) + } + + callbackPath := filepath.Join(authDir, ".oauth-openai-"+state+".oauth") + if _, errRead := os.ReadFile(callbackPath); errRead != nil { + t.Fatalf("expected plugin callback provider to stay openai: %v", errRead) + } + if _, errRead := os.ReadFile(filepath.Join(authDir, ".oauth-codex-"+state+".oauth")); errRead == nil { + t.Fatal("unexpected codex callback file for openai plugin provider") + } +} + func TestWriteOAuthCallbackFileForPendingSessionCreatesMissingAuthDirForCallbackProviders(t *testing.T) { providers := []string{"anthropic", "codex", "gemini", "antigravity", "xai"} for _, provider := range providers { diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go index 6c51ff4531..078c51c67f 100644 --- a/internal/api/handlers/management/oauth_sessions.go +++ b/internal/api/handlers/management/oauth_sessions.go @@ -200,9 +200,6 @@ func (s *oauthSessionStore) IsPending(state, provider string) bool { if session.Status != "" { return false } - if session.Source == oauthSessionSourcePlugin { - return false - } if provider == "" { return true } @@ -308,8 +305,6 @@ func NormalizeOAuthProvider(provider string) (string, error) { return "anthropic", nil case "codex", "openai": return "codex", nil - case "gemini", "google": - return "gemini", nil case "antigravity", "anti-gravity": return "antigravity", nil case "xai", "x-ai", "x.ai", "grok": @@ -319,6 +314,38 @@ func NormalizeOAuthProvider(provider string) (string, error) { } } +func NormalizeOAuthCallbackProvider(provider string) (string, error) { + if normalized, errNormalize := 
NormalizeOAuthProvider(provider); errNormalize == nil { + return normalized, nil + } + return NormalizePluginOAuthCallbackProvider(provider) +} + +func NormalizePluginOAuthCallbackProvider(provider string) (string, error) { + trimmed := strings.ToLower(strings.TrimSpace(provider)) + if trimmed == "" { + return "", errUnsupportedOAuthFlow + } + for _, r := range trimmed { + switch { + case r >= 'a' && r <= 'z': + case r >= '0' && r <= '9': + case r == '-': + default: + return "", errUnsupportedOAuthFlow + } + } + return trimmed, nil +} + +func normalizeOAuthCallbackProviderForPendingSession(provider, state string) (string, error) { + session, ok := oauthSessions.Get(state) + if ok && session.Source == oauthSessionSourcePlugin { + return NormalizePluginOAuthCallbackProvider(provider) + } + return NormalizeOAuthCallbackProvider(provider) +} + type oauthCallbackFilePayload struct { Code string `json:"code"` State string `json:"state"` @@ -326,12 +353,20 @@ type oauthCallbackFilePayload struct { } func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) { + canonicalProvider, err := NormalizeOAuthCallbackProvider(provider) + if err != nil { + return "", err + } + return writeOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage) +} + +func writeOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage string) (string, error) { if strings.TrimSpace(authDir) == "" { return "", fmt.Errorf("auth dir is empty") } - canonicalProvider, err := NormalizeOAuthProvider(provider) - if err != nil { - return "", err + canonicalProvider = strings.TrimSpace(canonicalProvider) + if canonicalProvider == "" { + return "", errUnsupportedOAuthFlow } if err := ValidateOAuthState(state); err != nil { return "", err @@ -358,12 +393,12 @@ func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) } func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) 
(string, error) { - canonicalProvider, err := NormalizeOAuthProvider(provider) + canonicalProvider, err := normalizeOAuthCallbackProviderForPendingSession(provider, state) if err != nil { return "", err } if !IsOAuthSessionPending(state, canonicalProvider) { return "", errOAuthSessionNotPending } - return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage) + return writeOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage) } diff --git a/internal/api/handlers/management/quota.go b/internal/api/handlers/management/quota.go index c7efd217bd..a87a05ef52 100644 --- a/internal/api/handlers/management/quota.go +++ b/internal/api/handlers/management/quota.go @@ -1,6 +1,12 @@ package management -import "github.com/gin-gonic/gin" +import ( + "fmt" + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) // Quota exceeded toggles func (h *Handler) GetSwitchProject(c *gin.Context) { @@ -16,3 +22,48 @@ func (h *Handler) GetSwitchPreviewModel(c *gin.Context) { func (h *Handler) PutSwitchPreviewModel(c *gin.Context) { h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchPreviewModel = v }) } + +// ResetQuota clears quota/cooldown routing state for one auth index. 
+func (h *Handler) ResetQuota(c *gin.Context) { + if h.authManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) + return + } + + var req struct { + AuthIndex string `json:"auth_index"` + } + if errBindJSON := c.ShouldBindJSON(&req); errBindJSON != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) + return + } + + authIndex := strings.TrimSpace(req.AuthIndex) + if authIndex == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "auth_index is required"}) + return + } + + auth := h.authByIndex(authIndex) + if auth == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "auth not found"}) + return + } + + updated, models, errReset := h.authManager.ResetQuota(c.Request.Context(), auth.ID) + if errReset != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to reset quota: %v", errReset)}) + return + } + if updated == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "auth not found"}) + return + } + updated.EnsureIndex() + + c.JSON(http.StatusOK, gin.H{ + "status": "ok", + "auth_index": updated.Index, + "models": models, + }) +} diff --git a/internal/api/handlers/management/quota_test.go b/internal/api/handlers/management/quota_test.go new file mode 100644 index 0000000000..aee9b1d8c2 --- /dev/null +++ b/internal/api/handlers/management/quota_test.go @@ -0,0 +1,134 @@ +package management + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" +) + +func TestResetQuota_UsesAuthIndex(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + + manager := coreauth.NewManager(nil, nil, nil) + next := time.Now().Add(time.Hour) + auth := &coreauth.Auth{ + ID: "reset-auth-id", + FileName: "reset-auth-file.json", + Provider: "claude", + Status: 
coreauth.StatusError, + StatusMessage: "quota exhausted", + Unavailable: true, + NextRetryAfter: next, + Quota: coreauth.QuotaState{Exceeded: true, Reason: "quota", NextRecoverAt: next, BackoffLevel: 2}, + ModelStates: map[string]*coreauth.ModelState{ + "claude-reset-model": { + Status: coreauth.StatusError, + StatusMessage: "quota exhausted", + Unavailable: true, + NextRetryAfter: next, + Quota: coreauth.QuotaState{Exceeded: true, Reason: "quota", NextRecoverAt: next, BackoffLevel: 2}, + }, + }, + } + authIndex := auth.EnsureIndex() + if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("failed to register auth record: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager) + + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPost, "/v0/management/reset-quota", strings.NewReader(`{"auth_index":"`+authIndex+`"}`)) + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + h.ResetQuota(ctx) + + if rec.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var payload map[string]any + if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil { + t.Fatalf("failed to decode response: %v", errUnmarshal) + } + if payload["auth_index"] != authIndex { + t.Fatalf("auth_index = %#v, want %q", payload["auth_index"], authIndex) + } + + updated, ok := manager.GetByID("reset-auth-id") + if !ok || updated == nil { + t.Fatalf("expected auth record to exist after reset") + } + if updated.Status != coreauth.StatusActive || updated.StatusMessage != "" || updated.Unavailable || !updated.NextRetryAfter.IsZero() { + t.Fatalf("updated auth state = status %q message %q unavailable %v next %v", updated.Status, updated.StatusMessage, updated.Unavailable, updated.NextRetryAfter) + } + if updated.Quota.Exceeded || 
updated.Quota.Reason != "" || !updated.Quota.NextRecoverAt.IsZero() || updated.Quota.BackoffLevel != 0 { + t.Fatalf("updated auth quota = %+v, want cleared", updated.Quota) + } + state := updated.ModelStates["claude-reset-model"] + if state == nil { + t.Fatalf("expected model state to remain") + } + if state.Status != coreauth.StatusActive || state.StatusMessage != "" || state.Unavailable || !state.NextRetryAfter.IsZero() { + t.Fatalf("updated model state = status %q message %q unavailable %v next %v", state.Status, state.StatusMessage, state.Unavailable, state.NextRetryAfter) + } + if state.Quota.Exceeded || state.Quota.Reason != "" || !state.Quota.NextRecoverAt.IsZero() || state.Quota.BackoffLevel != 0 { + t.Fatalf("updated model quota = %+v, want cleared", state.Quota) + } +} + +func TestResetQuota_DoesNotAcceptAuthIDOrFileName(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + + manager := coreauth.NewManager(nil, nil, nil) + auth := &coreauth.Auth{ + ID: "reset-auth-id-only", + FileName: "reset-auth-file-only.json", + Provider: "claude", + Status: coreauth.StatusError, + } + authIndex := auth.EnsureIndex() + if authIndex == auth.ID || authIndex == auth.FileName { + t.Fatalf("test auth_index unexpectedly matches id or file name: %q", authIndex) + } + if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("failed to register auth record: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager) + + tests := []struct { + name string + body string + wantCode int + }{ + {name: "auth_id field ignored", body: `{"auth_id":"reset-auth-id-only"}`, wantCode: http.StatusBadRequest}, + {name: "id field ignored", body: `{"id":"reset-auth-id-only"}`, wantCode: http.StatusBadRequest}, + {name: "file name is not an index", body: `{"auth_index":"reset-auth-file-only.json"}`, wantCode: http.StatusNotFound}, + {name: "auth id is not an index", body: 
`{"auth_index":"reset-auth-id-only"}`, wantCode: http.StatusNotFound}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPost, "/v0/management/reset-quota", strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + ctx.Request = req + h.ResetQuota(ctx) + + if rec.Code != tt.wantCode { + t.Fatalf("status = %d, want %d with body %s", rec.Code, tt.wantCode, rec.Body.String()) + } + }) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index beacc18adf..01875fd69a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -16,6 +16,7 @@ import ( "os" "path/filepath" "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -32,6 +33,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/managementasset" "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" "github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue" + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" "github.com/router-for-me/CLIProxyAPI/v7/internal/util" sdkaccess "github.com/router-for-me/CLIProxyAPI/v7/sdk/access" "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" @@ -422,7 +424,6 @@ func (s *Server) setupRoutes() { s.engine.GET("/management.html", s.serveManagementControlPanel) openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers) geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers) - geminiCLIHandlers := gemini.NewGeminiCLIAPIHandler(s.handlers) claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(s.handlers) openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(s.handlers) @@ -484,7 +485,6 @@ func (s *Server) setupRoutes() { }, }) }) - s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) // OAuth callback endpoints (reuse main server port) // These endpoints receive provider redirects and persist @@ -517,20 +517,6 @@ func (s *Server) setupRoutes() { 
c.String(http.StatusOK, oauthCallbackSuccessHTML) }) - s.engine.GET("/google/callback", func(c *gin.Context) { - code := c.Query("code") - state := c.Query("state") - errStr := c.Query("error") - if errStr == "" { - errStr = c.Query("error_description") - } - if state != "" { - _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr) - } - c.Header("Content-Type", "text/html; charset=utf-8") - c.String(http.StatusOK, oauthCallbackSuccessHTML) - }) - s.engine.GET("/antigravity/callback", func(c *gin.Context) { code := c.Query("code") state := c.Query("state") @@ -609,6 +595,9 @@ func (s *Server) registerManagementRoutes() { log.Info("management routes registered after secret key configuration") + s.engine.POST("/v0/management/oauth-callback", s.managementAvailabilityMiddleware(), s.mgmt.PostOAuthCallback) + s.engine.GET("/v0/management/oauth-callback", s.managementAvailabilityMiddleware(), s.mgmt.GetOAuthCallback) + mgmt := s.engine.Group("/v0/management") mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware()) { @@ -659,6 +648,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/quota-exceeded/switch-preview-model", s.mgmt.GetSwitchPreviewModel) mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel) mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel) + mgmt.POST("/reset-quota", s.mgmt.ResetQuota) mgmt.GET("/api-keys", s.mgmt.GetAPIKeys) mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys) @@ -741,11 +731,9 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken) mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken) - mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken) mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken) mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken) mgmt.GET("/xai-auth-url", s.mgmt.RequestXAIToken) - mgmt.POST("/oauth-callback", 
s.mgmt.PostOAuthCallback) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } } @@ -964,10 +952,20 @@ func (s *Server) watchKeepAlive() { } } +// isAnthropicModelsRequest reports whether a /v1/models request should be served in +// Anthropic format. Anthropic API clients send the Anthropic-Version header; Claude +// Code additionally uses a claude-cli User-Agent. +func isAnthropicModelsRequest(c *gin.Context) bool { + if c.GetHeader("Anthropic-Version") != "" { + return true + } + return strings.HasPrefix(c.GetHeader("User-Agent"), "claude-cli") +} + // unifiedModelsHandler creates a unified handler for the /v1/models endpoint -// that routes to different handlers based on the User-Agent header. -// If User-Agent starts with "claude-cli", it routes to Claude handler, -// otherwise it routes to OpenAI handler. +// that routes to different handlers based on the request. +// Anthropic API requests (Anthropic-Version header, or a claude-cli User-Agent) +// route to the Claude handler, otherwise they route to the OpenAI handler. func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc { return func(c *gin.Context) { if _, ok := c.Request.URL.Query()["client_version"]; ok { @@ -984,14 +982,10 @@ func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, cl return } - userAgent := c.GetHeader("User-Agent") - - // Route to Claude handler if User-Agent starts with "claude-cli" - if strings.HasPrefix(userAgent, "claude-cli") { - // log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent) + // Route to Claude handler for Anthropic API requests. 
+ if isAnthropicModelsRequest(c) { claudeHandler.ClaudeModels(c) } else { - // log.Debugf("Routing /v1/models to OpenAI handler for User-Agent: %s", userAgent) openaiHandler.OpenAIModels(c) } } @@ -1048,10 +1042,12 @@ func (s *Server) geminiGetHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Ha } type homeModelEntry struct { - id string - created int64 - ownedBy string - displayName string + id string + created int64 + ownedBy string + displayName string + contextLength int + maxCompletionTokens int } func (s *Server) handleHomeModels(c *gin.Context) { @@ -1060,25 +1056,10 @@ func (s *Server) handleHomeModels(c *gin.Context) { return } - userAgent := c.GetHeader("User-Agent") - isClaude := strings.HasPrefix(userAgent, "claude-cli") + isClaude := isAnthropicModelsRequest(c) if isClaude { - out := make([]map[string]any, 0, len(entries)) - for _, entry := range entries { - model := map[string]any{ - "id": entry.id, - "object": "model", - "owned_by": entry.ownedBy, - } - if entry.created > 0 { - model["created_at"] = entry.created - } - if entry.displayName != "" { - model["display_name"] = entry.displayName - } - out = append(out, model) - } + out := formatHomeClaudeModels(entries) firstID := "" lastID := "" if len(out) > 0 { @@ -1118,6 +1099,42 @@ func (s *Server) handleHomeModels(c *gin.Context) { }) } +func formatHomeClaudeModels(entries []homeModelEntry) []map[string]any { + out := make([]map[string]any, 0, len(entries)) + for _, entry := range entries { + out = append(out, formatHomeClaudeModel(entry)) + } + return out +} + +func formatHomeClaudeModel(entry homeModelEntry) map[string]any { + displayName := entry.displayName + if displayName == "" { + displayName = entry.id + } + maxInput := entry.contextLength + if maxInput <= 0 { + maxInput = registry.DefaultClaudeMaxInputTokens + } + maxOutput := entry.maxCompletionTokens + if maxOutput <= 0 { + maxOutput = registry.DefaultClaudeMaxOutputTokens + } + model := map[string]any{ + "id": entry.id, + "object": 
"model", + "owned_by": entry.ownedBy, + "type": "model", + "display_name": displayName, + "max_input_tokens": maxInput, + "max_tokens": maxOutput, + } + if entry.created > 0 { + model["created_at"] = time.Unix(entry.created, 0).UTC().Format(time.RFC3339) + } + return model +} + func (s *Server) handleHomeGeminiModels(c *gin.Context) { entries, ok := s.loadHomeModelEntries(c) if !ok { @@ -1333,20 +1350,6 @@ func decodeHomeModels(raw []byte) ([]homeModelEntry, error) { } seen[id] = struct{}{} - created := int64(0) - switch v := model["created"].(type) { - case float64: - created = int64(v) - case int64: - created = v - case int: - created = int64(v) - case json.Number: - if n, err := v.Int64(); err == nil { - created = n - } - } - ownedBy, _ := model["owned_by"].(string) ownedBy = strings.TrimSpace(ownedBy) displayName, _ := model["display_name"].(string) @@ -1357,10 +1360,12 @@ func decodeHomeModels(raw []byte) ([]homeModelEntry, error) { } out = append(out, homeModelEntry{ - id: id, - created: created, - ownedBy: ownedBy, - displayName: displayName, + id: id, + created: homeModelInt64Value(model, "created"), + ownedBy: ownedBy, + displayName: displayName, + contextLength: int(homeModelInt64Value(model, "context_length", "contextLength", "inputTokenLimit", "max_input_tokens")), + maxCompletionTokens: int(homeModelInt64Value(model, "max_completion_tokens", "maxCompletionTokens", "outputTokenLimit", "max_tokens")), }) } } @@ -1372,6 +1377,28 @@ func decodeHomeModels(raw []byte) ([]homeModelEntry, error) { return out, nil } +func homeModelInt64Value(model map[string]any, keys ...string) int64 { + for _, key := range keys { + switch value := model[key].(type) { + case float64: + return int64(value) + case int64: + return value + case int: + return int64(value) + case json.Number: + if n, errInt := value.Int64(); errInt == nil { + return n + } + case string: + if n, errParse := strconv.ParseInt(strings.TrimSpace(value), 10, 64); errParse == nil { + return n + } + } + } + 
return 0 +} + // Start begins listening for and serving HTTP or HTTPS requests. // It's a blocking call and will only return on an unrecoverable error. // diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 0f42cac19e..011c1f1e9b 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -11,6 +11,7 @@ import ( "time" gin "github.com/gin-gonic/gin" + managementHandlers "github.com/router-for-me/CLIProxyAPI/v7/internal/api/handlers/management" proxyconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config" internallogging "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" @@ -122,6 +123,30 @@ func TestManagementResponseExposesPluginSupportHeaderForCORS(t *testing.T) { } } +func TestOAuthCallbackRouteSkipsManagementKeyMiddleware(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "test-management-key") + + server := newTestServer(t) + state := "server-plugin-oauth-state" + if errRegister := managementHandlers.RegisterPluginOAuthSession(state, "gemini-cli", nil); errRegister != nil { + t.Fatalf("register plugin oauth session: %v", errRegister) + } + defer managementHandlers.CompleteOAuthSession(state) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/oauth-callback?state="+state+"&code=test-code", nil) + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + + callbackPath := filepath.Join(server.cfg.AuthDir, ".oauth-gemini-cli-"+state+".oauth") + if _, errRead := os.ReadFile(callbackPath); errRead != nil { + t.Fatalf("expected callback file to be written without management key: %v", errRead) + } +} + func TestNewServerWithPluginHostInjectsHandlerInterceptors(t *testing.T) { host := pluginhost.New() server := newTestServerWithOptions(t, WithPluginHost(host)) @@ -332,6 +357,100 @@ func 
TestHomeEnabledHidesManagementEndpointsAndControlPanel(t *testing.T) { }) } +func TestModelsDispatchByAnthropicVersionHeader(t *testing.T) { + modelRegistry := registry.GetGlobalRegistry() + clientID := "test-anthropic-version-dispatch" + modelRegistry.RegisterClient(clientID, "claude", []*registry.ModelInfo{ + { + ID: "claude-sonnet-4-6", + Object: "model", + OwnedBy: "anthropic", + Type: "claude", + DisplayName: "Claude 4.6 Sonnet", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + }) + t.Cleanup(func() { + modelRegistry.UnregisterClient(clientID) + }) + + server := newTestServer(t) + + // Anthropic API request (Anthropic-Version header, non-claude-cli User-Agent) -> Claude format. + t.Run("anthropic version header routes to claude format", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + req.Header.Set("Authorization", "Bearer test-key") + req.Header.Set("User-Agent", "Zed/1.0") + req.Header.Set("Anthropic-Version", "2023-06-01") + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + + var resp struct { + Object string `json:"object"` + HasMore *bool `json:"has_more"` + Data []map[string]any `json:"data"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String()) + } + if resp.Object == "list" { + t.Fatalf("expected Claude format (no object=list), got OpenAI format: %s", rr.Body.String()) + } + if resp.HasMore == nil { + t.Fatalf("expected Claude envelope with has_more, got %s", rr.Body.String()) + } + + var claudeModel map[string]any + for _, m := range resp.Data { + if id, _ := m["id"].(string); id == "claude-sonnet-4-6" { + claudeModel = m + } + } + if claudeModel == nil { + t.Fatalf("expected claude-sonnet-4-6 in response, got %s", rr.Body.String()) + } + for _, field := range 
[]string{"max_input_tokens", "max_tokens", "display_name"} { + if _, ok := claudeModel[field]; !ok { + t.Fatalf("expected Claude model to include %q, got %v", field, claudeModel) + } + } + }) + + // Plain request (no Anthropic-Version, non-claude-cli User-Agent) -> OpenAI format, unaffected. + t.Run("plain request stays on openai format", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + req.Header.Set("Authorization", "Bearer test-key") + req.Header.Set("User-Agent", "Mozilla/5.0") + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + + var resp struct { + Object string `json:"object"` + Data []map[string]any `json:"data"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String()) + } + if resp.Object != "list" { + t.Fatalf("expected OpenAI format (object=list), got %s", rr.Body.String()) + } + for _, m := range resp.Data { + if _, ok := m["max_input_tokens"]; ok { + t.Fatalf("did not expect max_input_tokens in OpenAI format, got %v", m) + } + } + }) +} + func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) { modelRegistry := registry.GetGlobalRegistry() clientID := "test-client-version-catalog" @@ -421,6 +540,9 @@ func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) { if got, _ := custom["display_name"].(string); got != "Custom Codex Model" { t.Fatalf("custom display_name = %q, want Custom Codex Model", got) } + if got := int(codexClientTestPriority(custom["priority"])); got != 129 { + t.Fatalf("custom priority = %v, want 129", custom["priority"]) + } if got, _ := custom["description"].(string); got != "Custom model from registry" { t.Fatalf("custom description = %q, want Custom model from registry", got) } @@ -437,6 +559,10 @@ func 
TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) { if got, _ := custom["prefer_websockets"].(bool); got { t.Fatalf("custom prefer_websockets = %v, want false", custom["prefer_websockets"]) } + customServiceTiers, ok := custom["service_tiers"].([]any) + if !ok || len(customServiceTiers) != 0 { + t.Fatalf("expected custom model service_tiers = [], got %#v", custom["service_tiers"]) + } if _, ok := custom["apply_patch_tool_type"]; ok { t.Fatal("expected custom model to omit apply_patch_tool_type") } @@ -471,6 +597,17 @@ func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) { } } +func codexClientTestPriority(raw any) int { + switch value := raw.(type) { + case int: + return value + case float64: + return int(value) + default: + return -1 + } +} + func assertCodexSupportedReasoningLevels(t *testing.T, model map[string]any, want []string) { t.Helper() @@ -591,6 +728,89 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) { } } +func TestFormatHomeClaudeModelIncludesAnthropicSchemaFields(t *testing.T) { + withMetadata := formatHomeClaudeModel(homeModelEntry{ + id: "claude-sonnet-4-6", + created: 1771372800, + ownedBy: "anthropic", + displayName: "Claude 4.6 Sonnet", + contextLength: 200000, + maxCompletionTokens: 64000, + }) + if got := withMetadata["created_at"]; got != "2026-02-18T00:00:00Z" { + t.Fatalf("created_at = %v, want RFC3339 timestamp", got) + } + if got := withMetadata["type"]; got != "model" { + t.Fatalf("type = %v, want model", got) + } + if got := withMetadata["display_name"]; got != "Claude 4.6 Sonnet" { + t.Fatalf("display_name = %v, want Claude 4.6 Sonnet", got) + } + if got := withMetadata["max_input_tokens"]; got != 200000 { + t.Fatalf("max_input_tokens = %v, want 200000", got) + } + if got := withMetadata["max_tokens"]; got != 64000 { + t.Fatalf("max_tokens = %v, want 64000", got) + } + + withDefaults := formatHomeClaudeModel(homeModelEntry{id: "claude-no-limits"}) + if got := 
withDefaults["display_name"]; got != "claude-no-limits" { + t.Fatalf("display_name fallback = %v, want claude-no-limits", got) + } + if got := withDefaults["max_input_tokens"]; got != registry.DefaultClaudeMaxInputTokens { + t.Fatalf("max_input_tokens fallback = %v, want %d", got, registry.DefaultClaudeMaxInputTokens) + } + if got := withDefaults["max_tokens"]; got != registry.DefaultClaudeMaxOutputTokens { + t.Fatalf("max_tokens fallback = %v, want %d", got, registry.DefaultClaudeMaxOutputTokens) + } + if _, ok := withDefaults["created_at"]; ok { + t.Fatalf("created_at should be omitted when source created is missing, got %v", withDefaults) + } +} + +func TestDecodeHomeModelsKeepsTokenMetadata(t *testing.T) { + entries, errDecode := decodeHomeModels([]byte(`{ + "claude": [ + { + "id": "claude-sonnet-4-6", + "created": 1771372800, + "owned_by": "anthropic", + "context_length": 200000, + "max_completion_tokens": 64000 + } + ], + "gemini": [ + { + "name": "models/gemini-3-pro", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536 + } + ] + }`)) + if errDecode != nil { + t.Fatalf("decodeHomeModels returned error: %v", errDecode) + } + + byID := make(map[string]homeModelEntry, len(entries)) + for _, entry := range entries { + byID[entry.id] = entry + } + claudeEntry, ok := byID["claude-sonnet-4-6"] + if !ok { + t.Fatalf("expected claude-sonnet-4-6 entry, got %v", byID) + } + if claudeEntry.contextLength != 200000 || claudeEntry.maxCompletionTokens != 64000 { + t.Fatalf("claude token metadata = %d/%d, want 200000/64000", claudeEntry.contextLength, claudeEntry.maxCompletionTokens) + } + geminiEntry, ok := byID["gemini-3-pro"] + if !ok { + t.Fatalf("expected gemini-3-pro entry, got %v", byID) + } + if geminiEntry.contextLength != 1048576 || geminiEntry.maxCompletionTokens != 65536 { + t.Fatalf("gemini token metadata = %d/%d, want 1048576/65536", geminiEntry.contextLength, geminiEntry.maxCompletionTokens) + } +} + func TestHomeModelsAuthStatus(t *testing.T) { cases := 
[]struct { name string diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go deleted file mode 100644 index 5b9ee82d26..0000000000 --- a/internal/auth/gemini/gemini_auth.go +++ /dev/null @@ -1,372 +0,0 @@ -// Package gemini provides authentication and token management functionality -// for Google's Gemini AI services. It handles OAuth2 authentication flows, -// including obtaining tokens via web-based authorization, storing tokens, -// and refreshing them when they expire. -package gemini - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "time" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex" - "github.com/router-for-me/CLIProxyAPI/v7/internal/browser" - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - "github.com/router-for-me/CLIProxyAPI/v7/sdk/proxyutil" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" -) - -// OAuth configuration constants for Gemini -const ( - ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - ClientSecret = "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" - DefaultCallbackPort = 8085 -) - -// OAuth scopes for Gemini authentication -var Scopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", -} - -// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow. -// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens -// for Google's Gemini AI services. -type GeminiAuth struct { -} - -// WebLoginOptions customizes the interactive OAuth flow. 
-type WebLoginOptions struct { - NoBrowser bool - CallbackPort int - Prompt func(string) (string, error) -} - -// NewGeminiAuth creates a new instance of GeminiAuth. -func NewGeminiAuth() *GeminiAuth { - return &GeminiAuth{} -} - -// GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls. -// It manages the entire OAuth2 flow, including handling proxies, loading existing tokens, -// initiating a new web-based OAuth flow if necessary, and refreshing tokens. -// -// Parameters: -// - ctx: The context for the HTTP client -// - ts: The Gemini token storage containing authentication tokens -// - cfg: The configuration containing proxy settings -// - opts: Optional parameters to customize browser and prompt behavior -// -// Returns: -// - *http.Client: An HTTP client configured with authentication -// - error: An error if the client configuration fails, nil otherwise -func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) { - callbackPort := DefaultCallbackPort - if opts != nil && opts.CallbackPort > 0 { - callbackPort = opts.CallbackPort - } - callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort) - - transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL) - if errBuild != nil { - log.Errorf("%v", errBuild) - } else if transport != nil { - proxyClient := &http.Client{Transport: transport} - ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) - } - - var err error - - // Configure the OAuth2 client. - conf := &oauth2.Config{ - ClientID: ClientID, - ClientSecret: ClientSecret, - RedirectURL: callbackURL, // This will be used by the local server. - Scopes: Scopes, - Endpoint: google.Endpoint, - } - - var token *oauth2.Token - - // If no token is found in storage, initiate the web-based OAuth flow. 
- if ts.Token == nil { - fmt.Printf("Could not load token from file, starting OAuth flow.\n") - token, err = g.getTokenFromWeb(ctx, conf, opts) - if err != nil { - return nil, fmt.Errorf("failed to get token from web: %w", err) - } - // After getting a new token, create a new token storage object with user info. - newTs, errCreateTokenStorage := g.createTokenStorage(ctx, conf, token, ts.ProjectID) - if errCreateTokenStorage != nil { - log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage) - return nil, errCreateTokenStorage - } - *ts = *newTs - } - - // Unmarshal the stored token into an oauth2.Token object. - tsToken, _ := json.Marshal(ts.Token) - if err = json.Unmarshal(tsToken, &token); err != nil { - return nil, fmt.Errorf("failed to unmarshal token: %w", err) - } - - // Return an HTTP client that automatically handles token refreshing. - return conf.Client(ctx, token), nil -} - -// createTokenStorage creates a new GeminiTokenStorage object. It fetches the user's email -// using the provided token and populates the storage structure. 
-// -// Parameters: -// - ctx: The context for the HTTP request -// - config: The OAuth2 configuration -// - token: The OAuth2 token to use for authentication -// - projectID: The Google Cloud Project ID to associate with this token -// -// Returns: -// - *GeminiTokenStorage: A new token storage object with user information -// - error: An error if the token storage creation fails, nil otherwise -func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*GeminiTokenStorage, error) { - httpClient := config.Client(ctx, token) - req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) - if err != nil { - return nil, fmt.Errorf("could not get user info: %v", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) - - resp, err := httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to execute request: %w", err) - } - defer func() { - if err = resp.Body.Close(); err != nil { - log.Printf("warn: failed to close response body: %v", err) - } - }() - - bodyBytes, _ := io.ReadAll(resp.Body) - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - } - - emailResult := gjson.GetBytes(bodyBytes, "email") - if emailResult.Exists() && emailResult.Type == gjson.String { - fmt.Printf("Authenticated user email: %s\n", emailResult.String()) - } else { - fmt.Println("Failed to get user email from token") - } - - var ifToken map[string]any - jsonData, _ := json.Marshal(token) - err = json.Unmarshal(jsonData, &ifToken) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal token: %w", err) - } - - ifToken["token_uri"] = "https://oauth2.googleapis.com/token" - ifToken["client_id"] = ClientID - ifToken["client_secret"] = ClientSecret - 
ifToken["scopes"] = Scopes - ifToken["universe_domain"] = "googleapis.com" - - ts := GeminiTokenStorage{ - Token: ifToken, - ProjectID: projectID, - Email: emailResult.String(), - } - - return &ts, nil -} - -// getTokenFromWeb initiates the web-based OAuth2 authorization flow. -// It starts a local HTTP server to listen for the callback from Google's auth server, -// opens the user's browser to the authorization URL, and exchanges the received -// authorization code for an access token. -// -// Parameters: -// - ctx: The context for the HTTP client -// - config: The OAuth2 configuration -// - opts: Optional parameters to customize browser and prompt behavior -// -// Returns: -// - *oauth2.Token: The OAuth2 token obtained from the authorization flow -// - error: An error if the token acquisition fails, nil otherwise -func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { - callbackPort := DefaultCallbackPort - if opts != nil && opts.CallbackPort > 0 { - callbackPort = opts.CallbackPort - } - callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort) - - // Use a channel to pass the authorization code from the HTTP handler to the main function. - codeChan := make(chan string, 1) - errChan := make(chan error, 1) - - // Create a new HTTP server with its own multiplexer. 
- mux := http.NewServeMux() - server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux} - config.RedirectURL = callbackURL - - mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { - if err := r.URL.Query().Get("error"); err != "" { - _, _ = fmt.Fprintf(w, "Authentication failed: %s", err) - select { - case errChan <- fmt.Errorf("authentication failed via callback: %s", err): - default: - } - return - } - code := r.URL.Query().Get("code") - if code == "" { - _, _ = fmt.Fprint(w, "Authentication failed: code not found.") - select { - case errChan <- fmt.Errorf("code not found in callback"): - default: - } - return - } - _, _ = fmt.Fprint(w, "

Authentication successful!

You can close this window.

") - select { - case codeChan <- code: - default: - } - }) - - // Start the server in a goroutine. - go func() { - if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - log.Errorf("ListenAndServe(): %v", err) - select { - case errChan <- err: - default: - } - } - }() - - // Open the authorization URL in the user's browser. - authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - - noBrowser := false - if opts != nil { - noBrowser = opts.NoBrowser - } - - if !noBrowser { - fmt.Println("Opening browser for authentication...") - - // Check if browser is available - if !browser.IsAvailable() { - log.Warn("No browser available on this system") - util.PrintSSHTunnelInstructions(callbackPort) - fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL) - } else { - if err := browser.OpenURL(authURL); err != nil { - authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err) - log.Warn(codex.GetUserFriendlyMessage(authErr)) - util.PrintSSHTunnelInstructions(callbackPort) - fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL) - - // Log platform info for debugging - platformInfo := browser.GetPlatformInfo() - log.Debugf("Browser platform info: %+v", platformInfo) - } else { - log.Debug("Browser opened successfully") - } - } - } else { - util.PrintSSHTunnelInstructions(callbackPort) - fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL) - } - - fmt.Println("Waiting for authentication callback...") - - // Wait for the authorization code or an error. 
- var authCode string - timeoutTimer := time.NewTimer(5 * time.Minute) - defer timeoutTimer.Stop() - - var manualPromptTimer *time.Timer - var manualPromptC <-chan time.Time - if opts != nil && opts.Prompt != nil { - manualPromptTimer = time.NewTimer(15 * time.Second) - manualPromptC = manualPromptTimer.C - defer manualPromptTimer.Stop() - } - - var manualInputCh <-chan string - var manualInputErrCh <-chan error - -waitForCallback: - for { - select { - case code := <-codeChan: - authCode = code - break waitForCallback - case err := <-errChan: - return nil, err - case <-manualPromptC: - manualPromptC = nil - if manualPromptTimer != nil { - manualPromptTimer.Stop() - } - select { - case code := <-codeChan: - authCode = code - break waitForCallback - case err := <-errChan: - return nil, err - default: - } - manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ") - continue - case input := <-manualInputCh: - manualInputCh = nil - manualInputErrCh = nil - parsed, errParse := misc.ParseOAuthCallback(input) - if errParse != nil { - return nil, errParse - } - if parsed == nil { - continue - } - if parsed.Error != "" { - return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error) - } - if parsed.Code == "" { - return nil, fmt.Errorf("code not found in callback") - } - authCode = parsed.Code - break waitForCallback - case errManual := <-manualInputErrCh: - return nil, errManual - case <-timeoutTimer.C: - return nil, fmt.Errorf("oauth flow timed out") - } - } - - // Shutdown the server. - if err := server.Shutdown(ctx); err != nil { - log.Errorf("Failed to shut down server: %v", err) - } - - // Exchange the authorization code for a token. 
- token, err := config.Exchange(ctx, authCode) - if err != nil { - return nil, fmt.Errorf("failed to exchange token: %w", err) - } - - fmt.Println("Authentication successful.") - return token, nil -} diff --git a/internal/auth/gemini/gemini_token.go b/internal/auth/gemini/gemini_token.go deleted file mode 100644 index a6ea8c5151..0000000000 --- a/internal/auth/gemini/gemini_token.go +++ /dev/null @@ -1,104 +0,0 @@ -// Package gemini provides authentication and token management functionality -// for Google's Gemini AI services. It handles OAuth2 token storage, serialization, -// and retrieval for maintaining authenticated sessions with the Gemini API. -package gemini - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" - log "github.com/sirupsen/logrus" -) - -// GeminiTokenStorage stores OAuth2 token information for Google Gemini API authentication. -// It maintains compatibility with the existing auth system while adding Gemini-specific fields -// for managing access tokens, refresh tokens, and user account information. -type GeminiTokenStorage struct { - // Token holds the raw OAuth2 token data, including access and refresh tokens. - Token any `json:"token"` - - // ProjectID is the Google Cloud Project ID associated with this token. - ProjectID string `json:"project_id"` - - // Email is the email address of the authenticated user. - Email string `json:"email"` - - // Auto indicates if the project ID was automatically selected. - Auto bool `json:"auto"` - - // Checked indicates if the associated Cloud AI API has been verified as enabled. - Checked bool `json:"checked"` - - // Type indicates the authentication provider type, always "gemini" for this storage. - Type string `json:"type"` - - // Metadata holds arbitrary key-value pairs injected via hooks. - // It is not exported to JSON directly to allow flattening during serialization. 
- Metadata map[string]any `json:"-"` -} - -// SetMetadata allows external callers to inject metadata into the storage before saving. -func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) { - ts.Metadata = meta -} - -// SaveTokenToFile serializes the Gemini token storage to a JSON file. -// This method creates the necessary directory structure and writes the token -// data in JSON format to the specified file path for persistent storage. -// It merges any injected metadata into the top-level JSON object. -// -// Parameters: -// - authFilePath: The full path where the token file should be saved -// -// Returns: -// - error: An error if the operation fails, nil otherwise -func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) - ts.Type = "gemini" - // Merge metadata using helper - data, errMerge := misc.MergeMetadata(ts, ts.Metadata) - if errMerge != nil { - return fmt.Errorf("failed to merge metadata: %w", errMerge) - } - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - f, err := os.Create(authFilePath) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - if errClose := f.Close(); errClose != nil { - log.Errorf("failed to close file: %v", errClose) - } - }() - - enc := json.NewEncoder(f) - enc.SetIndent("", " ") - if err := enc.Encode(data); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) - } - return nil -} - -// CredentialFileName returns the filename used to persist Gemini CLI credentials. -// When projectID represents multiple projects (comma-separated or literal ALL), -// the suffix is normalized to "all" and a "gemini-" prefix is enforced to keep -// web and CLI generated files consistent. 
-func CredentialFileName(email, projectID string, includeProviderPrefix bool) string { - email = strings.TrimSpace(email) - project := strings.TrimSpace(projectID) - if strings.EqualFold(project, "all") || strings.Contains(project, ",") { - return fmt.Sprintf("gemini-%s-all.json", email) - } - prefix := "" - if includeProviderPrefix { - prefix = "gemini-" - } - return fmt.Sprintf("%s%s-%s.json", prefix, email, project) -} diff --git a/internal/cache/antigravity_reasoning_replay_cache.go b/internal/cache/antigravity_reasoning_replay_cache.go new file mode 100644 index 0000000000..a9f58c28d3 --- /dev/null +++ b/internal/cache/antigravity_reasoning_replay_cache.go @@ -0,0 +1,347 @@ +package cache + +import ( + "context" + "encoding/json" + "sort" + "strings" + "sync" + "time" + + homekv "github.com/router-for-me/CLIProxyAPI/v7/internal/home" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + // AntigravityReasoningReplayCacheTTL limits how long encrypted reasoning replay + // items stay in process memory. + AntigravityReasoningReplayCacheTTL = 1 * time.Hour + + // AntigravityReasoningReplayCacheMaxEntries bounds process memory for replay + // continuity. Oldest entries are evicted first. + AntigravityReasoningReplayCacheMaxEntries = 10240 + + // AntigravityReasoningReplayCacheEvictBatchSize leaves headroom after the cache + // reaches capacity so high write volume does not rescan the map every turn. 
+ AntigravityReasoningReplayCacheEvictBatchSize = 128 + + minAntigravityThoughtSignatureReplayLen = 16 +) + +type antigravityReasoningReplayEntry struct { + Items [][]byte + Timestamp time.Time +} + +var ( + antigravityReasoningReplayMu sync.Mutex + antigravityReasoningReplayEntries = make(map[string]antigravityReasoningReplayEntry) +) + +type antigravityReasoningReplayKVClient interface { + KVGet(ctx context.Context, key string) ([]byte, bool, error) + KVSet(ctx context.Context, key string, value []byte, opts homekv.KVSetOptions) (bool, error) + KVDel(ctx context.Context, keys ...string) (int64, error) + KVExpire(ctx context.Context, key string, ttl time.Duration) (bool, error) +} + +var currentAntigravityReasoningReplayKVClient = func() (antigravityReasoningReplayKVClient, bool, error) { + return homekv.CurrentKVClient() +} + +// CacheAntigravityReasoningReplayItem stores a final GPT/Codex reasoning item for +// stateless replay. The stored item is normalized to the minimal shape accepted +// by Responses input replay. +func CacheAntigravityReasoningReplayItem(modelName, sessionKey string, item []byte) bool { + return CacheAntigravityReasoningReplayItems(modelName, sessionKey, [][]byte{item}) +} + +// CacheAntigravityReasoningReplayItems stores the final GPT/Codex assistant output +// items needed to replay a stateless next turn. +func CacheAntigravityReasoningReplayItems(modelName, sessionKey string, items [][]byte) bool { + return CacheAntigravityReasoningReplayItemsBestEffort(context.Background(), modelName, sessionKey, items) +} + +// CacheAntigravityReasoningReplayItemsBestEffort stores replay items for completed response paths. 
+func CacheAntigravityReasoningReplayItemsBestEffort(ctx context.Context, modelName, sessionKey string, items [][]byte) bool { + key := antigravityReasoningReplayCacheKey(modelName, sessionKey) + if key == "" { + return false + } + normalized, ok := normalizeAntigravityReasoningReplayItems(items) + if !ok { + return false + } + if client, homeMode, errClient := currentAntigravityReasoningReplayKVClient(); homeMode { + if errClient != nil { + log.Errorf("home kv best-effort antigravity reasoning replay set failed prefix=cpa:antigravity:*: %v", errClient) + return false + } + raw, errMarshal := json.Marshal(normalized) + if errMarshal != nil { + log.Errorf("home kv best-effort antigravity reasoning replay set failed prefix=cpa:antigravity:*: %v", errMarshal) + return false + } + written, errSet := client.KVSet(ctx, antigravityReasoningReplayKVKey(modelName, sessionKey), raw, homekv.KVSetOptions{EX: AntigravityReasoningReplayCacheTTL}) + if errSet != nil { + log.Errorf("home kv best-effort antigravity reasoning replay set failed prefix=cpa:antigravity:*: %v", errSet) + return false + } + return written + } + + cacheCleanupOnce.Do(startCacheCleanup) + now := time.Now() + antigravityReasoningReplayMu.Lock() + defer antigravityReasoningReplayMu.Unlock() + antigravityReasoningReplayEntries[key] = antigravityReasoningReplayEntry{ + Items: normalized, + Timestamp: now, + } + if len(antigravityReasoningReplayEntries) > AntigravityReasoningReplayCacheMaxEntries { + evictOldestAntigravityReasoningReplayEntries(AntigravityReasoningReplayCacheEvictBatchSize) + } + return true +} + +// GetAntigravityReasoningReplayItem retrieves a normalized reasoning replay item. 
+func GetAntigravityReasoningReplayItem(modelName, sessionKey string) ([]byte, bool) { + items, ok := GetAntigravityReasoningReplayItems(modelName, sessionKey) + if !ok || len(items) == 0 { + return nil, false + } + return items[0], true +} + +// GetAntigravityReasoningReplayItems retrieves normalized assistant output items. +func GetAntigravityReasoningReplayItems(modelName, sessionKey string) ([][]byte, bool) { + items, ok, err := GetAntigravityReasoningReplayItemsRequired(context.Background(), modelName, sessionKey) + if err == nil { + return items, ok + } + return nil, false +} + +// GetAntigravityReasoningReplayItemsRequired retrieves replay items for request-time paths. +func GetAntigravityReasoningReplayItemsRequired(ctx context.Context, modelName, sessionKey string) ([][]byte, bool, error) { + key := antigravityReasoningReplayCacheKey(modelName, sessionKey) + if key == "" { + return nil, false, nil + } + client, homeMode, errClient := currentAntigravityReasoningReplayKVClient() + if homeMode { + if errClient != nil { + return nil, false, errClient + } + raw, found, errGet := client.KVGet(ctx, antigravityReasoningReplayKVKey(modelName, sessionKey)) + if errGet != nil || !found { + return nil, false, errGet + } + var homeItems [][]byte + if errUnmarshal := json.Unmarshal(raw, &homeItems); errUnmarshal != nil { + return nil, false, errUnmarshal + } + if _, errExpire := client.KVExpire(ctx, antigravityReasoningReplayKVKey(modelName, sessionKey), AntigravityReasoningReplayCacheTTL); errExpire != nil { + return nil, false, errExpire + } + return cloneAntigravityReasoningReplayItems(homeItems), true, nil + } + + cacheCleanupOnce.Do(startCacheCleanup) + now := time.Now() + antigravityReasoningReplayMu.Lock() + defer antigravityReasoningReplayMu.Unlock() + entry, ok := antigravityReasoningReplayEntries[key] + if !ok { + return nil, false, nil + } + if now.Sub(entry.Timestamp) > AntigravityReasoningReplayCacheTTL { + delete(antigravityReasoningReplayEntries, key) + 
return nil, false, nil + } + entry.Timestamp = now + antigravityReasoningReplayEntries[key] = entry + return cloneAntigravityReasoningReplayItems(entry.Items), true, nil +} + +// DeleteAntigravityReasoningReplayItem removes one replay item after upstream rejects +// it or the caller otherwise knows it is stale. +func DeleteAntigravityReasoningReplayItem(modelName, sessionKey string) { + if errDelete := DeleteAntigravityReasoningReplayItemRequired(context.Background(), modelName, sessionKey); errDelete != nil { + return + } +} + +// DeleteAntigravityReasoningReplayItemRequired removes one replay item for request-time paths. +func DeleteAntigravityReasoningReplayItemRequired(ctx context.Context, modelName, sessionKey string) error { + key := antigravityReasoningReplayCacheKey(modelName, sessionKey) + if key == "" { + return nil + } + client, homeMode, errClient := currentAntigravityReasoningReplayKVClient() + if homeMode { + if errClient != nil { + return errClient + } + _, errDel := client.KVDel(ctx, antigravityReasoningReplayKVKey(modelName, sessionKey)) + return errDel + } + antigravityReasoningReplayMu.Lock() + delete(antigravityReasoningReplayEntries, key) + antigravityReasoningReplayMu.Unlock() + return nil +} + +// ClearAntigravityReasoningReplayCache clears all Antigravity reasoning replay state. +func ClearAntigravityReasoningReplayCache() { + antigravityReasoningReplayMu.Lock() + antigravityReasoningReplayEntries = make(map[string]antigravityReasoningReplayEntry) + antigravityReasoningReplayMu.Unlock() +} + +func antigravityReasoningReplayCacheKey(modelName, sessionKey string) string { + modelName = strings.TrimSpace(modelName) + sessionKey = strings.TrimSpace(sessionKey) + if modelName == "" || sessionKey == "" { + return "" + } + // The session key is the continuity boundary. Keep this independent from + // the selected upstream Codex credential so auth failover can preserve replay. 
+ return strings.Join([]string{"antigravity-reasoning-replay", modelName, sessionKey}, "\x00") +} + +func antigravityReasoningReplayKVKey(modelName, sessionKey string) string { + return "cpa:antigravity:reasoning-replay:" + homekv.HashKeyPart(strings.TrimSpace(modelName)) + ":" + homekv.HashKeyPart(strings.TrimSpace(sessionKey)) +} + +func normalizeAntigravityReasoningReplayItems(items [][]byte) ([][]byte, bool) { + normalized := make([][]byte, 0, len(items)) + for _, item := range items { + normalizedItem, ok := normalizeAntigravityReasoningReplayItem(item) + if ok { + normalized = append(normalized, normalizedItem) + } + } + return normalized, len(normalized) > 0 +} + +func normalizeAntigravityReasoningReplayItem(item []byte) ([]byte, bool) { + itemResult := gjson.ParseBytes(item) + switch strings.TrimSpace(itemResult.Get("type").String()) { + case "thought_signature": + return normalizeAntigravityThoughtSignatureReplayItem(itemResult) + case "function_call_part": + return normalizeAntigravityFunctionCallPartReplayItem(itemResult) + default: + return nil, false + } +} + +func normalizeAntigravityThoughtSignatureReplayItem(itemResult gjson.Result) ([]byte, bool) { + sig := strings.TrimSpace(itemResult.Get("thoughtSignature").String()) + if sig == "" { + sig = strings.TrimSpace(itemResult.Get("thought_signature").String()) + } + if sig == "" || len(sig) < minAntigravityThoughtSignatureReplayLen { + return nil, false + } + normalized := []byte(`{"type":"thought_signature"}`) + normalized, _ = sjson.SetBytes(normalized, "thoughtSignature", sig) + if contentIndex := itemResult.Get("contentIndex"); contentIndex.Type == gjson.Number { + normalized, _ = sjson.SetBytes(normalized, "contentIndex", contentIndex.Int()) + } + if partIndex := itemResult.Get("partIndex"); partIndex.Type == gjson.Number { + normalized, _ = sjson.SetBytes(normalized, "partIndex", partIndex.Int()) + } + return normalized, true +} + +func normalizeAntigravityFunctionCallPartReplayItem(itemResult 
gjson.Result) ([]byte, bool) { + callID := strings.TrimSpace(itemResult.Get("call_id").String()) + if callID == "" { + callID = strings.TrimSpace(itemResult.Get("id").String()) + } + name := strings.TrimSpace(itemResult.Get("name").String()) + args := itemResult.Get("args") + if name == "" || !args.Exists() { + fc := itemResult.Get("functionCall") + if fc.Exists() { + if callID == "" { + callID = strings.TrimSpace(fc.Get("id").String()) + } + if name == "" { + name = strings.TrimSpace(fc.Get("name").String()) + } + if !args.Exists() { + args = fc.Get("args") + } + } + } + if name == "" || !args.Exists() { + return nil, false + } + normalized := []byte(`{"type":"function_call_part"}`) + if callID != "" { + normalized, _ = sjson.SetBytes(normalized, "call_id", callID) + } + normalized, _ = sjson.SetBytes(normalized, "name", name) + if args.Type == gjson.String { + normalized, _ = sjson.SetBytes(normalized, "args", args.String()) + } else { + normalized, _ = sjson.SetRawBytes(normalized, "args", []byte(args.Raw)) + } + sig := strings.TrimSpace(itemResult.Get("thoughtSignature").String()) + if sig != "" { + normalized, _ = sjson.SetBytes(normalized, "thoughtSignature", sig) + } + if contentIndex := itemResult.Get("contentIndex"); contentIndex.Type == gjson.Number { + normalized, _ = sjson.SetBytes(normalized, "contentIndex", contentIndex.Int()) + } + if partIndex := itemResult.Get("partIndex"); partIndex.Type == gjson.Number { + normalized, _ = sjson.SetBytes(normalized, "partIndex", partIndex.Int()) + } + return normalized, true +} + +func cloneAntigravityReasoningReplayItems(items [][]byte) [][]byte { + cloned := make([][]byte, 0, len(items)) + for _, item := range items { + cloned = append(cloned, append([]byte(nil), item...)) + } + return cloned +} + +func evictOldestAntigravityReasoningReplayEntries(count int) { + if count <= 0 || len(antigravityReasoningReplayEntries) == 0 { + return + } + type candidate struct { + key string + timestamp time.Time + } + 
candidates := make([]candidate, 0, len(antigravityReasoningReplayEntries)) + for key, entry := range antigravityReasoningReplayEntries { + candidates = append(candidates, candidate{key: key, timestamp: entry.Timestamp}) + } + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].timestamp.Before(candidates[j].timestamp) + }) + if count > len(candidates) { + count = len(candidates) + } + for i := 0; i < count; i++ { + delete(antigravityReasoningReplayEntries, candidates[i].key) + } +} + +func purgeExpiredAntigravityReasoningReplayCache(now time.Time) { + antigravityReasoningReplayMu.Lock() + for key, entry := range antigravityReasoningReplayEntries { + if now.Sub(entry.Timestamp) > AntigravityReasoningReplayCacheTTL { + delete(antigravityReasoningReplayEntries, key) + } + } + antigravityReasoningReplayMu.Unlock() +} diff --git a/internal/cache/signature_cache.go b/internal/cache/signature_cache.go index 1f54458e40..72c3ddebc5 100644 --- a/internal/cache/signature_cache.go +++ b/internal/cache/signature_cache.go @@ -109,6 +109,7 @@ func purgeExpiredCaches() { return true }) purgeExpiredCodexReasoningReplayCache(now) + purgeExpiredAntigravityReasoningReplayCache(now) } // CacheSignature stores a thinking signature for a given model group and text. diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index a5882e654c..8d19be1cef 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -6,14 +6,13 @@ import ( // newAuthManager creates a new authentication manager instance with all supported // authenticators and a file-based token store. It initializes authenticators for -// Gemini, Codex, Claude, Antigravity, Kimi, and xAI providers. +// Codex, Claude, Antigravity, Kimi, and xAI providers. 
// // Returns: // - *sdkAuth.Manager: A configured authentication manager instance func newAuthManager() *sdkAuth.Manager { store := sdkAuth.GetTokenStore() manager := sdkAuth.NewManager(store, - sdkAuth.NewGeminiAuthenticator(), sdkAuth.NewCodexAuthenticator(), sdkAuth.NewClaudeAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), diff --git a/internal/cmd/login.go b/internal/cmd/login.go deleted file mode 100644 index a71bb28263..0000000000 --- a/internal/cmd/login.go +++ /dev/null @@ -1,663 +0,0 @@ -// Package cmd provides command-line interface functionality for the CLI Proxy API server. -// It includes authentication flows for various AI service providers, service startup, -// and other command-line operations. -package cmd - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "os" - "strconv" - "strings" - "time" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini" - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" -) - -const ( - geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com" - geminiCLIVersion = "v1internal" -) - -type projectSelectionRequiredError struct{} - -func (e *projectSelectionRequiredError) Error() string { - return "gemini cli: project selection required" -} - -// DoLogin handles Google Gemini authentication using the shared authentication manager. -// It initiates the OAuth flow for Google Gemini services, performs the legacy CLI user setup, -// and saves the authentication tokens to the configured auth directory. 
-// -// Parameters: -// - cfg: The application configuration -// - projectID: Optional Google Cloud project ID for Gemini services -// - options: Login options including browser behavior and prompts -func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { - if options == nil { - options = &LoginOptions{} - } - - ctx := context.Background() - - promptFn := options.Prompt - if promptFn == nil { - promptFn = defaultProjectPrompt() - } - - trimmedProjectID := strings.TrimSpace(projectID) - callbackPrompt := promptFn - if trimmedProjectID == "" { - callbackPrompt = nil - } - - loginOpts := &sdkAuth.LoginOptions{ - NoBrowser: options.NoBrowser, - ProjectID: trimmedProjectID, - CallbackPort: options.CallbackPort, - Metadata: map[string]string{}, - Prompt: callbackPrompt, - } - - authenticator := sdkAuth.NewGeminiAuthenticator() - record, errLogin := authenticator.Login(ctx, cfg, loginOpts) - if errLogin != nil { - log.Errorf("Gemini authentication failed: %v", errLogin) - return - } - - storage, okStorage := record.Storage.(*gemini.GeminiTokenStorage) - if !okStorage || storage == nil { - log.Error("Gemini authentication failed: unsupported token storage") - return - } - - geminiAuth := gemini.NewGeminiAuth() - httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{ - NoBrowser: options.NoBrowser, - CallbackPort: options.CallbackPort, - Prompt: callbackPrompt, - }) - if errClient != nil { - log.Errorf("Gemini authentication failed: %v", errClient) - return - } - - log.Info("Authentication successful.") - - var activatedProjects []string - - useGoogleOne := false - if trimmedProjectID == "" && promptFn != nil { - fmt.Println("\nSelect login mode:") - fmt.Println(" 1. Code Assist (GCP project, manual selection)") - fmt.Println(" 2. 
Google One (personal account, auto-discover project)") - choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ") - if errPrompt == nil && strings.TrimSpace(choice) == "2" { - useGoogleOne = true - } - } - - if useGoogleOne { - log.Info("Google One mode: auto-discovering project...") - if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil { - log.Errorf("Google One auto-discovery failed: %v", errSetup) - return - } - autoProject := strings.TrimSpace(storage.ProjectID) - if autoProject == "" { - log.Error("Google One auto-discovery returned empty project ID") - return - } - log.Infof("Auto-discovered project: %s", autoProject) - activatedProjects = []string{autoProject} - } else { - projects, errProjects := fetchGCPProjects(ctx, httpClient) - if errProjects != nil { - log.Errorf("Failed to get project list: %v", errProjects) - return - } - - selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn) - projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) - if errSelection != nil { - log.Errorf("Invalid project selection: %v", errSelection) - return - } - if len(projectSelections) == 0 { - log.Error("No project selected; aborting login.") - return - } - - seenProjects := make(map[string]bool) - for _, candidateID := range projectSelections { - log.Infof("Activating project %s", candidateID) - if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil { - if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok { - log.Error("Failed to start user onboarding: A project ID is required.") - showProjectSelectionHelp(storage.Email, projects) - return - } - log.Errorf("Failed to complete user setup: %v", errSetup) - return - } - finalID := strings.TrimSpace(storage.ProjectID) - if finalID == "" { - finalID = candidateID - } - - if seenProjects[finalID] { - log.Infof("Project %s already activated, skipping", finalID) - continue 
- } - seenProjects[finalID] = true - activatedProjects = append(activatedProjects, finalID) - } - } - - storage.Auto = false - storage.ProjectID = strings.Join(activatedProjects, ",") - - if !storage.Auto && !storage.Checked { - for _, pid := range activatedProjects { - isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid) - if errCheck != nil { - log.Errorf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck) - return - } - if !isChecked { - log.Errorf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid) - return - } - } - storage.Checked = true - } - - updateAuthRecord(record, storage) - - store := sdkAuth.GetTokenStore() - if setter, okSetter := store.(interface{ SetBaseDir(string) }); okSetter && cfg != nil { - setter.SetBaseDir(cfg.AuthDir) - } - - savedPath, errSave := store.Save(ctx, record) - if errSave != nil { - log.Errorf("Failed to save token to file: %v", errSave) - return - } - - if savedPath != "" { - fmt.Printf("Authentication saved to %s\n", savedPath) - } - - fmt.Println("Gemini authentication successful!") -} - -func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage *gemini.GeminiTokenStorage, requestedProject string) error { - metadata := map[string]string{ - "ideType": "IDE_UNSPECIFIED", - "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI", - } - - trimmedRequest := strings.TrimSpace(requestedProject) - explicitProject := trimmedRequest != "" - - loadReqBody := map[string]any{ - "metadata": metadata, - } - if explicitProject { - loadReqBody["cloudaicompanionProject"] = trimmedRequest - } - - var loadResp map[string]any - if errLoad := callGeminiCLI(ctx, httpClient, "loadCodeAssist", loadReqBody, &loadResp); errLoad != nil { - return fmt.Errorf("load code assist: %w", errLoad) - } - - tierID := "legacy-tier" - if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { - for _, rawTier := range tiers 
{ - tier, okTier := rawTier.(map[string]any) - if !okTier { - continue - } - if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { - if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { - tierID = strings.TrimSpace(id) - break - } - } - } - } - - projectID := trimmedRequest - if projectID == "" { - if id, okProject := loadResp["cloudaicompanionProject"].(string); okProject { - projectID = strings.TrimSpace(id) - } - if projectID == "" { - if projectMap, okProject := loadResp["cloudaicompanionProject"].(map[string]any); okProject { - if id, okID := projectMap["id"].(string); okID { - projectID = strings.TrimSpace(id) - } - } - } - } - if projectID == "" { - // Auto-discovery: try onboardUser without specifying a project - // to let Google auto-provision one (matches Gemini CLI headless behavior - // and Antigravity's FetchProjectID pattern). - autoOnboardReq := map[string]any{ - "tierId": tierID, - "metadata": metadata, - } - - autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second) - defer autoCancel() - for attempt := 1; ; attempt++ { - var onboardResp map[string]any - if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil { - return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard) - } - - if done, okDone := onboardResp["done"].(bool); okDone && done { - if resp, okResp := onboardResp["response"].(map[string]any); okResp { - switch v := resp["cloudaicompanionProject"].(type) { - case string: - projectID = strings.TrimSpace(v) - case map[string]any: - if id, okID := v["id"].(string); okID { - projectID = strings.TrimSpace(id) - } - } - } - break - } - - log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt) - select { - case <-autoCtx.Done(): - return &projectSelectionRequiredError{} - case <-time.After(2 * time.Second): - } - } - - if projectID == "" { - return &projectSelectionRequiredError{} - } - 
log.Infof("Auto-discovered project ID via onboarding: %s", projectID) - } - - onboardReqBody := map[string]any{ - "tierId": tierID, - "metadata": metadata, - "cloudaicompanionProject": projectID, - } - - // Store the requested project as a fallback in case the response omits it. - storage.ProjectID = projectID - - for { - var onboardResp map[string]any - if errOnboard := callGeminiCLI(ctx, httpClient, "onboardUser", onboardReqBody, &onboardResp); errOnboard != nil { - return fmt.Errorf("onboard user: %w", errOnboard) - } - - if done, okDone := onboardResp["done"].(bool); okDone && done { - responseProjectID := "" - if resp, okResp := onboardResp["response"].(map[string]any); okResp { - switch projectValue := resp["cloudaicompanionProject"].(type) { - case map[string]any: - if id, okID := projectValue["id"].(string); okID { - responseProjectID = strings.TrimSpace(id) - } - case string: - responseProjectID = strings.TrimSpace(projectValue) - } - } - - finalProjectID := projectID - if responseProjectID != "" { - if explicitProject && !strings.EqualFold(responseProjectID, projectID) { - log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID) - log.Infof("Using backend project ID: %s", responseProjectID) - } - finalProjectID = responseProjectID - } - - storage.ProjectID = strings.TrimSpace(finalProjectID) - if storage.ProjectID == "" { - storage.ProjectID = strings.TrimSpace(projectID) - } - if storage.ProjectID == "" { - return fmt.Errorf("onboard user completed without project id") - } - log.Infof("Onboarding complete. 
Using Project ID: %s", storage.ProjectID) - return nil - } - - log.Println("Onboarding in progress, waiting 5 seconds...") - time.Sleep(5 * time.Second) - } -} - -func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string, body any, result any) error { - url := fmt.Sprintf("%s/%s:%s", geminiCLIEndpoint, geminiCLIVersion, endpoint) - if strings.HasPrefix(endpoint, "operations/") { - url = fmt.Sprintf("%s/%s", geminiCLIEndpoint, endpoint) - } - - var reader io.Reader - if body != nil { - rawBody, errMarshal := json.Marshal(body) - if errMarshal != nil { - return fmt.Errorf("marshal request body: %w", errMarshal) - } - reader = bytes.NewReader(rawBody) - } - - req, errRequest := http.NewRequestWithContext(ctx, http.MethodPost, url, reader) - if errRequest != nil { - return fmt.Errorf("create request: %w", errRequest) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) - - resp, errDo := httpClient.Do(req) - if errDo != nil { - return fmt.Errorf("execute request: %w", errDo) - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - }() - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) - } - - if result == nil { - _, _ = io.Copy(io.Discard, resp.Body) - return nil - } - - if errDecode := json.NewDecoder(resp.Body).Decode(result); errDecode != nil { - return fmt.Errorf("decode response body: %w", errDecode) - } - - return nil -} - -func fetchGCPProjects(ctx context.Context, httpClient *http.Client) ([]interfaces.GCPProjectProjects, error) { - req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil) - if errRequest != nil { - return nil, 
fmt.Errorf("could not create project list request: %w", errRequest) - } - - resp, errDo := httpClient.Do(req) - if errDo != nil { - return nil, fmt.Errorf("failed to execute project list request: %w", errDo) - } - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("response body close error: %v", errClose) - } - }() - - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) - } - - var projects interfaces.GCPProject - if errDecode := json.NewDecoder(resp.Body).Decode(&projects); errDecode != nil { - return nil, fmt.Errorf("failed to unmarshal project list: %w", errDecode) - } - - return projects.Projects, nil -} - -// promptForProjectSelection prints available projects and returns the chosen project ID. -func promptForProjectSelection(projects []interfaces.GCPProjectProjects, presetID string, promptFn func(string) (string, error)) string { - trimmedPreset := strings.TrimSpace(presetID) - if len(projects) == 0 { - if trimmedPreset != "" { - return trimmedPreset - } - fmt.Println("No Google Cloud projects are available for selection.") - return "" - } - - fmt.Println("Available Google Cloud projects:") - defaultIndex := 0 - for idx, project := range projects { - fmt.Printf("[%d] %s (%s)\n", idx+1, project.ProjectID, project.Name) - if trimmedPreset != "" && project.ProjectID == trimmedPreset { - defaultIndex = idx - } - } - fmt.Println("Type 'ALL' to onboard every listed project.") - - defaultID := projects[defaultIndex].ProjectID - - if trimmedPreset != "" { - if strings.EqualFold(trimmedPreset, "ALL") { - return "ALL" - } - for _, project := range projects { - if project.ProjectID == trimmedPreset { - return trimmedPreset - } - } - log.Warnf("Provided project ID %s not found in available projects; please choose from the list.", 
trimmedPreset) - } - - for { - promptMsg := fmt.Sprintf("Enter project ID [%s] or ALL: ", defaultID) - answer, errPrompt := promptFn(promptMsg) - if errPrompt != nil { - log.Errorf("Project selection prompt failed: %v", errPrompt) - return defaultID - } - answer = strings.TrimSpace(answer) - if strings.EqualFold(answer, "ALL") { - return "ALL" - } - if answer == "" { - return defaultID - } - - for _, project := range projects { - if project.ProjectID == answer { - return project.ProjectID - } - } - - if idx, errAtoi := strconv.Atoi(answer); errAtoi == nil { - if idx >= 1 && idx <= len(projects) { - return projects[idx-1].ProjectID - } - } - - fmt.Println("Invalid selection, enter a project ID or a number from the list.") - } -} - -func resolveProjectSelections(selection string, projects []interfaces.GCPProjectProjects) ([]string, error) { - trimmed := strings.TrimSpace(selection) - if trimmed == "" { - return nil, nil - } - available := make(map[string]struct{}, len(projects)) - ordered := make([]string, 0, len(projects)) - for _, project := range projects { - id := strings.TrimSpace(project.ProjectID) - if id == "" { - continue - } - if _, exists := available[id]; exists { - continue - } - available[id] = struct{}{} - ordered = append(ordered, id) - } - if strings.EqualFold(trimmed, "ALL") { - if len(ordered) == 0 { - return nil, fmt.Errorf("no projects available for ALL selection") - } - return append([]string(nil), ordered...), nil - } - parts := strings.Split(trimmed, ",") - selections := make([]string, 0, len(parts)) - seen := make(map[string]struct{}, len(parts)) - for _, part := range parts { - id := strings.TrimSpace(part) - if id == "" { - continue - } - if _, dup := seen[id]; dup { - continue - } - if len(available) > 0 { - if _, ok := available[id]; !ok { - return nil, fmt.Errorf("project %s not found in available projects", id) - } - } - seen[id] = struct{}{} - selections = append(selections, id) - } - return selections, nil -} - -func 
defaultProjectPrompt() func(string) (string, error) { - reader := bufio.NewReader(os.Stdin) - return func(prompt string) (string, error) { - fmt.Print(prompt) - line, errRead := reader.ReadString('\n') - if errRead != nil { - if errors.Is(errRead, io.EOF) { - return strings.TrimSpace(line), nil - } - return "", errRead - } - return strings.TrimSpace(line), nil - } -} - -func showProjectSelectionHelp(email string, projects []interfaces.GCPProjectProjects) { - if email != "" { - log.Infof("Your account %s needs to specify a project ID.", email) - } else { - log.Info("You need to specify a project ID.") - } - - if len(projects) > 0 { - fmt.Println("========================================================================") - for _, p := range projects { - fmt.Printf("Project ID: %s\n", p.ProjectID) - fmt.Printf("Project Name: %s\n", p.Name) - fmt.Println("------------------------------------------------------------------------") - } - } else { - fmt.Println("No active projects were returned for this account.") - } - - fmt.Printf("Please run this command to login again with a specific project:\n\n%s --login --project_id \n", os.Args[0]) -} - -func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projectID string) (bool, error) { - serviceUsageURL := "https://serviceusage.googleapis.com" - requiredServices := []string{ - // "geminicloudassist.googleapis.com", // Gemini Cloud Assist API - "cloudaicompanion.googleapis.com", // Gemini for Google Cloud API - } - for _, service := range requiredServices { - checkUrl := fmt.Sprintf("%s/v1/projects/%s/services/%s", serviceUsageURL, projectID, service) - req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, checkUrl, nil) - if errRequest != nil { - return false, fmt.Errorf("failed to create request: %w", errRequest) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) - resp, errDo := httpClient.Do(req) - if errDo != nil { - return 
false, fmt.Errorf("failed to execute request: %w", errDo) - } - - if resp.StatusCode == http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - if gjson.GetBytes(bodyBytes, "state").String() == "ENABLED" { - _ = resp.Body.Close() - continue - } - } - _ = resp.Body.Close() - - enableUrl := fmt.Sprintf("%s/v1/projects/%s/services/%s:enable", serviceUsageURL, projectID, service) - req, errRequest = http.NewRequestWithContext(ctx, http.MethodPost, enableUrl, strings.NewReader("{}")) - if errRequest != nil { - return false, fmt.Errorf("failed to create request: %w", errRequest) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) - resp, errDo = httpClient.Do(req) - if errDo != nil { - return false, fmt.Errorf("failed to execute request: %w", errDo) - } - - bodyBytes, _ := io.ReadAll(resp.Body) - errMessage := string(bodyBytes) - errMessageResult := gjson.GetBytes(bodyBytes, "error.message") - if errMessageResult.Exists() { - errMessage = errMessageResult.String() - } - if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { - _ = resp.Body.Close() - continue - } else if resp.StatusCode == http.StatusBadRequest { - _ = resp.Body.Close() - if strings.Contains(strings.ToLower(errMessage), "already enabled") { - continue - } - } - _ = resp.Body.Close() - return false, fmt.Errorf("project activation required: %s", errMessage) - } - return true, nil -} - -func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStorage) { - if record == nil || storage == nil { - return - } - - finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true) - - if record.Metadata == nil { - record.Metadata = make(map[string]any) - } - record.Metadata["email"] = storage.Email - record.Metadata["project_id"] = storage.ProjectID - record.Metadata["auto"] = storage.Auto - record.Metadata["checked"] = storage.Checked - - record.ID = finalName - record.FileName = finalName - 
record.Storage = storage -} diff --git a/internal/cmd/login_prompt.go b/internal/cmd/login_prompt.go new file mode 100644 index 0000000000..156c836faf --- /dev/null +++ b/internal/cmd/login_prompt.go @@ -0,0 +1,24 @@ +package cmd + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" +) + +func defaultProjectPrompt() func(string) (string, error) { + reader := bufio.NewReader(os.Stdin) + return func(prompt string) (string, error) { + fmt.Print(prompt) + line, errRead := reader.ReadString('\n') + if errRead != nil { + if errRead == io.EOF { + return strings.TrimSpace(line), nil + } + return "", errRead + } + return strings.TrimSpace(line), nil + } +} diff --git a/internal/config/clone_test.go b/internal/config/clone_test.go index 152a852b05..1ee33035f5 100644 --- a/internal/config/clone_test.go +++ b/internal/config/clone_test.go @@ -129,7 +129,7 @@ func sampleCloneRuntimeConfig() *Config { AntigravitySignatureBypassStrict: &bypassStrict, GeminiKey: []GeminiKey{{ APIKey: "gemini-key", - Models: []GeminiModel{{Name: "gemini-upstream", Alias: "gemini-client"}}, + Models: []GeminiModel{{Name: "gemini-upstream", Alias: "gemini-upstream-alias"}}, Headers: map[string]string{"X-Gemini": "one"}, ExcludedModels: []string{"gemini-hidden"}, }}, diff --git a/internal/config/config.go b/internal/config/config.go index 885cbdf2b0..ffb67e4275 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -155,7 +155,7 @@ type Config struct { // OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels. // These aliases affect both model listing and model routing for supported channels: - // gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai. + // vertex, aistudio, antigravity, claude, codex, kimi, xai. // // NOTE: This does not apply to existing per-credential model alias features under: // gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, and vertex-api-key. 
@@ -456,6 +456,9 @@ type ClaudeKey struct { // ExcludedModels lists model IDs that should be excluded for this provider. ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` + // RebuildMidSystemMessage moves Claude messages with role "system" into the top-level system field. + RebuildMidSystemMessage bool `yaml:"rebuild-mid-system-message,omitempty" json:"rebuild-mid-system-message,omitempty"` + // DisableCooling disables auth/model cooldown scheduling for this credential when true. DisableCooling bool `yaml:"disable-cooling,omitempty" json:"disable-cooling,omitempty"` diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index 5c0bc4ab9c..995fd585c8 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -34,10 +34,6 @@ type SDKConfig struct { // Empty or invalid values use the default 3h. VideoResultAuthCacheTTL string `yaml:"video-result-auth-cache-ttl,omitempty" json:"video-result-auth-cache-ttl,omitempty"` - // EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled. - // Default is false for safety; when false, /v1internal:* requests are rejected. - EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"` - // ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview") // to target prefixed credentials. When false, unprefixed model requests may use prefixed // credentials as well. diff --git a/internal/constant/constant.go b/internal/constant/constant.go index 58b388a138..6a977077e1 100644 --- a/internal/constant/constant.go +++ b/internal/constant/constant.go @@ -7,9 +7,6 @@ const ( // Gemini represents the Google Gemini provider identifier. Gemini = "gemini" - // GeminiCLI represents the Google Gemini CLI provider identifier. - GeminiCLI = "gemini-cli" - // Codex represents the OpenAI Codex provider identifier. 
Codex = "codex" diff --git a/internal/interfaces/client_models.go b/internal/interfaces/client_models.go index c6e4ff7802..e2d6da82a1 100644 --- a/internal/interfaces/client_models.go +++ b/internal/interfaces/client_models.go @@ -3,46 +3,6 @@ // such as AI service clients, API handlers, and data models. package interfaces -import ( - "time" -) - -// GCPProject represents the response structure for a Google Cloud project list request. -// This structure is used when fetching available projects for a Google Cloud account. -type GCPProject struct { - // Projects is a list of Google Cloud projects accessible by the user. - Projects []GCPProjectProjects `json:"projects"` -} - -// GCPProjectLabels defines the labels associated with a GCP project. -// These labels can contain metadata about the project's purpose or configuration. -type GCPProjectLabels struct { - // GenerativeLanguage indicates if the project has generative language APIs enabled. - GenerativeLanguage string `json:"generative-language"` -} - -// GCPProjectProjects contains details about a single Google Cloud project. -// This includes identifying information, metadata, and configuration details. -type GCPProjectProjects struct { - // ProjectNumber is the unique numeric identifier for the project. - ProjectNumber string `json:"projectNumber"` - - // ProjectID is the unique string identifier for the project. - ProjectID string `json:"projectId"` - - // LifecycleState indicates the current state of the project (e.g., "ACTIVE"). - LifecycleState string `json:"lifecycleState"` - - // Name is the human-readable name of the project. - Name string `json:"name"` - - // Labels contains metadata labels associated with the project. - Labels GCPProjectLabels `json:"labels"` - - // CreateTime is the timestamp when the project was created. - CreateTime time.Time `json:"createTime"` -} - // Content represents a single message in a conversation, with a role and parts. 
// This structure models a message exchange between a user and an AI model. type Content struct { diff --git a/internal/misc/header_utils.go b/internal/misc/header_utils.go index ac022a9627..0c3abbf4b3 100644 --- a/internal/misc/header_utils.go +++ b/internal/misc/header_utils.go @@ -4,51 +4,10 @@ package misc import ( - "fmt" "net/http" - "runtime" "strings" ) -const ( - // GeminiCLIVersion is the version string reported in the User-Agent for upstream requests. - GeminiCLIVersion = "0.34.0" - - // GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream. - GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0" -) - -// geminiCLIOS maps Go runtime OS names to the Node.js-style platform strings used by Gemini CLI. -func geminiCLIOS() string { - switch runtime.GOOS { - case "windows": - return "win32" - default: - return runtime.GOOS - } -} - -// geminiCLIArch maps Go runtime architecture names to the Node.js-style arch strings used by Gemini CLI. -func geminiCLIArch() string { - switch runtime.GOARCH { - case "amd64": - return "x64" - case "386": - return "x86" - default: - return runtime.GOARCH - } -} - -// GeminiCLIUserAgent returns a User-Agent string that matches the Gemini CLI format. -// The model parameter is included in the UA; pass "" or "unknown" when the model is not applicable. -func GeminiCLIUserAgent(model string) string { - if model == "" { - model = "unknown" - } - return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s; terminal)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch()) -} - // ScrubProxyAndFingerprintHeaders removes all headers that could reveal // proxy infrastructure, client identity, or browser fingerprints from an // outgoing request. 
This ensures requests to upstream services look like they diff --git a/internal/pluginhost/auth_callbacks.go b/internal/pluginhost/auth_callbacks.go index f05329402b..3573999af5 100644 --- a/internal/pluginhost/auth_callbacks.go +++ b/internal/pluginhost/auth_callbacks.go @@ -556,9 +556,6 @@ func authProjectID(auth *coreauth.Auth) string { if projectID := strings.TrimSpace(auth.Attributes["project_id"]); projectID != "" { return projectID } - if projectID := strings.TrimSpace(auth.Attributes["gemini_virtual_project"]); projectID != "" { - return projectID - } } return "" } diff --git a/internal/pluginhost/auth_callbacks_test.go b/internal/pluginhost/auth_callbacks_test.go index c9c079449e..2a1b325eb6 100644 --- a/internal/pluginhost/auth_callbacks_test.go +++ b/internal/pluginhost/auth_callbacks_test.go @@ -34,15 +34,15 @@ func (s *memoryAuthStorage) SaveTokenToFile(authFilePath string) error { func TestHostAuthListCallbackUsesAuthManager(t *testing.T) { authDir := t.TempDir() - path := filepath.Join(authDir, "gemini-a.json") - if errWrite := os.WriteFile(path, []byte(`{"type":"gemini","email":"a@example.com","api_key":"k1"}`), 0o600); errWrite != nil { + path := filepath.Join(authDir, "demo-a.json") + if errWrite := os.WriteFile(path, []byte(`{"type":"demo","email":"a@example.com","api_key":"k1"}`), 0o600); errWrite != nil { t.Fatalf("write auth file: %v", errWrite) } auth := &coreauth.Auth{ - ID: "gemini-a.json", - Provider: "gemini", - FileName: "gemini-a.json", + ID: "demo-a.json", + Provider: "demo", + FileName: "demo-a.json", Label: "a@example.com", Status: coreauth.StatusActive, Attributes: map[string]string{ @@ -50,11 +50,11 @@ func TestHostAuthListCallbackUsesAuthManager(t *testing.T) { "source": path, }, Metadata: map[string]any{ - "type": "gemini", + "type": "demo", "email": "a@example.com", "api_key": "k1", }, - Storage: &memoryAuthStorage{payload: []byte(`{"type":"gemini","email":"a@example.com","api_key":"k1"}`)}, + Storage: 
&memoryAuthStorage{payload: []byte(`{"type":"demo","email":"a@example.com","api_key":"k1"}`)}, } auth.EnsureIndex() @@ -77,22 +77,22 @@ func TestHostAuthListCallbackUsesAuthManager(t *testing.T) { t.Fatalf("files = %#v, want one entry", resp.Files) } entry := resp.Files[0] - if entry.AuthIndex != auth.Index || entry.Name != "gemini-a.json" || entry.Email != "a@example.com" { + if entry.AuthIndex != auth.Index || entry.Name != "demo-a.json" || entry.Email != "a@example.com" { t.Fatalf("entry = %#v, want auth index and file metadata", entry) } } func TestHostAuthGetCallbackReturnsPhysicalJSONByAuthIndex(t *testing.T) { authDir := t.TempDir() - path := filepath.Join(authDir, "gemini-b.json") - if errWrite := os.WriteFile(path, []byte(`{"type":"gemini","email":"b@example.com","api_key":"k2"}`), 0o600); errWrite != nil { + path := filepath.Join(authDir, "demo-b.json") + if errWrite := os.WriteFile(path, []byte(`{"type":"demo","email":"b@example.com","api_key":"k2"}`), 0o600); errWrite != nil { t.Fatalf("write auth file: %v", errWrite) } auth := &coreauth.Auth{ - ID: "gemini-b.json", - Provider: "gemini", - FileName: "gemini-b.json", + ID: "demo-b.json", + Provider: "demo", + FileName: "demo-b.json", Label: "b@example.com", Status: coreauth.StatusActive, Attributes: map[string]string{ @@ -100,11 +100,11 @@ func TestHostAuthGetCallbackReturnsPhysicalJSONByAuthIndex(t *testing.T) { "source": path, }, Metadata: map[string]any{ - "type": "gemini", + "type": "demo", "email": "b@example.com", "api_key": "k2", }, - Storage: &memoryAuthStorage{payload: []byte(`{"type":"gemini","email":"b@example.com","api_key":"changed"}`)}, + Storage: &memoryAuthStorage{payload: []byte(`{"type":"demo","email":"b@example.com","api_key":"changed"}`)}, } auth.EnsureIndex() @@ -126,7 +126,7 @@ func TestHostAuthGetCallbackReturnsPhysicalJSONByAuthIndex(t *testing.T) { if errDecode != nil { t.Fatalf("decode response: %v", errDecode) } - if resp.AuthIndex != auth.Index || resp.Name != "gemini-b.json" 
{ + if resp.AuthIndex != auth.Index || resp.Name != "demo-b.json" { t.Fatalf("response = %#v, want auth index and name", resp) } var decoded map[string]any @@ -171,20 +171,20 @@ func TestHostAuthListCallbackFallsBackToDisk(t *testing.T) { func TestHostAuthGetRuntimeCallbackReturnsRuntimeInfo(t *testing.T) { auth := &coreauth.Auth{ - ID: "gemini-runtime.json", - Provider: "gemini", - FileName: "gemini-runtime.json", + ID: "demo-runtime.json", + Provider: "demo", + FileName: "demo-runtime.json", Label: "runtime@example.com", Status: coreauth.StatusActive, Attributes: map[string]string{ "runtime_only": "true", }, Metadata: map[string]any{ - "type": "gemini", + "type": "demo", "email": "runtime@example.com", "api_key": "runtime-key", }, - Storage: &memoryAuthStorage{payload: []byte(`{"type":"gemini","email":"runtime@example.com","api_key":"runtime-key"}`)}, + Storage: &memoryAuthStorage{payload: []byte(`{"type":"demo","email":"runtime@example.com","api_key":"runtime-key"}`)}, } auth.EnsureIndex() @@ -219,7 +219,7 @@ func TestHostAuthSaveCallbackWritesPhysicalFile(t *testing.T) { req, errMarshal := json.Marshal(pluginapi.HostAuthSaveRequest{ Name: "saved.json", - JSON: json.RawMessage(`{"type":"gemini","email":"saved@example.com","api_key":"saved-key"}`), + JSON: json.RawMessage(`{"type":"demo","email":"saved@example.com","api_key":"saved-key"}`), }) if errMarshal != nil { t.Fatalf("marshal request: %v", errMarshal) @@ -239,7 +239,7 @@ func TestHostAuthSaveCallbackWritesPhysicalFile(t *testing.T) { if errRead != nil { t.Fatalf("read saved file: %v", errRead) } - if string(data) != `{"type":"gemini","email":"saved@example.com","api_key":"saved-key"}` { + if string(data) != `{"type":"demo","email":"saved@example.com","api_key":"saved-key"}` { t.Fatalf("saved file = %q, want credential json", string(data)) } auths := host.currentAuthManager().List() diff --git a/internal/pluginhost/auth_provider.go b/internal/pluginhost/auth_provider.go index 6439f690f4..32cf6e24ce 100644 
--- a/internal/pluginhost/auth_provider.go +++ b/internal/pluginhost/auth_provider.go @@ -161,6 +161,14 @@ func (h *Host) callAuthProviderIdentifier(pluginID string, provider pluginapi.Au } func (h *Host) ParseAuth(ctx context.Context, req pluginapi.AuthParseRequest) (*coreauth.Auth, bool, error) { + auths, handled, errParseAuths := h.ParseAuths(ctx, req) + if errParseAuths != nil || !handled || len(auths) == 0 { + return nil, handled, errParseAuths + } + return auths[0], true, nil +} + +func (h *Host) ParseAuths(ctx context.Context, req pluginapi.AuthParseRequest) ([]*coreauth.Auth, bool, error) { if h == nil { return nil, false, nil } @@ -169,21 +177,29 @@ func (h *Host) ParseAuth(ctx context.Context, req pluginapi.AuthParseRequest) (* if record == nil { return nil, false, nil } - return h.callParseAuth(ctx, *record, req) + return h.callParseAuths(ctx, *record, req) } for _, record := range h.Snapshot().records { if record.plugin.Capabilities.AuthProvider == nil || h.isPluginFused(record.id) { continue } - auth, handled, errParse := h.callParseAuth(ctx, record, req) + auths, handled, errParse := h.callParseAuths(ctx, record, req) if errParse != nil || handled { - return auth, handled, errParse + return auths, handled, errParse } } return nil, false, nil } func (h *Host) callParseAuth(ctx context.Context, record capabilityRecord, req pluginapi.AuthParseRequest) (auth *coreauth.Auth, handled bool, err error) { + auths, handled, errParseAuths := h.callParseAuths(ctx, record, req) + if errParseAuths != nil || !handled || len(auths) == 0 { + return nil, handled, errParseAuths + } + return auths[0], true, nil +} + +func (h *Host) callParseAuths(ctx context.Context, record capabilityRecord, req pluginapi.AuthParseRequest) (auths []*coreauth.Auth, handled bool, err error) { provider := record.plugin.Capabilities.AuthProvider if h == nil || provider == nil || h.isPluginFused(record.id) { return nil, false, nil @@ -191,7 +207,7 @@ func (h *Host) callParseAuth(ctx 
context.Context, record capabilityRecord, req p defer func() { if recovered := recover(); recovered != nil { h.fusePlugin(record.id, "AuthProvider.ParseAuth", recovered) - auth = nil + auths = nil handled = false err = fmt.Errorf("auth provider panic: %v", recovered) } @@ -211,21 +227,32 @@ func (h *Host) callParseAuth(ctx context.Context, record capabilityRecord, req p if !resp.Handled { return nil, false, nil } - data := resp.Auth - if strings.TrimSpace(data.Provider) == "" { - data.Provider = req.Provider - } - if strings.TrimSpace(data.Provider) == "" { - data.Provider = normalizeProviderID(provider.Identifier()) - } - if normalizeProviderID(data.Provider) == "" { - return nil, true, fmt.Errorf("auth provider %s returned auth without provider", record.id) + datas := pluginAuthParseResponseAuths(resp) + auths = make([]*coreauth.Auth, 0, len(datas)) + for _, data := range datas { + if strings.TrimSpace(data.Provider) == "" { + data.Provider = req.Provider + } + if strings.TrimSpace(data.Provider) == "" { + data.Provider = normalizeProviderID(provider.Identifier()) + } + if normalizeProviderID(data.Provider) == "" { + return nil, true, fmt.Errorf("auth provider %s returned auth without provider", record.id) + } + parsed := h.AuthDataToCoreAuth(data, req.Path, req.FileName) + if parsed == nil { + return nil, true, fmt.Errorf("auth provider %s returned invalid auth data", record.id) + } + auths = append(auths, parsed) } - parsed := h.AuthDataToCoreAuth(data, req.Path, req.FileName) - if parsed == nil { - return nil, true, fmt.Errorf("auth provider %s returned invalid auth data", record.id) + return auths, true, nil +} + +func pluginAuthParseResponseAuths(resp pluginapi.AuthParseResponse) []pluginapi.AuthData { + if len(resp.Auths) > 0 { + return append([]pluginapi.AuthData(nil), resp.Auths...) 
} - return parsed, true, nil + return []pluginapi.AuthData{resp.Auth} } func (h *Host) StartLogin(ctx context.Context, provider string, baseURL string) (pluginapi.AuthLoginStartResponse, bool, error) { diff --git a/internal/pluginhost/auth_provider_test.go b/internal/pluginhost/auth_provider_test.go index 717d340b68..f6d01e9b2c 100644 --- a/internal/pluginhost/auth_provider_test.go +++ b/internal/pluginhost/auth_provider_test.go @@ -117,6 +117,51 @@ func TestParseAuthDefaultsProviderFromAuthProviderIdentifier(t *testing.T) { } } +func TestParseAuthsExpandsMultiplePluginAuths(t *testing.T) { + host := newHostWithRecords(capabilityRecord{ + id: "geminicli", + plugin: pluginapi.Plugin{ + Capabilities: pluginapi.Capabilities{ + AuthProvider: fakeAuthProvider{ + identifier: "gemini-cli", + parseAuth: func(ctx context.Context, req pluginapi.AuthParseRequest) (pluginapi.AuthParseResponse, error) { + return pluginapi.AuthParseResponse{ + Handled: true, + Auths: []pluginapi.AuthData{ + { + Provider: "gemini-cli", + ID: "user.json", + FileName: "user.json", + StorageJSON: []byte(`{"type":"gemini-cli"}`), + }, + { + Provider: "gemini-cli", + ID: "user-project-a.json", + FileName: "user-project-a.json", + StorageJSON: []byte(`{"type":"gemini-cli","project_id":"project-a"}`), + Metadata: map[string]any{"project_id": "project-a"}, + }, + }, + }, nil + }, + }, + }, + }, + }) + host.runtimeConfig = &config.Config{AuthDir: t.TempDir()} + + auths, handled, errParse := host.ParseAuths(context.Background(), pluginapi.AuthParseRequest{Provider: "gemini-cli"}) + if errParse != nil { + t.Fatalf("ParseAuths() error = %v", errParse) + } + if !handled || len(auths) != 2 { + t.Fatalf("ParseAuths() handled=%t len=%d, want two auths", handled, len(auths)) + } + if auths[1].Provider != "gemini-cli" || auths[1].Metadata["project_id"] != "project-a" { + t.Fatalf("second auth = %#v, want project-a virtual auth", auths[1]) + } +} + func TestStartLoginPassesProviderBaseURLHostAndHTTPClient(t 
*testing.T) { authDir := t.TempDir() expiresAt := time.Now().Add(time.Minute).UTC() diff --git a/internal/pluginhost/loader_unix.go b/internal/pluginhost/loader_unix.go index 32261752e3..9cfb08c755 100644 --- a/internal/pluginhost/loader_unix.go +++ b/internal/pluginhost/loader_unix.go @@ -191,6 +191,9 @@ func (c *dynamicLibraryClient) Call(ctx context.Context, method string, request C.cliproxy_free_plugin_buffer(c.api.free_buffer, response.ptr, response.len) } if rc != 0 { + if isPluginErrorEnvelope(out) { + return out, nil + } return nil, fmt.Errorf("plugin call %s returned %d: %s", method, int(rc), string(out)) } return out, nil diff --git a/internal/pluginhost/loader_windows.go b/internal/pluginhost/loader_windows.go index 317860e793..7bdc12dd62 100644 --- a/internal/pluginhost/loader_windows.go +++ b/internal/pluginhost/loader_windows.go @@ -128,6 +128,9 @@ func (c *dynamicLibraryClient) Call(ctx context.Context, method string, request _, _, _ = syscall.SyscallN(c.api.freeBuffer, response.ptr, response.len) } if rc != 0 { + if isPluginErrorEnvelope(out) { + return out, nil + } return nil, fmt.Errorf("plugin call %s returned %d: %s", method, rc, string(out)) } return out, nil diff --git a/internal/pluginhost/rpc_client.go b/internal/pluginhost/rpc_client.go index c4b29d0287..10f767a5a8 100644 --- a/internal/pluginhost/rpc_client.go +++ b/internal/pluginhost/rpc_client.go @@ -35,6 +35,19 @@ type rpcThinkingApplier struct { *rpcPluginAdapter } +type rpcPluginError struct { + message string + statusCode int +} + +func (e rpcPluginError) Error() string { + return e.message +} + +func (e rpcPluginError) StatusCode() int { + return e.statusCode +} + type rpcResponseNormalizer struct { *rpcPluginAdapter method string @@ -140,6 +153,9 @@ func callPlugin[T any](ctx context.Context, client pluginClient, method string, } out, errDecode := decodeEnvelopeResult[T](envelope) if errDecode != nil { + if !envelope.OK { + return zero, errDecode + } return zero, 
fmt.Errorf("decode plugin result %s: %w", method, errDecode) } return out, nil @@ -260,11 +276,26 @@ func decodeRPCEnvelope[T any](raw []byte) (T, error) { return decodeEnvelopeResult[T](envelope) } +func isPluginErrorEnvelope(raw []byte) bool { + var envelope pluginabi.Envelope + if errUnmarshal := json.Unmarshal(raw, &envelope); errUnmarshal != nil { + return false + } + return !envelope.OK && envelope.Error != nil +} + func decodeEnvelopeResult[T any](envelope pluginabi.Envelope) (T, error) { var zero T if !envelope.OK { if envelope.Error != nil { - return zero, fmt.Errorf("%s", envelope.Error.Message) + message := strings.TrimSpace(envelope.Error.Message) + if message == "" { + message = "plugin call failed" + } + if envelope.Error.HTTPStatus > 0 { + return zero, rpcPluginError{message: message, statusCode: envelope.Error.HTTPStatus} + } + return zero, fmt.Errorf("%s", message) } return zero, fmt.Errorf("plugin call failed") } diff --git a/internal/pluginhost/rpc_client_error_test.go b/internal/pluginhost/rpc_client_error_test.go new file mode 100644 index 0000000000..a74e6bb7a0 --- /dev/null +++ b/internal/pluginhost/rpc_client_error_test.go @@ -0,0 +1,82 @@ +package pluginhost + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginabi" +) + +type staticEnvelopePluginClient struct { + raw []byte +} + +func (c staticEnvelopePluginClient) Call(context.Context, string, []byte) ([]byte, error) { + return c.raw, nil +} + +func (c staticEnvelopePluginClient) Shutdown() {} + +func TestDecodeEnvelopeResultPreservesPluginHTTPStatus(t *testing.T) { + _, errDecode := decodeEnvelopeResult[rpcEmptyResponse](pluginabi.Envelope{ + OK: false, + Error: &pluginabi.Error{ + Code: "plugin_error", + Message: "license required", + HTTPStatus: http.StatusForbidden, + }, + }) + if errDecode == nil { + t.Fatal("decodeEnvelopeResult returned nil error") + } + if got := errDecode.Error(); got != "license required" { + 
t.Fatalf("error = %q, want license required", got) + } + statusProvider, ok := errDecode.(interface{ StatusCode() int }) + if !ok { + t.Fatalf("error %T does not expose StatusCode", errDecode) + } + if got := statusProvider.StatusCode(); got != http.StatusForbidden { + t.Fatalf("status = %d, want %d", got, http.StatusForbidden) + } +} + +func TestCallPluginReturnsPluginErrorWithoutMethodWrapper(t *testing.T) { + raw, errMarshal := json.Marshal(pluginabi.Envelope{ + OK: false, + Error: &pluginabi.Error{ + Code: "plugin_error", + Message: "license required", + HTTPStatus: http.StatusForbidden, + }, + }) + if errMarshal != nil { + t.Fatalf("marshal envelope: %v", errMarshal) + } + _, errCall := callPlugin[rpcEmptyResponse](context.Background(), staticEnvelopePluginClient{raw: raw}, pluginabi.MethodExecutorExecuteStream, rpcEmptyResponse{}) + if errCall == nil { + t.Fatal("callPlugin returned nil error") + } + if got := errCall.Error(); got != "license required" { + t.Fatalf("error = %q, want license required", got) + } + statusProvider, ok := errCall.(interface{ StatusCode() int }) + if !ok { + t.Fatalf("error %T does not expose StatusCode", errCall) + } + if got := statusProvider.StatusCode(); got != http.StatusForbidden { + t.Fatalf("status = %d, want %d", got, http.StatusForbidden) + } +} + +func TestIsPluginErrorEnvelopeAcceptsNonzeroReturnEnvelope(t *testing.T) { + raw := marshalRPCError("plugin_error", "upstream failed") + if !isPluginErrorEnvelope(raw) { + t.Fatalf("isPluginErrorEnvelope(%s) = false, want true", raw) + } + if isPluginErrorEnvelope([]byte(`not json`)) { + t.Fatal("isPluginErrorEnvelope accepted invalid JSON") + } +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 3f7263e442..ef1ce79524 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -20,7 +20,6 @@ type staticModelsJSON struct { Claude []*ModelInfo `json:"claude"` Gemini []*ModelInfo 
`json:"gemini"` Vertex []*ModelInfo `json:"vertex"` - GeminiCLI []*ModelInfo `json:"gemini-cli"` AIStudio []*ModelInfo `json:"aistudio"` CodexFree []*ModelInfo `json:"codex-free"` CodexTeam []*ModelInfo `json:"codex-team"` @@ -46,11 +45,6 @@ func GetGeminiVertexModels() []*ModelInfo { return cloneModelInfos(getModels().Vertex) } -// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI. -func GetGeminiCLIModels() []*ModelInfo { - return cloneModelInfos(getModels().GeminiCLI) -} - // GetAIStudioModels returns model definitions for AI Studio. func GetAIStudioModels() []*ModelInfo { return cloneModelInfos(getModels().AIStudio) @@ -278,7 +272,6 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo { // - claude // - gemini // - vertex -// - gemini-cli // - aistudio // - codex // - kimi @@ -293,8 +286,6 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { return GetGeminiModels() case "vertex": return GetGeminiVertexModels() - case "gemini-cli": - return GetGeminiCLIModels() case "aistudio": return GetAIStudioModels() case "codex": @@ -322,7 +313,6 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { data.Claude, data.Gemini, data.Vertex, - data.GeminiCLI, data.AIStudio, data.CodexPro, data.Kimi, diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 3fab95e38f..0b8e8415d8 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -18,6 +18,11 @@ import ( // OpenAIImageModelType marks models that are callable through OpenAI-compatible image endpoints. 
const OpenAIImageModelType = "openai-image" +const ( + DefaultClaudeMaxInputTokens = 200000 + DefaultClaudeMaxOutputTokens = 64000 +) + // ModelInfo represents information about an available model type ModelInfo struct { // ID is the unique identifier for the model @@ -1156,14 +1161,24 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) "owned_by": model.OwnedBy, } if model.Created > 0 { - result["created_at"] = model.Created - } - if model.Type != "" { - result["type"] = "model" + result["created_at"] = time.Unix(model.Created, 0).UTC().Format(time.RFC3339) } + result["type"] = "model" if model.DisplayName != "" { result["display_name"] = model.DisplayName + } else { + result["display_name"] = model.ID + } + maxInput := model.ContextLength + if maxInput <= 0 { + maxInput = DefaultClaudeMaxInputTokens + } + maxOutput := model.MaxCompletionTokens + if maxOutput <= 0 { + maxOutput = DefaultClaudeMaxOutputTokens } + result["max_input_tokens"] = maxInput + result["max_tokens"] = maxOutput return result case "gemini": diff --git a/internal/registry/model_registry_cache_test.go b/internal/registry/model_registry_cache_test.go index 4653167bee..fb49e1f4ac 100644 --- a/internal/registry/model_registry_cache_test.go +++ b/internal/registry/model_registry_cache_test.go @@ -22,6 +22,52 @@ func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) { } } +func TestGetAvailableModelsClaudeIncludesTokenLimits(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "Claude", []*ModelInfo{ + {ID: "claude-sonnet-4-6", OwnedBy: "anthropic", Type: "claude", Created: 1771372800, ContextLength: 200000, MaxCompletionTokens: 64000}, + {ID: "claude-no-limits", OwnedBy: "anthropic", Type: "claude"}, + }) + + models := r.GetAvailableModels("claude") + byID := make(map[string]map[string]any, len(models)) + for _, m := range models { + id, _ := m["id"].(string) + byID[id] = m + } + + withLimits, ok := byID["claude-sonnet-4-6"] + if !ok { 
+ t.Fatalf("expected claude-sonnet-4-6 in available models, got %v", byID) + } + if got := withLimits["max_input_tokens"]; got != 200000 { + t.Fatalf("expected max_input_tokens 200000, got %v", got) + } + if got := withLimits["max_tokens"]; got != 64000 { + t.Fatalf("expected max_tokens 64000, got %v", got) + } + if got := withLimits["created_at"]; got != "2026-02-18T00:00:00Z" { + t.Fatalf("expected created_at as RFC 3339 string, got %v", got) + } + + withDefaults, ok := byID["claude-no-limits"] + if !ok { + t.Fatalf("expected claude-no-limits in available models, got %v", byID) + } + if got := withDefaults["max_input_tokens"]; got != DefaultClaudeMaxInputTokens { + t.Fatalf("expected fallback max_input_tokens %d, got %v", DefaultClaudeMaxInputTokens, got) + } + if got := withDefaults["max_tokens"]; got != DefaultClaudeMaxOutputTokens { + t.Fatalf("expected fallback max_tokens %d, got %v", DefaultClaudeMaxOutputTokens, got) + } + if got := withDefaults["display_name"]; got != "claude-no-limits" { + t.Fatalf("expected display_name to fall back to id, got %v", got) + } + if got := withDefaults["type"]; got != "model" { + t.Fatalf("expected type to default to model, got %v", got) + } +} + func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) { r := newTestModelRegistry() r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}}) diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go index 40033801d0..4c398fb149 100644 --- a/internal/registry/model_updater.go +++ b/internal/registry/model_updater.go @@ -207,7 +207,6 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string { {"claude", oldData.Claude, newData.Claude}, {"gemini", oldData.Gemini, newData.Gemini}, {"vertex", oldData.Vertex, newData.Vertex}, - {"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI}, {"aistudio", oldData.AIStudio, newData.AIStudio}, {"codex", oldData.CodexFree, 
newData.CodexFree}, {"codex", oldData.CodexTeam, newData.CodexTeam}, @@ -328,7 +327,6 @@ func validateModelsCatalog(data *staticModelsJSON) error { {name: "claude", models: data.Claude}, {name: "gemini", models: data.Gemini}, {name: "vertex", models: data.Vertex}, - {name: "gemini-cli", models: data.GeminiCLI}, {name: "aistudio", models: data.AIStudio}, {name: "codex-free", models: data.CodexFree}, {name: "codex-team", models: data.CodexTeam}, diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index 944eff80f5..c2b80fd877 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -46,7 +46,8 @@ "levels": [ "low", "medium", - "high" + "high", + "max" ] } }, @@ -870,197 +871,6 @@ } } ], - "gemini-cli": [ - { - "id": "gemini-2.5-pro", - "object": "model", - "created": 1750118400, - "owned_by": "google", - "type": "gemini", - "display_name": "Gemini 2.5 Pro", - "name": "models/gemini-2.5-pro", - "version": "2.5", - "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": [ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent" - ], - "thinking": { - "min": 128, - "max": 32768, - "dynamic_allowed": true - } - }, - { - "id": "gemini-2.5-flash", - "object": "model", - "created": 1750118400, - "owned_by": "google", - "type": "gemini", - "display_name": "Gemini 2.5 Flash", - "name": "models/gemini-2.5-flash", - "version": "001", - "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": [ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent" - ], - "thinking": { - "max": 24576, - "zero_allowed": true, - "dynamic_allowed": true - } - }, - { - "id": 
"gemini-2.5-flash-lite", - "object": "model", - "created": 1753142400, - "owned_by": "google", - "type": "gemini", - "display_name": "Gemini 2.5 Flash Lite", - "name": "models/gemini-2.5-flash-lite", - "version": "2.5", - "description": "Our smallest and most cost effective model, built for at scale usage.", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": [ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent" - ], - "thinking": { - "max": 24576, - "zero_allowed": true, - "dynamic_allowed": true - } - }, - { - "id": "gemini-3-pro-preview", - "object": "model", - "created": 1737158400, - "owned_by": "google", - "type": "gemini", - "display_name": "Gemini 3 Pro Preview", - "name": "models/gemini-3-pro-preview", - "version": "3.0", - "description": "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": [ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent" - ], - "thinking": { - "min": 128, - "max": 32768, - "dynamic_allowed": true, - "levels": [ - "low", - "high" - ] - } - }, - { - "id": "gemini-3.1-pro-preview", - "object": "model", - "created": 1771459200, - "owned_by": "google", - "type": "gemini", - "display_name": "Gemini 3.1 Pro Preview", - "name": "models/gemini-3.1-pro-preview", - "version": "3.1", - "description": "Gemini 3.1 Pro Preview", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": [ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent" - ], - "thinking": { - "min": 128, - "max": 32768, - "dynamic_allowed": true, - "levels": [ - "low", - "medium", - "high" - ] - } - }, - { - "id": "gemini-3-flash-preview", - "object": "model", - "created": 1765929600, - "owned_by": "google", - "type": "gemini", - 
"display_name": "Gemini 3 Flash Preview", - "name": "models/gemini-3-flash-preview", - "version": "3.0", - "description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": [ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent" - ], - "thinking": { - "min": 128, - "max": 32768, - "dynamic_allowed": true, - "levels": [ - "minimal", - "low", - "medium", - "high" - ] - } - }, - { - "id": "gemini-3.1-flash-lite-preview", - "object": "model", - "created": 1776288000, - "owned_by": "google", - "type": "gemini", - "display_name": "Gemini 3.1 Flash Lite Preview", - "name": "models/gemini-3.1-flash-lite-preview", - "version": "3.1", - "description": "Our smallest and most cost effective model, built for at scale usage.", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": [ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent" - ], - "thinking": { - "min": 128, - "max": 32768, - "dynamic_allowed": true, - "levels": [ - "minimal", - "low", - "medium", - "high" - ] - } - } - ], "aistudio": [ { "id": "gemini-2.5-pro", @@ -1918,46 +1728,6 @@ ] } }, - { - "id": "gemini-3-pro-high", - "object": "model", - "owned_by": "antigravity", - "type": "antigravity", - "display_name": "Gemini 3 Pro (High)", - "name": "gemini-3-pro-high", - "description": "Gemini 3 Pro (High)", - "context_length": 1048576, - "max_completion_tokens": 65535, - "thinking": { - "min": 128, - "max": 32768, - "dynamic_allowed": true, - "levels": [ - "low", - "high" - ] - } - }, - { - "id": "gemini-3-pro-low", - "object": "model", - "owned_by": "antigravity", - "type": "antigravity", - "display_name": "Gemini 3 Pro (Low)", - "name": "gemini-3-pro-low", - "description": "Gemini 3 Pro (Low)", - "context_length": 1048576, - "max_completion_tokens": 65535, - 
"thinking": { - "min": 128, - "max": 32768, - "dynamic_allowed": true, - "levels": [ - "low", - "high" - ] - } - }, { "id": "gemini-3.1-flash-image", "object": "model", @@ -2073,7 +1843,6 @@ ] } } - ], "xai": [ { @@ -2086,16 +1855,7 @@ "name": "grok-build-0.1", "description": "Grok Build 0.1 is xAI’s fast coding model trained specifically for agentic software engineering workflows.", "context_length": 256000, - "max_completion_tokens": 256000, - "thinking": { - "zero_allowed": true, - "levels": [ - "none", - "low", - "medium", - "high" - ] - } + "max_completion_tokens": 256000 }, { "id": "grok-4.3", @@ -2209,14 +1969,7 @@ "name": "grok-composer-2.5-fast", "description": "xAI Composer 2.5 Fast model for the Responses API.", "context_length": 200000, - "max_completion_tokens": 32768, - "thinking": { - "levels": [ - "low", - "medium", - "high" - ] - } + "max_completion_tokens": 32768 } ] -} +} \ No newline at end of file diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 6fd1146d29..a6973783ee 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -306,9 +306,6 @@ func validateAntigravityRequestSignatures(ctx context.Context, modelName string, rawJSON = antigravityclaude.StripEmptySignatureThinkingBlocks(rawJSON) logAntigravitySignatureStrip(before, countClaudeThinkingBlocks(rawJSON), "prefix_cleanup", "empty_or_non_claude_signature") if cache.SignatureCacheEnabled() { - if errRequire := antigravityclaude.RequireCachedThinkingSignatures(ctx, modelName, rawJSON); errRequire != nil { - return nil, homeKVUnavailableStatusErr(errRequire) - } return rawJSON, nil } if !cache.SignatureBypassStrictMode() { @@ -490,7 +487,7 @@ func decideAntigravity429(body []byte) antigravity429Decision { return decision } - if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil { + if retryAfter, parseErr := 
helps.ParseRetryDelay(body); parseErr == nil && retryAfter != nil { decision.retryAfter = retryAfter } @@ -607,7 +604,7 @@ func antigravityHasExplicitCreditsBalanceExhaustedReason(body []byte) bool { func newAntigravityStatusErr(statusCode int, body []byte) statusErr { err := statusErr{code: statusCode, msg: string(body)} if statusCode == http.StatusTooManyRequests { - if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil { + if retryAfter, parseErr := helps.ParseRetryDelay(body); parseErr == nil && retryAfter != nil { err.retryAfter = retryAfter } } @@ -691,6 +688,15 @@ attemptLoop: helps.MarkCreditsUsed(ctx) } } + replayScope := antigravityReasoningReplayScope{} + if antigravityUsesReasoningReplayCache(baseModel) { + var errReplay error + requestPayload, replayScope, errReplay = prepareAntigravityGeminiReasoningReplayPayload(ctx, baseModel, req, opts, requestPayload) + if errReplay != nil { + err = errReplay + return resp, err + } + } httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, false, opts.Alt, baseURL) if errReq != nil { @@ -798,6 +804,10 @@ attemptLoop: continue attemptLoop } } + if errClear := clearAntigravityReasoningReplayOnInvalidSignature(ctx, replayScope, httpResp.StatusCode, bodyBytes); errClear != nil { + err = errClear + return resp, err + } err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes) return resp, err } @@ -806,6 +816,7 @@ attemptLoop: if useCredits { clearAntigravityCreditsFailureState(auth) } + cacheAntigravityReasoningReplayFromResponse(ctx, replayScope, requestPayload, bodyBytes) bodyBytes = e.resolveWebSearchGroundingURLs(ctx, auth, from, originalPayload, translated, bodyBytes) reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes)) var param any @@ -1369,6 +1380,15 @@ attemptLoop: helps.MarkCreditsUsed(ctx) } } + replayScope := antigravityReasoningReplayScope{} + if antigravityUsesReasoningReplayCache(baseModel) { + var errReplay error + requestPayload, 
replayScope, errReplay = prepareAntigravityGeminiReasoningReplayPayload(ctx, baseModel, req, opts, requestPayload) + if errReplay != nil { + err = errReplay + return nil, err + } + } httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL) if errReq != nil { err = errReq @@ -1487,6 +1507,10 @@ attemptLoop: continue attemptLoop } } + if errClear := clearAntigravityReasoningReplayOnInvalidSignature(ctx, replayScope, httpResp.StatusCode, bodyBytes); errClear != nil { + err = errClear + return nil, err + } err = newAntigravityStatusErr(httpResp.StatusCode, bodyBytes) return nil, err } @@ -1495,12 +1519,16 @@ attemptLoop: if useCredits { clearAntigravityCreditsFailureState(auth) } + replayAccumulator := newAntigravityReasoningReplayAccumulator(replayScope, requestPayload) out := make(chan cliproxyexecutor.StreamChunk) go func(resp *http.Response) { defer close(out) defer func() { + if replayAccumulator != nil { + replayAccumulator.Flush(ctx) + } if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("antigravity executor: close response body error: %v", errClose) + log.Errorf("antigravity executor: close response line error: %v", errClose) } }() scanner := bufio.NewScanner(resp.Body) @@ -1509,6 +1537,9 @@ attemptLoop: for scanner.Scan() { line := scanner.Bytes() helps.AppendAPIResponseChunk(ctx, e.cfg, line) + if replayAccumulator != nil { + replayAccumulator.ObserveSSELine(line) + } // Filter usage metadata for all models // Only retain usage statistics in the terminal chunk @@ -1655,9 +1686,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut return cliproxyexecutor.Response{}, err } - payload = deleteJSONField(payload, "project") - payload = deleteJSONField(payload, "model") - payload = deleteJSONField(payload, "request.safetySettings") + payload = helps.DeleteJSONField(payload, "project") + payload = helps.DeleteJSONField(payload, "model") + payload = 
helps.DeleteJSONField(payload, "request.safetySettings") baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) @@ -1758,7 +1789,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut } sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)} if httpResp.StatusCode == http.StatusTooManyRequests { - if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil { + if retryAfter, parseErr := helps.ParseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil { sErr.retryAfter = retryAfter } } @@ -1769,7 +1800,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut case lastStatus != 0: sErr := statusErr{code: lastStatus, msg: string(lastBody)} if lastStatus == http.StatusTooManyRequests { - if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil { + if retryAfter, parseErr := helps.ParseRetryDelay(lastBody); parseErr == nil && retryAfter != nil { sErr.retryAfter = retryAfter } } @@ -1979,7 +2010,7 @@ func (e *AntigravityExecutor) refreshTokenSingleFlight(ctx context.Context, auth if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)} if httpResp.StatusCode == http.StatusTooManyRequests { - if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil { + if retryAfter, parseErr := helps.ParseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil { sErr.retryAfter = retryAfter } } @@ -2229,9 +2260,10 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens") } - bodyReader = strings.NewReader(payloadStr) + payloadStrBytes := applyAntigravityNativeSignatureReplayIfNeeded(modelName, []byte(payloadStr)) + bodyReader = 
bytes.NewReader(payloadStrBytes) if e.cfg != nil && e.cfg.RequestLog { - payloadLog = []byte(payloadStr) + payloadLog = append([]byte(nil), payloadStrBytes...) } } else { if strings.Contains(modelName, "claude") { @@ -2240,6 +2272,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens") } + payload = applyAntigravityNativeSignatureReplayIfNeeded(modelName, payload) bodyReader = bytes.NewReader(payload) if e.cfg != nil && e.cfg.RequestLog { payloadLog = append([]byte(nil), payload...) diff --git a/internal/runtime/executor/antigravity_executor_buildrequest_test.go b/internal/runtime/executor/antigravity_executor_buildrequest_test.go index ff4f69f1aa..b5329d7894 100644 --- a/internal/runtime/executor/antigravity_executor_buildrequest_test.go +++ b/internal/runtime/executor/antigravity_executor_buildrequest_test.go @@ -300,17 +300,20 @@ func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any "parametersJsonSchema": { "$schema": "http://json-schema.org/draft-07/schema#", "$id": "root-schema", + "$comment": "root comment should be removed", "type": "object", "properties": { "$id": {"type": "string"}, "arg": { "type": "object", + "$comment": "nested comment should be removed", "prefill": "hello", "properties": { "mode": { "type": "string", "deprecated": true, "enum": ["a", "b"], + "enumDescriptions": ["Alpha", "Beta"], "enumTitles": ["A", "B"] } } @@ -389,6 +392,9 @@ func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]a if _, ok := params["$id"]; ok { t.Fatalf("root $id should be removed from schema") } + if _, ok := params["$comment"]; ok { + t.Fatalf("root $comment should be removed from schema") + } if _, ok := params["patternProperties"]; ok { t.Fatalf("patternProperties should be removed from schema") } @@ -408,6 +414,9 @@ func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]a if 
_, ok := arg["prefill"]; ok { t.Fatalf("prefill should be removed from nested schema") } + if _, ok := arg["$comment"]; ok { + t.Fatalf("nested $comment should be removed from schema") + } argProps, ok := arg["properties"].(map[string]any) if !ok { @@ -420,6 +429,9 @@ func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]a if _, ok := mode["enumTitles"]; ok { t.Fatalf("enumTitles should be removed from nested schema") } + if _, ok := mode["enumDescriptions"]; ok { + t.Fatalf("enumDescriptions should be removed from nested schema") + } if _, ok := mode["deprecated"]; ok { t.Fatalf("deprecated should be removed from nested schema") } diff --git a/internal/runtime/executor/antigravity_executor_credits_test.go b/internal/runtime/executor/antigravity_executor_credits_test.go index 507a57b356..74f84a5855 100644 --- a/internal/runtime/executor/antigravity_executor_credits_test.go +++ b/internal/runtime/executor/antigravity_executor_credits_test.go @@ -14,6 +14,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v7/internal/config" homekv "github.com/router-for-me/CLIProxyAPI/v7/internal/home" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" @@ -247,16 +248,16 @@ func TestInjectEnabledCreditTypes(t *testing.T) { func TestParseRetryDelay_HumanReadableDuration(t *testing.T) { body := []byte(`{"error":{"message":"You have exhausted your capacity on this model. 
Your quota will reset after 1h43m56s."}}`) - retryAfter, err := parseRetryDelay(body) + retryAfter, err := helps.ParseRetryDelay(body) if err != nil { - t.Fatalf("parseRetryDelay() error = %v", err) + t.Fatalf("helps.ParseRetryDelay() error = %v", err) } if retryAfter == nil { - t.Fatal("parseRetryDelay() returned nil") + t.Fatal("helps.ParseRetryDelay() returned nil") } want := time.Hour + 43*time.Minute + 56*time.Second if *retryAfter != want { - t.Fatalf("parseRetryDelay() = %v, want %v", *retryAfter, want) + t.Fatalf("helps.ParseRetryDelay() = %v, want %v", *retryAfter, want) } } diff --git a/internal/runtime/executor/antigravity_reasoning_replay.go b/internal/runtime/executor/antigravity_reasoning_replay.go new file mode 100644 index 0000000000..79d2ccc1be --- /dev/null +++ b/internal/runtime/executor/antigravity_reasoning_replay.go @@ -0,0 +1,607 @@ +package executor + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "net/http" + "strings" + + internalcache "github.com/router-for-me/CLIProxyAPI/v7/internal/cache" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type antigravityReasoningReplayScope struct { + modelName string + sessionKey string +} + +func (s antigravityReasoningReplayScope) valid() bool { + return strings.TrimSpace(s.modelName) != "" && strings.TrimSpace(s.sessionKey) != "" +} + +func antigravityReasoningReplayScopeFromPayload(modelName string, payload []byte) antigravityReasoningReplayScope { + sessionID := antigravityReplaySessionIDFromPayload(payload) + if sessionID == "" { + if stable := strings.TrimSpace(generateStableSessionID(payload)); stable != "" { + sessionID = strings.TrimPrefix(stable, "-") + if sessionID == "" { + sessionID = stable + } + } + } + if sessionID == "" { + return antigravityReasoningReplayScope{} + } + return 
antigravityReasoningReplayScope{ + modelName: strings.TrimSpace(modelName), + sessionKey: "session:" + sessionID, + } +} + +func antigravityReasoningReplayScopeFromRequest(ctx context.Context, modelName string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, payload []byte) antigravityReasoningReplayScope { + if scope := antigravityReasoningReplayScopeFromPayload(modelName, payload); scope.valid() { + return scope + } + if scope := antigravityReasoningReplayScopeFromPayload(modelName, req.Payload); scope.valid() { + return scope + } + if value := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" { + return antigravityReasoningReplayScope{modelName: modelName, sessionKey: "execution:" + value} + } + if value := metadataString(req.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" { + return antigravityReasoningReplayScope{modelName: modelName, sessionKey: "execution:" + value} + } + _ = ctx + return antigravityReasoningReplayScope{} +} + +func antigravityReplaySessionIDFromPayload(payload []byte) string { + if len(payload) == 0 { + return "" + } + for _, path := range []string{"sessionId", "session_id", "request.sessionId", "request.session_id"} { + if id := strings.TrimSpace(gjson.GetBytes(payload, path).String()); id != "" { + return id + } + } + return "" +} + +func antigravityReasoningReplayPendingModelContentIndex(payload []byte) (contentIndex int, basePartIndex int) { + contents := gjson.GetBytes(payload, "request.contents") + if !contents.IsArray() { + return 0, 0 + } + arr := contents.Array() + if len(arr) == 0 { + return 0, 0 + } + last := arr[len(arr)-1] + if strings.EqualFold(strings.TrimSpace(last.Get("role").String()), "model") { + ci := len(arr) - 1 + parts := last.Get("parts") + base := 0 + if parts.IsArray() { + base = len(parts.Array()) + } + return ci, base + } + return len(arr), 0 +} + +func antigravityReasoningReplayResolveContentIndex(payload []byte, cached int) int { + 
contents := gjson.GetBytes(payload, "request.contents") + if !contents.IsArray() { + return cached + } + arr := contents.Array() + if cached >= 0 && cached < len(arr) { + return cached + } + for i := len(arr) - 1; i >= 0; i-- { + if strings.EqualFold(strings.TrimSpace(arr[i].Get("role").String()), "model") { + return i + } + } + if len(arr) == 0 { + return 0 + } + return len(arr) - 1 +} + +func prepareAntigravityGeminiReasoningReplayPayload(ctx context.Context, modelName string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, payload []byte) ([]byte, antigravityReasoningReplayScope, error) { + if !antigravityUsesReasoningReplayCache(modelName) { + return payload, antigravityReasoningReplayScope{}, nil + } + return applyAntigravityReasoningReplayCache(ctx, modelName, req, opts, payload) +} + +func clearAntigravityReasoningReplayOnInvalidSignature(ctx context.Context, scope antigravityReasoningReplayScope, statusCode int, body []byte) error { + if !scope.valid() { + return nil + } + if statusCode != http.StatusBadRequest { + return nil + } + bodyText := strings.ToLower(string(body)) + if !strings.Contains(bodyText, "thoughtsignature") && !strings.Contains(bodyText, "thought_signature") && !strings.Contains(bodyText, "signature") { + return nil + } + return internalcache.DeleteAntigravityReasoningReplayItemRequired(ctx, scope.modelName, scope.sessionKey) +} + +func applyAntigravityReasoningReplayCache(ctx context.Context, modelName string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, payload []byte) ([]byte, antigravityReasoningReplayScope, error) { + scope := antigravityReasoningReplayScopeFromRequest(ctx, modelName, req, opts, payload) + if !scope.valid() { + return payload, scope, nil + } + items, ok, err := internalcache.GetAntigravityReasoningReplayItemsRequired(ctx, scope.modelName, scope.sessionKey) + if err != nil || !ok || len(items) == 0 { + return payload, scope, err + } + items = 
filterAntigravityReasoningReplayItemsForRequest(payload, items) + if len(items) == 0 { + return payload, scope, nil + } + updated, okApply := insertAntigravityReasoningReplayItems(payload, items) + if !okApply { + return payload, scope, nil + } + return updated, scope, nil +} + +func filterAntigravityReasoningReplayItemsForRequest(payload []byte, items [][]byte) [][]byte { + existing := antigravityExistingToolCallKeys(payload) + filtered := make([][]byte, 0, len(items)) + for _, item := range items { + itemResult := gjson.ParseBytes(item) + switch strings.TrimSpace(itemResult.Get("type").String()) { + case "function_call_part": + keys := antigravityReplayToolCallKeys(itemResult) + if len(keys) == 0 { + continue + } + if antigravityAnyKeyExists(existing, keys) { + if !antigravityNeedsSignatureReplayForExistingFunctionCall(payload, itemResult) { + continue + } + } + if !antigravityRequestHasMatchingFunctionResponse(payload, itemResult) { + continue + } + case "thought_signature": + if antigravityRequestHasThoughtSignatureAt(payload, itemResult) { + continue + } + default: + continue + } + filtered = append(filtered, item) + } + return filtered +} + +func antigravityExistingToolCallKeys(payload []byte) map[string]bool { + existing := make(map[string]bool) + contents := gjson.GetBytes(payload, "request.contents") + if !contents.IsArray() { + return existing + } + for _, content := range contents.Array() { + parts := content.Get("parts") + if !parts.IsArray() { + continue + } + for _, part := range parts.Array() { + if fc := part.Get("functionCall"); fc.Exists() { + for _, key := range antigravityReplayToolCallKeysFromPart(fc) { + existing[key] = true + } + } + } + } + return existing +} + +func antigravityReplayToolCallKeys(itemResult gjson.Result) []string { + callID := strings.TrimSpace(itemResult.Get("call_id").String()) + if callID == "" { + callID = strings.TrimSpace(itemResult.Get("id").String()) + } + name := strings.TrimSpace(itemResult.Get("name").String()) + 
if name == "" { + return nil + } + args := itemResult.Get("args").Raw + key := antigravityFunctionCallKey(name, args, callID) + if key == "" { + return nil + } + return []string{key} +} + +func antigravityReplayToolCallKeysFromPart(fc gjson.Result) []string { + return antigravityReplayToolCallKeys(gjson.Parse(fc.Raw)) +} + +func antigravityFunctionCallKey(name, argsRaw, callID string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + h := sha256.Sum256([]byte(strings.Join([]string{name, argsRaw, callID}, "\x00"))) + return fmt.Sprintf("fc:%x", h[:8]) +} + +func antigravityAnyKeyExists(existing map[string]bool, keys []string) bool { + for _, key := range keys { + if existing[key] { + return true + } + } + return false +} + +func antigravityNeedsSignatureReplayForExistingFunctionCall(payload []byte, itemResult gjson.Result) bool { + callID := strings.TrimSpace(itemResult.Get("call_id").String()) + if callID == "" { + callID = strings.TrimSpace(itemResult.Get("id").String()) + } + sig := strings.TrimSpace(itemResult.Get("thoughtSignature").String()) + if callID == "" || sig == "" { + return false + } + ci, pi, ok := antigravityFunctionCallPartLocation(payload, callID) + if !ok { + return false + } + pathSig := fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", ci, pi) + return strings.TrimSpace(gjson.GetBytes(payload, pathSig).String()) == "" +} + +func antigravityRequestHasMatchingFunctionResponse(payload []byte, itemResult gjson.Result) bool { + callID := strings.TrimSpace(itemResult.Get("call_id").String()) + if callID == "" { + return true + } + _, ok := antigravityFunctionResponseContentIndex(payload, callID) + return ok +} + +func antigravityFunctionResponseContentIndex(payload []byte, callID string) (int, bool) { + callID = strings.TrimSpace(callID) + if callID == "" { + return -1, false + } + contents := gjson.GetBytes(payload, "request.contents") + if !contents.IsArray() { + return -1, false + } + for i, content := range 
contents.Array() { + parts := content.Get("parts") + if !parts.IsArray() { + continue + } + for _, part := range parts.Array() { + fr := part.Get("functionResponse") + if fr.Exists() && strings.TrimSpace(fr.Get("id").String()) == callID { + return i, true + } + } + } + return -1, false +} + +func antigravityPayloadHasFunctionCallID(payload []byte, callID string) bool { + _, _, ok := antigravityFunctionCallPartLocation(payload, callID) + return ok +} + +func antigravityFunctionCallPartLocation(payload []byte, callID string) (contentIndex int, partIndex int, ok bool) { + callID = strings.TrimSpace(callID) + if callID == "" { + return -1, -1, false + } + contents := gjson.GetBytes(payload, "request.contents") + if !contents.IsArray() { + return -1, -1, false + } + for ci, content := range contents.Array() { + parts := content.Get("parts") + if !parts.IsArray() { + continue + } + for pi, part := range parts.Array() { + fc := part.Get("functionCall") + if fc.Exists() && strings.TrimSpace(fc.Get("id").String()) == callID { + return ci, pi, true + } + } + } + return -1, -1, false +} + +func insertAntigravityModelFunctionCallBeforeContent(payload []byte, beforeIndex int, name, callID, thoughtSig string, args gjson.Result) ([]byte, bool) { + contents := gjson.GetBytes(payload, "request.contents") + if !contents.IsArray() { + return payload, false + } + arr := contents.Array() + if beforeIndex < 0 || beforeIndex > len(arr) { + return payload, false + } + fc := map[string]any{"name": name} + if callID != "" { + fc["id"] = callID + } + if args.Exists() { + fc["args"] = args.Value() + } + part := map[string]any{"functionCall": fc} + if thoughtSig != "" { + part["thoughtSignature"] = thoughtSig + } + newContent := map[string]any{ + "role": "model", + "parts": []any{part}, + } + newArr := make([]any, 0, len(arr)+1) + for i := 0; i < beforeIndex; i++ { + newArr = append(newArr, arr[i].Value()) + } + newArr = append(newArr, newContent) + for i := beforeIndex; i < len(arr); i++ { + 
newArr = append(newArr, arr[i].Value()) + } + updated, err := sjson.SetBytes(payload, "request.contents", newArr) + if err != nil { + return payload, false + } + return updated, true +} + +func antigravityRequestHasThoughtSignatureAt(payload []byte, itemResult gjson.Result) bool { + ci := int(itemResult.Get("contentIndex").Int()) + pi := int(itemResult.Get("partIndex").Int()) + path := fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", ci, pi) + return strings.TrimSpace(gjson.GetBytes(payload, path).String()) != "" +} + +func insertAntigravityReasoningReplayItems(payload []byte, items [][]byte) ([]byte, bool) { + out := payload + changed := false + for _, item := range items { + itemResult := gjson.ParseBytes(item) + switch strings.TrimSpace(itemResult.Get("type").String()) { + case "thought_signature": + ci := antigravityReasoningReplayResolveContentIndex(out, int(itemResult.Get("contentIndex").Int())) + pi := int(itemResult.Get("partIndex").Int()) + sig := strings.TrimSpace(itemResult.Get("thoughtSignature").String()) + if sig == "" { + continue + } + path := fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", ci, pi) + if strings.TrimSpace(gjson.GetBytes(out, path).String()) != "" { + continue + } + updated, err := sjson.SetBytes(out, path, sig) + if err != nil { + continue + } + out = updated + changed = true + case "function_call_part": + updated, ok := mergeAntigravityFunctionCallPartReplay(out, itemResult) + if ok { + out = updated + changed = true + } + } + } + return out, changed +} + +func mergeAntigravityFunctionCallPartReplay(payload []byte, itemResult gjson.Result) ([]byte, bool) { + name := strings.TrimSpace(itemResult.Get("name").String()) + args := itemResult.Get("args") + callID := strings.TrimSpace(itemResult.Get("call_id").String()) + sig := strings.TrimSpace(itemResult.Get("thoughtSignature").String()) + if name == "" || !args.Exists() { + return payload, false + } + if callID != "" { + if ci, pi, exists := 
antigravityFunctionCallPartLocation(payload, callID); exists { + if sig != "" { + pathSig := fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", ci, pi) + if strings.TrimSpace(gjson.GetBytes(payload, pathSig).String()) == "" { + if updated, err := sjson.SetBytes(payload, pathSig, sig); err == nil { + return updated, true + } + } + } + return payload, false + } + if frIndex, ok := antigravityFunctionResponseContentIndex(payload, callID); ok { + return insertAntigravityModelFunctionCallBeforeContent(payload, frIndex, name, callID, sig, args) + } + } + + ci := antigravityReasoningReplayResolveContentIndex(payload, int(itemResult.Get("contentIndex").Int())) + pi := int(itemResult.Get("partIndex").Int()) + pathSig := fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", ci, pi) + out := payload + changed := false + if sig != "" && strings.TrimSpace(gjson.GetBytes(out, pathSig).String()) == "" { + if updated, err := sjson.SetBytes(out, pathSig, sig); err == nil { + out = updated + changed = true + } + } + pathFC := fmt.Sprintf("request.contents.%d.parts.%d.functionCall", ci, pi) + if !gjson.GetBytes(out, pathFC).Exists() { + fc := map[string]any{"name": name} + if callID != "" { + fc["id"] = callID + } + if args.Type == gjson.String { + fc["args"] = args.String() + } else { + var parsed any + if json.Unmarshal([]byte(args.Raw), &parsed) == nil { + fc["args"] = parsed + } + } + if updated, err := sjson.SetBytes(out, pathFC, fc); err == nil { + out = updated + changed = true + } + } + return out, changed +} + +type antigravityReasoningReplayAccumulator struct { + scope antigravityReasoningReplayScope + requestPayload []byte + items [][]byte + seenFC map[string]bool + contentIndex int + nextPartIndex int +} + +func newAntigravityReasoningReplayAccumulator(scope antigravityReasoningReplayScope, requestPayload []byte) *antigravityReasoningReplayAccumulator { + if !scope.valid() { + return nil + } + contentIndex, basePartIndex := 
antigravityReasoningReplayPendingModelContentIndex(requestPayload) + return &antigravityReasoningReplayAccumulator{ + scope: scope, + requestPayload: append([]byte(nil), requestPayload...), + seenFC: make(map[string]bool), + contentIndex: contentIndex, + nextPartIndex: basePartIndex, + } +} + +func (a *antigravityReasoningReplayAccumulator) ObserveSSELine(line []byte) { + if a == nil { + return + } + payload := helps.JSONPayload(line) + if payload == nil { + return + } + a.observeResponsePayload(payload) +} + +func (a *antigravityReasoningReplayAccumulator) observeResponsePayload(payload []byte) { + parts := gjson.GetBytes(payload, "response.candidates.0.content.parts") + if !parts.IsArray() { + return + } + parts.ForEach(func(_, part gjson.Result) bool { + pi := a.nextPartIndex + a.nextPartIndex++ + sig := antigravityNativePartThoughtSignature(part) + if fc := part.Get("functionCall"); fc.Exists() { + keys := antigravityReplayToolCallKeysFromPart(fc) + for _, k := range keys { + if a.seenFC[k] { + return true + } + } + for _, k := range keys { + a.seenFC[k] = true + } + item := buildAntigravityFunctionCallPartItem(a.contentIndex, pi, fc, sig) + if len(item) > 0 { + a.items = append(a.items, item) + } + return true + } + if sig != "" { + item := buildAntigravityThoughtSignatureItem(a.contentIndex, pi, sig) + a.items = append(a.items, item) + } + return true + }) +} + +func buildAntigravityThoughtSignatureItem(contentIndex, partIndex int, signature string) []byte { + return []byte(fmt.Sprintf(`{"type":"thought_signature","thoughtSignature":%q,"contentIndex":%d,"partIndex":%d}`, + signature, contentIndex, partIndex)) +} + +func buildAntigravityFunctionCallPartItem(contentIndex, partIndex int, fc gjson.Result, signature string) []byte { + item := map[string]any{ + "type": "function_call_part", + "contentIndex": contentIndex, + "partIndex": partIndex, + "name": fc.Get("name").String(), + } + if id := strings.TrimSpace(fc.Get("id").String()); id != "" { + 
item["call_id"] = id + } + if args := fc.Get("args"); args.Exists() { + if args.Type == gjson.String { + item["args"] = args.String() + } else { + item["args"] = json.RawMessage(args.Raw) + } + } + if signature != "" { + item["thoughtSignature"] = signature + } + raw, err := json.Marshal(item) + if err != nil { + return nil + } + return raw +} + +func (a *antigravityReasoningReplayAccumulator) Flush(ctx context.Context) { + if a == nil || !a.scope.valid() || len(a.items) == 0 { + return + } + if !internalcache.CacheAntigravityReasoningReplayItemsBestEffort(ctx, a.scope.modelName, a.scope.sessionKey, a.items) { + _ = internalcache.DeleteAntigravityReasoningReplayItemRequired(ctx, a.scope.modelName, a.scope.sessionKey) + } +} + +func cacheAntigravityReasoningReplayFromResponse(ctx context.Context, scope antigravityReasoningReplayScope, requestPayload, body []byte) { + if !scope.valid() || len(body) == 0 { + return + } + acc := newAntigravityReasoningReplayAccumulator(scope, requestPayload) + acc.observeResponsePayload(body) + acc.Flush(ctx) +} + +func applyAntigravityNativeSignatureReplayIfNeeded(modelName string, payload []byte) []byte { + if antigravityUsesReasoningReplayCache(modelName) { + return payload + } + // Native per-part signature replay is not on upstream/dev; Gemini uses HOME replay only. 
+ return payload +} + +func antigravityUsesReasoningReplayCache(modelName string) bool { + modelName = strings.ToLower(modelName) + if strings.Contains(modelName, "claude") { + return false + } + return strings.Contains(modelName, "gemini") || strings.Contains(modelName, "flash") || strings.Contains(modelName, "agent") +} + +func antigravityNativePartThoughtSignature(part gjson.Result) string { + for _, path := range []string{"thoughtSignature", "thought_signature", "extra_content.google.thought_signature"} { + if signature := strings.TrimSpace(part.Get(path).String()); signature != "" { + return signature + } + } + return "" +} diff --git a/internal/runtime/executor/antigravity_reasoning_replay_clear_test.go b/internal/runtime/executor/antigravity_reasoning_replay_clear_test.go new file mode 100644 index 0000000000..a15f15ece9 --- /dev/null +++ b/internal/runtime/executor/antigravity_reasoning_replay_clear_test.go @@ -0,0 +1,66 @@ +package executor + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + internalcache "github.com/router-for-me/CLIProxyAPI/v7/internal/cache" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" +) + +func TestAntigravityReasoningReplayClearsOnInvalidSignature400(t *testing.T) { + internalcache.ClearAntigravityReasoningReplayCache() + t.Cleanup(internalcache.ClearAntigravityReasoningReplayCache) + + model := "gemini-3-flash-agent" + sessionKey := "session:pr3900-invalid-sig" + bad := []byte(`{"type":"thought_signature","thoughtSignature":"INVALID_REPLAY_SIGNATURE_PR3900_XXXXXXXXX","contentIndex":1,"partIndex":0}`) + if !internalcache.CacheAntigravityReasoningReplayItems(model, sessionKey, [][]byte{bad}) { + t.Fatal("failed to seed replay cache") + } + + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"Invalid thoughtSignature in model content","code":400}}`)) + })) + defer server.Close() + + exec := NewAntigravityExecutor(&config.Config{RequestRetry: 1}) + auth := &cliproxyauth.Auth{ + ID: "auth-pr3900-invalid-sig", + Attributes: map[string]string{ + "base_url": server.URL, + }, + Metadata: map[string]any{ + "access_token": "token", + "project_id": "project-1", + "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), + }, + } + + payload := []byte(`{"sessionId":"pr3900-invalid-sig","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]},{"role":"user","parts":[{"functionResponse":{"id":"id1","name":"Bash","response":{"result":"ok"}}}]}]}}`) + _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: model, + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatAntigravity, + Stream: false, + }) + if err == nil { + t.Fatal("expected upstream 400 error") + } + if _, ok, errGet := internalcache.GetAntigravityReasoningReplayItemsRequired(context.Background(), model, sessionKey); errGet != nil { + t.Fatalf("get after clear: %v", errGet) + } else if ok { + t.Fatal("invalid signature 400 should clear cached replay item") + } +} diff --git a/internal/runtime/executor/antigravity_reasoning_replay_test.go b/internal/runtime/executor/antigravity_reasoning_replay_test.go new file mode 100644 index 0000000000..cc53da2790 --- /dev/null +++ b/internal/runtime/executor/antigravity_reasoning_replay_test.go @@ -0,0 +1,146 @@ +package executor + +import ( + "context" + "strings" + "testing" + + internalcache "github.com/router-for-me/CLIProxyAPI/v7/internal/cache" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + 
"github.com/tidwall/gjson" +) + +func TestAntigravityReasoningReplayAccumulatorMultiToolSSEChunks(t *testing.T) { + internalcache.ClearAntigravityReasoningReplayCache() + t.Cleanup(internalcache.ClearAntigravityReasoningReplayCache) + + requestPayload := []byte(`{"sessionId":"sess-1","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`) + scope := antigravityReasoningReplayScope{modelName: "gemini-3-flash-agent", sessionKey: "session:sess-1"} + acc := newAntigravityReasoningReplayAccumulator(scope, requestPayload) + if acc == nil { + t.Fatal("accumulator is nil") + } + if acc.contentIndex != 1 || acc.nextPartIndex != 0 { + t.Fatalf("pending model slot = %d/%d, want 1/0", acc.contentIndex, acc.nextPartIndex) + } + + line1 := []byte(`data: {"response":{"candidates":[{"content":{"parts":[{"thoughtSignature":"sig-first","functionCall":{"name":"Read","args":{"file_path":"/a"},"id":"id1"}}]}}]}}`) + line2 := []byte(`data: {"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"Read","args":{"file_path":"/b"},"id":"id2"}}]}}]}}`) + acc.ObserveSSELine(line1) + acc.ObserveSSELine(line2) + acc.Flush(context.Background()) + + items, ok := internalcache.GetAntigravityReasoningReplayItems("gemini-3-flash-agent", "session:sess-1") + if !ok || len(items) != 2 { + t.Fatalf("cached items = %v ok=%v, want 2 items", len(items), ok) + } + pi0 := int(gjson.GetBytes(items[0], "partIndex").Int()) + pi1 := int(gjson.GetBytes(items[1], "partIndex").Int()) + if pi0 != 0 || pi1 != 1 { + t.Fatalf("partIndex = %d,%d, want 0,1", pi0, pi1) + } + if got := gjson.GetBytes(items[0], "thoughtSignature").String(); got != "sig-first" { + t.Fatalf("first sig = %q", got) + } +} + +func TestPrepareAntigravityGeminiReasoningReplayPayloadInjectsCachedToolPart(t *testing.T) { + internalcache.ClearAntigravityReasoningReplayCache() + t.Cleanup(internalcache.ClearAntigravityReasoningReplayCache) + + item := 
[]byte(`{"type":"function_call_part","contentIndex":1,"partIndex":0,"name":"Read","call_id":"id1","args":{"file_path":"/a"},"thoughtSignature":"sig-first"}`) + if !internalcache.CacheAntigravityReasoningReplayItems("gemini-3-flash-agent", "session:sess-2", [][]byte{item}) { + t.Fatal("cache write failed") + } + + req := cliproxyexecutor.Request{} + opts := cliproxyexecutor.Options{} + payload := []byte(`{"sessionId":"sess-2","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]},{"role":"user","parts":[{"functionResponse":{"id":"id1","name":"Read","response":{"result":"ok"}}}]}]}}`) + out, scope, err := prepareAntigravityGeminiReasoningReplayPayload(context.Background(), "gemini-3-flash-agent", req, opts, payload) + if err != nil { + t.Fatalf("prepare error: %v", err) + } + if !scope.valid() { + t.Fatal("scope invalid") + } + if gjson.GetBytes(out, "request.contents.1.role").String() != "model" { + t.Fatalf("functionCall replay must be model role at [1], got %s", string(out)) + } + if got := gjson.GetBytes(out, "request.contents.1.parts.0.thoughtSignature").String(); got != "sig-first" { + t.Fatalf("thoughtSignature = %q, want sig-first", got) + } + if !gjson.GetBytes(out, "request.contents.1.parts.0.functionCall").Exists() { + t.Fatalf("functionCall not injected: %s", string(out)) + } + if !gjson.GetBytes(out, "request.contents.2.parts.0.functionResponse").Exists() { + t.Fatalf("functionResponse should follow model functionCall at [2]: %s", string(out)) + } +} + +func TestPrepareAntigravityGeminiReasoningReplayInsertsBeforeModelFunctionResponse(t *testing.T) { + internalcache.ClearAntigravityReasoningReplayCache() + t.Cleanup(internalcache.ClearAntigravityReasoningReplayCache) + + item := []byte(`{"type":"function_call_part","contentIndex":1,"partIndex":0,"name":"Read","call_id":"id1","args":{"file_path":"/a"},"thoughtSignature":"sig-first"}`) + internalcache.CacheAntigravityReasoningReplayItems("gemini-3-flash-agent", "session:sess-3", [][]byte{item}) + + 
payload := []byte(`{"sessionId":"sess-3","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]},{"role":"model","parts":[{"functionResponse":{"id":"id1","name":"Read","response":{"result":"ok"}}}]}]}}`) + out, _, err := prepareAntigravityGeminiReasoningReplayPayload(context.Background(), "gemini-3-flash-agent", cliproxyexecutor.Request{}, cliproxyexecutor.Options{}, payload) + if err != nil { + t.Fatal(err) + } + if !gjson.GetBytes(out, "request.contents.1.parts.0.functionCall").Exists() || gjson.GetBytes(out, "request.contents.1.role").String() != "model" { + t.Fatalf("want model functionCall at [1]: %s", string(out)) + } + if !gjson.GetBytes(out, "request.contents.2.parts.0.functionResponse").Exists() { + t.Fatalf("functionResponse should be at [2]: %s", string(out)) + } +} + +func TestMergeAntigravityFunctionCallPartReplayMergesSignatureIntoExistingFunctionCall(t *testing.T) { + internalcache.ClearAntigravityReasoningReplayCache() + t.Cleanup(internalcache.ClearAntigravityReasoningReplayCache) + + item := []byte(`{"type":"function_call_part","contentIndex":1,"partIndex":0,"name":"Read","call_id":"id1","args":{"file_path":"/a"},"thoughtSignature":"sig-first"}`) + internalcache.CacheAntigravityReasoningReplayItems("gemini-3-flash-agent", "session:sess-merge", [][]byte{item}) + + payload := []byte(`{"sessionId":"sess-merge","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]},{"role":"model","parts":[{"functionCall":{"id":"id1","name":"Read","args":{"file_path":"/a"}}}]},{"role":"user","parts":[{"functionResponse":{"id":"id1","name":"Read","response":{"result":"ok"}}}]}]}}`) + out, _, err := prepareAntigravityGeminiReasoningReplayPayload(context.Background(), "gemini-3-flash-agent", cliproxyexecutor.Request{}, cliproxyexecutor.Options{}, payload) + if err != nil { + t.Fatal(err) + } + if got := gjson.GetBytes(out, "request.contents.1.parts.0.thoughtSignature").String(); got != "sig-first" { + t.Fatalf("thoughtSignature = %q, want sig-first; 
body=%s", got, out) + } +} + +func TestAntigravityReasoningReplayScopeUsesStableSessionWithoutSessionId(t *testing.T) { + payload := []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"stable-user-text"}]}]}}`) + scope := antigravityReasoningReplayScopeFromPayload("gemini-3-flash-agent", payload) + if !scope.valid() { + t.Fatal("scope should be valid from stable session hash") + } + if !strings.HasPrefix(scope.sessionKey, "session:") { + t.Fatalf("sessionKey = %q", scope.sessionKey) + } +} + +func TestAntigravityReplayToolCallKeysUsesNativeFunctionCallID(t *testing.T) { + fc := gjson.Parse(`{"name":"Read","args":{"file_path":"/a"},"id":"id-native"}`) + keys := antigravityReplayToolCallKeysFromPart(fc) + if len(keys) != 1 { + t.Fatalf("keys = %v", keys) + } + fc2 := gjson.Parse(`{"name":"Read","args":{"file_path":"/a"},"id":"id-native-2"}`) + keys2 := antigravityReplayToolCallKeysFromPart(fc2) + if keys[0] == keys2[0] { + t.Fatalf("parallel tool calls should not share replay key: %v vs %v", keys, keys2) + } +} + +func TestAntigravityRequestHasMatchingFunctionResponseWhitespaceCallID(t *testing.T) { + item := gjson.Parse(`{"call_id":" "}`) + if !antigravityRequestHasMatchingFunctionResponse(nil, item) { + t.Fatal("whitespace-only call_id should be treated as empty => true") + } +} diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index fec288b894..c0ece15288 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -45,11 +45,18 @@ type ClaudeExecutor struct { // Previously "proxy_" was used but this is a detectable fingerprint difference. 
const claudeToolPrefix = "" +func shouldSanitizeClaudeMessagesForUpstream(baseModel string) bool { + return sigcompat.SignatureProviderFromModelName(baseModel) == sigcompat.SignatureProviderClaude +} + func sanitizeClaudeMessagesForClaudeUpstreamWithDebug(ctx context.Context, body []byte, baseModel string) []byte { - sanitized, report := sigcompat.SanitizeClaudeMessagesForClaudeUpstream(body, baseModel) - logClaudeSignatureSanitizeReport(ctx, baseModel, report) - sanitized = sanitizeClaudeWebSearchDomains(sanitized) - return sanitized + sanitized := body + if shouldSanitizeClaudeMessagesForUpstream(baseModel) { + var report sigcompat.SignatureSanitizeReport + sanitized, report = sigcompat.SanitizeClaudeMessagesForClaudeUpstream(body, baseModel) + logClaudeSignatureSanitizeReport(ctx, baseModel, report) + } + return sanitizeClaudeWebSearchDomains(sanitized) } // sanitizeClaudeWebSearchDomains removes empty allowed_domains/blocked_domains @@ -219,6 +226,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r if err != nil { return resp, err } + if rebuildMidSystemMessageEnabled(e.cfg, auth) { + body = rebuildMidSystemMessagesToTopLevel(body) + } // Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation) // based on client type and configuration. @@ -406,6 +416,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if err != nil { return nil, err } + if rebuildMidSystemMessageEnabled(e.cfg, auth) { + body = rebuildMidSystemMessagesToTopLevel(body) + } // Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation) // based on client type and configuration. 
@@ -674,6 +687,9 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut stream := from != to body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream) body, _ = sjson.SetBytes(body, "model", baseModel) + if rebuildMidSystemMessageEnabled(e.cfg, auth) { + body = rebuildMidSystemMessagesToTopLevel(body) + } if !strings.HasPrefix(baseModel, "claude-3-5-haiku") { body = checkSystemInstructions(body) @@ -1141,6 +1157,91 @@ func checkSystemInstructions(payload []byte) []byte { return checkSystemInstructionsWithSigningMode(payload, false, false, false, "2.1.63", "", "") } +func rebuildMidSystemMessagesToTopLevel(payload []byte) []byte { + messages := gjson.GetBytes(payload, "messages") + if !messages.IsArray() { + return payload + } + + var movedSystemParts []string + keptMessages := make([]string, 0, int(messages.Get("#").Int())) + messages.ForEach(func(_, message gjson.Result) bool { + if strings.EqualFold(strings.TrimSpace(message.Get("role").String()), "system") { + movedSystemParts = append(movedSystemParts, claudeSystemTextParts(message.Get("content"))...) + return true + } + keptMessages = append(keptMessages, message.Raw) + return true + }) + if len(movedSystemParts) == 0 { + return payload + } + + systemParts := claudeSystemTextParts(gjson.GetBytes(payload, "system")) + systemParts = append(systemParts, movedSystemParts...) 
+ if len(systemParts) > 0 { + if updated, errSetSystem := sjson.SetRawBytes(payload, "system", rawJSONArray(systemParts)); errSetSystem == nil { + payload = updated + } + } + if updated, errSetMessages := sjson.SetRawBytes(payload, "messages", rawJSONArray(keptMessages)); errSetMessages == nil { + payload = updated + } + return payload +} + +func claudeSystemTextParts(content gjson.Result) []string { + if !content.Exists() { + return nil + } + if content.Type == gjson.String { + text := content.String() + if strings.TrimSpace(text) == "" { + return nil + } + block := []byte(`{"type":"text","text":""}`) + block, _ = sjson.SetBytes(block, "text", text) + return []string{string(block)} + } + if !content.IsArray() { + return nil + } + + var parts []string + content.ForEach(func(_, item gjson.Result) bool { + if item.Type == gjson.String { + text := item.String() + if strings.TrimSpace(text) != "" { + block := []byte(`{"type":"text","text":""}`) + block, _ = sjson.SetBytes(block, "text", text) + parts = append(parts, string(block)) + } + return true + } + if item.IsObject() && item.Get("type").String() == "text" && strings.TrimSpace(item.Get("text").String()) != "" { + parts = append(parts, item.Raw) + } + return true + }) + return parts +} + +func rawJSONArray(items []string) []byte { + if len(items) == 0 { + return []byte("[]") + } + var builder strings.Builder + builder.WriteByte('[') + for i, item := range items { + if i > 0 { + builder.WriteByte(',') + } + builder.WriteString(item) + } + builder.WriteByte(']') + return []byte(builder.String()) +} + func isClaudeOAuthToken(apiKey string) bool { return strings.Contains(apiKey, "sk-ant-oat") } diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index be4a97190c..0d960bcfef 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/gzip" "context" + 
"encoding/base64" "fmt" "io" "net/http" @@ -31,6 +32,10 @@ func resetClaudeDeviceProfileCache() { helps.ResetClaudeDeviceProfileCache() } +func malformedClaudeTreeSignatureForClaudeExecutorTest() string { + return base64.StdEncoding.EncodeToString([]byte{0x12, 0xFF, 0xFE, 0xFD}) +} + func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request { t.Helper() @@ -857,6 +862,420 @@ func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) { } } +func TestClaudeExecutor_ExecuteStripsOpenAIEncryptedThinkingBeforeUpstream(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{ + "messages": [ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"codex reasoning","signature":"gAAAAABopenai-encrypted-content"}, + {"type":"text","text":"Answer"} + ]}, + {"role":"user","content":[{"type":"text","text":"next"}]} + ] + }`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if len(seenBody) == 0 { + t.Fatal("expected request body to be captured") + } + if strings.Contains(string(seenBody), "gAAAAABopenai-encrypted-content") || strings.Contains(string(seenBody), "codex reasoning") { + t.Fatalf("invalid thinking 
block was forwarded: %s", string(seenBody)) + } + content := gjson.GetBytes(seenBody, "messages.0.content").Array() + if len(content) != 1 { + t.Fatalf("messages.0.content length = %d, want 1: %s", len(content), string(seenBody)) + } + if got := content[0].Get("text").String(); got != "Answer" { + t.Fatalf("remaining content text = %q, want Answer", got) + } +} + +func TestClaudeExecutor_ExecuteStripsForeignToolUseSignaturesBeforeUpstream(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{ + "messages": [ + {"role":"assistant","content":[ + { + "type":"tool_use", + "id":"toolu_1", + "name":"lookup", + "input":{"q":"x"}, + "signature":"skip_thought_signature_validator", + "thought_signature":"skip_thought_signature_validator", + "extra_content":{"google":{"thought_signature":"skip_thought_signature_validator"}} + } + ]}, + {"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_1","content":"ok"}]} + ] + }`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if len(seenBody) == 0 { + t.Fatal("expected request body to be captured") + } + toolUse := gjson.GetBytes(seenBody, "messages.0.content.0") + if 
!toolUse.Get("type").Exists() || toolUse.Get("type").String() != "tool_use" { + t.Fatalf("tool_use block was not preserved: %s", string(seenBody)) + } + for _, path := range []string{"signature", "thought_signature", "extra_content"} { + if toolUse.Get(path).Exists() { + t.Fatalf("foreign tool_use signature field %s was forwarded: %s", path, string(seenBody)) + } + } +} + +func TestShouldSanitizeClaudeMessagesForUpstream_OnlyClaudeFamily(t *testing.T) { + cases := []struct { + model string + want bool + }{ + {model: "claude-sonnet-4-5", want: true}, + {model: "claude-3-5-sonnet-20241022", want: true}, + {model: "kimi-k2.5", want: false}, + {model: "mimo-v2", want: false}, + {model: "gemini-3.5-flash", want: false}, + } + for _, tc := range cases { + t.Run(tc.model, func(t *testing.T) { + got := shouldSanitizeClaudeMessagesForUpstream(tc.model) + if got != tc.want { + t.Errorf("shouldSanitizeClaudeMessagesForUpstream(%q) = %v, want %v", tc.model, got, tc.want) + } + }) + } +} + +func TestSanitizeClaudeMessagesForClaudeUpstream_BypassesUnknownModelSignatureMatrix(t *testing.T) { + rawSignature := "skip_thought_signature_validator" + body := []byte(`{ + "model": "kimi-k2.5", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "keep", "signature": "` + rawSignature + `"}, + {"type": "text", "text": "hello"}, + {"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {}, "signature": "` + rawSignature + `"} + ] + } + ] + }`) + + output := sanitizeClaudeMessagesForClaudeUpstreamWithDebug(context.Background(), body, "kimi-k2.5") + parts := gjson.GetBytes(output, "messages.0.content").Array() + if len(parts) != 3 { + t.Fatalf("content length = %d, want 3 when sanitizer is bypassed: %s", len(parts), output) + } + if got := parts[0].Get("signature").String(); got != rawSignature { + t.Fatalf("thinking signature = %q, want preserved %q", got, rawSignature) + } + if got := parts[2].Get("signature").String(); got != 
rawSignature { + t.Fatalf("tool_use signature = %q, want preserved %q", got, rawSignature) + } +} + +func TestClaudeExecutor_ExecuteBypassesSignatureSanitizerForUnknownModel(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"mimo-v2","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{ + "messages": [ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"keep reasoning","signature":""}, + {"type":"text","text":"Answer"} + ]}, + {"role":"user","content":[{"type":"text","text":"next"}]} + ] + }`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "mimo-v2", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if len(seenBody) == 0 { + t.Fatal("expected request body to be captured") + } + if !strings.Contains(string(seenBody), "keep reasoning") { + t.Fatalf("unknown-model thinking block should bypass Claude sanitizer: %s", string(seenBody)) + } +} + +func TestClaudeExecutor_ExecuteStripsMalformedEPrefixThinkingBeforeUpstream(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = 
w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + malformedSignature := malformedClaudeTreeSignatureForClaudeExecutorTest() + payload := []byte(`{ + "messages": [ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"bad reasoning","signature":"` + malformedSignature + `"}, + {"type":"text","text":"Answer"} + ]}, + {"role":"user","content":[{"type":"text","text":"next"}]} + ] + }`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if len(seenBody) == 0 { + t.Fatal("expected request body to be captured") + } + if strings.Contains(string(seenBody), malformedSignature) || strings.Contains(string(seenBody), "bad reasoning") { + t.Fatalf("malformed E-prefix thinking block was forwarded: %s", string(seenBody)) + } + content := gjson.GetBytes(seenBody, "messages.0.content").Array() + if len(content) != 1 { + t.Fatalf("messages.0.content length = %d, want 1: %s", len(content), string(seenBody)) + } + if got := content[0].Get("text").String(); got != "Answer" { + t.Fatalf("remaining content text = %q, want Answer", got) + } +} + +func TestClaudeExecutor_ExecuteStripsInvalidBase64ThinkingBeforeUpstream(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = 
w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{ + "messages": [ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"bad reasoning","signature":"E!!!invalid!!!"}, + {"type":"text","text":"Answer"} + ]}, + {"role":"user","content":[{"type":"text","text":"next"}]} + ] + }`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if len(seenBody) == 0 { + t.Fatal("expected request body to be captured") + } + if strings.Contains(string(seenBody), "E!!!invalid!!!") || strings.Contains(string(seenBody), "bad reasoning") { + t.Fatalf("invalid-base64 thinking block was forwarded: %s", string(seenBody)) + } + content := gjson.GetBytes(seenBody, "messages.0.content").Array() + if len(content) != 1 { + t.Fatalf("messages.0.content length = %d, want 1: %s", len(content), string(seenBody)) + } +} + +func TestClaudeExecutor_ExecuteStripsEmptySignatureEmptyTextThinking(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := 
&cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{ + "messages": [ + {"role":"assistant","content":[ + {"type":"thinking","text":"","signature":""}, + {"type":"text","text":"Answer"} + ]}, + {"role":"user","content":[{"type":"text","text":"next"}]} + ] + }`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if len(seenBody) == 0 { + t.Fatal("expected request body to be captured") + } + content := gjson.GetBytes(seenBody, "messages.0.content").Array() + if len(content) != 1 { + t.Fatalf("messages.0.content length = %d, want 1: %s", len(content), string(seenBody)) + } + if got := content[0].Get("type").String(); got != "text" { + t.Fatalf("remaining content type = %q, want text: %s", got, string(seenBody)) + } + if got := content[0].Get("text").String(); got != "Answer" { + t.Fatalf("remaining content text = %q, want Answer: %s", got, string(seenBody)) + } +} + +func TestClaudeExecutor_ExecuteStreamStripsOpenAIEncryptedThinkingBeforeUpstream(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n")) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{ + "messages": [ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"codex reasoning","signature":"gAAAAABopenai-encrypted-content"}, + {"type":"text","text":"Answer"} + ]}, + 
{"role":"user","content":[{"type":"text","text":"next"}]} + ] + }`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("ExecuteStream() error = %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected chunk error: %v", chunk.Err) + } + } + if len(seenBody) == 0 { + t.Fatal("expected request body to be captured") + } + if strings.Contains(string(seenBody), "gAAAAABopenai-encrypted-content") || strings.Contains(string(seenBody), "codex reasoning") { + t.Fatalf("invalid thinking block was forwarded: %s", string(seenBody)) + } +} + +func TestClaudeExecutor_CountTokensStripsOpenAIEncryptedThinkingBeforeUpstream(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"input_tokens":42}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{ + "messages": [ + {"role":"assistant","content":[ + {"type":"thinking","thinking":"codex reasoning","signature":"gAAAAABopenai-encrypted-content"}, + {"type":"text","text":"Answer"} + ]}, + {"role":"user","content":[{"type":"text","text":"next"}]} + ] + }`) + + _, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("CountTokens() error = %v", err) + } + if len(seenBody) == 0 { + t.Fatal("expected request body to 
be captured") + } + if strings.Contains(string(seenBody), "gAAAAABopenai-encrypted-content") || strings.Contains(string(seenBody), "codex reasoning") { + t.Fatalf("invalid thinking block was forwarded: %s", string(seenBody)) + } +} + func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) { var userIDs []string var requestModels []string @@ -2113,6 +2532,103 @@ func TestClaudeExecutor_ExperimentalCCHSigningOptInSignsFinalBody(t *testing.T) } } +func TestClaudeExecutor_RebuildMidSystemMessageDisabledByDefault(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{ + ClaudeKey: []config.ClaudeKey{{ + APIKey: "key-123", + BaseURL: server.URL, + }}, + }) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"system":[{"type":"text","text":"Top rule","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]},{"role":"system","content":"Mid rule"},{"role":"user","content":[{"type":"text","text":"continue"}]}]}`) + ctx := contextWithGinHeaders(map[string]string{"User-Agent": "claude-cli/2.1.153 (external, cli)"}) + + _, errExecute := executor.Execute(ctx, auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if errExecute != nil { + t.Fatalf("Execute() error = %v", errExecute) + } + if len(seenBody) == 0 { + t.Fatal("expected request body to be 
captured") + } + if got := gjson.GetBytes(seenBody, "system.0.text").String(); got != "Top rule" { + t.Fatalf("system.0.text = %q, want top-level system preserved", got) + } + if got := gjson.GetBytes(seenBody, `messages.#(role=="system").content`).String(); got != "Mid rule" { + t.Fatalf("mid system message = %q, want original message preserved", got) + } +} + +func TestClaudeExecutor_RebuildMidSystemMessageOptInMovesSystemMessages(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{ + ClaudeKey: []config.ClaudeKey{{ + APIKey: "key-123", + BaseURL: server.URL, + RebuildMidSystemMessage: true, + }}, + }) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"system":"Top rule","messages":[{"role":"user","content":[{"type":"text","text":"hi"}]},{"role":"system","content":"Mid string rule"},{"role":"assistant","content":[{"type":"text","text":"ok"}]},{"role":"system","content":[{"type":"text","text":"Mid array rule","cache_control":{"type":"ephemeral"}}]},{"role":"user","content":[{"type":"text","text":"continue"}]}]}`) + ctx := contextWithGinHeaders(map[string]string{"User-Agent": "claude-cli/2.1.153 (external, cli)"}) + + _, errExecute := executor.Execute(ctx, auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if errExecute != nil { + t.Fatalf("Execute() error = %v", errExecute) + } + if len(seenBody) == 
0 { + t.Fatal("expected request body to be captured") + } + + system := gjson.GetBytes(seenBody, "system").Array() + if len(system) != 3 { + t.Fatalf("system has %d items, want 3: %s", len(system), gjson.GetBytes(seenBody, "system").Raw) + } + wantTexts := []string{"Top rule", "Mid string rule", "Mid array rule"} + for i, want := range wantTexts { + if got := system[i].Get("text").String(); got != want { + t.Fatalf("system[%d].text = %q, want %q", i, got, want) + } + } + if got := gjson.GetBytes(seenBody, "system.2.cache_control.type").String(); got != "ephemeral" { + t.Fatalf("system.2.cache_control.type = %q, want ephemeral", got) + } + if gjson.GetBytes(seenBody, `messages.#(role=="system")`).Exists() { + t.Fatalf("messages should not contain system role after rebuild: %s", gjson.GetBytes(seenBody, "messages").Raw) + } + if got := gjson.GetBytes(seenBody, "messages.#").Int(); got != 3 { + t.Fatalf("messages count = %d, want 3", got) + } +} + func TestApplyCloaking_PreservesConfiguredStrictModeAndSensitiveWordsWhenModeOmitted(t *testing.T) { cfg := &config.Config{ ClaudeKey: []config.ClaudeKey{{ diff --git a/internal/runtime/executor/claude_signing.go b/internal/runtime/executor/claude_signing.go index 060e86e846..8afd57a675 100644 --- a/internal/runtime/executor/claude_signing.go +++ b/internal/runtime/executor/claude_signing.go @@ -79,3 +79,11 @@ func experimentalCCHSigningEnabled(cfg *config.Config, auth *cliproxyauth.Auth) entry := resolveClaudeKeyConfig(cfg, auth) return entry != nil && entry.ExperimentalCCHSigning } + +func rebuildMidSystemMessageEnabled(cfg *config.Config, auth *cliproxyauth.Auth) bool { + if auth != nil && auth.Attributes != nil && strings.EqualFold(strings.TrimSpace(auth.Attributes["rebuild_mid_system_message"]), "true") { + return true + } + entry := resolveClaudeKeyConfig(cfg, auth) + return entry != nil && entry.RebuildMidSystemMessage +} diff --git a/internal/runtime/executor/codex_executor.go 
b/internal/runtime/executor/codex_executor.go index 24a520cc4b..7b69f67d79 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -829,6 +829,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re body = ensureImageGenerationTool(body, baseModel, auth) } body = sanitizeOpenAIResponsesReasoningEncryptedContent(ctx, "codex executor", body) + body = normalizeCodexParallelToolCallsForTools(body) body, replayScope, errReplay := applyCodexReasoningReplayCacheRequired(ctx, from, req, opts, body) if errReplay != nil { return resp, errReplay @@ -1004,6 +1005,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A body = ensureImageGenerationTool(body, baseModel, auth) } body = sanitizeOpenAIResponsesReasoningEncryptedContent(ctx, "codex executor", body) + body = normalizeCodexParallelToolCallsForTools(body) reporter.SetTranslatedReasoningEffort(body, to.String()) url := strings.TrimSuffix(baseURL, "/") + "/responses/compact" @@ -1113,6 +1115,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body = ensureImageGenerationTool(body, baseModel, auth) } body = sanitizeOpenAIResponsesReasoningEncryptedContent(ctx, "codex executor", body) + body = normalizeCodexParallelToolCallsForTools(body) body, replayScope, errReplay := applyCodexReasoningReplayCacheRequired(ctx, from, req, opts, body) if errReplay != nil { return nil, errReplay @@ -1775,6 +1778,21 @@ func ensureImageGenerationTool(body []byte, baseModel string, auth *cliproxyauth return body } +func normalizeCodexParallelToolCallsForTools(body []byte) []byte { + if !gjson.GetBytes(body, "parallel_tool_calls").Exists() { + return body + } + + tools := gjson.GetBytes(body, "tools") + hasTools := tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 + if hasTools { + return body + } + + body, _ = sjson.DeleteBytes(body, "parallel_tool_calls") + return body +} + func 
publishCodexImageToolUsage(ctx context.Context, reporter *helps.UsageReporter, body []byte, completedData []byte) { detail, ok := helps.ParseCodexImageToolUsage(completedData) if !ok { diff --git a/internal/runtime/executor/codex_executor_parallel_tool_calls_test.go b/internal/runtime/executor/codex_executor_parallel_tool_calls_test.go new file mode 100644 index 0000000000..d1f4f8e174 --- /dev/null +++ b/internal/runtime/executor/codex_executor_parallel_tool_calls_test.go @@ -0,0 +1,40 @@ +package executor + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestNormalizeCodexParallelToolCallsForTools_DropsWhenToolsMissing(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","parallel_tool_calls":true,"input":"hi"}`) + + out := normalizeCodexParallelToolCallsForTools(body) + + if gjson.GetBytes(out, "parallel_tool_calls").Exists() { + t.Fatalf("parallel_tool_calls should be removed when tools are missing: %s", string(out)) + } +} + +func TestNormalizeCodexParallelToolCallsForTools_DropsWhenToolsEmpty(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","tools":[],"parallel_tool_calls":false,"input":"hi"}`) + + out := normalizeCodexParallelToolCallsForTools(body) + + if gjson.GetBytes(out, "parallel_tool_calls").Exists() { + t.Fatalf("parallel_tool_calls should be removed when tools are empty: %s", string(out)) + } + if !gjson.GetBytes(out, "tools").Exists() { + t.Fatalf("tools should be preserved: %s", string(out)) + } +} + +func TestNormalizeCodexParallelToolCallsForTools_PreservesWhenToolsPresent(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","tools":[{"type":"function","name":"lookup"}],"parallel_tool_calls":true,"input":"hi"}`) + + out := normalizeCodexParallelToolCallsForTools(body) + + if !gjson.GetBytes(out, "parallel_tool_calls").Bool() { + t.Fatalf("parallel_tool_calls should be preserved when tools are present: %s", string(out)) + } +} diff --git a/internal/runtime/executor/gemini_cli_executor.go 
b/internal/runtime/executor/gemini_cli_executor.go deleted file mode 100644 index 7055f8ad01..0000000000 --- a/internal/runtime/executor/gemini_cli_executor.go +++ /dev/null @@ -1,1041 +0,0 @@ -// Package executor provides runtime execution capabilities for various AI service providers. -// This file implements the Gemini CLI executor that talks to Cloud Code Assist endpoints -// using OAuth credentials from auth metadata. -package executor - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "regexp" - "strconv" - "strings" - "time" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" - "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" - "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/geminicli" - "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" -) - -const ( - codeAssistEndpoint = "https://cloudcode-pa.googleapis.com" - codeAssistVersion = "v1internal" - geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - geminiOAuthClientSecret = "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" -) - -var geminiOAuthScopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", -} - -// GeminiCLIExecutor talks to the Cloud Code Assist endpoint using OAuth credentials from auth metadata. 
-type GeminiCLIExecutor struct { - cfg *config.Config -} - -// NewGeminiCLIExecutor creates a new Gemini CLI executor instance. -// -// Parameters: -// - cfg: The application configuration -// -// Returns: -// - *GeminiCLIExecutor: A new Gemini CLI executor instance -func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor { - return &GeminiCLIExecutor{cfg: cfg} -} - -// Identifier returns the executor identifier. -func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" } - -// PrepareRequest injects Gemini CLI credentials into the outgoing HTTP request. -func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { - if req == nil { - return nil - } - tokenSource, _, errSource := prepareGeminiCLITokenSource(req.Context(), e.cfg, auth) - if errSource != nil { - return errSource - } - tok, errTok := tokenSource.Token() - if errTok != nil { - return errTok - } - if strings.TrimSpace(tok.AccessToken) == "" { - return statusErr{code: http.StatusUnauthorized, msg: "missing access token"} - } - req.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(req, "unknown") - var attrs map[string]string - if auth != nil { - attrs = auth.Attributes - } - util.ApplyCustomHeadersFromAttrs(req, attrs) - return nil -} - -// HttpRequest injects Gemini CLI credentials into the request and executes it. -func (e *GeminiCLIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { - if req == nil { - return nil, fmt.Errorf("gemini-cli executor: request is nil") - } - if ctx == nil { - ctx = req.Context() - } - httpReq := req.WithContext(ctx) - if err := e.PrepareRequest(httpReq, auth); err != nil { - return nil, err - } - httpClient := newHTTPClient(ctx, e.cfg, auth, 0) - return httpClient.Do(httpReq) -} - -// Execute performs a non-streaming request to the Gemini CLI API. 
-func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - if opts.Alt == "responses/compact" { - return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} - } - baseModel := thinking.ParseSuffix(req.Model).ModelName - - tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth) - if err != nil { - return resp, err - } - - reporter := helps.NewExecutorUsageReporter(ctx, e, baseModel, auth) - defer reporter.TrackFailure(ctx, &err) - - from := opts.SourceFormat - responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) - to := sdktranslator.FromString("gemini-cli") - - originalPayloadSource := req.Payload - if len(opts.OriginalRequest) > 0 { - originalPayloadSource = opts.OriginalRequest - } - originalPayload := originalPayloadSource - originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) - - basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier()) - if err != nil { - return resp, err - } - - basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload) - requestedModel := helps.PayloadRequestedModel(opts, req.Model) - requestPath := helps.PayloadRequestPath(opts) - basePayload = helps.ApplyPayloadConfigWithRequest(e.cfg, baseModel, "gemini", from.String(), "request", basePayload, originalTranslated, requestedModel, requestPath, opts.Headers) - basePayload = cleanGeminiCLIRequestSchemas(basePayload) - reporter.SetTranslatedReasoningEffort(basePayload, to.String()) - - action := "generateContent" - if req.Metadata != nil { - if a, _ := req.Metadata["action"].(string); a == "countTokens" { - action = "countTokens" - } - } - - projectID := resolveGeminiProjectID(auth) - models := 
cliPreviewFallbackOrder(baseModel) - if len(models) == 0 || models[0] != baseModel { - models = append([]string{baseModel}, models...) - } - - httpClient := newHTTPClient(ctx, e.cfg, auth, 0) - httpClient = reporter.TrackHTTPClient(httpClient) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - - var authID, authLabel, authType, authValue string - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - - var lastStatus int - var lastBody []byte - - for idx, attemptModel := range models { - payload := append([]byte(nil), basePayload...) - if action == "countTokens" { - payload = deleteJSONField(payload, "project") - payload = deleteJSONField(payload, "model") - } else { - payload = setJSONField(payload, "project", projectID) - payload = setJSONField(payload, "model", attemptModel) - } - - tok, errTok := tokenSource.Token() - if errTok != nil { - err = errTok - return resp, err - } - updateGeminiCLITokenMetadata(auth, baseTokenData, tok) - - url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, action) - if opts.Alt != "" && action != "countTokens" { - url = url + fmt.Sprintf("?$alt=%s", opts.Alt) - } - - reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) - if errReq != nil { - err = errReq - return resp, err - } - reqHTTP.Header.Set("Content-Type", "application/json") - reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(reqHTTP, attemptModel) - reqHTTP.Header.Set("Accept", "application/json") - util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes) - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: reqHTTP.Header.Clone(), - Body: payload, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpResp, errDo := httpClient.Do(reqHTTP) - if errDo != nil { - helps.RecordAPIResponseError(ctx, e.cfg, 
errDo) - err = errDo - return resp, err - } - - data, errRead := io.ReadAll(httpResp.Body) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("gemini cli executor: close response body error: %v", errClose) - } - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if errRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errRead) - err = errRead - return resp, err - } - helps.AppendAPIResponseChunk(ctx, e.cfg, data) - if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 { - reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data)) - var param any - out := sdktranslator.TranslateNonStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, payload, data, ¶m) - resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} - return resp, nil - } - - lastStatus = httpResp.StatusCode - lastBody = append([]byte(nil), data...) - helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - if httpResp.StatusCode == 429 { - if idx+1 < len(models) { - log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) - } else { - log.Debug("gemini cli executor: rate limited, no additional fallback model") - } - continue - } - - err = newGeminiStatusErr(httpResp.StatusCode, data) - return resp, err - } - - if len(lastBody) > 0 { - helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody) - } - if lastStatus == 0 { - lastStatus = 429 - } - err = newGeminiStatusErr(lastStatus, lastBody) - return resp, err -} - -// ExecuteStream performs a streaming request to the Gemini CLI API. 
-func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { - if opts.Alt == "responses/compact" { - return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} - } - baseModel := thinking.ParseSuffix(req.Model).ModelName - - tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth) - if err != nil { - return nil, err - } - - reporter := helps.NewExecutorUsageReporter(ctx, e, baseModel, auth) - defer reporter.TrackFailure(ctx, &err) - - from := opts.SourceFormat - responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) - to := sdktranslator.FromString("gemini-cli") - - originalPayloadSource := req.Payload - if len(opts.OriginalRequest) > 0 { - originalPayloadSource = opts.OriginalRequest - } - originalPayload := originalPayloadSource - originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) - - basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier()) - if err != nil { - return nil, err - } - - basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload) - requestedModel := helps.PayloadRequestedModel(opts, req.Model) - requestPath := helps.PayloadRequestPath(opts) - basePayload = helps.ApplyPayloadConfigWithRequest(e.cfg, baseModel, "gemini", from.String(), "request", basePayload, originalTranslated, requestedModel, requestPath, opts.Headers) - basePayload = cleanGeminiCLIRequestSchemas(basePayload) - reporter.SetTranslatedReasoningEffort(basePayload, to.String()) - - projectID := resolveGeminiProjectID(auth) - - models := cliPreviewFallbackOrder(baseModel) - if len(models) == 0 || models[0] != baseModel { - models = append([]string{baseModel}, models...) 
- } - - httpClient := newHTTPClient(ctx, e.cfg, auth, 0) - httpClient = reporter.TrackHTTPClient(httpClient) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - - var authID, authLabel, authType, authValue string - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - - var lastStatus int - var lastBody []byte - - for idx, attemptModel := range models { - payload := append([]byte(nil), basePayload...) - payload = setJSONField(payload, "project", projectID) - payload = setJSONField(payload, "model", attemptModel) - - tok, errTok := tokenSource.Token() - if errTok != nil { - err = errTok - return nil, err - } - updateGeminiCLITokenMetadata(auth, baseTokenData, tok) - - url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "streamGenerateContent") - if opts.Alt == "" { - url = url + "?alt=sse" - } else { - url = url + fmt.Sprintf("?$alt=%s", opts.Alt) - } - - reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) - if errReq != nil { - err = errReq - return nil, err - } - reqHTTP.Header.Set("Content-Type", "application/json") - reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(reqHTTP, attemptModel) - reqHTTP.Header.Set("Accept", "text/event-stream") - util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes) - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: reqHTTP.Header.Clone(), - Body: payload, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - httpResp, errDo := httpClient.Do(reqHTTP) - if errDo != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errDo) - err = errDo - return nil, err - } - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - data, errRead := io.ReadAll(httpResp.Body) - if 
errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("gemini cli executor: close response body error: %v", errClose) - } - if errRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errRead) - err = errRead - return nil, err - } - helps.AppendAPIResponseChunk(ctx, e.cfg, data) - lastStatus = httpResp.StatusCode - lastBody = append([]byte(nil), data...) - helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - if httpResp.StatusCode == 429 { - if idx+1 < len(models) { - log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) - } else { - log.Debug("gemini cli executor: rate limited, no additional fallback model") - } - continue - } - err = newGeminiStatusErr(httpResp.StatusCode, data) - return nil, err - } - - out := make(chan cliproxyexecutor.StreamChunk) - go func(resp *http.Response, reqBody []byte, attemptModel string) { - defer close(out) - defer func() { - if errClose := resp.Body.Close(); errClose != nil { - log.Errorf("gemini cli executor: close response body error: %v", errClose) - } - }() - if opts.Alt == "" { - scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(nil, streamScannerBuffer) - var param any - for scanner.Scan() { - line := scanner.Bytes() - helps.AppendAPIResponseChunk(ctx, e.cfg, line) - if detail, ok := helps.ParseGeminiCLIStreamUsage(line); ok { - reporter.Publish(ctx, detail) - } - if bytes.HasPrefix(line, dataTag) { - segments := sdktranslator.TranslateStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m) - for i := range segments { - select { - case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}: - case <-ctx.Done(): - return - } - } - } - } - - segments := sdktranslator.TranslateStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m) - for i := range 
segments { - select { - case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}: - case <-ctx.Done(): - return - } - } - if errScan := scanner.Err(); errScan != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errScan) - reporter.PublishFailure(ctx, errScan) - select { - case out <- cliproxyexecutor.StreamChunk{Err: errScan}: - case <-ctx.Done(): - } - return - } - reporter.EnsurePublished(ctx) - return - } - - data, errRead := io.ReadAll(resp.Body) - if errRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errRead) - reporter.PublishFailure(ctx, errRead) - select { - case out <- cliproxyexecutor.StreamChunk{Err: errRead}: - case <-ctx.Done(): - } - return - } - helps.AppendAPIResponseChunk(ctx, e.cfg, data) - reporter.Publish(ctx, helps.ParseGeminiCLIUsage(data)) - var param any - segments := sdktranslator.TranslateStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, reqBody, data, ¶m) - for i := range segments { - select { - case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}: - case <-ctx.Done(): - return - } - } - - segments = sdktranslator.TranslateStream(respCtx, to, responseFormat, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m) - for i := range segments { - select { - case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}: - case <-ctx.Done(): - return - } - } - }(httpResp, append([]byte(nil), payload...), attemptModel) - - return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil - } - - if len(lastBody) > 0 { - helps.AppendAPIResponseChunk(ctx, e.cfg, lastBody) - } - if lastStatus == 0 { - lastStatus = 429 - } - err = newGeminiStatusErr(lastStatus, lastBody) - return nil, err -} - -// CountTokens counts tokens for the given request using the Gemini CLI API. 
-func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - baseModel := thinking.ParseSuffix(req.Model).ModelName - - tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth) - if err != nil { - return cliproxyexecutor.Response{}, err - } - - from := opts.SourceFormat - responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) - to := sdktranslator.FromString("gemini-cli") - - models := cliPreviewFallbackOrder(baseModel) - if len(models) == 0 || models[0] != baseModel { - models = append([]string{baseModel}, models...) - } - - httpClient := newHTTPClient(ctx, e.cfg, auth, 0) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - - var authID, authLabel, authType, authValue string - if auth != nil { - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } - - var lastStatus int - var lastBody []byte - - // The loop variable attemptModel is only used as the concrete model id sent to the upstream - // Gemini CLI endpoint when iterating fallback variants. 
- for range models { - payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) - - payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier()) - if err != nil { - return cliproxyexecutor.Response{}, err - } - - payload = deleteJSONField(payload, "project") - payload = deleteJSONField(payload, "model") - payload = deleteJSONField(payload, "request.safetySettings") - payload = fixGeminiCLIImageAspectRatio(baseModel, payload) - payload = cleanGeminiCLIRequestSchemas(payload) - - tok, errTok := tokenSource.Token() - if errTok != nil { - return cliproxyexecutor.Response{}, errTok - } - updateGeminiCLITokenMetadata(auth, baseTokenData, tok) - - url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "countTokens") - if opts.Alt != "" { - url = url + fmt.Sprintf("?$alt=%s", opts.Alt) - } - - reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) - if errReq != nil { - return cliproxyexecutor.Response{}, errReq - } - reqHTTP.Header.Set("Content-Type", "application/json") - reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(reqHTTP, baseModel) - reqHTTP.Header.Set("Accept", "application/json") - util.ApplyCustomHeadersFromAttrs(reqHTTP, auth.Attributes) - helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ - URL: url, - Method: http.MethodPost, - Headers: reqHTTP.Header.Clone(), - Body: payload, - Provider: e.Identifier(), - AuthID: authID, - AuthLabel: authLabel, - AuthType: authType, - AuthValue: authValue, - }) - - resp, errDo := httpClient.Do(reqHTTP) - if errDo != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errDo) - return cliproxyexecutor.Response{}, errDo - } - data, errRead := io.ReadAll(resp.Body) - if errClose := resp.Body.Close(); errClose != nil { - helps.LogWithRequestID(ctx).Errorf("response body close error: %v", errClose) - } - helps.RecordAPIResponseMetadata(ctx, e.cfg, 
resp.StatusCode, resp.Header.Clone()) - if errRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errRead) - return cliproxyexecutor.Response{}, errRead - } - helps.AppendAPIResponseChunk(ctx, e.cfg, data) - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - count := gjson.GetBytes(data, "totalTokens").Int() - translated := sdktranslator.TranslateTokenCount(respCtx, to, responseFormat, count, data) - return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil - } - lastStatus = resp.StatusCode - lastBody = append([]byte(nil), data...) - if resp.StatusCode == 429 { - log.Debugf("gemini cli executor: rate limited, retrying with next model") - continue - } - break - } - - if lastStatus == 0 { - lastStatus = 429 - } - return cliproxyexecutor.Response{}, newGeminiStatusErr(lastStatus, lastBody) -} - -// Refresh refreshes the authentication credentials (no-op for Gemini CLI). -func (e *GeminiCLIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled { - return refreshed, err - } - return auth, nil -} - -func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth) (oauth2.TokenSource, map[string]any, error) { - metadata := geminiOAuthMetadata(auth) - if auth == nil || metadata == nil { - return nil, nil, fmt.Errorf("gemini-cli auth metadata missing") - } - - buildToken := func(meta map[string]any) (map[string]any, oauth2.Token) { - var base map[string]any - if tokenRaw, ok := meta["token"].(map[string]any); ok && tokenRaw != nil { - base = cloneMap(tokenRaw) - } else { - base = make(map[string]any) - } - - var token oauth2.Token - if len(base) > 0 { - if raw, err := json.Marshal(base); err == nil { - _ = json.Unmarshal(raw, &token) - } - } - - if token.AccessToken == "" { - token.AccessToken = stringValue(meta, "access_token") - } - if token.RefreshToken == "" { - 
token.RefreshToken = stringValue(meta, "refresh_token") - } - if token.TokenType == "" { - token.TokenType = stringValue(meta, "token_type") - } - if token.Expiry.IsZero() { - if expiry := stringValue(meta, "expiry"); expiry != "" { - if ts, err := time.Parse(time.RFC3339, expiry); err == nil { - token.Expiry = ts - } - } - } - - return base, token - } - - base, token := buildToken(metadata) - - conf := &oauth2.Config{ - ClientID: geminiOAuthClientID, - ClientSecret: geminiOAuthClientSecret, - Scopes: geminiOAuthScopes, - Endpoint: google.Endpoint, - } - - ctxToken := ctx - if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil { - ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient) - } - - if cfg != nil && cfg.Home.Enabled { - now := time.Now() - if token.AccessToken == "" || (!token.Expiry.IsZero() && token.Expiry.Before(now.Add(30*time.Second))) { - refreshed, handled, errRefresh := helps.RefreshAuthViaHome(ctx, cfg, auth) - if handled { - if errRefresh != nil { - return nil, nil, errRefresh - } - auth = refreshed - metadata = geminiOAuthMetadata(auth) - if metadata == nil { - return nil, nil, fmt.Errorf("gemini-cli auth metadata missing") - } - base, token = buildToken(metadata) - } - } - if token.AccessToken == "" { - return nil, nil, fmt.Errorf("gemini-cli access token missing") - } - updateGeminiCLITokenMetadata(auth, base, &token) - return oauth2.StaticTokenSource(&token), base, nil - } - - src := conf.TokenSource(ctxToken, &token) - currentToken, err := src.Token() - if err != nil { - return nil, nil, err - } - updateGeminiCLITokenMetadata(auth, base, currentToken) - return oauth2.ReuseTokenSource(currentToken, src), base, nil -} - -func updateGeminiCLITokenMetadata(auth *cliproxyauth.Auth, base map[string]any, tok *oauth2.Token) { - if auth == nil || tok == nil { - return - } - merged := buildGeminiTokenMap(base, tok) - fields := buildGeminiTokenFields(tok, merged) - shared := 
geminicli.ResolveSharedCredential(auth.Runtime) - if shared != nil { - snapshot := shared.MergeMetadata(fields) - if !geminicli.IsVirtual(auth.Runtime) { - auth.Metadata = snapshot - } - return - } - if auth.Metadata == nil { - auth.Metadata = make(map[string]any) - } - for k, v := range fields { - auth.Metadata[k] = v - } -} - -func buildGeminiTokenMap(base map[string]any, tok *oauth2.Token) map[string]any { - merged := cloneMap(base) - if merged == nil { - merged = make(map[string]any) - } - if raw, err := json.Marshal(tok); err == nil { - var tokenMap map[string]any - if err = json.Unmarshal(raw, &tokenMap); err == nil { - for k, v := range tokenMap { - merged[k] = v - } - } - } - return merged -} - -func buildGeminiTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any { - fields := make(map[string]any, 5) - if tok.AccessToken != "" { - fields["access_token"] = tok.AccessToken - } - if tok.TokenType != "" { - fields["token_type"] = tok.TokenType - } - if tok.RefreshToken != "" { - fields["refresh_token"] = tok.RefreshToken - } - if !tok.Expiry.IsZero() { - fields["expiry"] = tok.Expiry.Format(time.RFC3339) - } - if len(merged) > 0 { - fields["token"] = cloneMap(merged) - } - return fields -} - -func resolveGeminiProjectID(auth *cliproxyauth.Auth) string { - if auth == nil { - return "" - } - if runtime := auth.Runtime; runtime != nil { - if virtual, ok := runtime.(*geminicli.VirtualCredential); ok && virtual != nil { - return strings.TrimSpace(virtual.ProjectID) - } - } - return strings.TrimSpace(stringValue(auth.Metadata, "project_id")) -} - -func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any { - if auth == nil { - return nil - } - if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil { - if snapshot := shared.MetadataSnapshot(); len(snapshot) > 0 { - return snapshot - } - } - return auth.Metadata -} - -func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) 
*http.Client { - return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout) -} - -func cloneMap(in map[string]any) map[string]any { - if in == nil { - return nil - } - out := make(map[string]any, len(in)) - for k, v := range in { - out[k] = v - } - return out -} - -func stringValue(m map[string]any, key string) string { - if m == nil { - return "" - } - if v, ok := m[key]; ok { - switch typed := v.(type) { - case string: - return typed - case fmt.Stringer: - return typed.String() - } - } - return "" -} - -// applyGeminiCLIHeaders sets required headers for the Gemini CLI upstream. -// User-Agent is always forced to the GeminiCLI format regardless of the client's value, -// so that upstream identifies the request as a native GeminiCLI client. -func applyGeminiCLIHeaders(r *http.Request, model string) { - r.Header.Set("User-Agent", misc.GeminiCLIUserAgent(model)) - r.Header.Set("X-Goog-Api-Client", misc.GeminiCLIApiClientHeader) -} - -// cliPreviewFallbackOrder returns preview model candidates for a base model. -func cliPreviewFallbackOrder(model string) []string { - switch model { - case "gemini-2.5-pro": - return []string{ - // "gemini-2.5-pro-preview-05-06", - // "gemini-2.5-pro-preview-06-05", - } - case "gemini-2.5-flash": - return []string{ - // "gemini-2.5-flash-preview-04-17", - // "gemini-2.5-flash-preview-05-20", - } - case "gemini-2.5-flash-lite": - return []string{ - // "gemini-2.5-flash-lite-preview-06-17", - } - default: - return nil - } -} - -// setJSONField sets a top-level JSON field on a byte slice payload via sjson. -func setJSONField(body []byte, key, value string) []byte { - if key == "" { - return body - } - updated, err := sjson.SetBytes(body, key, value) - if err != nil { - return body - } - return updated -} - -// deleteJSONField removes a top-level key if present (best-effort) via sjson. 
-func deleteJSONField(body []byte, key string) []byte { - if key == "" || len(body) == 0 { - return body - } - updated, err := sjson.DeleteBytes(body, key) - if err != nil { - return body - } - return updated -} - -func cleanGeminiCLIRequestSchemas(body []byte) []byte { - if len(body) == 0 { - return body - } - hasTools := gjson.GetBytes(body, "request.tools.0").Exists() - hasResponseSchema := gjson.GetBytes(body, "request.generationConfig.responseSchema").Exists() - hasResponseJSONSchema := gjson.GetBytes(body, "request.generationConfig.responseJsonSchema").Exists() - if !hasTools && !hasResponseSchema && !hasResponseJSONSchema { - return body - } - - tools := gjson.GetBytes(body, "request.tools") - if tools.IsArray() { - for i, tool := range tools.Array() { - for _, declarationsKey := range []string{"function_declarations", "functionDeclarations"} { - funcDecls := tool.Get(declarationsKey) - if !funcDecls.IsArray() { - continue - } - for j, decl := range funcDecls.Array() { - for _, schemaKey := range []string{"parameters", "parametersJsonSchema"} { - params := decl.Get(schemaKey) - if !params.Exists() || !params.IsObject() { - continue - } - cleaned := util.CleanJSONSchemaForGemini(params.Raw) - path := fmt.Sprintf("request.tools.%d.%s.%d.%s", i, declarationsKey, j, schemaKey) - updated, errSet := sjson.SetRawBytes(body, path, []byte(cleaned)) - if errSet != nil { - log.Errorf("gemini cli executor: failed to set cleaned schema at %s: %v", path, errSet) - continue - } - body = updated - } - } - } - } - } - - for _, schemaPath := range []string{ - "request.generationConfig.responseSchema", - "request.generationConfig.responseJsonSchema", - } { - responseSchema := gjson.GetBytes(body, schemaPath) - if !responseSchema.IsObject() { - continue - } - cleaned := util.CleanJSONSchemaForGemini(responseSchema.Raw) - updated, errSet := sjson.SetRawBytes(body, schemaPath, []byte(cleaned)) - if errSet != nil { - log.Errorf("gemini cli executor: failed to set cleaned response 
schema at %s: %v", schemaPath, errSet) - continue - } - body = updated - } - - return body -} - -func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte { - if modelName == "gemini-2.5-flash-image-preview" { - aspectRatioResult := gjson.GetBytes(rawJSON, "request.generationConfig.imageConfig.aspectRatio") - if aspectRatioResult.Exists() { - contents := gjson.GetBytes(rawJSON, "request.contents") - contentArray := contents.Array() - if len(contentArray) > 0 { - hasInlineData := false - loopContent: - for i := 0; i < len(contentArray); i++ { - parts := contentArray[i].Get("parts").Array() - for j := 0; j < len(parts); j++ { - if parts[j].Get("inlineData").Exists() { - hasInlineData = true - break loopContent - } - } - } - - if !hasInlineData { - emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String()) - emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`) - emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed) - newPartsJson := []byte(`[]`) - newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. 
The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)) - newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart) - - parts := contentArray[0].Get("parts").Array() - for j := 0; j < len(parts); j++ { - newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw)) - } - - rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", newPartsJson) - rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`)) - } - } - rawJSON, _ = sjson.DeleteBytes(rawJSON, "request.generationConfig.imageConfig") - } - } - return rawJSON -} - -func newGeminiStatusErr(statusCode int, body []byte) statusErr { - err := statusErr{code: statusCode, msg: string(body)} - if statusCode == http.StatusTooManyRequests { - if retryAfter, parseErr := parseRetryDelay(body); parseErr == nil && retryAfter != nil { - err.retryAfter = retryAfter - } - } - return err -} - -// parseRetryDelay extracts the retry delay from a Google API 429 error response. -// The error response contains a RetryInfo.retryDelay field in the format "0.847655010s". -// Returns the parsed duration or an error if it cannot be determined. 
-func parseRetryDelay(errorBody []byte) (*time.Duration, error) { - // Try to parse the retryDelay from the error response - // Format: error.details[].retryDelay where @type == "type.googleapis.com/google.rpc.RetryInfo" - details := gjson.GetBytes(errorBody, "error.details") - if details.Exists() && details.IsArray() { - for _, detail := range details.Array() { - typeVal := detail.Get("@type").String() - if typeVal == "type.googleapis.com/google.rpc.RetryInfo" { - retryDelay := detail.Get("retryDelay").String() - if retryDelay != "" { - // Parse duration string like "0.847655010s" - duration, err := time.ParseDuration(retryDelay) - if err != nil { - return nil, fmt.Errorf("failed to parse duration") - } - return &duration, nil - } - } - } - - // Fallback: try ErrorInfo.metadata.quotaResetDelay (e.g., "373.801628ms") - for _, detail := range details.Array() { - typeVal := detail.Get("@type").String() - if typeVal == "type.googleapis.com/google.rpc.ErrorInfo" { - quotaResetDelay := detail.Get("metadata.quotaResetDelay").String() - if quotaResetDelay != "" { - duration, err := time.ParseDuration(quotaResetDelay) - if err == nil { - return &duration, nil - } - } - } - } - } - - // Fallback: parse from error.message "Your quota will reset after Xs." 
- message := gjson.GetBytes(errorBody, "error.message").String() - if message != "" { - re := regexp.MustCompile(`after\s+(\d+)s\.?`) - if matches := re.FindStringSubmatch(message); len(matches) > 1 { - seconds, err := strconv.Atoi(matches[1]) - if err == nil { - duration := time.Duration(seconds) * time.Second - return &duration, nil - } - } - reHuman := regexp.MustCompile(`after\s+((?:\d+h)?(?:\d+m)?(?:\d+s)?)\.?`) - if matches := reHuman.FindStringSubmatch(strings.ToLower(message)); len(matches) > 1 { - if duration, err := time.ParseDuration(matches[1]); err == nil && duration > 0 { - return &duration, nil - } - } - } - - return nil, fmt.Errorf("no RetryInfo found") -} diff --git a/internal/runtime/executor/gemini_cli_executor_test.go b/internal/runtime/executor/gemini_cli_executor_test.go deleted file mode 100644 index b77134ed8c..0000000000 --- a/internal/runtime/executor/gemini_cli_executor_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package executor - -import ( - "strings" - "testing" - - "github.com/tidwall/gjson" -) - -func TestCleanGeminiCLIRequestSchemasFlattensFunctionDeclarationTypeArray(t *testing.T) { - input := []byte(`{ - "request": { - "tools": [ - { - "function_declarations": [ - { - "name": "wecom_mcp", - "parameters": { - "type": "object", - "properties": { - "args": { - "description": "call args", - "type": ["string", "object"] - } - } - } - } - ] - }, - { - "functionDeclarations": [ - { - "name": "camel_tool", - "parametersJsonSchema": { - "type": "object", - "properties": { - "value": { - "type": ["integer", "string"] - } - } - } - } - ] - } - ], - "nonSchema": { - "type": ["string", "object"] - } - } - }`) - - out := cleanGeminiCLIRequestSchemas(input) - - argsType := gjson.GetBytes(out, "request.tools.0.function_declarations.0.parameters.properties.args.type") - if argsType.String() != "string" { - t.Fatalf("args.type = %s, want string; body=%s", argsType.Raw, string(out)) - } - argsDesc := gjson.GetBytes(out, 
"request.tools.0.function_declarations.0.parameters.properties.args.description").String() - if !strings.Contains(argsDesc, "Accepts: string | object") { - t.Fatalf("args.description = %q, want accepted type hint", argsDesc) - } - - valueType := gjson.GetBytes(out, "request.tools.1.functionDeclarations.0.parametersJsonSchema.properties.value.type") - if valueType.String() != "integer" { - t.Fatalf("value.type = %s, want integer; body=%s", valueType.Raw, string(out)) - } - valueDesc := gjson.GetBytes(out, "request.tools.1.functionDeclarations.0.parametersJsonSchema.properties.value.description").String() - if !strings.Contains(valueDesc, "Accepts: integer | string") { - t.Fatalf("value.description = %q, want accepted type hint", valueDesc) - } - - if nonSchema := gjson.GetBytes(out, "request.nonSchema.type"); !nonSchema.IsArray() { - t.Fatalf("request.nonSchema.type should be preserved outside schema paths, got %s", nonSchema.Raw) - } -} diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 6f502a737b..f68a7073a9 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -37,8 +37,7 @@ const ( ) // GeminiExecutor is a stateless executor for the official Gemini API using API keys. -// It handles both API key and OAuth bearer token authentication, supporting both -// regular and streaming requests to the Google Generative Language API. +// It supports regular and streaming requests to the Google Generative Language API. type GeminiExecutor struct { // cfg holds the application configuration. 
cfg *config.Config @@ -63,13 +62,10 @@ func (e *GeminiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Au if req == nil { return nil } - apiKey, bearer := geminiCreds(auth) + apiKey := geminiAPIKey(auth) if apiKey != "" { req.Header.Set("x-goog-api-key", apiKey) req.Header.Del("Authorization") - } else if bearer != "" { - req.Header.Set("Authorization", "Bearer "+bearer) - req.Header.Del("x-goog-api-key") } applyGeminiHeaders(req, auth) return nil @@ -110,12 +106,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } baseModel := thinking.ParseSuffix(req.Model).ModelName - apiKey, bearer := geminiCreds(auth) + apiKey := geminiAPIKey(auth) reporter := helps.NewExecutorUsageReporter(ctx, e, baseModel, auth) defer reporter.TrackFailure(ctx, &err) - // Official Gemini API via API key or OAuth bearer + // Official Gemini API via API key. from := opts.SourceFormat responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) to := sdktranslator.FromString("gemini") @@ -161,8 +157,6 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r httpReq.Header.Set("Content-Type", "application/json") if apiKey != "" { httpReq.Header.Set("x-goog-api-key", apiKey) - } else if bearer != "" { - httpReq.Header.Set("Authorization", "Bearer "+bearer) } applyGeminiHeaders(httpReq, auth) var authID, authLabel, authType, authValue string @@ -223,7 +217,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } baseModel := thinking.ParseSuffix(req.Model).ModelName - apiKey, bearer := geminiCreds(auth) + apiKey := geminiAPIKey(auth) reporter := helps.NewExecutorUsageReporter(ctx, e, baseModel, auth) defer reporter.TrackFailure(ctx, &err) @@ -269,8 +263,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A httpReq.Header.Set("Content-Type", "application/json") if apiKey != "" { httpReq.Header.Set("x-goog-api-key", apiKey) - } else { - 
httpReq.Header.Set("Authorization", "Bearer "+bearer) } applyGeminiHeaders(httpReq, auth) var authID, authLabel, authType, authValue string @@ -364,7 +356,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { baseModel := thinking.ParseSuffix(req.Model).ModelName - apiKey, bearer := geminiCreds(auth) + apiKey := geminiAPIKey(auth) from := opts.SourceFormat responseFormat := cliproxyexecutor.ResponseFormatOrSource(opts) @@ -395,8 +387,6 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut httpReq.Header.Set("Content-Type", "application/json") if apiKey != "" { httpReq.Header.Set("x-goog-api-key", apiKey) - } else { - httpReq.Header.Set("Authorization", "Bearer "+bearer) } applyGeminiHeaders(httpReq, auth) var authID, authLabel, authType, authValue string @@ -454,27 +444,16 @@ func (e *GeminiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) ( return auth, nil } -func geminiCreds(a *cliproxyauth.Auth) (apiKey, bearer string) { +func geminiAPIKey(a *cliproxyauth.Auth) string { if a == nil { - return "", "" + return "" } if a.Attributes != nil { if v := a.Attributes["api_key"]; v != "" { - apiKey = v - } - } - if a.Metadata != nil { - // GeminiTokenStorage.Token is a map that may contain access_token - if v, ok := a.Metadata["access_token"].(string); ok && v != "" { - bearer = v - } - if token, ok := a.Metadata["token"].(map[string]any); ok && token != nil { - if v, ok2 := token["access_token"].(string); ok2 && v != "" { - bearer = v - } + return v } } - return + return "" } func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string { diff --git a/internal/runtime/executor/helps/json_retry_helpers.go b/internal/runtime/executor/helps/json_retry_helpers.go new file mode 100644 index 0000000000..e2b1412301 --- /dev/null 
+++ b/internal/runtime/executor/helps/json_retry_helpers.go @@ -0,0 +1,80 @@ +package helps + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// DeleteJSONField removes a top-level or nested JSON field from a payload. +func DeleteJSONField(body []byte, key string) []byte { + if key == "" || len(body) == 0 { + return body + } + updated, err := sjson.DeleteBytes(body, key) + if err != nil { + return body + } + return updated +} + +// ParseRetryDelay extracts the retry delay from a Google API 429 error response. +func ParseRetryDelay(errorBody []byte) (*time.Duration, error) { + details := gjson.GetBytes(errorBody, "error.details") + if details.Exists() && details.IsArray() { + for _, detail := range details.Array() { + if detail.Get("@type").String() != "type.googleapis.com/google.rpc.RetryInfo" { + continue + } + retryDelay := detail.Get("retryDelay").String() + if retryDelay == "" { + continue + } + duration, err := time.ParseDuration(retryDelay) + if err != nil { + return nil, fmt.Errorf("failed to parse duration") + } + return &duration, nil + } + + for _, detail := range details.Array() { + if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" { + continue + } + quotaResetDelay := detail.Get("metadata.quotaResetDelay").String() + if quotaResetDelay == "" { + continue + } + duration, err := time.ParseDuration(quotaResetDelay) + if err == nil { + return &duration, nil + } + } + } + + message := gjson.GetBytes(errorBody, "error.message").String() + if message != "" { + re := regexp.MustCompile(`after\s+(\d+)s\.?`) + if matches := re.FindStringSubmatch(message); len(matches) > 1 { + seconds, err := strconv.Atoi(matches[1]) + if err == nil { + duration := time.Duration(seconds) * time.Second + return &duration, nil + } + } + reHuman := regexp.MustCompile(`after\s+((?:\d+h)?(?:\d+m)?(?:\d+s)?)\.?`) + if matches := 
reHuman.FindStringSubmatch(strings.ToLower(message)); len(matches) > 1 { + duration, err := time.ParseDuration(matches[1]) + if err == nil && duration > 0 { + return &duration, nil + } + } + } + + return nil, fmt.Errorf("no RetryInfo found") +} diff --git a/internal/runtime/executor/helps/payload_helpers.go b/internal/runtime/executor/helps/payload_helpers.go index 8f8434c82c..2035898309 100644 --- a/internal/runtime/executor/helps/payload_helpers.go +++ b/internal/runtime/executor/helps/payload_helpers.go @@ -15,8 +15,8 @@ import ( ) // ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter -// paths as relative to the provided root path (for example, "request" for Gemini CLI) -// and restricts matches to the given protocol when supplied. Defaults are checked +// paths as relative to the provided root path and restricts matches to the given +// protocol when supplied. Defaults are checked // against the original payload when provided. requestedModel carries the client-visible // model name before alias resolution so payload rules can target aliases precisely. // requestPath is the inbound HTTP request path (when available) used for endpoint-scoped gates. 
@@ -398,8 +398,6 @@ func normalizePayloadFromProtocol(protocol string) string { switch protocol { case "openai-response", "openai-responses", "response": return "responses" - case "gemini-cli": - return "gemini" default: return protocol } diff --git a/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go b/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go index fe6de37f64..d2649703ba 100644 --- a/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go +++ b/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go @@ -35,7 +35,7 @@ func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWith } payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`) - out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "", "") + out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "antigravity", "request", payload, nil, "", "") tools := gjson.GetBytes(out, "request.tools") if !tools.Exists() || !tools.IsArray() { @@ -69,7 +69,7 @@ func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolChoiceByNa } payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}],"tool_choice":{"type":"tool","name":"image_generation"}}}`) - out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "", "") + out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "antigravity", "request", payload, nil, "", "") if gjson.GetBytes(out, "request.tool_choice").Exists() { t.Fatalf("expected request.tool_choice to be removed") diff --git a/internal/runtime/executor/helps/thinking_providers.go b/internal/runtime/executor/helps/thinking_providers.go index 013f93e34f..e879ff1308 100644 --- a/internal/runtime/executor/helps/thinking_providers.go +++ b/internal/runtime/executor/helps/thinking_providers.go @@ -5,7 +5,6 @@ import ( _ 
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/claude" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/codex" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/gemini" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/geminicli" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/kimi" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/openai" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/xai" diff --git a/internal/runtime/executor/helps/usage_helpers.go b/internal/runtime/executor/helps/usage_helpers.go index 551bd02ad3..33004d8c86 100644 --- a/internal/runtime/executor/helps/usage_helpers.go +++ b/internal/runtime/executor/helps/usage_helpers.go @@ -404,11 +404,6 @@ func APIKeyFromContext(ctx context.Context) string { func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string { if auth != nil { provider := strings.TrimSpace(auth.Provider) - if strings.EqualFold(provider, "gemini-cli") { - if id := strings.TrimSpace(auth.ID); id != "" { - return id - } - } if strings.EqualFold(provider, "vertex") { if auth.Metadata != nil { if projectID, ok := auth.Metadata["project_id"].(string); ok { @@ -590,28 +585,6 @@ func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail { return detail } -func hasGeminiFamilyUsageTokenFields(node gjson.Result) bool { - return node.Get("promptTokenCount").Exists() || - node.Get("candidatesTokenCount").Exists() || - node.Get("thoughtsTokenCount").Exists() || - node.Get("totalTokenCount").Exists() || - node.Get("cachedContentTokenCount").Exists() -} - -func ParseGeminiCLIUsage(data []byte) usage.Detail { - usageNode := gjson.ParseBytes(data) - node := firstExistingUsageNode(usageNode, - "response.usageMetadata", - "response.usage_metadata", - "usageMetadata", - "usage_metadata", - ) - if !node.Exists() { - return usage.Detail{} - } - return parseGeminiFamilyUsageDetail(node) 
-} - func ParseGeminiUsage(data []byte) usage.Detail { usageNode := gjson.ParseBytes(data) node := usageNode.Get("usageMetadata") @@ -639,27 +612,6 @@ func ParseGeminiStreamUsage(line []byte) (usage.Detail, bool) { return parseGeminiFamilyUsageDetail(node), true } -func ParseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) { - payload := jsonPayload(line) - if len(payload) == 0 || !gjson.ValidBytes(payload) { - return usage.Detail{}, false - } - root := gjson.ParseBytes(payload) - node := firstExistingUsageNode(root, - "response.usageMetadata", - "response.usage_metadata", - "usageMetadata", - "usage_metadata", - ) - if !node.Exists() { - return usage.Detail{}, false - } - if !hasGeminiFamilyUsageTokenFields(node) { - return usage.Detail{}, false - } - return parseGeminiFamilyUsageDetail(node), true -} - func firstExistingUsageNode(root gjson.Result, paths ...string) gjson.Result { for _, path := range paths { node := root.Get(path) diff --git a/internal/runtime/executor/helps/usage_helpers_test.go b/internal/runtime/executor/helps/usage_helpers_test.go index 5cca50acac..a7557aea57 100644 --- a/internal/runtime/executor/helps/usage_helpers_test.go +++ b/internal/runtime/executor/helps/usage_helpers_test.go @@ -123,50 +123,6 @@ func TestParseClaudeUsageFallsBackCachedTokensToCacheCreation(t *testing.T) { } } -func TestParseGeminiCLIUsage_TopLevelUsageMetadata(t *testing.T) { - data := []byte(`{"usageMetadata":{"promptTokenCount":11,"candidatesTokenCount":7,"thoughtsTokenCount":3,"totalTokenCount":21,"cachedContentTokenCount":5}}`) - detail := ParseGeminiCLIUsage(data) - if detail.InputTokens != 11 { - t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 11) - } - if detail.OutputTokens != 7 { - t.Fatalf("output tokens = %d, want %d", detail.OutputTokens, 7) - } - if detail.ReasoningTokens != 3 { - t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 3) - } - if detail.TotalTokens != 21 { - t.Fatalf("total tokens = %d, want %d", 
detail.TotalTokens, 21) - } - if detail.CachedTokens != 5 { - t.Fatalf("cached tokens = %d, want %d", detail.CachedTokens, 5) - } -} - -func TestParseGeminiCLIStreamUsage_ResponseSnakeCaseUsageMetadata(t *testing.T) { - line := []byte(`data: {"response":{"usage_metadata":{"promptTokenCount":13,"candidatesTokenCount":2,"totalTokenCount":15}}}`) - detail, ok := ParseGeminiCLIStreamUsage(line) - if !ok { - t.Fatal("ParseGeminiCLIStreamUsage() ok = false, want true") - } - if detail.InputTokens != 13 { - t.Fatalf("input tokens = %d, want %d", detail.InputTokens, 13) - } - if detail.OutputTokens != 2 { - t.Fatalf("output tokens = %d, want %d", detail.OutputTokens, 2) - } - if detail.TotalTokens != 15 { - t.Fatalf("total tokens = %d, want %d", detail.TotalTokens, 15) - } -} - -func TestParseGeminiCLIStreamUsage_IgnoresTrafficTypeOnlyUsageMetadata(t *testing.T) { - line := []byte(`data: {"response":{"usageMetadata":{"trafficType":"ON_DEMAND"}}}`) - if detail, ok := ParseGeminiCLIStreamUsage(line); ok { - t.Fatalf("ParseGeminiCLIStreamUsage() = (%+v, true), want false for traffic-only usage metadata", detail) - } -} - func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) { reporter := &UsageReporter{ provider: "openai", diff --git a/internal/runtime/executor/xai_executor.go b/internal/runtime/executor/xai_executor.go index ff9acd08b6..c6795ef98c 100644 --- a/internal/runtime/executor/xai_executor.go +++ b/internal/runtime/executor/xai_executor.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/google/uuid" xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" @@ -49,6 +50,7 @@ const ( xaiVideosExtensionsPath = "/videos/extensions" xaiVideosPath = "/videos" xaiIdempotencyKeyMetaKey = "idempotency_key" + xaiComposerModelPrefix = "grok-composer-" ) // XAIExecutor is a stateless executor for xAI Grok's Responses 
API. @@ -837,6 +839,9 @@ func (e *XAIExecutor) prepareResponsesRequestTo(ctx context.Context, req cliprox body = sanitizeXAIResponsesBody(body, baseModel) sessionID := xaiExecutionSessionID(req, opts) + if sessionID == "" && xaiRequiresIsolatedConversation(baseModel) { + sessionID = uuid.NewString() + } if sessionID != "" { body, _ = sjson.SetBytes(body, "prompt_cache_key", sessionID) } @@ -925,6 +930,10 @@ func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.O return "" } +func xaiRequiresIsolatedConversation(model string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), xaiComposerModelPrefix) +} + func xaiImageEndpointPath(opts cliproxyexecutor.Options) string { if opts.SourceFormat.String() != xaiImageHandlerType { return "" diff --git a/internal/runtime/executor/xai_executor_test.go b/internal/runtime/executor/xai_executor_test.go index 8ed24fe9c2..5e7b371a22 100644 --- a/internal/runtime/executor/xai_executor_test.go +++ b/internal/runtime/executor/xai_executor_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" @@ -159,6 +160,100 @@ func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) { } } +func TestXAIExecutorComposerSessionIsolation(t *testing.T) { + exec := NewXAIExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + Provider: "xai", + Metadata: map[string]any{"access_token": "xai-token"}, + } + + tests := []struct { + name string + model string + payload []byte + wantGenerated bool + wantSession string + }{ + { + name: "composer_generates_fresh_session", + model: "grok-composer-2.5-fast", + payload: []byte(`{"model":"grok-composer-2.5-fast","input":"hello"}`), + wantGenerated: true, + }, + { + name: "grok_build_stays_stateless_without_session", + model: 
"grok-build-0.1", + payload: []byte(`{"model":"grok-build-0.1","input":"hello"}`), + }, + { + name: "explicit_prompt_cache_key_is_preserved", + model: "grok-composer-2.5-fast", + payload: []byte(`{"model":"grok-composer-2.5-fast","prompt_cache_key":"client-session","input":"hello"}`), + wantSession: "client-session", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prepared, err := exec.prepareResponsesRequest(context.Background(), cliproxyexecutor.Request{ + Model: tt.model, + Payload: tt.payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FormatOpenAIResponse, + Stream: true, + }, true) + if err != nil { + t.Fatalf("prepareResponsesRequest() error = %v", err) + } + + gotSession := prepared.sessionID + gotPromptCacheKey := gjson.GetBytes(prepared.body, "prompt_cache_key").String() + httpReq, errRequest := http.NewRequest(http.MethodPost, "https://example.test/responses", bytes.NewReader(prepared.body)) + if errRequest != nil { + t.Fatalf("NewRequest() error = %v", errRequest) + } + applyXAIHeaders(httpReq, auth, "xai-token", true, gotSession) + gotGrokConvID := httpReq.Header.Get("x-grok-conv-id") + + if tt.wantGenerated { + if _, errParse := uuid.Parse(gotSession); errParse != nil { + t.Fatalf("generated sessionID = %q, want UUID; body=%s", gotSession, string(prepared.body)) + } + if gotPromptCacheKey != gotSession { + t.Fatalf("prompt_cache_key = %q, want sessionID %q; body=%s", gotPromptCacheKey, gotSession, string(prepared.body)) + } + if gotGrokConvID != gotSession { + t.Fatalf("x-grok-conv-id = %q, want sessionID %q", gotGrokConvID, gotSession) + } + return + } + + if tt.wantSession != "" { + if gotSession != tt.wantSession { + t.Fatalf("sessionID = %q, want %q", gotSession, tt.wantSession) + } + if gotPromptCacheKey != tt.wantSession { + t.Fatalf("prompt_cache_key = %q, want %q; body=%s", gotPromptCacheKey, tt.wantSession, string(prepared.body)) + } + if gotGrokConvID != tt.wantSession { + 
t.Fatalf("x-grok-conv-id = %q, want %q", gotGrokConvID, tt.wantSession) + } + return + } + + if gotSession != "" { + t.Fatalf("sessionID = %q, want empty", gotSession) + } + if gotPromptCacheKey != "" { + t.Fatalf("prompt_cache_key = %q, want empty; body=%s", gotPromptCacheKey, string(prepared.body)) + } + if gotGrokConvID != "" { + t.Fatalf("x-grok-conv-id = %q, want empty", gotGrokConvID) + } + }) + } +} + func TestXAIExecutorCompactUsesCompactEndpoint(t *testing.T) { var gotPath string var gotAuth string diff --git a/internal/runtime/geminicli/state.go b/internal/runtime/geminicli/state.go deleted file mode 100644 index e323b44bf2..0000000000 --- a/internal/runtime/geminicli/state.go +++ /dev/null @@ -1,144 +0,0 @@ -package geminicli - -import ( - "strings" - "sync" -) - -// SharedCredential keeps canonical OAuth metadata for a multi-project Gemini CLI login. -type SharedCredential struct { - primaryID string - email string - metadata map[string]any - projectIDs []string - mu sync.RWMutex -} - -// NewSharedCredential builds a shared credential container for the given primary entry. -func NewSharedCredential(primaryID, email string, metadata map[string]any, projectIDs []string) *SharedCredential { - return &SharedCredential{ - primaryID: strings.TrimSpace(primaryID), - email: strings.TrimSpace(email), - metadata: cloneMap(metadata), - projectIDs: cloneStrings(projectIDs), - } -} - -// PrimaryID returns the owning credential identifier. -func (s *SharedCredential) PrimaryID() string { - if s == nil { - return "" - } - return s.primaryID -} - -// Email returns the associated account email. -func (s *SharedCredential) Email() string { - if s == nil { - return "" - } - return s.email -} - -// ProjectIDs returns a snapshot of the configured project identifiers. -func (s *SharedCredential) ProjectIDs() []string { - if s == nil { - return nil - } - return cloneStrings(s.projectIDs) -} - -// MetadataSnapshot returns a deep copy of the stored OAuth metadata. 
-func (s *SharedCredential) MetadataSnapshot() map[string]any { - if s == nil { - return nil - } - s.mu.RLock() - defer s.mu.RUnlock() - return cloneMap(s.metadata) -} - -// MergeMetadata merges the provided fields into the shared metadata and returns an updated copy. -func (s *SharedCredential) MergeMetadata(values map[string]any) map[string]any { - if s == nil { - return nil - } - if len(values) == 0 { - return s.MetadataSnapshot() - } - s.mu.Lock() - defer s.mu.Unlock() - if s.metadata == nil { - s.metadata = make(map[string]any, len(values)) - } - for k, v := range values { - if v == nil { - delete(s.metadata, k) - continue - } - s.metadata[k] = v - } - return cloneMap(s.metadata) -} - -// SetProjectIDs updates the stored project identifiers. -func (s *SharedCredential) SetProjectIDs(ids []string) { - if s == nil { - return - } - s.mu.Lock() - s.projectIDs = cloneStrings(ids) - s.mu.Unlock() -} - -// VirtualCredential tracks a per-project virtual auth entry that reuses a primary credential. -type VirtualCredential struct { - ProjectID string - Parent *SharedCredential -} - -// NewVirtualCredential creates a virtual credential descriptor bound to the shared parent. -func NewVirtualCredential(projectID string, parent *SharedCredential) *VirtualCredential { - return &VirtualCredential{ProjectID: strings.TrimSpace(projectID), Parent: parent} -} - -// ResolveSharedCredential returns the shared credential backing the provided runtime payload. -func ResolveSharedCredential(runtime any) *SharedCredential { - switch typed := runtime.(type) { - case *SharedCredential: - return typed - case *VirtualCredential: - return typed.Parent - default: - return nil - } -} - -// IsVirtual reports whether the runtime payload represents a virtual credential. 
-func IsVirtual(runtime any) bool { - if runtime == nil { - return false - } - _, ok := runtime.(*VirtualCredential) - return ok -} - -func cloneMap(in map[string]any) map[string]any { - if len(in) == 0 { - return nil - } - out := make(map[string]any, len(in)) - for k, v := range in { - out[k] = v - } - return out -} - -func cloneStrings(in []string) []string { - if len(in) == 0 { - return nil - } - out := make([]string, len(in)) - copy(out, in) - return out -} diff --git a/internal/thinking/apply.go b/internal/thinking/apply.go index de2e604ee6..389196b0e0 100644 --- a/internal/thinking/apply.go +++ b/internal/thinking/apply.go @@ -21,7 +21,6 @@ var providerAppliersMu sync.RWMutex // nativeProviderAppliers maps built-in provider names to their implementations. var nativeProviderAppliers = map[string]ProviderApplier{ "gemini": nil, - "gemini-cli": nil, "claude": nil, "openai": nil, "codex": nil, @@ -140,7 +139,7 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool { // - body: Original request body JSON // - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)") // - fromFormat: Source request format (e.g., openai, codex, gemini) -// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, kimi, xai) +// - toFormat: Target provider format for the request body (gemini, antigravity, claude, openai, codex, kimi, xai) // - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai) // // Returns: @@ -413,7 +412,7 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig { switch provider { case "claude": return extractClaudeConfig(body) - case "gemini", "gemini-cli", "antigravity": + case "gemini", "antigravity": return extractGeminiConfig(body, provider) case "openai": return extractOpenAIConfig(body) @@ -560,13 +559,13 @@ func extractClaudeConfig(body []byte) ThinkingConfig { // - 
generationConfig.thinkingConfig.thinkingLevel: "none", "auto", or level name (Gemini 3) // - generationConfig.thinkingConfig.thinkingBudget: integer (Gemini 2.5) // -// For gemini-cli and antigravity providers, the path is prefixed with "request.". +// For the antigravity provider, the path is prefixed with "request.". // // Priority: thinkingLevel is checked first (Gemini 3 format), then thinkingBudget (Gemini 2.5 format). // This allows newer Gemini 3 level-based configs to take precedence. func extractGeminiConfig(body []byte, provider string) ThinkingConfig { prefix := "generationConfig.thinkingConfig" - if provider == "gemini-cli" || provider == "antigravity" { + if provider == "antigravity" { prefix = "request.generationConfig.thinkingConfig" } diff --git a/internal/thinking/provider/antigravity/apply.go b/internal/thinking/provider/antigravity/apply.go index 4a2c76c30e..cb0659f123 100644 --- a/internal/thinking/provider/antigravity/apply.go +++ b/internal/thinking/provider/antigravity/apply.go @@ -1,6 +1,6 @@ // Package antigravity implements thinking configuration for Antigravity API format. // -// Antigravity uses request.generationConfig.thinkingConfig.* path (same as gemini-cli) +// Antigravity uses request.generationConfig.thinkingConfig.* path // but requires additional normalization for Claude models: // - Ensure thinking budget < max_tokens // - Remove thinkingConfig if budget < minimum allowed diff --git a/internal/thinking/provider/geminicli/apply.go b/internal/thinking/provider/geminicli/apply.go deleted file mode 100644 index 1bc9315f0e..0000000000 --- a/internal/thinking/provider/geminicli/apply.go +++ /dev/null @@ -1,165 +0,0 @@ -// Package geminicli implements thinking configuration for Gemini CLI API format. -// -// Gemini CLI uses request.generationConfig.thinkingConfig.* path instead of -// generationConfig.thinkingConfig.* used by standard Gemini API. 
-package geminicli - -import ( - "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" - "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// Applier applies thinking configuration for Gemini CLI API format. -type Applier struct{} - -var _ thinking.ProviderApplier = (*Applier)(nil) - -// NewApplier creates a new Gemini CLI thinking applier. -func NewApplier() *Applier { - return &Applier{} -} - -func init() { - thinking.RegisterProvider("gemini-cli", NewApplier()) -} - -// Apply applies thinking configuration to Gemini CLI request body. -func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) { - if thinking.IsUserDefinedModel(modelInfo) { - return a.applyCompatible(body, config) - } - if modelInfo.Thinking == nil { - return body, nil - } - - if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto { - return body, nil - } - - if len(body) == 0 || !gjson.ValidBytes(body) { - body = []byte(`{}`) - } - - // ModeAuto: Always use Budget format with thinkingBudget=-1 - if config.Mode == thinking.ModeAuto { - return a.applyBudgetFormat(body, config) - } - if config.Mode == thinking.ModeBudget { - return a.applyBudgetFormat(body, config) - } - - // For non-auto modes, choose format based on model capabilities - support := modelInfo.Thinking - if len(support.Levels) > 0 { - return a.applyLevelFormat(body, config) - } - return a.applyBudgetFormat(body, config) -} - -func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) ([]byte, error) { - if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto { - return body, nil - } - - if len(body) == 0 || !gjson.ValidBytes(body) { - body = []byte(`{}`) - } - - if config.Mode == thinking.ModeAuto { 
- return a.applyBudgetFormat(body, config) - } - - if config.Mode == thinking.ModeLevel || (config.Mode == thinking.ModeNone && config.Level != "") { - return a.applyLevelFormat(body, config) - } - - return a.applyBudgetFormat(body, config) -} - -func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) { - // Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output - result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget") - result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget") - result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level") - // Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing. - result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts") - - if config.Mode == thinking.ModeNone { - if config.Budget == 0 && config.Level == "" { - result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig") - return result, nil - } - result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false) - if config.Level != "" { - result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", string(config.Level)) - } - return result, nil - } - - // Only handle ModeLevel - budget conversion should be done by upper layer - if config.Mode != thinking.ModeLevel { - return body, nil - } - - level := string(config.Level) - result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level) - - // Respect user's explicit includeThoughts setting from original body; default to true if not set - // Support both camelCase and snake_case variants - includeThoughts := true - if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() { - includeThoughts = inc.Bool() - } else 
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() { - includeThoughts = inc.Bool() - } - result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts) - return result, nil -} - -func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) { - // Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output - result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel") - result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level") - result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget") - // Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing. - result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts") - - budget := config.Budget - - // For ModeNone, always set includeThoughts to false regardless of user setting. - // This ensures that when user requests budget=0 (disable thinking output), - // the includeThoughts is correctly set to false even if budget is clamped to min. 
- if config.Mode == thinking.ModeNone { - result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget) - result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false) - return result, nil - } - - // Determine includeThoughts: respect user's explicit setting from original body if provided - // Support both camelCase and snake_case variants - var includeThoughts bool - var userSetIncludeThoughts bool - if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() { - includeThoughts = inc.Bool() - userSetIncludeThoughts = true - } else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() { - includeThoughts = inc.Bool() - userSetIncludeThoughts = true - } - - if !userSetIncludeThoughts { - // No explicit setting, use default logic based on mode - switch config.Mode { - case thinking.ModeAuto: - includeThoughts = true - default: - includeThoughts = budget > 0 - } - } - - result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget) - result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts) - return result, nil -} diff --git a/internal/thinking/strip.go b/internal/thinking/strip.go index 75755b31ff..9fac8ae9ed 100644 --- a/internal/thinking/strip.go +++ b/internal/thinking/strip.go @@ -33,7 +33,7 @@ func StripThinkingConfig(body []byte, provider string) []byte { paths = []string{"thinking", "output_config.effort"} case "gemini": paths = []string{"generationConfig.thinkingConfig"} - case "gemini-cli", "antigravity": + case "antigravity": paths = []string{"request.generationConfig.thinkingConfig"} case "openai": paths = []string{"reasoning_effort"} diff --git a/internal/thinking/validate.go b/internal/thinking/validate.go index 2baa93f1da..7a7a8fa664 100644 --- a/internal/thinking/validate.go +++ 
b/internal/thinking/validate.go @@ -339,7 +339,7 @@ func normalizeLevels(levels []string) []string { // These providers may also support level-based thinking (hybrid models). func isBudgetCapableProvider(provider string) bool { switch provider { - case "gemini", "gemini-cli", "antigravity", "claude": + case "gemini", "antigravity", "claude": return true default: return false @@ -348,7 +348,7 @@ func isBudgetCapableProvider(provider string) bool { func isGeminiFamily(provider string) bool { switch provider { - case "gemini", "gemini-cli", "antigravity": + case "gemini", "antigravity": return true default: return false diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index d196de7cba..94d600fb0f 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -1,8 +1,8 @@ // Package claude provides request translation functionality for Claude Code API compatibility. -// This package handles the conversion of Claude Code API requests into Gemini CLI-compatible +// This package handles the conversion of Claude Code API requests into Antigravity-compatible // JSON format, transforming message contents, system instructions, and tool declarations -// into the format expected by Gemini CLI API clients. It performs JSON data transformation -// to ensure compatibility between Claude Code API format and Gemini CLI API's expected format. +// into the format expected by Antigravity API clients. It performs JSON data transformation +// to ensure compatibility between Claude Code API format and Antigravity API's expected format. 
package claude import ( @@ -288,12 +288,12 @@ func logDroppedAntigravityToolUseSignature(modelName string, messageIndex, conte }).Debug("antigravity claude translator: dropped tool_use signature field") } -// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format. +// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Antigravity API format. // It extracts the model name, system instruction, message contents, and tool declarations -// from the raw JSON request and returns them in the format expected by the Gemini CLI API. +// from the raw JSON request and returns them in the format expected by the Antigravity API. // The function performs the following transformations: // 1. Extracts the model information from the request -// 2. Restructures the JSON to match Gemini CLI API format +// 2. Restructures the JSON to match Antigravity API format // 3. Converts system instructions to the expected format // 4. Maps message contents with proper role transformations // 5. 
Handles tool declarations and tool choices @@ -305,7 +305,7 @@ func logDroppedAntigravityToolUseSignature(modelName string, messageIndex, conte // - stream: A boolean indicating if the request is for a streaming response (unused in current implementation) // // Returns: -// - []byte: The transformed request data in Gemini CLI API format +// - []byte: The transformed request data in Antigravity API format func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte { enableThoughtTranslate := true rawJSON := inputRawJSON @@ -681,7 +681,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } - // Build output Gemini CLI request JSON + // Build output Antigravity request JSON out := []byte(`{"model":"","request":{"contents":[]}}`) out, _ = sjson.SetBytes(out, "model", modelName) diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index da5098df98..6dd061f58c 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -95,7 +95,7 @@ var toolUseIDCounter uint64 // Parameters: // - ctx: The context for the request, used for cancellation and timeout handling // - modelName: The name of the model being used for the response (unused in current implementation) -// - rawJSON: The raw JSON response from the Gemini CLI API +// - rawJSON: The raw JSON response from the Antigravity API // - param: A pointer to a parameter object for maintaining state between calls // // Returns: @@ -158,7 +158,7 @@ func ConvertAntigravityResponseToClaude(ctx context.Context, _ string, originalR messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int()) } - // Override default values with actual response metadata if available from the Gemini CLI response + // Override 
default values with actual response metadata if available from the Antigravity response if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String()) } @@ -454,12 +454,12 @@ func resolveStopReason(params *Params) string { return "end_turn" } -// ConvertAntigravityResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. +// ConvertAntigravityResponseToClaudeNonStream converts a non-streaming Antigravity response to a non-streaming Claude response. // // Parameters: // - ctx: The context for the request. // - modelName: The name of the model. -// - rawJSON: The raw JSON response from the Gemini CLI API. +// - rawJSON: The raw JSON response from the Antigravity API. // - param: A pointer to a parameter object for the conversion. // // Returns: diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request.go b/internal/translator/antigravity/gemini/antigravity_gemini_request.go index 1beaecff4c..2d373890a5 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request.go @@ -1,8 +1,8 @@ -// Package gemini provides request translation functionality for Gemini CLI to Gemini API compatibility. -// It handles parsing and transforming Gemini CLI API requests into Gemini API format, +// Package gemini provides request translation functionality for Antigravity to Gemini API compatibility. +// It handles parsing and transforming Antigravity API requests into Gemini API format, // extracting model information, system instructions, message contents, and tool declarations. // The package performs JSON data transformation to ensure compatibility -// between Gemini CLI API format and Gemini API's expected format. 
+// between Antigravity API format and Gemini API's expected format. package gemini import ( @@ -18,7 +18,7 @@ import ( "github.com/tidwall/sjson" ) -// ConvertGeminiRequestToAntigravity parses and transforms a Gemini CLI API request into Gemini API format. +// ConvertGeminiRequestToAntigravity parses and transforms an Antigravity API request into Gemini API format. // It extracts the model name, system instruction, message contents, and tool declarations // from the raw JSON request and returns them in the format expected by the Gemini API. // The function performs the following transformations: @@ -29,7 +29,7 @@ import ( // // Parameters: // - modelName: The name of the model to use for the request (unused in current implementation) -// - rawJSON: The raw JSON request data from the Gemini CLI API +// - rawJSON: The raw JSON request data from the Antigravity API // - stream: A boolean indicating if the request is for a streaming response (unused in current implementation) // // Returns: diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_response.go b/internal/translator/antigravity/gemini/antigravity_gemini_response.go index b0deb7320a..b6a0cc8b76 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_response.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_response.go @@ -1,8 +1,8 @@ -// Package gemini provides request translation functionality for Gemini to Gemini CLI API compatibility. -// It handles parsing and transforming Gemini API requests into Gemini CLI API format, +// Package gemini provides request translation functionality for Gemini to Antigravity API compatibility. +// It handles parsing and transforming Gemini API requests into Antigravity API format, // extracting model information, system instructions, message contents, and tool declarations. // The package performs JSON data transformation to ensure compatibility -// between Gemini API format and Gemini CLI API's expected format. 
+// between Gemini API format and Antigravity API's expected format. package gemini import ( @@ -14,7 +14,7 @@ import ( "github.com/tidwall/sjson" ) -// ConvertAntigravityResponseToGemini parses and transforms a Gemini CLI API request into Gemini API format. +// ConvertAntigravityResponseToGemini parses and transforms an Antigravity API request into Gemini API format. // It extracts the model name, system instruction, message contents, and tool declarations // from the raw JSON request and returns them in the format expected by the Gemini API. // The function performs the following transformations: @@ -25,7 +25,7 @@ import ( // Parameters: // - ctx: The context for the request, used for cancellation and timeout handling // - modelName: The name of the model to use for the request (unused in current implementation) -// - rawJSON: The raw JSON request data from the Gemini CLI API +// - rawJSON: The raw JSON request data from the Antigravity API // - param: A pointer to a parameter object for the conversion (unused in current implementation) // // Returns: @@ -62,14 +62,14 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR return [][]byte{} } -// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response. -// This function processes the complete Gemini CLI request and transforms it into a single Gemini-compatible +// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Antigravity request to a non-streaming Gemini response. +// This function processes the complete Antigravity request and transforms it into a single Gemini-compatible // JSON response. It extracts the response data from the request and returns it in the expected format. 
// // Parameters: // - ctx: The context for the request, used for cancellation and timeout handling // - modelName: The name of the model being used for the response (unused in current implementation) -// - rawJSON: The raw JSON request data from the Gemini CLI API +// - rawJSON: The raw JSON request data from the Antigravity API // - param: A pointer to a parameter object for the conversion (unused in current implementation) // // Returns: diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index 0d9ee6fe0a..65c9790c9a 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -1,5 +1,5 @@ -// Package openai provides request translation functionality for OpenAI to Gemini CLI API compatibility. -// It converts OpenAI Chat Completions requests into Gemini CLI compatible JSON using gjson/sjson only. +// Package openai provides request translation functionality for OpenAI to Antigravity API compatibility. +// It converts OpenAI Chat Completions requests into Antigravity compatible JSON using gjson/sjson only. package chat_completions import ( @@ -14,10 +14,10 @@ import ( "github.com/tidwall/sjson" ) -const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator" +const antigravityFunctionThoughtSignature = "skip_thought_signature_validator" // ConvertOpenAIRequestToAntigravity converts an OpenAI Chat Completions request (raw JSON) -// into a complete Gemini CLI request JSON. All JSON construction uses sjson and lookups use gjson. +// into a complete Antigravity request JSON. All JSON construction uses sjson and lookups use gjson. 
// // Parameters: // - modelName: The name of the model to use for the request @@ -25,7 +25,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator" // - stream: A boolean indicating if the request is for a streaming response (unused in current implementation) // // Returns: -// - []byte: The transformed request data in Gemini CLI API format +// - []byte: The transformed request data in Antigravity API format func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte { rawJSON := inputRawJSON // Base envelope (no default thinkingConfig) @@ -39,7 +39,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw)) } - // Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig. + // Apply thinking configuration: convert OpenAI reasoning_effort to Antigravity thinkingConfig. // Inline translation-only mapping; capability checks happen later in ApplyThinking. re := gjson.GetBytes(rawJSON, "reasoning_effort") if re.Exists() { @@ -77,7 +77,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } - // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities + // Map OpenAI modalities -> Antigravity request.generationConfig.responseModalities // e.g. 
"modalities": ["image", "text"] -> ["IMAGE", "TEXT"] if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() { var responseMods []string @@ -194,7 +194,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ data := pieces[1][7:] node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", antigravityFunctionThoughtSignature) p++ } } @@ -269,7 +269,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ data := pieces[1][7:] node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", antigravityFunctionThoughtSignature) p++ } } @@ -295,7 +295,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ } else { node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", []byte(fargs)) } - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", antigravityFunctionThoughtSignature) p++ if fid != "" { fIDs = append(fIDs, fid) diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index 2be24102ff..8890255f89 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ 
b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -1,5 +1,5 @@ -// Package openai provides response translation functionality for Gemini CLI to OpenAI API compatibility. -// This package handles the conversion of Gemini CLI API responses into OpenAI Chat Completions-compatible +// Package openai provides response translation functionality for Antigravity to OpenAI API compatibility. +// This package handles the conversion of Antigravity API responses into OpenAI Chat Completions-compatible // JSON format, transforming streaming events and non-streaming responses into the format // expected by OpenAI API clients. It supports both streaming and non-streaming modes, // handling text content, tool calls, reasoning content, and usage metadata appropriately. @@ -34,15 +34,15 @@ type convertCliResponseToOpenAIChatParams struct { var functionCallIDCounter uint64 // ConvertAntigravityResponseToOpenAI translates a single chunk of a streaming response from the -// Gemini CLI API format to the OpenAI Chat Completions streaming format. -// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses. +// Antigravity API format to the OpenAI Chat Completions streaming format. +// It processes various Antigravity event types and transforms them into OpenAI-compatible JSON responses. // The function handles text content, tool calls, reasoning content, and usage metadata, outputting // responses that match the OpenAI API format. It supports incremental updates for streaming responses. 
// // Parameters: // - ctx: The context for the request, used for cancellation and timeout handling // - modelName: The name of the model being used for the response (unused in current implementation) -// - rawJSON: The raw JSON response from the Gemini CLI API +// - rawJSON: The raw JSON response from the Antigravity API // - param: A pointer to a parameter object for maintaining state between calls // // Returns: @@ -225,15 +225,15 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq return [][]byte{template} } -// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response. -// This function processes the complete Gemini CLI response and transforms it into a single OpenAI-compatible +// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Antigravity response to a non-streaming OpenAI response. +// This function processes the complete Antigravity response and transforms it into a single OpenAI-compatible // JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all // the information into a single response that matches the OpenAI API format. 
// // Parameters: // - ctx: The context for the request, used for cancellation and timeout handling // - modelName: The name of the model being used for the response -// - rawJSON: The raw JSON response from the Gemini CLI API +// - rawJSON: The raw JSON response from the Antigravity API // - param: A pointer to a parameter object for the conversion // // Returns: diff --git a/internal/translator/claude/gemini-cli/claude_gemini-cli_request.go b/internal/translator/claude/gemini-cli/claude_gemini-cli_request.go deleted file mode 100644 index fd68a957f5..0000000000 --- a/internal/translator/claude/gemini-cli/claude_gemini-cli_request.go +++ /dev/null @@ -1,45 +0,0 @@ -// Package geminiCLI provides request translation functionality for Gemini CLI to Claude Code API compatibility. -// It handles parsing and transforming Gemini CLI API requests into Claude Code API format, -// extracting model information, system instructions, message contents, and tool declarations. -// The package performs JSON data transformation to ensure compatibility -// between Gemini CLI API format and Claude Code API's expected format. -package geminiCLI - -import ( - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/claude/gemini" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// ConvertGeminiCLIRequestToClaude parses and transforms a Gemini CLI API request into Claude Code API format. -// It extracts the model name, system instruction, message contents, and tool declarations -// from the raw JSON request and returns them in the format expected by the Claude Code API. -// The function performs the following transformations: -// 1. Extracts the model information from the request -// 2. Restructures the JSON to match Claude Code API format -// 3. Converts system instructions to the expected format -// 4. 
Delegates to the Gemini-to-Claude conversion function for further processing -// -// Parameters: -// - modelName: The name of the model to use for the request -// - rawJSON: The raw JSON request data from the Gemini CLI API -// - stream: A boolean indicating if the request is for a streaming response -// -// Returns: -// - []byte: The transformed request data in Claude Code API format -func ConvertGeminiCLIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte { - rawJSON := inputRawJSON - - modelResult := gjson.GetBytes(rawJSON, "model") - // Extract the inner request object and promote it to the top level - rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw) - // Restore the model information at the top level - rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String()) - // Convert systemInstruction field to system_instruction for Claude Code compatibility - if gjson.GetBytes(rawJSON, "systemInstruction").Exists() { - rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw)) - rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction") - } - // Delegate to the Gemini-to-Claude conversion function for further processing - return ConvertGeminiRequestToClaude(modelName, rawJSON, stream) -} diff --git a/internal/translator/claude/gemini-cli/claude_gemini-cli_response.go b/internal/translator/claude/gemini-cli/claude_gemini-cli_response.go deleted file mode 100644 index 858886c272..0000000000 --- a/internal/translator/claude/gemini-cli/claude_gemini-cli_response.go +++ /dev/null @@ -1,57 +0,0 @@ -// Package geminiCLI provides response translation functionality for Claude Code to Gemini CLI API compatibility. -// This package handles the conversion of Claude Code API responses into Gemini CLI-compatible -// JSON format, transforming streaming events and non-streaming responses into the format -// expected by Gemini CLI API clients. 
-package geminiCLI - -import ( - "context" - - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/claude/gemini" - translatorcommon "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/common" -) - -// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format. -// This function processes various Claude Code event types and transforms them into Gemini-compatible JSON responses. -// It handles text content, tool calls, and usage metadata, outputting responses that match the Gemini CLI API format. -// The function wraps each converted response in a "response" object to match the Gemini CLI API structure. -// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response -// - rawJSON: The raw JSON response from the Claude Code API -// - param: A pointer to a parameter object for maintaining state between calls -// -// Returns: -// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object -func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { - outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) - // Wrap each converted response in a "response" object to match Gemini CLI API structure - newOutputs := make([][]byte, 0, len(outputs)) - for i := 0; i < len(outputs); i++ { - newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i])) - } - return newOutputs -} - -// ConvertClaudeResponseToGeminiCLINonStream converts a non-streaming Claude Code response to a non-streaming Gemini CLI response. -// This function processes the complete Claude Code response and transforms it into a single Gemini-compatible -// JSON response. 
It wraps the converted response in a "response" object to match the Gemini CLI API structure. -// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response -// - rawJSON: The raw JSON response from the Claude Code API -// - param: A pointer to a parameter object for the conversion -// -// Returns: -// - []byte: A Gemini-compatible JSON response wrapped in a response object -func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte { - out := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) - // Wrap the converted response in a "response" object to match Gemini CLI API structure - return translatorcommon.WrapGeminiCLIResponse(out) -} - -func GeminiCLITokenCount(ctx context.Context, count int64) []byte { - return GeminiTokenCount(ctx, count) -} diff --git a/internal/translator/claude/gemini-cli/init.go b/internal/translator/claude/gemini-cli/init.go deleted file mode 100644 index 33a1332daf..0000000000 --- a/internal/translator/claude/gemini-cli/init.go +++ /dev/null @@ -1,20 +0,0 @@ -package geminiCLI - -import ( - . 
"github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/translator" -) - -func init() { - translator.Register( - GeminiCLI, - Claude, - ConvertGeminiCLIRequestToClaude, - interfaces.TranslateResponse{ - Stream: ConvertClaudeResponseToGeminiCLI, - NonStream: ConvertClaudeResponseToGeminiCLINonStream, - TokenCount: GeminiCLITokenCount, - }, - ) -} diff --git a/internal/translator/claude/gemini/claude_gemini_request.go b/internal/translator/claude/gemini/claude_gemini_request.go index d716d28f35..1f5bf8ed90 100644 --- a/internal/translator/claude/gemini/claude_gemini_request.go +++ b/internal/translator/claude/gemini/claude_gemini_request.go @@ -80,6 +80,25 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream return "toolu_" + b.String() } + getGeminiToolID := func(value gjson.Result) string { + if toolID := strings.TrimSpace(value.Get("id").String()); toolID != "" { + return toolID + } + return strings.TrimSpace(value.Get("call_id").String()) + } + + removePendingToolID := func(ids []string, toolID string) []string { + if toolID == "" { + return ids + } + for idx, pendingID := range ids { + if pendingID == toolID { + return append(ids[:idx], ids[idx+1:]...) 
+ } + } + return ids + } + // FIFO queue to store tool call IDs for matching with tool results // Gemini uses sequential pairing across possibly multiple in-flight // functionCalls, so we keep a FIFO queue of generated tool IDs and @@ -262,9 +281,11 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream if fc := part.Get("functionCall"); fc.Exists() && role == "assistant" { toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`) - // Generate a unique tool ID and enqueue it for later matching - // with the corresponding functionResponse - toolID := genToolCallID() + // Reuse gateway-provided IDs when present, otherwise generate one for pairing. + toolID := getGeminiToolID(fc) + if toolID == "" { + toolID = genToolCallID() + } pendingToolIDs = append(pendingToolIDs, toolID) toolUse, _ = sjson.SetBytes(toolUse, "id", toolID) @@ -285,7 +306,10 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream // Attach the oldest queued tool_id to pair the response // with its call. If the queue is empty, generate a new id. 
var toolID string - if len(pendingToolIDs) > 0 { + if customID := getGeminiToolID(fr); customID != "" { + toolID = customID + pendingToolIDs = removePendingToolID(pendingToolIDs, toolID) + } else if len(pendingToolIDs) > 0 { toolID = pendingToolIDs[0] // Pop the first element from the queue pendingToolIDs = pendingToolIDs[1:] diff --git a/internal/translator/claude/gemini/claude_gemini_request_test.go b/internal/translator/claude/gemini/claude_gemini_request_test.go new file mode 100644 index 0000000000..06224d5a3f --- /dev/null +++ b/internal/translator/claude/gemini/claude_gemini_request_test.go @@ -0,0 +1,63 @@ +package gemini + +import ( + "fmt" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertGeminiRequestToClaude_PreservesCustomToolIDs(t *testing.T) { + tests := []struct { + name string + callField string + responseField string + want string + }{ + { + name: "id", + callField: `"id":"call_gateway_id"`, + responseField: `"id":"call_gateway_id"`, + want: "call_gateway_id", + }, + { + name: "call_id", + callField: `"call_id":"call_gateway_call_id"`, + responseField: `"call_id":"call_gateway_call_id"`, + want: "call_gateway_call_id", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + raw := []byte(fmt.Sprintf(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "lookup", %s, "args": {"query": "status"}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "lookup", %s, "response": {"result": "ok"}}} + ] + } + ] + }`, tt.callField, tt.responseField)) + + out := ConvertGeminiRequestToClaude("claude-sonnet-4", raw, false) + + gotCallID := gjson.GetBytes(out, "messages.0.content.0.id").String() + if gotCallID != tt.want { + t.Fatalf("expected tool_use id %q, got %q; output=%s", tt.want, gotCallID, string(out)) + } + + gotResultID := gjson.GetBytes(out, "messages.1.content.0.tool_use_id").String() + if gotResultID != tt.want { + t.Fatalf("expected tool_result tool_use_id 
%q, got %q; output=%s", tt.want, gotResultID, string(out)) + } + }) + } +} diff --git a/internal/translator/claude/gemini/claude_gemini_response.go b/internal/translator/claude/gemini/claude_gemini_response.go index 3f127e3205..74865ead30 100644 --- a/internal/translator/claude/gemini/claude_gemini_response.go +++ b/internal/translator/claude/gemini/claude_gemini_response.go @@ -37,6 +37,7 @@ type ConvertAnthropicResponseToGeminiParams struct { // Keyed by content_block index from Claude SSE events ToolUseNames map[int]string // function/tool name per block index ToolUseArgs map[int]*strings.Builder // accumulates partial_json across deltas + ToolUseIDs map[int]string // tool use ID per block index } // ConvertClaudeResponseToGemini converts Claude Code streaming response format to Gemini format. @@ -110,6 +111,12 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original if name := cb.Get("name"); name.Exists() { (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames[idx] = name.String() } + if toolID := cb.Get("id").String(); toolID != "" { + if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseIDs == nil { + (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseIDs = map[int]string{} + } + (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseIDs[idx] = toolID + } } } return [][]byte{} @@ -169,6 +176,10 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original argsTrim = strings.TrimSpace(b.String()) } } + toolID := "" + if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseIDs != nil { + toolID = (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseIDs[idx] + } if name != "" || argsTrim != "" { functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`) if name != "" { @@ -177,6 +188,9 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original if argsTrim != "" { functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", 
[]byte(argsTrim)) } + if toolID != "" { + functionCall, _ = sjson.SetBytes(functionCall, "functionCall.id", toolID) + } template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall) template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP") (*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...) @@ -187,6 +201,9 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil { delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames, idx) } + if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseIDs != nil { + delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseIDs, idx) + } return [][]byte{template} } return [][]byte{} @@ -308,6 +325,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, IsStreaming: false, ToolUseNames: nil, ToolUseArgs: nil, + ToolUseIDs: nil, } // Process each streaming event and collect parts @@ -348,6 +366,12 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, if name := cb.Get("name"); name.Exists() { newParam.ToolUseNames[idx] = name.String() } + if toolID := cb.Get("id").String(); toolID != "" { + if newParam.ToolUseIDs == nil { + newParam.ToolUseIDs = map[int]string{} + } + newParam.ToolUseIDs[idx] = toolID + } } } continue @@ -401,6 +425,10 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, argsTrim = strings.TrimSpace(b.String()) } } + toolID := "" + if newParam.ToolUseIDs != nil { + toolID = newParam.ToolUseIDs[idx] + } if name != "" || argsTrim != "" { functionCallJSON := []byte(`{"functionCall":{"name":"","args":{}}}`) if name != "" { @@ -409,6 +437,9 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, if argsTrim != "" { functionCallJSON, _ = sjson.SetRawBytes(functionCallJSON, 
"functionCall.args", []byte(argsTrim)) } + if toolID != "" { + functionCallJSON, _ = sjson.SetBytes(functionCallJSON, "functionCall.id", toolID) + } allParts = append(allParts, functionCallJSON) // cleanup used state for this index if newParam.ToolUseArgs != nil { @@ -417,6 +448,9 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, if newParam.ToolUseNames != nil { delete(newParam.ToolUseNames, idx) } + if newParam.ToolUseIDs != nil { + delete(newParam.ToolUseIDs, idx) + } } case "message_delta": diff --git a/internal/translator/claude/gemini/claude_gemini_response_test.go b/internal/translator/claude/gemini/claude_gemini_response_test.go new file mode 100644 index 0000000000..8fb6744c73 --- /dev/null +++ b/internal/translator/claude/gemini/claude_gemini_response_test.go @@ -0,0 +1,53 @@ +package gemini + +import ( + "context" + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertClaudeResponseToGemini_StreamPreservesToolUseID(t *testing.T) { + ctx := context.Background() + var param any + + start := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_gateway","name":"lookup"}}`) + out := ConvertClaudeResponseToGemini(ctx, "gemini-2.5-pro", nil, nil, start, ¶m) + if len(out) != 0 { + t.Fatalf("expected content_block_start to be buffered, got %d chunks", len(out)) + } + + delta := []byte(`data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"query\":\"status\"}"}}`) + out = ConvertClaudeResponseToGemini(ctx, "gemini-2.5-pro", nil, nil, delta, ¶m) + if len(out) != 0 { + t.Fatalf("expected input_json_delta to be buffered, got %d chunks", len(out)) + } + + stop := []byte(`data: {"type":"content_block_stop","index":0}`) + out = ConvertClaudeResponseToGemini(ctx, "gemini-2.5-pro", nil, nil, stop, ¶m) + if len(out) != 1 { + t.Fatalf("expected content_block_stop to emit 1 chunk, got %d", len(out)) + } + + got := 
gjson.GetBytes(out[0], "candidates.0.content.parts.0.functionCall.id").String() + if got != "toolu_gateway" { + t.Fatalf("expected functionCall.id %q, got %q; chunk=%s", "toolu_gateway", got, string(out[0])) + } +} + +func TestConvertClaudeResponseToGeminiNonStreamPreservesToolUseID(t *testing.T) { + ctx := context.Background() + raw := []byte(strings.Join([]string{ + `data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_gateway","name":"lookup"}}`, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"query\":\"status\"}"}}`, + `data: {"type":"content_block_stop","index":0}`, + }, "\n")) + + out := ConvertClaudeResponseToGeminiNonStream(ctx, "gemini-2.5-pro", nil, nil, raw, nil) + + got := gjson.GetBytes(out, "candidates.0.content.parts.0.functionCall.id").String() + if got != "toolu_gateway" { + t.Fatalf("expected functionCall.id %q, got %q; chunk=%s", "toolu_gateway", got, string(out)) + } +} diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 3a8dab5e6d..ace43013ab 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -23,18 +23,28 @@ var ( // ConvertCodexResponseToClaudeParams holds parameters for response conversion. 
type ConvertCodexResponseToClaudeParams struct { - HasToolCall bool + HasToolCall bool + BlockIndex int + HasReceivedArgumentsDelta bool + HasTextDelta bool + TextBlockOpen bool + ThinkingBlockOpen bool + ThinkingStopPending bool + ThinkingSignature string + ThinkingSummarySeen bool + WebSearchToolUseIDs map[string]struct{} + WebSearchToolResultIDs map[string]struct{} + LastWebSearchToolUseID string + PendingFunctionCalls map[string]*pendingCodexFunctionCall + LastPendingFunctionCallKey string +} + +type pendingCodexFunctionCall struct { + CallID string + Arguments string BlockIndex int HasReceivedArgumentsDelta bool - HasTextDelta bool - TextBlockOpen bool - ThinkingBlockOpen bool - ThinkingStopPending bool - ThinkingSignature string - ThinkingSummarySeen bool - WebSearchToolUseIDs map[string]struct{} - WebSearchToolResultIDs map[string]struct{} - LastWebSearchToolUseID string + StartEmitted bool } // ConvertCodexResponseToClaude performs sophisticated streaming response format conversion. @@ -145,24 +155,26 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa output = append(output, stopCodexTextBlock(params)...) 
params.HasToolCall = true params.HasReceivedArgumentsDelta = false - template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`) - template, _ = sjson.SetBytes(template, "index", params.BlockIndex) - template, _ = sjson.SetBytes(template, "content_block.id", shortenCodexCallIDIfNeeded(util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))) - { - name := itemResult.Get("name").String() - rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON) - if orig, ok := rev[name]; ok { - name = orig + + callID := itemResult.Get("call_id").String() + name := itemResult.Get("name").String() + key := codexFunctionCallKey(rootResult, itemResult) + if name == "" { + if params.PendingFunctionCalls == nil { + params.PendingFunctionCalls = map[string]*pendingCodexFunctionCall{} + } + params.PendingFunctionCalls[key] = &pendingCodexFunctionCall{ + CallID: callID, + BlockIndex: params.BlockIndex, } - template, _ = sjson.SetBytes(template, "content_block.name", name) + params.LastPendingFunctionCallKey = key + params.BlockIndex++ + break } - output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) - - template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) - template, _ = sjson.SetBytes(template, "index", params.BlockIndex) - - output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) + delete(params.PendingFunctionCalls, key) + output = appendCodexFunctionCallStart(output, originalRequestRawJSON, callID, name, params.BlockIndex) + output = appendCodexFunctionCallArgumentDelta(output, "", params.BlockIndex) case "reasoning": params.ThinkingSummarySeen = false params.ThinkingSignature = itemResult.Get("encrypted_content").String() @@ -207,11 +219,36 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa output = append(output, 
stopCodexTextBlock(params)...) params.HasTextDelta = true case "function_call": - template = []byte(`{"type":"content_block_stop","index":0}`) - template, _ = sjson.SetBytes(template, "index", params.BlockIndex) - params.BlockIndex++ + key := codexFunctionCallKey(rootResult, itemResult) + if pending, pendingKey := pendingCodexFunctionCallForKey(params, key); pending != nil && !pending.StartEmitted { + name := itemResult.Get("name").String() + if name == "" { + return [][]byte{output} + } + callID := pending.CallID + if callID == "" { + callID = itemResult.Get("call_id").String() + } + output = appendCodexFunctionCallStart(output, originalRequestRawJSON, callID, name, pending.BlockIndex) + pending.StartEmitted = true - output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2) + args := pending.Arguments + if args == "" { + args = itemResult.Get("arguments").String() + } + if args != "" { + output = appendCodexFunctionCallArgumentDelta(output, args, pending.BlockIndex) + } + output = appendCodexFunctionCallStop(output, pending.BlockIndex) + + delete(params.PendingFunctionCalls, pendingKey) + if params.LastPendingFunctionCallKey == pendingKey { + params.LastPendingFunctionCallKey = "" + } + } else { + output = appendCodexFunctionCallStop(output, params.BlockIndex) + params.BlockIndex++ + } case "reasoning": if signature := itemResult.Get("encrypted_content").String(); signature != "" { params.ThinkingSignature = signature @@ -227,20 +264,28 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa output = appendCodexWebSearchToolResult(output, params, rootResult, itemResult) } case "response.function_call_arguments.delta": - params.HasReceivedArgumentsDelta = true - template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) - template, _ = sjson.SetBytes(template, "index", params.BlockIndex) - template, _ = sjson.SetBytes(template, "delta.partial_json", 
rootResult.Get("delta").String()) + delta := rootResult.Get("delta").String() + key := codexArgumentsFunctionCallKey(params, rootResult) + if pending, _ := pendingCodexFunctionCallForKey(params, key); pending != nil && !pending.StartEmitted { + pending.HasReceivedArgumentsDelta = true + pending.Arguments += delta + break + } - output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) + params.HasReceivedArgumentsDelta = true + output = appendCodexFunctionCallArgumentDelta(output, delta, params.BlockIndex) case "response.function_call_arguments.done": + key := codexArgumentsFunctionCallKey(params, rootResult) + if pending, _ := pendingCodexFunctionCallForKey(params, key); pending != nil && !pending.StartEmitted { + if !pending.HasReceivedArgumentsDelta { + pending.Arguments = rootResult.Get("arguments").String() + } + break + } + if !params.HasReceivedArgumentsDelta { if args := rootResult.Get("arguments").String(); args != "" { - template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) - template, _ = sjson.SetBytes(template, "index", params.BlockIndex) - template, _ = sjson.SetBytes(template, "delta.partial_json", args) - - output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) + output = appendCodexFunctionCallArgumentDelta(output, args, params.BlockIndex) } } } @@ -458,6 +503,69 @@ func setClaudeStopSequence(out []byte, path string, responseData gjson.Result) [ return out } +func codexFunctionCallKey(rootResult, itemResult gjson.Result) string { + if outputIndex := rootResult.Get("output_index"); outputIndex.Exists() { + return "output:" + outputIndex.Raw + } + if callID := itemResult.Get("call_id").String(); callID != "" { + return "call:" + callID + } + return "last" +} + +func codexArgumentsFunctionCallKey(params *ConvertCodexResponseToClaudeParams, rootResult gjson.Result) string { + if outputIndex := rootResult.Get("output_index"); 
outputIndex.Exists() { + return "output:" + outputIndex.Raw + } + return params.LastPendingFunctionCallKey +} + +func pendingCodexFunctionCallForKey(params *ConvertCodexResponseToClaudeParams, key string) (*pendingCodexFunctionCall, string) { + if params == nil || params.PendingFunctionCalls == nil { + return nil, "" + } + if key != "" { + if pending, ok := params.PendingFunctionCalls[key]; ok { + return pending, key + } + } + if params.LastPendingFunctionCallKey != "" { + if pending, ok := params.PendingFunctionCalls[params.LastPendingFunctionCallKey]; ok { + return pending, params.LastPendingFunctionCallKey + } + } + return nil, "" +} + +func appendCodexFunctionCallStart(output []byte, originalRequestRawJSON []byte, callID, name string, blockIndex int) []byte { + template := []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`) + template, _ = sjson.SetBytes(template, "index", blockIndex) + template, _ = sjson.SetBytes(template, "content_block.id", shortenCodexCallIDIfNeeded(util.SanitizeClaudeToolID(callID))) + template, _ = sjson.SetBytes(template, "content_block.name", resolveCodexClaudeToolUseName(originalRequestRawJSON, name)) + return translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) +} + +func appendCodexFunctionCallArgumentDelta(output []byte, partialJSON string, blockIndex int) []byte { + template := []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) + template, _ = sjson.SetBytes(template, "index", blockIndex) + template, _ = sjson.SetBytes(template, "delta.partial_json", partialJSON) + return translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) +} + +func appendCodexFunctionCallStop(output []byte, blockIndex int) []byte { + template := []byte(`{"type":"content_block_stop","index":0}`) + template, _ = sjson.SetBytes(template, "index", blockIndex) + return 
translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2) +} + +func resolveCodexClaudeToolUseName(originalRequestRawJSON []byte, name string) string { + rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON) + if orig, ok := rev[name]; ok { + return orig + } + return name +} + func extractResponsesUsage(usage gjson.Result) (int64, int64, int64) { if !usage.Exists() || usage.Type == gjson.Null { return 0, 0, 0 diff --git a/internal/translator/codex/claude/codex_claude_response_test.go b/internal/translator/codex/claude/codex_claude_response_test.go index e707fa6fb8..c4c828623c 100644 --- a/internal/translator/codex/claude/codex_claude_response_test.go +++ b/internal/translator/codex/claude/codex_claude_response_test.go @@ -531,6 +531,62 @@ func TestConvertCodexResponseToClaude_StreamTextBeforeToolCallsDoesNotEmitGhostS } } +func TestConvertCodexResponseToClaude_StreamFunctionCallDefersStartUntilDoneName(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"tools":[{"name":"web_search","description":"search"}]}`) + var param any + + _ = ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, []byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5"}}`), ¶m) + addedOutputs := ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_1"},"output_index":1}`), ¶m) + argumentsOutputs := ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, []byte(`data: {"type":"response.function_call_arguments.done","arguments":"{\"query\":\"example\"}","output_index":1}`), ¶m) + doneOutputs := ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, []byte(`data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_1","name":"web_search","arguments":"{\"query\":\"example\"}"},"output_index":1}`), ¶m) + + if bytes.Contains(bytes.Join(addedOutputs, nil), 
[]byte(`"content_block_start"`)) { + t.Fatalf("function_call without name must not emit content_block_start: %q", addedOutputs) + } + if bytes.Contains(bytes.Join(argumentsOutputs, nil), []byte(`"input_json_delta"`)) { + t.Fatalf("arguments must be buffered until the tool name is available: %q", argumentsOutputs) + } + + var toolStartCount int + var toolStopCount int + var argumentDeltas []string + for _, out := range doneOutputs { + for _, line := range strings.Split(string(out), "\n") { + if !strings.HasPrefix(line, "data: ") { + continue + } + data := gjson.Parse(strings.TrimPrefix(line, "data: ")) + switch data.Get("type").String() { + case "content_block_start": + if data.Get("content_block.type").String() != "tool_use" { + continue + } + toolStartCount++ + if got := data.Get("content_block.name").String(); got != "web_search" { + t.Fatalf("unexpected tool_use name %q in %s", got, data.Raw) + } + case "content_block_delta": + if data.Get("delta.type").String() == "input_json_delta" { + argumentDeltas = append(argumentDeltas, data.Get("delta.partial_json").String()) + } + case "content_block_stop": + toolStopCount++ + } + } + } + + if toolStartCount != 1 { + t.Fatalf("expected one deferred tool_use start, got %d in %q", toolStartCount, doneOutputs) + } + if len(argumentDeltas) != 1 || argumentDeltas[0] != `{"query":"example"}` { + t.Fatalf("unexpected buffered argument deltas: %v", argumentDeltas) + } + if toolStopCount != 1 { + t.Fatalf("expected one deferred tool_use stop, got %d in %q", toolStopCount, doneOutputs) + } +} + func TestConvertCodexResponseToClaude_StreamEmptyOutputUsesOutputItemDoneMessageFallback(t *testing.T) { ctx := context.Background() originalRequest := []byte(`{"tools":[]}`) diff --git a/internal/translator/codex/gemini-cli/codex_gemini-cli_request.go b/internal/translator/codex/gemini-cli/codex_gemini-cli_request.go deleted file mode 100644 index b69bab11ee..0000000000 --- 
a/internal/translator/codex/gemini-cli/codex_gemini-cli_request.go +++ /dev/null @@ -1,41 +0,0 @@ -// Package geminiCLI provides request translation functionality for Gemini CLI to Codex API compatibility. -// It handles parsing and transforming Gemini CLI API requests into Codex API format, -// extracting model information, system instructions, message contents, and tool declarations. -// The package performs JSON data transformation to ensure compatibility -// between Gemini CLI API format and Codex API's expected format. -package geminiCLI - -import ( - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/codex/gemini" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// ConvertGeminiCLIRequestToCodex parses and transforms a Gemini CLI API request into Codex API format. -// It extracts the model name, system instruction, message contents, and tool declarations -// from the raw JSON request and returns them in the format expected by the Codex API. -// The function performs the following transformations: -// 1. Extracts the inner request object and promotes it to the top level -// 2. Restores the model information at the top level -// 3. Converts systemInstruction field to system_instruction for Codex compatibility -// 4. 
Delegates to the Gemini-to-Codex conversion function for further processing -// -// Parameters: -// - modelName: The name of the model to use for the request -// - rawJSON: The raw JSON request data from the Gemini CLI API -// - stream: A boolean indicating if the request is for a streaming response -// -// Returns: -// - []byte: The transformed request data in Codex API format -func ConvertGeminiCLIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte { - rawJSON := inputRawJSON - - rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw) - rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) - if gjson.GetBytes(rawJSON, "systemInstruction").Exists() { - rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw)) - rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction") - } - - return ConvertGeminiRequestToCodex(modelName, rawJSON, stream) -} diff --git a/internal/translator/codex/gemini-cli/codex_gemini-cli_request_test.go b/internal/translator/codex/gemini-cli/codex_gemini-cli_request_test.go deleted file mode 100644 index fc41452b10..0000000000 --- a/internal/translator/codex/gemini-cli/codex_gemini-cli_request_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package geminiCLI - -import ( - "testing" - - "github.com/tidwall/gjson" -) - -func TestConvertGeminiCLIRequestToCodex_PreservesSchemaPropertyNamedType(t *testing.T) { - input := []byte(`{ - "request": { - "tools": [ - { - "functionDeclarations": [ - { - "name": "ask_user", - "description": "Ask the user one or more questions.", - "parametersJsonSchema": { - "type": "object", - "properties": { - "questions": { - "type": "array", - "items": { - "type": "object", - "properties": { - "header": { - "type": "string" - }, - "type": { - "default": "choice", - "description": "Question type.", - "enum": [ - "choice", - "text", - "yesno" - ], - "type": "string" - } - }, - "required": [ - "question", - "header", - "type" - ] - } - } - }, 
- "required": [ - "questions" - ] - } - } - ] - } - ] - } - }`) - - out := ConvertGeminiCLIRequestToCodex("gpt-5.2", input, true) - tool := gjson.GetBytes(out, "tools.0") - if got := tool.Get("type").String(); got != "function" { - t.Fatalf("expected tool type %q, got %q; output=%s", "function", got, string(out)) - } - - typeProperty := tool.Get("parameters.properties.questions.items.properties.type") - if !typeProperty.IsObject() { - t.Fatalf("expected schema property named type to stay an object; output=%s", string(out)) - } - if got := typeProperty.Get("type").String(); got != "string" { - t.Fatalf("expected schema property type %q, got %q; output=%s", "string", got, string(out)) - } - if got := typeProperty.Get("default").String(); got != "choice" { - t.Fatalf("expected default %q, got %q; output=%s", "choice", got, string(out)) - } - if got := typeProperty.Get("enum.2").String(); got != "yesno" { - t.Fatalf("expected enum value %q, got %q; output=%s", "yesno", got, string(out)) - } -} diff --git a/internal/translator/codex/gemini-cli/codex_gemini-cli_response.go b/internal/translator/codex/gemini-cli/codex_gemini-cli_response.go deleted file mode 100644 index 01dbc0f831..0000000000 --- a/internal/translator/codex/gemini-cli/codex_gemini-cli_response.go +++ /dev/null @@ -1,55 +0,0 @@ -// Package geminiCLI provides response translation functionality for Codex to Gemini CLI API compatibility. -// This package handles the conversion of Codex API responses into Gemini CLI-compatible -// JSON format, transforming streaming events and non-streaming responses into the format -// expected by Gemini CLI API clients. -package geminiCLI - -import ( - "context" - - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/codex/gemini" - translatorcommon "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/common" -) - -// ConvertCodexResponseToGeminiCLI converts Codex streaming response format to Gemini CLI format. 
-// This function processes various Codex event types and transforms them into Gemini-compatible JSON responses. -// It handles text content, tool calls, and usage metadata, outputting responses that match the Gemini CLI API format. -// The function wraps each converted response in a "response" object to match the Gemini CLI API structure. -// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response -// - rawJSON: The raw JSON response from the Codex API -// - param: A pointer to a parameter object for maintaining state between calls -// -// Returns: -// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object -func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { - outputs := ConvertCodexResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) - newOutputs := make([][]byte, 0, len(outputs)) - for i := 0; i < len(outputs); i++ { - newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i])) - } - return newOutputs -} - -// ConvertCodexResponseToGeminiCLINonStream converts a non-streaming Codex response to a non-streaming Gemini CLI response. -// This function processes the complete Codex response and transforms it into a single Gemini-compatible -// JSON response. It wraps the converted response in a "response" object to match the Gemini CLI API structure. 
-// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response -// - rawJSON: The raw JSON response from the Codex API -// - param: A pointer to a parameter object for the conversion -// -// Returns: -// - []byte: A Gemini-compatible JSON response wrapped in a response object -func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte { - out := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) - return translatorcommon.WrapGeminiCLIResponse(out) -} - -func GeminiCLITokenCount(ctx context.Context, count int64) []byte { - return translatorcommon.GeminiTokenCountJSON(count) -} diff --git a/internal/translator/codex/gemini-cli/init.go b/internal/translator/codex/gemini-cli/init.go deleted file mode 100644 index 2958e0a825..0000000000 --- a/internal/translator/codex/gemini-cli/init.go +++ /dev/null @@ -1,20 +0,0 @@ -package geminiCLI - -import ( - . 
"github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/translator" -) - -func init() { - translator.Register( - GeminiCLI, - Codex, - ConvertGeminiCLIRequestToCodex, - interfaces.TranslateResponse{ - Stream: ConvertCodexResponseToGeminiCLI, - NonStream: ConvertCodexResponseToGeminiCLINonStream, - TokenCount: GeminiCLITokenCount, - }, - ) -} diff --git a/internal/translator/codex/gemini/codex_gemini_request.go b/internal/translator/codex/gemini/codex_gemini_request.go index e96d5aaca1..03a862ba08 100644 --- a/internal/translator/codex/gemini/codex_gemini_request.go +++ b/internal/translator/codex/gemini/codex_gemini_request.go @@ -81,6 +81,25 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) return "call_" + b.String() } + getGeminiCallID := func(value gjson.Result) string { + if callID := strings.TrimSpace(value.Get("id").String()); callID != "" { + return callID + } + return strings.TrimSpace(value.Get("call_id").String()) + } + + removePendingCallID := func(ids []string, callID string) []string { + if callID == "" { + return ids + } + for idx, pendingID := range ids { + if pendingID == callID { + return append(ids[:idx], ids[idx+1:]...) + } + } + return ids + } + // Model out, _ = sjson.SetBytes(out, "model", modelName) @@ -155,10 +174,11 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) if args := fc.Get("args"); args.Exists() { fn, _ = sjson.SetBytes(fn, "arguments", args.Raw) } - // generate a paired random call_id and enqueue it so the - // corresponding functionResponse can pop the earliest id - // to preserve ordering when multiple calls are present. - id := genCallID() + // Reuse gateway-provided IDs when present, otherwise generate one for pairing. 
+ id := getGeminiCallID(fc) + if id == "" { + id = genCallID() + } fn, _ = sjson.SetBytes(fn, "call_id", id) pendingCallIDs = append(pendingCallIDs, id) out, _ = sjson.SetRawBytes(out, "input.-1", fn) @@ -178,7 +198,10 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) // attach the oldest queued call_id to pair the response // with its call. If the queue is empty, generate a new id. var id string - if len(pendingCallIDs) > 0 { + if customID := getGeminiCallID(fr); customID != "" { + id = customID + pendingCallIDs = removePendingCallID(pendingCallIDs, id) + } else if len(pendingCallIDs) > 0 { id = pendingCallIDs[0] // pop the first element pendingCallIDs = pendingCallIDs[1:] diff --git a/internal/translator/codex/gemini/codex_gemini_request_test.go b/internal/translator/codex/gemini/codex_gemini_request_test.go new file mode 100644 index 0000000000..a98cdba4dd --- /dev/null +++ b/internal/translator/codex/gemini/codex_gemini_request_test.go @@ -0,0 +1,63 @@ +package gemini + +import ( + "fmt" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertGeminiRequestToCodex_PreservesCustomCallIDs(t *testing.T) { + tests := []struct { + name string + callField string + responseField string + want string + }{ + { + name: "id", + callField: `"id":"call_gateway_id"`, + responseField: `"id":"call_gateway_id"`, + want: "call_gateway_id", + }, + { + name: "call_id", + callField: `"call_id":"call_gateway_call_id"`, + responseField: `"call_id":"call_gateway_call_id"`, + want: "call_gateway_call_id", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + raw := []byte(fmt.Sprintf(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "lookup", %s, "args": {"query": "status"}}} + ] + }, + { + "role": "user", + "parts": [ + {"functionResponse": {"name": "lookup", %s, "response": {"result": "ok"}}} + ] + } + ] + }`, tt.callField, tt.responseField)) + + out := 
ConvertGeminiRequestToCodex("gpt-5.1-codex", raw, false) + + gotCallID := gjson.GetBytes(out, "input.0.call_id").String() + if gotCallID != tt.want { + t.Fatalf("expected function_call call_id %q, got %q; output=%s", tt.want, gotCallID, string(out)) + } + + gotOutputID := gjson.GetBytes(out, "input.1.call_id").String() + if gotOutputID != tt.want { + t.Fatalf("expected function_call_output call_id %q, got %q; output=%s", tt.want, gotOutputID, string(out)) + } + }) + } +} diff --git a/internal/translator/codex/gemini/codex_gemini_response.go b/internal/translator/codex/gemini/codex_gemini_response.go index ecf9cf4de8..a5144ea633 100644 --- a/internal/translator/codex/gemini/codex_gemini_response.go +++ b/internal/translator/codex/gemini/codex_gemini_response.go @@ -156,6 +156,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr)) } } + functionCall = setGeminiFunctionCallID(functionCall, itemResult) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall) template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP") @@ -361,6 +362,7 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr)) } } + functionCall = setGeminiFunctionCallID(functionCall, value) pendingFunctionCalls = append(pendingFunctionCalls, functionCall) } @@ -410,6 +412,17 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string { return rev } +func setGeminiFunctionCallID(functionCall []byte, item gjson.Result) []byte { + if callID := strings.TrimSpace(item.Get("call_id").String()); callID != "" { + functionCall, _ = sjson.SetBytes(functionCall, "functionCall.id", callID) + return functionCall + } + if id := strings.TrimSpace(item.Get("id").String()); id != "" { + functionCall, _ = 
sjson.SetBytes(functionCall, "functionCall.id", id) + } + return functionCall +} + func GeminiTokenCount(ctx context.Context, count int64) []byte { return translatorcommon.GeminiTokenCountJSON(count) } diff --git a/internal/translator/codex/gemini/codex_gemini_response_test.go b/internal/translator/codex/gemini/codex_gemini_response_test.go index 547ee84715..55b1352908 100644 --- a/internal/translator/codex/gemini/codex_gemini_response_test.go +++ b/internal/translator/codex/gemini/codex_gemini_response_test.go @@ -109,3 +109,43 @@ func TestConvertCodexResponseToGemini_NonStreamImageGenerationCallAddsInlineData t.Fatalf("expected inlineData.mimeType %q, got %q; chunk=%s", "image/png", gotMime, string(out)) } } + +func TestConvertCodexResponseToGemini_StreamPreservesFunctionCallID(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"tools":[]}`) + var param any + + out := ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, []byte(`data: {"type":"response.output_item.done","item":{"type":"function_call","call_id":"call_gateway","name":"lookup","arguments":"{\"query\":\"status\"}"}}`), &param) + if len(out) != 0 { + t.Fatalf("expected function call output to be buffered, got %d chunks", len(out)) + } + + out = ConvertCodexResponseToGemini(ctx, "gemini-2.5-pro", originalRequest, nil, []byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`), &param) + if len(out) == 0 { + t.Fatal("expected buffered function call to be emitted on completion") + } + + got := "" + for _, chunk := range out { + if value := gjson.GetBytes(chunk, "candidates.0.content.parts.0.functionCall.id").String(); value != "" { + got = value + break + } + } + if got != "call_gateway" { + t.Fatalf("expected functionCall.id %q, got %q; chunks=%q", "call_gateway", got, out) + } +} + +func TestConvertCodexResponseToGeminiNonStreamPreservesFunctionCallID(t *testing.T) { + ctx := context.Background() + originalRequest :=
[]byte(`{"tools":[]}`) + + raw := []byte(`{"type":"response.completed","response":{"id":"resp_123","created_at":1700000000,"usage":{"input_tokens":1,"output_tokens":1},"output":[{"type":"function_call","call_id":"call_gateway","name":"lookup","arguments":"{\"query\":\"status\"}"}]}}`) + out := ConvertCodexResponseToGeminiNonStream(ctx, "gemini-2.5-pro", originalRequest, nil, raw, nil) + + got := gjson.GetBytes(out, "candidates.0.content.parts.0.functionCall.id").String() + if got != "call_gateway" { + t.Fatalf("expected functionCall.id %q, got %q; chunk=%s", "call_gateway", got, string(out)) + } +} diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_request.go b/internal/translator/codex/openai/chat-completions/codex_openai_request.go index 569e06e316..046216b42f 100644 --- a/internal/translator/codex/openai/chat-completions/codex_openai_request.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_request.go @@ -193,6 +193,20 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b msg, _ = sjson.SetRawBytes(msg, "content.-1", part) } } + case "input_audio": + if role == "user" { + audioData := it.Get("input_audio.data").String() + audioFormat := it.Get("input_audio.format").String() + if audioData != "" { + part := []byte(`{}`) + part, _ = sjson.SetBytes(part, "type", "input_audio") + part, _ = sjson.SetBytes(part, "data", audioData) + if audioFormat != "" { + part, _ = sjson.SetBytes(part, "format", audioFormat) + } + msg, _ = sjson.SetRawBytes(msg, "content.-1", part) + } + } } } } diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_request_test.go b/internal/translator/codex/openai/chat-completions/codex_openai_request_test.go index e31db6d373..5be9c8b851 100644 --- a/internal/translator/codex/openai/chat-completions/codex_openai_request_test.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_request_test.go @@ -352,6 +352,39 @@ func 
TestToolCallOutputWithNonStringJSONContent(t *testing.T) { } } +func TestConvertOpenAIRequestToCodexPreservesInputAudio(t *testing.T) { + input := []byte(`{ + "model": "gpt-5.5", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Transcribe this audio verbatim."}, + {"type": "input_audio", "input_audio": {"data": "SUQzBA==", "format": "mp3"}} + ] + } + ] + }`) + + out := ConvertOpenAIRequestToCodex("gpt-5.5", input, true) + parts := gjson.GetBytes(out, "input.0.content").Array() + if len(parts) != 2 { + t.Fatalf("expected 2 content parts, got %d: %s", len(parts), gjson.GetBytes(out, "input.0.content").Raw) + } + if parts[0].Get("type").String() != "input_text" || parts[0].Get("text").String() != "Transcribe this audio verbatim." { + t.Fatalf("part 0: expected input_text with prompt text, got %s", parts[0].Raw) + } + if parts[1].Get("type").String() != "input_audio" { + t.Fatalf("part 1: expected input_audio, got %s", parts[1].Raw) + } + if parts[1].Get("data").String() != "SUQzBA==" { + t.Fatalf("part 1: expected audio data to be preserved, got %s", parts[1].Get("data").String()) + } + if parts[1].Get("format").String() != "mp3" { + t.Fatalf("part 1: expected audio format mp3, got %s", parts[1].Get("format").String()) + } +} + // Parallel tool calls: assistant invokes 3 tools at once, all call_ids // and outputs must be translated and paired correctly. 
func TestMultipleToolCalls(t *testing.T) { diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request.go b/internal/translator/codex/openai/responses/codex_openai-responses_request.go index cc218b12b3..be0383bcc5 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request.go @@ -1,6 +1,7 @@ package responses import ( + "encoding/json" "fmt" log "github.com/sirupsen/logrus" @@ -71,18 +72,38 @@ func convertSystemRoleToDeveloper(rawJSON []byte) []byte { return rawJSON } - inputArray := inputResult.Array() - result := rawJSON + inputItems := inputResult.Array() + if len(inputItems) == 0 { + return rawJSON + } - // Directly modify role values for items with "system" role - for i := 0; i < len(inputArray); i++ { - rolePath := fmt.Sprintf("input.%d.role", i) - if gjson.GetBytes(result, rolePath).String() == "system" { - result, _ = sjson.SetBytes(result, rolePath, "developer") + changed := false + rebuiltInput := make([]json.RawMessage, 0, len(inputItems)) + for _, item := range inputItems { + itemRaw := []byte(item.Raw) + if item.IsObject() && item.Get("role").String() == "system" { + updatedItem, errSetItem := sjson.SetRawBytes(itemRaw, "role", []byte(`"developer"`)) + if errSetItem != nil { + return rawJSON + } + itemRaw = updatedItem + changed = true } + rebuiltInput = append(rebuiltInput, json.RawMessage(itemRaw)) + } + if !changed { + return rawJSON } - return result + inputRaw, errMarshalInput := json.Marshal(rebuiltInput) + if errMarshalInput != nil { + return rawJSON + } + updated, errSetInput := sjson.SetRawBytes(rawJSON, "input", inputRaw) + if errSetInput != nil { + return rawJSON + } + return updated } // normalizeCodexBuiltinTools rewrites legacy/preview built-in tool variants to the diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go 
b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go index 3b48a76e04..7b0ebadb38 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go @@ -1,11 +1,17 @@ package responses import ( + "fmt" + "strconv" + "strings" "testing" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) +var benchmarkConvertSystemRoleOutput []byte + // TestConvertSystemRoleToDeveloper_BasicConversion tests the basic system -> developer role conversion func TestConvertSystemRoleToDeveloper_BasicConversion(t *testing.T) { inputJSON := []byte(`{ @@ -364,3 +370,101 @@ func TestTruncationRemovedForCodexCompatibility(t *testing.T) { t.Fatalf("truncation should be removed for Codex compatibility") } } + +func BenchmarkConvertSystemRoleToDeveloperLargeInput(b *testing.B) { + cases := []struct { + name string + inputJSON []byte + }{ + { + name: "200_input_1_system", + inputJSON: makeLargeResponsesInputForBenchmark(200, 200), + }, + { + name: "200_input_2_system", + inputJSON: makeLargeResponsesInputForBenchmark(200, 100), + }, + { + name: "2000_input_20_system", + inputJSON: makeLargeResponsesInputForBenchmark(2000, 100), + }, + } + benchmarks := []struct { + name string + fn func([]byte) []byte + }{ + { + name: "previous_root_path_rewrite", + fn: convertSystemRoleToDeveloperPreviousRootPathRewriteForBenchmark, + }, + { + name: "current_rebuilt_input_json_marshal", + fn: convertSystemRoleToDeveloper, + }, + } + + for _, testCase := range cases { + for _, benchmark := range benchmarks { + b.Run(testCase.name+"/"+benchmark.name, func(b *testing.B) { + output := benchmark.fn(testCase.inputJSON) + if got := gjson.GetBytes(output, "input.0.role").String(); got != "developer" { + b.Fatalf("input.0.role = %q, want %q", got, "developer") + } + if got := gjson.GetBytes(output, "input.1.role").String(); got != "user" { + b.Fatalf("input.1.role = %q, 
want %q", got, "user") + } + + b.ReportAllocs() + b.SetBytes(int64(len(testCase.inputJSON))) + b.ResetTimer() + + var benchmarkOutput []byte + for i := 0; i < b.N; i++ { + benchmarkOutput = benchmark.fn(testCase.inputJSON) + } + benchmarkConvertSystemRoleOutput = benchmarkOutput + }) + } + } +} + +func makeLargeResponsesInputForBenchmark(inputCount int, systemEvery int) []byte { + var builder strings.Builder + builder.Grow(inputCount * 96) + builder.WriteString(`{"model":"gpt-5.2","input":[`) + for i := 0; i < inputCount; i++ { + if i > 0 { + builder.WriteByte(',') + } + role := "user" + if i%systemEvery == 0 { + role = "system" + } + builder.WriteString(`{"type":"message","role":"`) + builder.WriteString(role) + builder.WriteString(`","content":[{"type":"input_text","text":"message `) + builder.WriteString(strconv.Itoa(i)) + builder.WriteString(`"}]}`) + } + builder.WriteString(`]}`) + return []byte(builder.String()) +} + +func convertSystemRoleToDeveloperPreviousRootPathRewriteForBenchmark(rawJSON []byte) []byte { + inputResult := gjson.GetBytes(rawJSON, "input") + if !inputResult.IsArray() { + return rawJSON + } + + inputArray := inputResult.Array() + result := rawJSON + + for i := 0; i < len(inputArray); i++ { + rolePath := fmt.Sprintf("input.%d.role", i) + if gjson.GetBytes(result, rolePath).String() == "system" { + result, _ = sjson.SetBytes(result, rolePath, "developer") + } + } + + return result +} diff --git a/internal/translator/common/bytes.go b/internal/translator/common/bytes.go index ff42d7e9d4..96bec594e2 100644 --- a/internal/translator/common/bytes.go +++ b/internal/translator/common/bytes.go @@ -2,18 +2,8 @@ package common import ( "strconv" - - "github.com/tidwall/sjson" ) -func WrapGeminiCLIResponse(response []byte) []byte { - out, err := sjson.SetRawBytes([]byte(`{"response":{}}`), "response", response) - if err != nil { - return response - } - return out -} - func GeminiTokenCountJSON(count int64) []byte { out := make([]byte, 0, 96) out = 
append(out, `{"totalTokens":`...) diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go deleted file mode 100644 index 5291df4378..0000000000 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go +++ /dev/null @@ -1,257 +0,0 @@ -// Package claude provides request translation functionality for Claude Code API compatibility. -// This package handles the conversion of Claude Code API requests into Gemini CLI-compatible -// JSON format, transforming message contents, system instructions, and tool declarations -// into the format expected by Gemini CLI API clients. It performs JSON data transformation -// to ensure compatibility between Claude Code API format and Gemini CLI API's expected format. -package claude - -import ( - "strings" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/common" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator" - -// ConvertClaudeRequestToCLI parses and transforms a Claude Code API request into Gemini CLI API format. -// It extracts the model name, system instruction, message contents, and tool declarations -// from the raw JSON request and returns them in the format expected by the Gemini CLI API. -// The function performs the following transformations: -// 1. Extracts the model information from the request -// 2. Restructures the JSON to match Gemini CLI API format -// 3. Converts system instructions to the expected format -// 4. Maps message contents with proper role transformations -// 5. Handles tool declarations and tool choices -// 6. 
Maps generation configuration parameters -// -// Parameters: -// - modelName: The name of the model to use for the request -// - rawJSON: The raw JSON request data from the Claude Code API -// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation) -// -// Returns: -// - []byte: The transformed request data in Gemini CLI API format -func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte { - rawJSON := inputRawJSON - - // Build output Gemini CLI request JSON - out := []byte(`{"model":"","request":{"contents":[]}}`) - out, _ = sjson.SetBytes(out, "model", modelName) - - // system instruction - if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() { - systemInstruction := []byte(`{"role":"user","parts":[]}`) - hasSystemParts := false - systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool { - if systemPromptResult.Get("type").String() == "text" { - textResult := systemPromptResult.Get("text") - if textResult.Type == gjson.String { - if util.IsClaudeCodeAttributionSystemText(textResult.String()) { - return true - } - part := []byte(`{"text":""}`) - part, _ = sjson.SetBytes(part, "text", textResult.String()) - systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part) - hasSystemParts = true - } - } - return true - }) - if hasSystemParts { - out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstruction) - } - } else if systemResult.Type == gjson.String && !util.IsClaudeCodeAttributionSystemText(systemResult.String()) { - out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.-1.text", systemResult.String()) - } - - // contents - if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() { - messagesResult.ForEach(func(_, messageResult gjson.Result) bool { - roleResult := messageResult.Get("role") - if roleResult.Type != gjson.String { - return true - } - role := roleResult.String() - if 
role == "assistant" { - role = "model" - } else if role == "system" { - role = "user" - } - - contentJSON := []byte(`{"role":"","parts":[]}`) - contentJSON, _ = sjson.SetBytes(contentJSON, "role", role) - - contentsResult := messageResult.Get("content") - if contentsResult.IsArray() { - contentsResult.ForEach(func(_, contentResult gjson.Result) bool { - switch contentResult.Get("type").String() { - case "text": - part := []byte(`{"text":""}`) - part, _ = sjson.SetBytes(part, "text", contentResult.Get("text").String()) - contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) - - case "tool_use": - functionName := util.SanitizeFunctionName(contentResult.Get("name").String()) - functionArgs := contentResult.Get("input").String() - argsResult := gjson.Parse(functionArgs) - if argsResult.IsObject() && gjson.Valid(functionArgs) { - part := []byte(`{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`) - part, _ = sjson.SetBytes(part, "thoughtSignature", geminiCLIClaudeThoughtSignature) - part, _ = sjson.SetBytes(part, "functionCall.name", functionName) - part, _ = sjson.SetRawBytes(part, "functionCall.args", []byte(functionArgs)) - contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) - } - - case "tool_result": - toolCallID := contentResult.Get("tool_use_id").String() - if toolCallID == "" { - return true - } - funcName := toolCallID - toolCallIDs := strings.Split(toolCallID, "-") - if len(toolCallIDs) > 1 { - funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-") - } - toolResult := util.ConvertClaudeToolResultContent(contentResult.Get("content")) - part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`) - part, _ = sjson.SetBytes(part, "functionResponse.name", util.SanitizeFunctionName(funcName)) - if toolResult.ResultIsRaw { - part, _ = sjson.SetRawBytes(part, "functionResponse.response.result", []byte(toolResult.Result)) - } else { - part, _ = sjson.SetBytes(part, "functionResponse.response.result", 
toolResult.Result) - } - contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) - for _, img := range toolResult.Images { - imagePart := []byte(`{"inlineData":{"mime_type":"","data":""}}`) - imagePart, _ = sjson.SetBytes(imagePart, "inlineData.mime_type", img.MimeType) - imagePart, _ = sjson.SetBytes(imagePart, "inlineData.data", img.Data) - contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", imagePart) - } - - case "image": - source := contentResult.Get("source") - if source.Get("type").String() == "base64" { - mimeType := source.Get("media_type").String() - data := source.Get("data").String() - if mimeType != "" && data != "" { - part := []byte(`{"inlineData":{"mime_type":"","data":""}}`) - part, _ = sjson.SetBytes(part, "inlineData.mime_type", mimeType) - part, _ = sjson.SetBytes(part, "inlineData.data", data) - contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) - } - } - } - return true - }) - out, _ = sjson.SetRawBytes(out, "request.contents.-1", contentJSON) - } else if contentsResult.Type == gjson.String { - part := []byte(`{"text":""}`) - part, _ = sjson.SetBytes(part, "text", contentsResult.String()) - contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) - out, _ = sjson.SetRawBytes(out, "request.contents.-1", contentJSON) - } - return true - }) - } - - // tools - if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() { - hasTools := false - toolsResult.ForEach(func(_, toolResult gjson.Result) bool { - inputSchemaResult := toolResult.Get("input_schema") - if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { - inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw) - tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema") - tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema)) - tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String())) - tool, _ = sjson.DeleteBytes(tool, "strict") - 
tool, _ = sjson.DeleteBytes(tool, "input_examples") - tool, _ = sjson.DeleteBytes(tool, "type") - tool, _ = sjson.DeleteBytes(tool, "cache_control") - tool, _ = sjson.DeleteBytes(tool, "defer_loading") - tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming") - if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() { - if !hasTools { - out, _ = sjson.SetRawBytes(out, "request.tools", []byte(`[{"functionDeclarations":[]}]`)) - hasTools = true - } - out, _ = sjson.SetRawBytes(out, "request.tools.0.functionDeclarations.-1", tool) - } - } - return true - }) - if !hasTools { - out, _ = sjson.DeleteBytes(out, "request.tools") - } - } - - // tool_choice - toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice") - if toolChoiceResult.Exists() { - toolChoiceType := "" - toolChoiceName := "" - if toolChoiceResult.IsObject() { - toolChoiceType = toolChoiceResult.Get("type").String() - toolChoiceName = toolChoiceResult.Get("name").String() - } else if toolChoiceResult.Type == gjson.String { - toolChoiceType = toolChoiceResult.String() - } - - switch toolChoiceType { - case "auto": - out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO") - case "none": - out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE") - case "any": - out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY") - case "tool": - out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY") - if toolChoiceName != "" { - out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)}) - } - } - } - - // Map Anthropic thinking -> Gemini CLI thinkingConfig when enabled - // Translator only does format conversion, ApplyThinking handles model capability validation. 
- if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() { - switch t.Get("type").String() { - case "enabled": - if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { - budget := int(b.Int()) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true) - } - case "adaptive", "auto": - // For adaptive thinking: - // - If output_config.effort is explicitly present, pass through as thinkingLevel. - // - Otherwise, treat it as "enabled with target-model maximum" and emit high. - // ApplyThinking handles clamping to target model's supported levels. - effort := "" - if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String { - effort = strings.ToLower(strings.TrimSpace(v.String())) - } - if effort != "" { - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort) - } else { - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") - } - out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true) - } - } - if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number { - out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num) - } - if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number { - out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num) - } - if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number { - out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num) - } - - out = common.AttachDefaultSafetySettings(out, "request.safetySettings") - return out -} diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_request_test.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_request_test.go deleted file mode 100644 index 
ea634205b1..0000000000 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_request_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package claude - -import ( - "testing" - - "github.com/tidwall/gjson" -) - -func TestConvertClaudeRequestToCLI_ToolChoice_SpecificTool(t *testing.T) { - inputJSON := []byte(`{ - "model": "gemini-3-flash-preview", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "hi"} - ] - } - ], - "tools": [ - { - "name": "json", - "description": "A JSON tool", - "input_schema": { - "type": "object", - "properties": {} - } - } - ], - "tool_choice": {"type": "tool", "name": "json"} - }`) - - output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false) - - if got := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.mode").String(); got != "ANY" { - t.Fatalf("Expected request.toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got) - } - allowed := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Array() - if len(allowed) != 1 || allowed[0].String() != "json" { - t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw) - } -} - -func TestConvertClaudeRequestToCLI_StripsClaudeCodeAttribution(t *testing.T) { - inputJSON := []byte(`{ - "model": "claude-sonnet-4-5", - "system": [ - {"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"}, - {"type": "text", "text": "User system prompt"} - ], - "messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}] - }`) - - output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false) - - parts := gjson.GetBytes(output, "request.systemInstruction.parts").Array() - if len(parts) != 1 { - t.Fatalf("Expected 1 system part after attribution strip, got %d: %s", len(parts), gjson.GetBytes(output, "request.systemInstruction.parts").Raw) - } 
- if got := parts[0].Get("text").String(); got != "User system prompt" { - t.Fatalf("Unexpected system part: %q", got) - } -} - -func TestConvertClaudeRequestToCLI_ConvertsMessageSystemRoleToUserContent(t *testing.T) { - inputJSON := []byte(`{ - "model": "gemini-3-flash-preview", - "system": [{"type": "text", "text": "Top-level rules"}], - "messages": [ - {"role": "user", "content": [{"type": "text", "text": "Hello"}]}, - {"role": "system", "content": "String mid-conversation rule"}, - {"role": "system", "content": [{"type": "text", "text": "Array mid-conversation rule"}]} - ] - }`) - - output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false) - - if systemContent := gjson.GetBytes(output, `request.contents.#(role=="system")`); systemContent.Exists() { - t.Fatalf("system role should not be emitted in request.contents: %s", systemContent.Raw) - } - - contents := gjson.GetBytes(output, "request.contents").Array() - if len(contents) != 3 { - t.Fatalf("Expected the user and message-level system turns in request.contents, got %d: %s", len(contents), gjson.GetBytes(output, "request.contents").Raw) - } - if got := contents[0].Get("role").String(); got != "user" { - t.Fatalf("Expected first content role user, got %q", got) - } - if got := contents[1].Get("role").String(); got != "user" { - t.Fatalf("Expected message-level string system content to be downgraded to user role, got %q", got) - } - if got := contents[1].Get("parts.0.text").String(); got != "String mid-conversation rule" { - t.Fatalf("Unexpected string message-level system content text: %q", got) - } - if got := contents[2].Get("role").String(); got != "user" { - t.Fatalf("Expected message-level array system content to be downgraded to user role, got %q", got) - } - if got := contents[2].Get("parts.0.text").String(); got != "Array mid-conversation rule" { - t.Fatalf("Unexpected array message-level system content text: %q", got) - } - - parts := gjson.GetBytes(output, 
"request.systemInstruction.parts").Array() - if len(parts) != 1 { - t.Fatalf("Expected only top-level system parts, got %d: %s", len(parts), gjson.GetBytes(output, "request.systemInstruction.parts").Raw) - } - if got := parts[0].Get("text").String(); got != "Top-level rules" { - t.Fatalf("Unexpected first system part: %q", got) - } -} - -func TestConvertClaudeRequestToCLI_StructuredToolResult(t *testing.T) { - inputJSON := []byte(`{ - "model": "gemini-3-flash-preview", - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "tool_use", "id": "json-call-1", "name": "json", "input": {"ok": true}} - ] - }, - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "json-call-1", - "content": [ - {"type": "text", "text": "alpha"}, - {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "aGVsbG8="}} - ] - } - ] - } - ] - }`) - - output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false) - - fr := gjson.GetBytes(output, "request.contents.1.parts.0.functionResponse") - if !fr.Exists() { - t.Fatalf("expected functionResponse part, contents=%s", gjson.GetBytes(output, "request.contents").Raw) - } - // The text block must remain structured JSON, not a double-encoded string blob. - if got := fr.Get("response.result.text").String(); got != "alpha" { - t.Fatalf("expected structured result text 'alpha', got result=%s", fr.Get("response.result").Raw) - } - // The image block must be emitted as a separate inlineData part, not embedded in result. 
- img := gjson.GetBytes(output, "request.contents.1.parts.1.inlineData") - if got := img.Get("mime_type").String(); got != "image/png" { - t.Fatalf("expected image mime type 'image/png', got '%s'", got) - } - if got := img.Get("data").String(); got != "aGVsbG8=" { - t.Fatalf("expected image data 'aGVsbG8=', got '%s'", got) - } -} - -func TestConvertClaudeRequestToCLI_StringToolResult(t *testing.T) { - inputJSON := []byte(`{ - "model": "gemini-3-flash-preview", - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "tool_use", "id": "json-call-1", "name": "json", "input": {"ok": true}} - ] - }, - { - "role": "user", - "content": [ - {"type": "tool_result", "tool_use_id": "json-call-1", "content": "alpha"} - ] - } - ] - }`) - - output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false) - - fr := gjson.GetBytes(output, "request.contents.1.parts.0.functionResponse") - if !fr.Exists() { - t.Fatalf("expected functionResponse part, contents=%s", gjson.GetBytes(output, "request.contents").Raw) - } - // String content must not be double-encoded: result should be exactly "alpha". - if got := fr.Get("response.result").String(); got != "alpha" { - t.Fatalf("expected result 'alpha', got '%s' (raw=%s)", got, fr.Get("response.result").Raw) - } -} diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go deleted file mode 100644 index 607d6b9fc0..0000000000 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ /dev/null @@ -1,358 +0,0 @@ -// Package claude provides response translation functionality for Claude Code API compatibility. -// This package handles the conversion of backend client responses into Claude Code-compatible -// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages -// different response types including text content, thinking processes, and function calls. 
-// The translation ensures proper sequencing of SSE events and maintains state across -// multiple response chunks to provide a seamless streaming experience. -package claude - -import ( - "bytes" - "context" - "fmt" - "strings" - "sync/atomic" - "time" - - translatorcommon "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/common" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// Params holds parameters for response conversion and maintains state across streaming chunks. -// This structure tracks the current state of the response translation process to ensure -// proper sequencing of SSE events and transitions between different content types. -type Params struct { - HasFirstResponse bool // Indicates if the initial message_start event has been sent - ResponseType int // Current response type: 0=none, 1=content, 2=thinking, 3=function - ResponseIndex int // Index counter for content blocks in the streaming response - HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output - - // Reverse map: sanitized Gemini function name → original Claude tool name. - ToolNameMap map[string]string -} - -// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. -var toolUseIDCounter uint64 - -// ConvertGeminiCLIResponseToClaude performs sophisticated streaming response format conversion. -// This function implements a complex state machine that translates backend client responses -// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types -// and handles state transitions between content blocks, thinking processes, and function calls. -// -// Response type states: 0=none, 1=content, 2=thinking, 3=function -// The function maintains state across multiple calls to ensure proper SSE event sequencing. 
-// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response (unused in current implementation) -// - rawJSON: The raw JSON response from the Gemini CLI API -// - param: A pointer to a parameter object for maintaining state between calls -// -// Returns: -// - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload. -func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { - if *param == nil { - *param = &Params{ - HasFirstResponse: false, - ResponseType: 0, - ResponseIndex: 0, - ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON), - } - } - - if bytes.Equal(rawJSON, []byte("[DONE]")) { - // Only send message_stop if we have actually output content - if (*param).(*Params).HasContent { - return [][]byte{translatorcommon.AppendSSEEventString(nil, "message_stop", `{"type":"message_stop"}`, 3)} - } - return [][]byte{} - } - - // Track whether tools are being used in this response chunk - usedTool := false - output := make([]byte, 0, 1024) - appendEvent := func(event, payload string) { - output = translatorcommon.AppendSSEEventString(output, event, payload, 3) - } - - // Initialize the streaming session with a message_start event - // This is only sent for the very first response chunk to establish the streaming session - if !(*param).(*Params).HasFirstResponse { - // Create the initial message structure with default values according to Claude Code API specification - // This follows the Claude Code API specification for streaming message initialization - messageStartTemplate := []byte(`{"type":"message_start","message":{"id":"msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet-20241022","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`) - 
- // Override default values with actual response metadata if available from the Gemini CLI response - if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { - messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String()) - } - if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { - messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String()) - } - appendEvent("message_start", string(messageStartTemplate)) - - (*param).(*Params).HasFirstResponse = true - } - - // Process the response parts array from the backend client - // Each part can contain text content, thinking content, or function calls - partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts") - if partsResult.IsArray() { - partResults := partsResult.Array() - for i := 0; i < len(partResults); i++ { - partResult := partResults[i] - - // Extract the different types of content from each part - partTextResult := partResult.Get("text") - functionCallResult := partResult.Get("functionCall") - - // Handle text content (both regular content and thinking) - if partTextResult.Exists() { - // Process thinking content (internal reasoning) - if partResult.Get("thought").Bool() { - // Continue existing thinking block if already in thinking state - if (*param).(*Params).ResponseType == 2 { - data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String()) - appendEvent("content_block_delta", string(data)) - (*param).(*Params).HasContent = true - } else { - // Transition from another state to thinking - // First, close any existing content block - if (*param).(*Params).ResponseType != 0 { - if (*param).(*Params).ResponseType == 2 { - // output = output + "event: 
content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - (*param).(*Params).ResponseIndex++ - } - - // Start a new thinking content block - appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)) - data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String()) - appendEvent("content_block_delta", string(data)) - (*param).(*Params).ResponseType = 2 // Set state to thinking - (*param).(*Params).HasContent = true - } - } else { - // Process regular text content (user-visible output) - // Continue existing text block if already in content state - if (*param).(*Params).ResponseType == 1 { - data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String()) - appendEvent("content_block_delta", string(data)) - (*param).(*Params).HasContent = true - } else { - // Transition from another state to text content - // First, close any existing content block - if (*param).(*Params).ResponseType != 0 { - if (*param).(*Params).ResponseType == 2 { - // output = output + "event: content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - appendEvent("content_block_stop", 
fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - (*param).(*Params).ResponseIndex++ - } - - // Start a new text content block - appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)) - data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String()) - appendEvent("content_block_delta", string(data)) - (*param).(*Params).ResponseType = 1 // Set state to content - (*param).(*Params).HasContent = true - } - } - } else if functionCallResult.Exists() { - // Handle function/tool calls from the AI model - // This processes tool usage requests and formats them for Claude Code API compatibility - usedTool = true - fcName := util.RestoreSanitizedToolName((*param).(*Params).ToolNameMap, functionCallResult.Get("name").String()) - - // Handle state transitions when switching to function calls - // Close any existing function call block first - if (*param).(*Params).ResponseType == 3 { - appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - (*param).(*Params).ResponseIndex++ - (*param).(*Params).ResponseType = 0 - } - - // Special handling for thinking state transition - if (*param).(*Params).ResponseType == 2 { - // output = output + "event: content_block_delta\n" - // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) - // output = output + "\n\n\n" - } - - // Close any other existing content block - if (*param).(*Params).ResponseType != 0 { - appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - 
(*param).(*Params).ResponseIndex++ - } - - // Start a new tool use content block - // This creates the structure for a function call in Claude Code format - // Create the tool use block with unique ID and function details - data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)) - data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))) - data, _ = sjson.SetBytes(data, "content_block.name", fcName) - appendEvent("content_block_start", string(data)) - - if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex)), "delta.partial_json", fcArgsResult.Raw) - appendEvent("content_block_delta", string(data)) - } - (*param).(*Params).ResponseType = 3 - (*param).(*Params).HasContent = true - } - } - } - - usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata") - // Process usage metadata and finish reason when present in the response - if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) { - if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { - // Only send final events if we have actually output content - if (*param).(*Params).HasContent { - // Close the final content block - appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)) - - // Create the message delta template with appropriate stop reason - template := []byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`) - // Set tool_use stop reason if tools were used 
in this response - if usedTool { - template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`) - } else if finish := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" { - template = []byte(`{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`) - } - - // Include thinking tokens in output token count if present - thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.SetBytes(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount) - template, _ = sjson.SetBytes(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int()) - - appendEvent("message_delta", string(template)) - } - } - } - - return [][]byte{output} -} - -// ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. -// -// Parameters: -// - ctx: The context for the request. -// - modelName: The name of the model. -// - rawJSON: The raw JSON response from the Gemini CLI API. -// - param: A pointer to a parameter object for the conversion. -// -// Returns: -// - []byte: A Claude-compatible JSON response. 
-func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte { - toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON) - _ = requestRawJSON - - root := gjson.ParseBytes(rawJSON) - - out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`) - out, _ = sjson.SetBytes(out, "id", root.Get("response.responseId").String()) - out, _ = sjson.SetBytes(out, "model", root.Get("response.modelVersion").String()) - - inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int() - outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int() - out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens) - out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens) - - parts := root.Get("response.candidates.0.content.parts") - textBuilder := strings.Builder{} - thinkingBuilder := strings.Builder{} - toolIDCounter := 0 - hasToolCall := false - - flushText := func() { - if textBuilder.Len() == 0 { - return - } - block := []byte(`{"type":"text","text":""}`) - block, _ = sjson.SetBytes(block, "text", textBuilder.String()) - out, _ = sjson.SetRawBytes(out, "content.-1", block) - textBuilder.Reset() - } - - flushThinking := func() { - if thinkingBuilder.Len() == 0 { - return - } - block := []byte(`{"type":"thinking","thinking":""}`) - block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String()) - out, _ = sjson.SetRawBytes(out, "content.-1", block) - thinkingBuilder.Reset() - } - - if parts.IsArray() { - for _, part := range parts.Array() { - if text := part.Get("text"); text.Exists() && text.String() != "" { - if part.Get("thought").Bool() { - flushText() - thinkingBuilder.WriteString(text.String()) - continue - } - flushThinking() - textBuilder.WriteString(text.String()) 
- continue - } - - if functionCall := part.Get("functionCall"); functionCall.Exists() { - flushThinking() - flushText() - hasToolCall = true - - name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String()) - toolIDCounter++ - toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`) - toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) - toolBlock, _ = sjson.SetBytes(toolBlock, "name", name) - inputRaw := "{}" - if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() { - inputRaw = args.Raw - } - toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw)) - out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock) - continue - } - } - } - - flushThinking() - flushText() - - stopReason := "end_turn" - if hasToolCall { - stopReason = "tool_use" - } else { - if finish := root.Get("response.candidates.0.finishReason"); finish.Exists() { - switch finish.String() { - case "MAX_TOKENS": - stopReason = "max_tokens" - case "STOP", "FINISH_REASON_UNSPECIFIED", "UNKNOWN": - stopReason = "end_turn" - default: - stopReason = "end_turn" - } - } - } - out, _ = sjson.SetBytes(out, "stop_reason", stopReason) - - if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() { - out, _ = sjson.DeleteBytes(out, "usage") - } - - return out -} - -func ClaudeTokenCount(ctx context.Context, count int64) []byte { - return translatorcommon.ClaudeInputTokensJSON(count) -} diff --git a/internal/translator/gemini-cli/claude/init.go b/internal/translator/gemini-cli/claude/init.go deleted file mode 100644 index fa2fabdf77..0000000000 --- a/internal/translator/gemini-cli/claude/init.go +++ /dev/null @@ -1,20 +0,0 @@ -package claude - -import ( - . 
"github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/translator" -) - -func init() { - translator.Register( - Claude, - GeminiCLI, - ConvertClaudeRequestToCLI, - interfaces.TranslateResponse{ - Stream: ConvertGeminiCLIResponseToClaude, - NonStream: ConvertGeminiCLIResponseToClaudeNonStream, - TokenCount: ClaudeTokenCount, - }, - ) -} diff --git a/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go b/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go deleted file mode 100644 index 3627757502..0000000000 --- a/internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go +++ /dev/null @@ -1,286 +0,0 @@ -// Package gemini provides request translation functionality for Gemini CLI to Gemini API compatibility. -// It handles parsing and transforming Gemini CLI API requests into Gemini API format, -// extracting model information, system instructions, message contents, and tool declarations. -// The package performs JSON data transformation to ensure compatibility -// between Gemini CLI API format and Gemini API's expected format. -package gemini - -import ( - "fmt" - "strings" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/signature" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/common" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// ConvertGeminiRequestToGeminiCLI parses and transforms a Gemini CLI API request into Gemini API format. -// It extracts the model name, system instruction, message contents, and tool declarations -// from the raw JSON request and returns them in the format expected by the Gemini API. -// The function performs the following transformations: -// 1. Extracts the model information from the request -// 2. 
Restructures the JSON to match Gemini API format -// 3. Converts system instructions to the expected format -// 4. Fixes CLI tool response format and grouping -// -// Parameters: -// - modelName: The name of the model to use for the request (unused in current implementation) -// - rawJSON: The raw JSON request data from the Gemini CLI API -// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation) -// -// Returns: -// - []byte: The transformed request data in Gemini API format -func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte { - rawJSON := inputRawJSON - template := []byte(`{"project":"","request":{},"model":""}`) - template, _ = sjson.SetRawBytes(template, "request", rawJSON) - template, _ = sjson.SetBytes(template, "model", gjson.GetBytes(template, "request.model").String()) - template, _ = sjson.DeleteBytes(template, "request.model") - - templateStr, errFixCLIToolResponse := fixCLIToolResponse(string(template)) - if errFixCLIToolResponse != nil { - return []byte{} - } - template = []byte(templateStr) - - systemInstructionResult := gjson.GetBytes(template, "request.system_instruction") - if systemInstructionResult.Exists() { - template, _ = sjson.SetRawBytes(template, "request.systemInstruction", []byte(systemInstructionResult.Raw)) - template, _ = sjson.DeleteBytes(template, "request.system_instruction") - } - rawJSON = template - - // Normalize roles in request.contents: default to valid values if missing/invalid - contents := gjson.GetBytes(rawJSON, "request.contents") - if contents.Exists() { - prevRole := "" - idx := 0 - contents.ForEach(func(_ gjson.Result, value gjson.Result) bool { - role := value.Get("role").String() - valid := role == "user" || role == "model" - if role == "" || !valid { - var newRole string - if prevRole == "" { - newRole = "user" - } else if prevRole == "user" { - newRole = "model" - } else { - newRole = "user" - } - path := 
fmt.Sprintf("request.contents.%d.role", idx) - rawJSON, _ = sjson.SetBytes(rawJSON, path, newRole) - role = newRole - } - prevRole = role - idx++ - return true - }) - } - - toolsResult := gjson.GetBytes(rawJSON, "request.tools") - if toolsResult.Exists() && toolsResult.IsArray() { - toolResults := toolsResult.Array() - for i := 0; i < len(toolResults); i++ { - functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations", i)) - if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() { - functionDeclarationsResults := functionDeclarationsResult.Array() - for j := 0; j < len(functionDeclarationsResults); j++ { - parametersResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j)) - if parametersResult.Exists() { - strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j), fmt.Sprintf("request.tools.%d.function_declarations.%d.parametersJsonSchema", i, j)) - rawJSON = []byte(strJson) - } - } - } - } - } - - rawJSON = signature.SanitizeGeminiRequestThoughtSignatures(rawJSON, "request.contents") - - // Filter out contents with empty parts to avoid Gemini API error: - // "required oneof field 'data' must have one initialized field" - filteredContents := []byte(`[]`) - hasFiltered := false - gjson.GetBytes(rawJSON, "request.contents").ForEach(func(_, content gjson.Result) bool { - parts := content.Get("parts") - if !parts.IsArray() || len(parts.Array()) == 0 { - hasFiltered = true - return true - } - filteredContents, _ = sjson.SetRawBytes(filteredContents, "-1", []byte(content.Raw)) - return true - }) - if hasFiltered { - rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents", filteredContents) - } - - return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings") -} - -// FunctionCallGroup represents a group of function calls and their responses -type FunctionCallGroup struct 
{ - ResponsesNeeded int - CallNames []string // ordered function call names for backfilling empty response names -} - -// backfillFunctionResponseName ensures that a functionResponse JSON object has a non-empty name, -// falling back to fallbackName if the original is empty. -func backfillFunctionResponseName(raw string, fallbackName string) string { - name := gjson.Get(raw, "functionResponse.name").String() - if strings.TrimSpace(name) == "" && fallbackName != "" { - rawBytes, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName) - raw = string(rawBytes) - } - return raw -} - -// fixCLIToolResponse performs sophisticated tool response format conversion and grouping. -// This function transforms the CLI tool response format by intelligently grouping function calls -// with their corresponding responses, ensuring proper conversation flow and API compatibility. -// It converts from a linear format (1.json) to a grouped format (2.json) where function calls -// and their responses are properly associated and structured. 
-// -// Parameters: -// - input: The input JSON string to be processed -// -// Returns: -// - string: The processed JSON string with grouped function calls and responses -// - error: An error if the processing fails -func fixCLIToolResponse(input string) (string, error) { - // Parse the input JSON to extract the conversation structure - parsed := gjson.Parse(input) - - // Extract the contents array which contains the conversation messages - contents := parsed.Get("request.contents") - if !contents.Exists() { - // log.Debugf(input) - return input, fmt.Errorf("contents not found in input") - } - - // Initialize data structures for processing and grouping - contentsWrapper := []byte(`{"contents":[]}`) - var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses - var collectedResponses []gjson.Result // Standalone responses to be matched - - // Process each content object in the conversation - // This iterates through messages and groups function calls with their responses - contents.ForEach(func(key, value gjson.Result) bool { - role := value.Get("role").String() - parts := value.Get("parts") - - // Check if this content has function responses - var responsePartsInThisContent []gjson.Result - parts.ForEach(func(_, part gjson.Result) bool { - if part.Get("functionResponse").Exists() { - responsePartsInThisContent = append(responsePartsInThisContent, part) - } - return true - }) - - // If this content has function responses, collect them - if len(responsePartsInThisContent) > 0 { - collectedResponses = append(collectedResponses, responsePartsInThisContent...) 
- - // Check if pending groups can be satisfied (FIFO: oldest group first) - for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded { - group := pendingGroups[0] - pendingGroups = pendingGroups[1:] - - // Take the needed responses for this group - groupResponses := collectedResponses[:group.ResponsesNeeded] - collectedResponses = collectedResponses[group.ResponsesNeeded:] - - // Create merged function response content - functionResponseContent := []byte(`{"parts":[],"role":"function"}`) - for ri, response := range groupResponses { - if !response.IsObject() { - log.Warnf("failed to parse function response") - continue - } - raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri]) - functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(raw)) - } - - if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 { - contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent) - } - } - - return true // Skip adding this content, responses are merged - } - - // If this is a model with function calls, create a new group - if role == "model" { - var callNames []string - parts.ForEach(func(_, part gjson.Result) bool { - if part.Get("functionCall").Exists() { - callNames = append(callNames, part.Get("functionCall.name").String()) - } - return true - }) - - if len(callNames) > 0 { - // Add the model content - if !value.IsObject() { - log.Warnf("failed to parse model content") - return true - } - contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw)) - - // Create a new group for tracking responses - group := &FunctionCallGroup{ - ResponsesNeeded: len(callNames), - CallNames: callNames, - } - pendingGroups = append(pendingGroups, group) - } else { - // Regular model content without function calls - if !value.IsObject() { - log.Warnf("failed to parse content") - return true - } - contentsWrapper, _ = 
sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw)) - } - } else { - // Non-model content (user, etc.) - if !value.IsObject() { - log.Warnf("failed to parse content") - return true - } - contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw)) - } - - return true - }) - - // Handle any remaining pending groups with remaining responses - for _, group := range pendingGroups { - if len(collectedResponses) >= group.ResponsesNeeded { - groupResponses := collectedResponses[:group.ResponsesNeeded] - collectedResponses = collectedResponses[group.ResponsesNeeded:] - - functionResponseContent := []byte(`{"parts":[],"role":"function"}`) - for ri, response := range groupResponses { - if !response.IsObject() { - log.Warnf("failed to parse function response") - continue - } - raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri]) - functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(raw)) - } - - if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 { - contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent) - } - } - } - - // Update the original JSON with the new contents - result := []byte(input) - result, _ = sjson.SetRawBytes(result, "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw)) - - return string(result), nil -} diff --git a/internal/translator/gemini-cli/gemini/gemini-cli_gemini_response.go b/internal/translator/gemini-cli/gemini/gemini-cli_gemini_response.go deleted file mode 100644 index 0e100c1489..0000000000 --- a/internal/translator/gemini-cli/gemini/gemini-cli_gemini_response.go +++ /dev/null @@ -1,86 +0,0 @@ -// Package gemini provides request translation functionality for Gemini to Gemini CLI API compatibility. 
-// It handles parsing and transforming Gemini API requests into Gemini CLI API format, -// extracting model information, system instructions, message contents, and tool declarations. -// The package performs JSON data transformation to ensure compatibility -// between Gemini API format and Gemini CLI API's expected format. -package gemini - -import ( - "bytes" - "context" - - translatorcommon "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/common" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// ConvertGeminiCliResponseToGemini parses and transforms a Gemini CLI API request into Gemini API format. -// It extracts the model name, system instruction, message contents, and tool declarations -// from the raw JSON request and returns them in the format expected by the Gemini API. -// The function performs the following transformations: -// 1. Extracts the response data from the request -// 2. Handles alternative response formats -// 3. Processes array responses by extracting individual response objects -// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model to use for the request (unused in current implementation) -// - rawJSON: The raw JSON request data from the Gemini CLI API -// - param: A pointer to a parameter object for the conversion (unused in current implementation) -// -// Returns: -// - [][]byte: The transformed request data in Gemini API format -func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte { - if bytes.HasPrefix(rawJSON, []byte("data:")) { - rawJSON = bytes.TrimSpace(rawJSON[5:]) - } - - if alt, ok := ctx.Value("alt").(string); ok { - var chunk []byte - if alt == "" { - responseResult := gjson.GetBytes(rawJSON, "response") - if responseResult.Exists() { - chunk = []byte(responseResult.Raw) - } - } else { - chunkTemplate := []byte(`[]`) - 
responseResult := gjson.ParseBytes(chunk) - if responseResult.IsArray() { - responseResultItems := responseResult.Array() - for i := 0; i < len(responseResultItems); i++ { - responseResultItem := responseResultItems[i] - if responseResultItem.Get("response").Exists() { - chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw)) - } - } - } - chunk = chunkTemplate - } - return [][]byte{chunk} - } - return [][]byte{} -} - -// ConvertGeminiCliResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response. -// This function processes the complete Gemini CLI request and transforms it into a single Gemini-compatible -// JSON response. It extracts the response data from the request and returns it in the expected format. -// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response (unused in current implementation) -// - rawJSON: The raw JSON request data from the Gemini CLI API -// - param: A pointer to a parameter object for the conversion (unused in current implementation) -// -// Returns: -// - []byte: A Gemini-compatible JSON response containing the response data -func ConvertGeminiCliResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte { - responseResult := gjson.GetBytes(rawJSON, "response") - if responseResult.Exists() { - return []byte(responseResult.Raw) - } - return rawJSON -} - -func GeminiTokenCount(ctx context.Context, count int64) []byte { - return translatorcommon.GeminiTokenCountJSON(count) -} diff --git a/internal/translator/gemini-cli/gemini/init.go b/internal/translator/gemini-cli/gemini/init.go deleted file mode 100644 index 1c2f38f215..0000000000 --- a/internal/translator/gemini-cli/gemini/init.go +++ /dev/null @@ -1,20 +0,0 @@ -package gemini - -import ( - . 
"github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/translator" -) - -func init() { - translator.Register( - Gemini, - GeminiCLI, - ConvertGeminiRequestToGeminiCLI, - interfaces.TranslateResponse{ - Stream: ConvertGeminiCliResponseToGemini, - NonStream: ConvertGeminiCliResponseToGeminiNonStream, - TokenCount: GeminiTokenCount, - }, - ) -} diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go deleted file mode 100644 index c0c7a8deb8..0000000000 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go +++ /dev/null @@ -1,416 +0,0 @@ -// Package openai provides request translation functionality for OpenAI to Gemini CLI API compatibility. -// It converts OpenAI Chat Completions requests into Gemini CLI compatible JSON using gjson/sjson only. -package chat_completions - -import ( - "fmt" - "strings" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" - sigcompat "github.com/router-for-me/CLIProxyAPI/v7/internal/signature" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/common" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator" - -// ConvertOpenAIRequestToGeminiCLI converts an OpenAI Chat Completions request (raw JSON) -// into a complete Gemini CLI request JSON. All JSON construction uses sjson and lookups use gjson. 
-// -// Parameters: -// - modelName: The name of the model to use for the request -// - rawJSON: The raw JSON request data from the OpenAI API -// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation) -// -// Returns: -// - []byte: The transformed request data in Gemini CLI API format -func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bool) []byte { - rawJSON := inputRawJSON - // Base envelope (no default thinkingConfig) - out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`) - - // Model - out, _ = sjson.SetBytes(out, "model", modelName) - - // Let user-provided generationConfig pass through - if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() { - out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw)) - } - - // Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig. - // Inline translation-only mapping; capability checks happen later in ApplyThinking. 
- re := gjson.GetBytes(rawJSON, "reasoning_effort") - if re.Exists() { - effort := strings.ToLower(strings.TrimSpace(re.String())) - if effort != "" { - thinkingPath := "request.generationConfig.thinkingConfig" - if effort == "auto" { - out, _ = sjson.SetBytes(out, thinkingPath+".thinkingBudget", -1) - out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", true) - } else { - out, _ = sjson.SetBytes(out, thinkingPath+".thinkingLevel", effort) - out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", effort != "none") - } - } - } - - // Temperature/top_p/top_k - if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { - out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) - } - if tpr := gjson.GetBytes(rawJSON, "top_p"); tpr.Exists() && tpr.Type == gjson.Number { - out, _ = sjson.SetBytes(out, "request.generationConfig.topP", tpr.Num) - } - if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number { - out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num) - } - - // Candidate count (OpenAI 'n' parameter) - if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number { - if val := n.Int(); val > 1 { - out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val) - } - } - - // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities - // e.g. 
"modalities": ["image", "text"] -> ["IMAGE", "TEXT"] - if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() { - var responseMods []string - for _, m := range mods.Array() { - switch strings.ToLower(m.String()) { - case "text": - responseMods = append(responseMods, "TEXT") - case "image": - responseMods = append(responseMods, "IMAGE") - } - } - if len(responseMods) > 0 { - out, _ = sjson.SetBytes(out, "request.generationConfig.responseModalities", responseMods) - } - } - - // OpenRouter-style image_config support - // If the input uses top-level image_config.aspect_ratio, map it into request.generationConfig.imageConfig.aspectRatio. - if imgCfg := gjson.GetBytes(rawJSON, "image_config"); imgCfg.Exists() && imgCfg.IsObject() { - if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String { - out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.aspectRatio", ar.Str) - } - if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String { - out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.imageSize", size.Str) - } - } - - // messages -> systemInstruction + contents - messages := gjson.GetBytes(rawJSON, "messages") - if messages.IsArray() { - arr := messages.Array() - // First pass: assistant tool_calls id->name map - tcID2Name := map[string]string{} - for i := 0; i < len(arr); i++ { - m := arr[i] - if m.Get("role").String() == "assistant" { - tcs := m.Get("tool_calls") - if tcs.IsArray() { - for _, tc := range tcs.Array() { - if tc.Get("type").String() == "function" { - id := tc.Get("id").String() - name := tc.Get("function.name").String() - if id != "" && name != "" { - tcID2Name[id] = name - } - } - } - } - } - } - - // Second pass build systemInstruction/tool responses cache - toolResponses := map[string]string{} // tool_call_id -> response text - for i := 0; i < len(arr); i++ { - m := arr[i] - role := m.Get("role").String() - if role == "tool" { - toolCallID := 
m.Get("tool_call_id").String() - if toolCallID != "" { - c := m.Get("content") - toolResponses[toolCallID] = c.Raw - } - } - } - - systemPartIndex := 0 - for i := 0; i < len(arr); i++ { - m := arr[i] - role := m.Get("role").String() - content := m.Get("content") - - if (role == "system" || role == "developer") && len(arr) > 1 { - // system -> request.systemInstruction as a user message style - if content.Type == gjson.String { - out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") - out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.String()) - systemPartIndex++ - } else if content.IsObject() && content.Get("type").String() == "text" { - out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") - out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String()) - systemPartIndex++ - } else if content.IsArray() { - contents := content.Array() - if len(contents) > 0 { - out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") - for j := 0; j < len(contents); j++ { - out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String()) - systemPartIndex++ - } - } - } - } else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) { - // Build single user content node to avoid splitting into multiple contents - node := []byte(`{"role":"user","parts":[]}`) - if content.Type == gjson.String { - node, _ = sjson.SetBytes(node, "parts.0.text", content.String()) - } else if content.IsArray() { - items := content.Array() - p := 0 - for _, item := range items { - switch item.Get("type").String() { - case "text": - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String()) - p++ - case "image_url": - imageURL := item.Get("image_url.url").String() - if len(imageURL) > 5 { - pieces := 
strings.SplitN(imageURL[5:], ";", 2) - if len(pieces) == 2 && len(pieces[1]) > 7 { - mime := pieces[0] - data := pieces[1][7:] - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime) - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) - p++ - } - } - case "file": - filename := item.Get("file.filename").String() - fileData := item.Get("file.file_data").String() - ext := "" - if sp := strings.Split(filename, "."); len(sp) > 1 { - ext = sp[len(sp)-1] - } - if mimeType, ok := misc.MimeTypes[ext]; ok { - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType) - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData) - p++ - } else { - log.Warnf("Unknown file name extension '%s' in user message, skip", ext) - } - } - } - } - out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) - } else if role == "assistant" { - p := 0 - node := []byte(`{"role":"model","parts":[]}`) - if content.Type == gjson.String { - // Assistant text -> single model content - node, _ = sjson.SetBytes(node, "parts.-1.text", content.String()) - p++ - } else if content.IsArray() { - // Assistant multimodal content (e.g. text + image) -> single model content with parts - for _, item := range content.Array() { - switch item.Get("type").String() { - case "text": - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String()) - p++ - case "image_url": - // If the assistant returned an inline data URL, preserve it for history fidelity. - imageURL := item.Get("image_url.url").String() - if len(imageURL) > 5 { // expect data:... 
- pieces := strings.SplitN(imageURL[5:], ";", 2) - if len(pieces) == 2 && len(pieces[1]) > 7 { - mime := pieces[0] - data := pieces[1][7:] - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime) - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) - p++ - } - } - } - } - } - - // Tool calls -> single model content with functionCall parts - tcs := m.Get("tool_calls") - if tcs.IsArray() { - fIDs := make([]string, 0) - for _, tc := range tcs.Array() { - if tc.Get("type").String() != "function" { - continue - } - fid := tc.Get("id").String() - fname := util.SanitizeFunctionName(tc.Get("function.name").String()) - fargs := tc.Get("function.arguments").String() - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname) - node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs)) - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", openAIToolCallGeminiThoughtSignature(tc)) - p++ - if fid != "" { - fIDs = append(fIDs, fid) - } - } - out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) - - // Append a single tool content combining name + response per function - toolNode := []byte(`{"role":"user","parts":[]}`) - pp := 0 - for _, fid := range fIDs { - if name, ok := tcID2Name[fid]; ok { - toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name)) - resp := toolResponses[fid] - if resp == "" { - resp = "{}" - } - toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp)) - pp++ - } - } - if pp > 0 { - out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode) - } - } else { - out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) - } - } - } - } - - // tools -> request.tools[].functionDeclarations + 
request.tools[].googleSearch/codeExecution/urlContext passthrough - tools := gjson.GetBytes(rawJSON, "tools") - if tools.IsArray() && len(tools.Array()) > 0 { - functionToolNode := []byte(`{}`) - hasFunction := false - googleSearchNodes := make([][]byte, 0) - codeExecutionNodes := make([][]byte, 0) - urlContextNodes := make([][]byte, 0) - for _, t := range tools.Array() { - if t.Get("type").String() == "function" { - fn := t.Get("function") - if fn.Exists() && fn.IsObject() { - fnRaw := []byte(fn.Raw) - if fn.Get("parameters").Exists() { - renamed, errRename := util.RenameKey(fn.Raw, "parameters", "parametersJsonSchema") - if errRename != nil { - log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename) - var errSet error - fnRaw, errSet = sjson.SetBytes(fnRaw, "parametersJsonSchema.type", "object") - if errSet != nil { - log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) - continue - } - fnRaw, errSet = sjson.SetRawBytes(fnRaw, "parametersJsonSchema.properties", []byte(`{}`)) - if errSet != nil { - log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) - continue - } - } else { - fnRaw = []byte(renamed) - } - } else { - var errSet error - fnRaw, errSet = sjson.SetBytes(fnRaw, "parametersJsonSchema.type", "object") - if errSet != nil { - log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) - continue - } - fnRaw, errSet = sjson.SetRawBytes(fnRaw, "parametersJsonSchema.properties", []byte(`{}`)) - if errSet != nil { - log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) - continue - } - } - fnRaw, _ = sjson.SetBytes(fnRaw, "name", util.SanitizeFunctionName(fn.Get("name").String())) - fnRaw, _ = sjson.DeleteBytes(fnRaw, "strict") - if !hasFunction { - functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", 
[]byte("[]")) - } - tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", fnRaw) - if errSet != nil { - log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) - continue - } - functionToolNode = tmp - hasFunction = true - } - } - if gs := t.Get("google_search"); gs.Exists() { - googleToolNode := []byte(`{}`) - var errSet error - googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw)) - if errSet != nil { - log.Warnf("Failed to set googleSearch tool: %v", errSet) - continue - } - googleSearchNodes = append(googleSearchNodes, googleToolNode) - } - if ce := t.Get("code_execution"); ce.Exists() { - codeToolNode := []byte(`{}`) - var errSet error - codeToolNode, errSet = sjson.SetRawBytes(codeToolNode, "codeExecution", []byte(ce.Raw)) - if errSet != nil { - log.Warnf("Failed to set codeExecution tool: %v", errSet) - continue - } - codeExecutionNodes = append(codeExecutionNodes, codeToolNode) - } - if uc := t.Get("url_context"); uc.Exists() { - urlToolNode := []byte(`{}`) - var errSet error - urlToolNode, errSet = sjson.SetRawBytes(urlToolNode, "urlContext", []byte(uc.Raw)) - if errSet != nil { - log.Warnf("Failed to set urlContext tool: %v", errSet) - continue - } - urlContextNodes = append(urlContextNodes, urlToolNode) - } - } - if hasFunction || len(googleSearchNodes) > 0 || len(codeExecutionNodes) > 0 || len(urlContextNodes) > 0 { - toolsNode := []byte("[]") - if hasFunction { - toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode) - } - for _, googleNode := range googleSearchNodes { - toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode) - } - for _, codeNode := range codeExecutionNodes { - toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", codeNode) - } - for _, urlNode := range urlContextNodes { - toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", urlNode) - } - out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode) - } - } - - return 
common.AttachDefaultSafetySettings(out, "request.safetySettings") -} - -func openAIToolCallGeminiThoughtSignature(toolCall gjson.Result) string { - for _, path := range []string{ - "extra_content.google.thought_signature", - "function.extra_content.google.thought_signature", - "thoughtSignature", - "thought_signature", - } { - if signatureResult := toolCall.Get(path); signatureResult.Exists() { - return sigcompat.GeminiReplaySignatureOrBypass(signatureResult.String(), sigcompat.SignatureBlockKindGeminiFunctionCall) - } - } - return geminiCLIFunctionThoughtSignature -} - -// itoa converts int to string without strconv import for few usages. -func itoa(i int) string { return fmt.Sprintf("%d", i) } diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go deleted file mode 100644 index beba911e5a..0000000000 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go +++ /dev/null @@ -1,246 +0,0 @@ -// Package openai provides response translation functionality for Gemini CLI to OpenAI API compatibility. -// This package handles the conversion of Gemini CLI API responses into OpenAI Chat Completions-compatible -// JSON format, transforming streaming events and non-streaming responses into the format -// expected by OpenAI API clients. It supports both streaming and non-streaming modes, -// handling text content, tool calls, reasoning content, and usage metadata appropriately. -package chat_completions - -import ( - "bytes" - "context" - "fmt" - "strings" - "sync/atomic" - "time" - - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/openai/chat-completions" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// convertCliResponseToOpenAIChatParams holds parameters for response conversion. 
-type convertCliResponseToOpenAIChatParams struct { - UnixTimestamp int64 - FunctionIndex int - SawToolCall bool - UpstreamFinishReason string - SanitizedNameMap map[string]string -} - -// functionCallIDCounter provides a process-wide unique counter for function call identifiers. -var functionCallIDCounter uint64 - -// ConvertCliResponseToOpenAI translates a single chunk of a streaming response from the -// Gemini CLI API format to the OpenAI Chat Completions streaming format. -// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses. -// The function handles text content, tool calls, reasoning content, and usage metadata, outputting -// responses that match the OpenAI API format. It supports incremental updates for streaming responses. -// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response (unused in current implementation) -// - rawJSON: The raw JSON response from the Gemini CLI API -// - param: A pointer to a parameter object for maintaining state between calls -// -// Returns: -// - [][]byte: A slice of OpenAI-compatible JSON responses -func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { - if *param == nil { - *param = &convertCliResponseToOpenAIChatParams{ - UnixTimestamp: 0, - FunctionIndex: 0, - SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON), - } - } - if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil { - (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON) - } - - if bytes.Equal(rawJSON, []byte("[DONE]")) { - return [][]byte{} - } - - // Initialize the OpenAI SSE template. 
- template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`) - - // Extract and set the model version. - if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { - template, _ = sjson.SetBytes(template, "model", modelVersionResult.String()) - } - - // Extract and set the creation timestamp. - if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() { - t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) - if err == nil { - (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix() - } - template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp) - } else { - template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp) - } - - // Extract and set the response ID. - if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { - template, _ = sjson.SetBytes(template, "id", responseIDResult.String()) - } - - if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() { - (*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(finishReasonResult.String()) - } - if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() && stopReasonResult.String() != "" { - (*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(stopReasonResult.String()) - } - - // Extract and set usage metadata (token counts). 
- if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { - cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() - if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { - template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) - } - if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { - template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int()) - } - promptTokenCount := usageResult.Get("promptTokenCount").Int() - thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount) - if thoughtsTokenCount > 0 { - template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) - } - // Include cached token count if present (indicates prompt caching is working) - if cachedTokenCount > 0 { - var err error - template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) - if err != nil { - log.Warnf("gemini-cli openai response: failed to set cached_tokens: %v", err) - } - } - } - - // Process the main content part of the response. 
- partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts") - if partsResult.IsArray() { - partResults := partsResult.Array() - for i := 0; i < len(partResults); i++ { - partResult := partResults[i] - partTextResult := partResult.Get("text") - functionCallResult := partResult.Get("functionCall") - thoughtSignatureResult := partResult.Get("thoughtSignature") - if !thoughtSignatureResult.Exists() { - thoughtSignatureResult = partResult.Get("thought_signature") - } - inlineDataResult := partResult.Get("inlineData") - if !inlineDataResult.Exists() { - inlineDataResult = partResult.Get("inline_data") - } - - hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" - hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists() - - // Ignore encrypted thoughtSignature but keep any actual content in the same part. - if hasThoughtSignature && !hasContentPayload { - continue - } - - if partTextResult.Exists() { - textContent := partTextResult.String() - - // Handle text content, distinguishing between regular content and reasoning/thoughts. - if partResult.Get("thought").Bool() { - template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent) - } else { - template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent) - } - template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant") - } else if functionCallResult.Exists() { - // Handle function call content. 
- (*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true - toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls") - functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex - (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++ - if toolCallsResult.Exists() && toolCallsResult.IsArray() { - functionCallIndex = len(toolCallsResult.Array()) - } else { - template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`)) - } - - functionCallTemplate := []byte(`{"id":"","index":0,"type":"function","function":{"name":"","arguments":""}}`) - fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String()) - functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) - functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex) - functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName) - if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw) - } - template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant") - template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate) - } else if inlineDataResult.Exists() { - data := inlineDataResult.Get("data").String() - if data == "" { - continue - } - mimeType := inlineDataResult.Get("mimeType").String() - if mimeType == "" { - mimeType = inlineDataResult.Get("mime_type").String() - } - if mimeType == "" { - mimeType = "image/png" - } - imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) - imagesResult := gjson.GetBytes(template, "choices.0.delta.images") - if !imagesResult.Exists() || !imagesResult.IsArray() { - 
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`)) - } - imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array()) - imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`) - imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex) - imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL) - template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant") - template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload) - } - } - } - - params := (*param).(*convertCliResponseToOpenAIChatParams) - upstreamFinishReason := params.UpstreamFinishReason - sawToolCall := params.SawToolCall - usageExists := gjson.GetBytes(rawJSON, "response.usageMetadata").Exists() - isFinalChunk := upstreamFinishReason != "" && usageExists - - if isFinalChunk { - var finishReason string - if sawToolCall { - finishReason = "tool_calls" - } else if upstreamFinishReason == "MAX_TOKENS" { - finishReason = "max_tokens" - } else { - finishReason = "stop" - } - template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason) - template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason)) - } - - return [][]byte{template} -} - -// ConvertCliResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response. -// This function processes the complete Gemini CLI response and transforms it into a single OpenAI-compatible -// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all -// the information into a single response that matches the OpenAI API format. 
-// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response -// - rawJSON: The raw JSON response from the Gemini CLI API -// - param: A pointer to a parameter object for the conversion -// -// Returns: -// - []byte: An OpenAI-compatible JSON response containing all message content and metadata -func ConvertCliResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte { - responseResult := gjson.GetBytes(rawJSON, "response") - if responseResult.Exists() { - return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param) - } - return []byte{} -} diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response_test.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response_test.go deleted file mode 100644 index fad60e352b..0000000000 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package chat_completions - -import ( - "context" - "testing" - - "github.com/tidwall/gjson" -) - -func TestCliFinishReasonOnlyOnFinalChunk(t *testing.T) { - ctx := context.Background() - var param any - - chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"C:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}}`) - result1 := ConvertCliResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m) - if len(result1) != 1 { - t.Fatalf("expected 1 result from chunk1, got %d", len(result1)) - } - fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason") - if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" { - t.Fatalf("expected null finish_reason on tool chunk, got %v", fr1.String()) - } - - chunk2 := 
[]byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"D:/"}}}]}}],"usageMetadata":{"trafficType":"ON_DEMAND"}}}`) - ConvertCliResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m) - - chunk3 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}}`) - result3 := ConvertCliResponseToOpenAI(ctx, "model", nil, nil, chunk3, ¶m) - if len(result3) != 1 { - t.Fatalf("expected 1 result from chunk3, got %d", len(result3)) - } - fr3 := gjson.GetBytes(result3[0], "choices.0.finish_reason").String() - if fr3 != "tool_calls" { - t.Fatalf("expected finish_reason tool_calls, got %s", fr3) - } - nfr3 := gjson.GetBytes(result3[0], "choices.0.native_finish_reason").String() - if nfr3 != "stop" { - t.Fatalf("expected native_finish_reason stop, got %s", nfr3) - } -} diff --git a/internal/translator/gemini-cli/openai/chat-completions/init.go b/internal/translator/gemini-cli/openai/chat-completions/init.go deleted file mode 100644 index fcd85f2450..0000000000 --- a/internal/translator/gemini-cli/openai/chat-completions/init.go +++ /dev/null @@ -1,19 +0,0 @@ -package chat_completions - -import ( - . 
"github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/translator" -) - -func init() { - translator.Register( - OpenAI, - GeminiCLI, - ConvertOpenAIRequestToGeminiCLI, - interfaces.TranslateResponse{ - Stream: ConvertCliResponseToOpenAI, - NonStream: ConvertCliResponseToOpenAINonStream, - }, - ) -} diff --git a/internal/translator/gemini-cli/openai/responses/gemini-cli_openai-responses_request.go b/internal/translator/gemini-cli/openai/responses/gemini-cli_openai-responses_request.go deleted file mode 100644 index bea4b7a1fe..0000000000 --- a/internal/translator/gemini-cli/openai/responses/gemini-cli_openai-responses_request.go +++ /dev/null @@ -1,12 +0,0 @@ -package responses - -import ( - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini-cli/gemini" - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/openai/responses" -) - -func ConvertOpenAIResponsesRequestToGeminiCLI(modelName string, inputRawJSON []byte, stream bool) []byte { - rawJSON := inputRawJSON - rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream) - return ConvertGeminiRequestToGeminiCLI(modelName, rawJSON, stream) -} diff --git a/internal/translator/gemini-cli/openai/responses/gemini-cli_openai-responses_response.go b/internal/translator/gemini-cli/openai/responses/gemini-cli_openai-responses_response.go deleted file mode 100644 index 29db8c19ef..0000000000 --- a/internal/translator/gemini-cli/openai/responses/gemini-cli_openai-responses_response.go +++ /dev/null @@ -1,35 +0,0 @@ -package responses - -import ( - "context" - - . 
"github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/openai/responses" - "github.com/tidwall/gjson" -) - -func ConvertGeminiCLIResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { - responseResult := gjson.GetBytes(rawJSON, "response") - if responseResult.Exists() { - rawJSON = []byte(responseResult.Raw) - } - return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) -} - -func ConvertGeminiCLIResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte { - responseResult := gjson.GetBytes(rawJSON, "response") - if responseResult.Exists() { - rawJSON = []byte(responseResult.Raw) - } - - requestResult := gjson.GetBytes(originalRequestRawJSON, "request") - if responseResult.Exists() { - originalRequestRawJSON = []byte(requestResult.Raw) - } - - requestResult = gjson.GetBytes(requestRawJSON, "request") - if responseResult.Exists() { - requestRawJSON = []byte(requestResult.Raw) - } - - return ConvertGeminiResponseToOpenAIResponsesNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) -} diff --git a/internal/translator/gemini-cli/openai/responses/init.go b/internal/translator/gemini-cli/openai/responses/init.go deleted file mode 100644 index e1d437715f..0000000000 --- a/internal/translator/gemini-cli/openai/responses/init.go +++ /dev/null @@ -1,19 +0,0 @@ -package responses - -import ( - . 
"github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/translator" -) - -func init() { - translator.Register( - OpenaiResponse, - GeminiCLI, - ConvertOpenAIResponsesRequestToGeminiCLI, - interfaces.TranslateResponse{ - Stream: ConvertGeminiCLIResponseToOpenAIResponses, - NonStream: ConvertGeminiCLIResponseToOpenAIResponsesNonStream, - }, - ) -} diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index 96d04a18e9..e248445a52 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -19,7 +19,7 @@ import ( const geminiClaudeThoughtSignature = "skip_thought_signature_validator" // ConvertClaudeRequestToGemini parses a Claude API request and returns a complete -// Gemini CLI request body (as JSON bytes) ready to be sent via SendRawMessageStream. +// Gemini request body (as JSON bytes) ready to be sent via SendRawMessageStream. // All JSON transformations are performed using gjson/sjson. // // Parameters: @@ -28,10 +28,10 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator" // - stream: A boolean indicating if the request is for a streaming response. // // Returns: -// - []byte: The transformed request in Gemini CLI format. +// - []byte: The transformed request in Gemini format. 
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte { rawJSON := inputRawJSON - // Build output Gemini CLI request JSON + // Build output Gemini request JSON out := []byte(`{"contents":[]}`) out, _ = sjson.SetBytes(out, "model", modelName) diff --git a/internal/translator/gemini/gemini-cli/gemini_gemini-cli_request.go b/internal/translator/gemini/gemini-cli/gemini_gemini-cli_request.go deleted file mode 100644 index 0d1da6c79a..0000000000 --- a/internal/translator/gemini/gemini-cli/gemini_gemini-cli_request.go +++ /dev/null @@ -1,52 +0,0 @@ -// Package gemini provides request translation functionality for Claude API. -// It handles parsing and transforming Claude API requests into the internal client format, -// extracting model information, system instructions, message contents, and tool declarations. -// The package also performs JSON data cleaning and transformation to ensure compatibility -// between Claude API format and the internal client's expected format. -package geminiCLI - -import ( - "fmt" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/signature" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/common" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// PrepareClaudeRequest parses and transforms a Claude API request into internal client format. -// It extracts the model name, system instruction, message contents, and tool declarations -// from the raw JSON request and returns them in the format expected by the internal client. 
-func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte { - rawJSON := inputRawJSON - modelResult := gjson.GetBytes(rawJSON, "model") - rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw) - rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String()) - if gjson.GetBytes(rawJSON, "systemInstruction").Exists() { - rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw)) - rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction") - } - - toolsResult := gjson.GetBytes(rawJSON, "tools") - if toolsResult.Exists() && toolsResult.IsArray() { - toolResults := toolsResult.Array() - for i := 0; i < len(toolResults); i++ { - functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("tools.%d.function_declarations", i)) - if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() { - functionDeclarationsResults := functionDeclarationsResult.Array() - for j := 0; j < len(functionDeclarationsResults); j++ { - parametersResult := gjson.GetBytes(rawJSON, fmt.Sprintf("tools.%d.function_declarations.%d.parameters", i, j)) - if parametersResult.Exists() { - strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("tools.%d.function_declarations.%d.parameters", i, j), fmt.Sprintf("tools.%d.function_declarations.%d.parametersJsonSchema", i, j)) - rawJSON = []byte(strJson) - } - } - } - } - } - - rawJSON = signature.SanitizeGeminiRequestThoughtSignatures(rawJSON, "contents") - - return common.AttachDefaultSafetySettings(rawJSON, "safetySettings") -} diff --git a/internal/translator/gemini/gemini-cli/gemini_gemini-cli_response.go b/internal/translator/gemini/gemini-cli/gemini_gemini-cli_response.go deleted file mode 100644 index 36fa0d39b5..0000000000 --- a/internal/translator/gemini/gemini-cli/gemini_gemini-cli_response.go +++ /dev/null @@ -1,60 +0,0 @@ -// Package gemini_cli provides response translation functionality for Gemini API to Gemini CLI API. 
-// This package handles the conversion of Gemini API responses into Gemini CLI-compatible -// JSON format, transforming streaming events and non-streaming responses into the format -// expected by Gemini CLI API clients. -package geminiCLI - -import ( - "bytes" - "context" - - translatorcommon "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/common" - "github.com/tidwall/sjson" -) - -var dataTag = []byte("data:") - -// ConvertGeminiResponseToGeminiCLI converts Gemini streaming response format to Gemini CLI single-line JSON format. -// This function processes various Gemini event types and transforms them into Gemini CLI-compatible JSON responses. -// It handles thinking content, regular text content, and function calls, outputting single-line JSON -// that matches the Gemini CLI API response format. -// -// Parameters: -// - ctx: The context for the request. -// - modelName: The name of the model. -// - rawJSON: The raw JSON response from the Gemini API. -// - param: A pointer to a parameter object for the conversion (unused). -// -// Returns: -// - [][]byte: A slice of Gemini CLI-compatible JSON responses. -func ConvertGeminiResponseToGeminiCLI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte { - if !bytes.HasPrefix(rawJSON, dataTag) { - return [][]byte{} - } - rawJSON = bytes.TrimSpace(rawJSON[5:]) - - if bytes.Equal(rawJSON, []byte("[DONE]")) { - return [][]byte{} - } - rawJSON, _ = sjson.SetRawBytes([]byte(`{"response":{}}`), "response", rawJSON) - return [][]byte{rawJSON} -} - -// ConvertGeminiResponseToGeminiCLINonStream converts a non-streaming Gemini response to a non-streaming Gemini CLI response. -// -// Parameters: -// - ctx: The context for the request. -// - modelName: The name of the model. -// - rawJSON: The raw JSON response from the Gemini API. -// - param: A pointer to a parameter object for the conversion (unused). 
-// -// Returns: -// - []byte: A Gemini CLI-compatible JSON response. -func ConvertGeminiResponseToGeminiCLINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte { - rawJSON, _ = sjson.SetRawBytes([]byte(`{"response":{}}`), "response", rawJSON) - return rawJSON -} - -func GeminiCLITokenCount(ctx context.Context, count int64) []byte { - return translatorcommon.GeminiTokenCountJSON(count) -} diff --git a/internal/translator/gemini/gemini-cli/init.go b/internal/translator/gemini/gemini-cli/init.go deleted file mode 100644 index ed18b5f0af..0000000000 --- a/internal/translator/gemini/gemini-cli/init.go +++ /dev/null @@ -1,20 +0,0 @@ -package geminiCLI - -import ( - . "github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/translator" -) - -func init() { - translator.Register( - GeminiCLI, - Gemini, - ConvertGeminiCLIRequestToGemini, - interfaces.TranslateResponse{ - Stream: ConvertGeminiResponseToGeminiCLI, - NonStream: ConvertGeminiResponseToGeminiCLINonStream, - TokenCount: GeminiCLITokenCount, - }, - ) -} diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index 4a59c6ccd5..28086f5291 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -196,6 +196,18 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) p++ } } + case "video_url": + videoURL := item.Get("video_url.url").String() + if len(videoURL) > 5 { + pieces := strings.SplitN(videoURL[5:], ";", 2) + if len(pieces) == 2 && len(pieces[1]) > 7 { + mime := pieces[0] + data := pieces[1][7:] + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime) + 
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) + p++ + } + } case "file": filename := item.Get("file.filename").String() fileData := item.Get("file.file_data").String() @@ -210,6 +222,14 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) } else { log.Warnf("Unknown file name extension '%s' in user message, skip", ext) } + case "input_audio": + audioData := item.Get("input_audio.data").String() + if audioData != "" { + mimeType := openAIInputAudioMimeType(item.Get("input_audio.format").String()) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData) + p++ + } } } } @@ -355,6 +375,9 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) fnRawBytes := []byte(fnRaw) fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String())) fnRaw = string(fnRawBytes) + if parameters := gjson.Get(fnRaw, "parametersJsonSchema"); parameters.Exists() { + fnRaw, _ = sjson.SetRaw(fnRaw, "parametersJsonSchema", util.CleanJSONSchemaForGemini(parameters.Raw)) + } fnRaw, _ = sjson.Delete(fnRaw, "strict") if !hasFunction { functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]")) @@ -438,3 +461,26 @@ func openAIToolCallGeminiThoughtSignature(toolCall gjson.Result) string { // itoa converts int to string without strconv import for few usages. 
func itoa(i int) string { return fmt.Sprintf("%d", i) } + +func openAIInputAudioMimeType(audioFormat string) string { + switch audioFormat { + case "", "wav": + return "audio/wav" + case "mp3": + return "audio/mpeg" + case "ogg": + return "audio/ogg" + case "flac": + return "audio/flac" + case "aac": + return "audio/aac" + case "webm": + return "audio/webm" + case "pcm16": + return "audio/pcm" + case "g711_ulaw", "g711_alaw": + return "audio/basic" + default: + return "audio/" + audioFormat + } +} diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request_test.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request_test.go index f9c0d272c4..ad79869cec 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request_test.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request_test.go @@ -26,3 +26,110 @@ func TestConvertOpenAIRequestToGemini_StripsTrailingAssistantPrefill(t *testing. t.Fatalf("final remaining role = %q, want %q", got, "user") } } + +func TestConvertOpenAIRequestToGeminiPreservesInputAudio(t *testing.T) { + inputJSON := `{ + "model": "gpt-5.5", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Transcribe this audio verbatim."}, + {"type": "input_audio", "input_audio": {"data": "SUQzBA==", "format": "mp3"}} + ] + } + ] + }` + + result := ConvertOpenAIRequestToGemini("gemini-3.1-pro-high", []byte(inputJSON), false) + resultJSON := gjson.ParseBytes(result) + parts := resultJSON.Get("contents.0.parts").Array() + + if len(parts) != 2 { + t.Fatalf("parts length = %d, want 2. parts=%s", len(parts), resultJSON.Get("contents.0.parts").Raw) + } + if got := parts[0].Get("text").String(); got != "Transcribe this audio verbatim." 
{ + t.Fatalf("text part = %q, want prompt text", got) + } + if got := parts[1].Get("inlineData.mime_type").String(); got != "audio/mpeg" { + t.Fatalf("audio mime_type = %q, want %q", got, "audio/mpeg") + } + if got := parts[1].Get("inlineData.data").String(); got != "SUQzBA==" { + t.Fatalf("audio data = %q, want %q", got, "SUQzBA==") + } +} + +func TestConvertOpenAIRequestToGeminiPreservesVideoURL(t *testing.T) { + inputJSON := `{ + "model": "gemini-3-flash", + "messages": [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,AAAAIGZ0eXBtcDQy"}}, + {"type": "text", "text": "Describe the video"} + ] + } + ] + }` + + result := ConvertOpenAIRequestToGemini("gemini-3-flash", []byte(inputJSON), false) + resultJSON := gjson.ParseBytes(result) + parts := resultJSON.Get("contents.0.parts").Array() + + if len(parts) != 2 { + t.Fatalf("parts length = %d, want 2. parts=%s", len(parts), resultJSON.Get("contents.0.parts").Raw) + } + if got := parts[0].Get("inlineData.mime_type").String(); got != "video/mp4" { + t.Fatalf("video mime_type = %q, want %q", got, "video/mp4") + } + if got := parts[0].Get("inlineData.data").String(); got != "AAAAIGZ0eXBtcDQy" { + t.Fatalf("video data = %q, want %q", got, "AAAAIGZ0eXBtcDQy") + } + if got := parts[1].Get("text").String(); got != "Describe the video" { + t.Fatalf("text part = %q, want prompt text", got) + } +} + +func TestConvertOpenAIRequestToGeminiCleansToolSchemaRequiredFields(t *testing.T) { + inputJSON := `{ + "model": "gemini-2.0-flash", + "messages": [{"role": "user", "content": "hi"}], + "tools": [{ + "type": "function", + "function": { + "name": "search_company", + "description": "Search", + "parameters": { + "type": "object", + "title": "SearchCompany", + "properties": { + "country": {"type": "string"}, + "industry": {"type": "string"} + }, + "required": ["country", "industry", "stale_field", "another_stale"] + } + } + }] + }` + + output := 
ConvertOpenAIRequestToGemini("gemini-2.0-flash", []byte(inputJSON), false) + schema := gjson.GetBytes(output, "tools.0.functionDeclarations.0.parametersJsonSchema") + + if !schema.Exists() { + t.Fatalf("parametersJsonSchema missing. Output: %s", output) + } + if schema.Get("title").Exists() { + t.Fatalf("schema title should be removed. Output: %s", output) + } + required := schema.Get("required").Array() + if len(required) != 2 { + t.Fatalf("required length = %d, want 2. Schema: %s", len(required), schema.Raw) + } + if got := required[0].String(); got != "country" { + t.Fatalf("required[0] = %q, want country. Schema: %s", got, schema.Raw) + } + if got := required[1].String(); got != "industry" { + t.Fatalf("required[1] = %q, want industry. Schema: %s", got, schema.Raw) + } +} diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go index 0907519b3b..b9a6efe618 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go @@ -395,7 +395,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte funcDecl, _ = sjson.SetBytes(funcDecl, "description", desc.String()) } if params := tool.Get("parameters"); params.Exists() { - funcDecl, _ = sjson.SetRawBytes(funcDecl, "parametersJsonSchema", []byte(params.Raw)) + funcDecl, _ = sjson.SetRawBytes(funcDecl, "parametersJsonSchema", []byte(util.CleanJSONSchemaForGemini(params.Raw))) } geminiTools, _ = sjson.SetRawBytes(geminiTools, "0.functionDeclarations.-1", funcDecl) @@ -445,6 +445,8 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte out, _ = sjson.SetBytes(out, "generationConfig.stopSequences", sequences) } + out = applyOpenAIResponsesTextFormatToGemini(out, root) + // Apply thinking configuration: convert OpenAI Responses API 
reasoning.effort to Gemini thinkingConfig. // Inline translation-only mapping; capability checks happen later in ApplyThinking. re := root.Get("reasoning.effort") @@ -470,3 +472,38 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte func openAIResponsesGeminiThoughtSignature(rawSignature string) string { return sigcompat.GeminiReplaySignatureOrBypass(rawSignature, sigcompat.SignatureBlockKindGeminiModelPart) } + +func applyOpenAIResponsesTextFormatToGemini(out []byte, root gjson.Result) []byte { + textFormat := root.Get("text.format") + if !textFormat.Exists() { + return out + } + + formatType := strings.ToLower(strings.TrimSpace(textFormat.Get("type").String())) + switch formatType { + case "json_object": + out = ensureGeminiGenerationConfig(out) + out, _ = sjson.SetBytes(out, "generationConfig.responseMimeType", "application/json") + case "json_schema": + out = ensureGeminiGenerationConfig(out) + out, _ = sjson.SetBytes(out, "generationConfig.responseMimeType", "application/json") + out, _ = sjson.DeleteBytes(out, "generationConfig.responseSchema") + + schema := textFormat.Get("schema") + if !schema.Exists() { + schema = textFormat.Get("json_schema.schema") + } + if schema.Exists() { + out, _ = sjson.SetRawBytes(out, "generationConfig.responseJsonSchema", []byte(schema.Raw)) + } + } + + return out +} + +func ensureGeminiGenerationConfig(out []byte) []byte { + if !gjson.GetBytes(out, "generationConfig").Exists() { + out, _ = sjson.SetRawBytes(out, "generationConfig", []byte(`{}`)) + } + return out +} diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_request_test.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_request_test.go index 071fadc8b0..446ee753be 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_request_test.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_request_test.go @@ -38,6 +38,93 @@ func 
TestConvertOpenAIResponsesRequestToGemini_StripsTrailingAssistantPrefill(t } } +func TestConvertOpenAIResponsesRequestToGemini_TextFormatJSONSchema(t *testing.T) { + inputJSON := `{ + "model": "gemini-flash-lite", + "temperature": 0.2, + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": "Return structured JSON." + } + ] + } + ], + "text": { + "format": { + "type": "json_schema", + "strict": true, + "name": "response", + "schema": { + "type": "object", + "properties": { + "cleanedContent": { + "type": "string" + } + }, + "required": [ + "cleanedContent" + ], + "additionalProperties": false + } + } + } + }` + + output := ConvertOpenAIResponsesRequestToGemini("gemini-3.1-flash-lite", []byte(inputJSON), false) + result := gjson.ParseBytes(output) + genConfig := result.Get("generationConfig") + + if got := genConfig.Get("responseMimeType").String(); got != "application/json" { + t.Fatalf("responseMimeType = %q, want application/json. Output: %s", got, output) + } + schema := genConfig.Get("responseJsonSchema") + if !schema.Exists() { + t.Fatalf("responseJsonSchema missing. Output: %s", output) + } + if genConfig.Get("responseSchema").Exists() { + t.Fatalf("responseSchema should not be set with responseJsonSchema. Output: %s", output) + } + if got := schema.Get("type").String(); got != "object" { + t.Fatalf("schema type = %q, want object. Output: %s", got, output) + } + if got := schema.Get("properties.cleanedContent.type").String(); got != "string" { + t.Fatalf("cleanedContent type = %q, want string. Output: %s", got, output) + } + if additionalProperties := schema.Get("additionalProperties"); !additionalProperties.Exists() || additionalProperties.Bool() { + t.Fatalf("additionalProperties = %s, want false. Output: %s", additionalProperties.Raw, output) + } + if got := genConfig.Get("temperature").Float(); got != 0.2 { + t.Fatalf("temperature = %v, want 0.2. 
Output: %s", got, output) + } +} + +func TestConvertOpenAIResponsesRequestToGemini_TextFormatJSONObject(t *testing.T) { + inputJSON := `{ + "model": "gemini-flash-lite", + "input": "Return a JSON object.", + "text": { + "format": { + "type": "json_object" + } + } + }` + + output := ConvertOpenAIResponsesRequestToGemini("gemini-3.1-flash-lite", []byte(inputJSON), false) + result := gjson.ParseBytes(output) + genConfig := result.Get("generationConfig") + + if got := genConfig.Get("responseMimeType").String(); got != "application/json" { + t.Fatalf("responseMimeType = %q, want application/json. Output: %s", got, output) + } + if genConfig.Get("responseJsonSchema").Exists() { + t.Fatalf("responseJsonSchema should not be set for json_object. Output: %s", output) + } +} + func TestConvertOpenAIResponsesRequestToGemini_ReasoningSignatureCompatibility(t *testing.T) { tests := []struct { name string @@ -158,6 +245,47 @@ func TestConvertOpenAIResponsesRequestToGemini_SystemAndDeveloperRoles(t *testin } } +func TestConvertOpenAIResponsesRequestToGeminiCleansToolSchemaRequiredFields(t *testing.T) { + inputJSON := `{ + "model": "gemini-2.0-flash", + "input": "hi", + "tools": [{ + "type": "function", + "name": "search_company", + "description": "Search", + "parameters": { + "type": "object", + "title": "SearchCompany", + "properties": { + "country": {"type": "string"}, + "industry": {"type": "string"} + }, + "required": ["country", "industry", "stale_field", "another_stale"] + } + }] + }` + + output := ConvertOpenAIResponsesRequestToGemini("gemini-2.0-flash", []byte(inputJSON), false) + schema := gjson.GetBytes(output, "tools.0.functionDeclarations.0.parametersJsonSchema") + + if !schema.Exists() { + t.Fatalf("parametersJsonSchema missing. Output: %s", output) + } + if schema.Get("title").Exists() { + t.Fatalf("schema title should be removed. Output: %s", output) + } + required := schema.Get("required").Array() + if len(required) != 2 { + t.Fatalf("required length = %d, want 2. 
Schema: %s", len(required), schema.Raw) + } + if got := required[0].String(); got != "country" { + t.Fatalf("required[0] = %q, want country. Schema: %s", got, schema.Raw) + } + if got := required[1].String(); got != "industry" { + t.Fatalf("required[1] = %q, want industry. Schema: %s", got, schema.Raw) + } +} + func validResponsesGPTReasoningSignature() string { raw := make([]byte, 1+8+16+16+32) raw[0] = 0x80 diff --git a/internal/translator/init.go b/internal/translator/init.go index 5f88a400ec..c0cccc9cdd 100644 --- a/internal/translator/init.go +++ b/internal/translator/init.go @@ -2,30 +2,21 @@ package translator import ( _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/claude/gemini" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/claude/gemini-cli" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/claude/openai/chat-completions" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/claude/openai/responses" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/codex/claude" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/codex/gemini" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/codex/gemini-cli" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/codex/openai/chat-completions" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/codex/openai/responses" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini-cli/claude" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini-cli/gemini" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini-cli/openai/chat-completions" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini-cli/openai/responses" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/claude" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/gemini" - _ 
"github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/gemini-cli" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/openai/chat-completions" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/gemini/openai/responses" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/openai/claude" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/openai/gemini" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/openai/gemini-cli" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/openai/openai/chat-completions" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/openai/openai/responses" diff --git a/internal/translator/openai/gemini-cli/init.go b/internal/translator/openai/gemini-cli/init.go deleted file mode 100644 index 7b52d06dc0..0000000000 --- a/internal/translator/openai/gemini-cli/init.go +++ /dev/null @@ -1,20 +0,0 @@ -package geminiCLI - -import ( - . "github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/translator" -) - -func init() { - translator.Register( - GeminiCLI, - OpenAI, - ConvertGeminiCLIRequestToOpenAI, - interfaces.TranslateResponse{ - Stream: ConvertOpenAIResponseToGeminiCLI, - NonStream: ConvertOpenAIResponseToGeminiCLINonStream, - TokenCount: GeminiCLITokenCount, - }, - ) -} diff --git a/internal/translator/openai/gemini-cli/openai_gemini_request.go b/internal/translator/openai/gemini-cli/openai_gemini_request.go deleted file mode 100644 index c651826669..0000000000 --- a/internal/translator/openai/gemini-cli/openai_gemini_request.go +++ /dev/null @@ -1,27 +0,0 @@ -// Package geminiCLI provides request translation functionality for Gemini to OpenAI API. 
-// It handles parsing and transforming Gemini API requests into OpenAI Chat Completions API format, -// extracting model information, generation config, message contents, and tool declarations. -// The package performs JSON data transformation to ensure compatibility -// between Gemini API format and OpenAI API's expected format. -package geminiCLI - -import ( - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/openai/gemini" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// ConvertGeminiCLIRequestToOpenAI parses and transforms a Gemini API request into OpenAI Chat Completions API format. -// It extracts the model name, generation config, message contents, and tool declarations -// from the raw JSON request and returns them in the format expected by the OpenAI API. -func ConvertGeminiCLIRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte { - rawJSON := inputRawJSON - rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw) - rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) - if gjson.GetBytes(rawJSON, "systemInstruction").Exists() { - rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw)) - rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction") - } - - return ConvertGeminiRequestToOpenAI(modelName, rawJSON, stream) -} diff --git a/internal/translator/openai/gemini-cli/openai_gemini_response.go b/internal/translator/openai/gemini-cli/openai_gemini_response.go deleted file mode 100644 index e54e08fc27..0000000000 --- a/internal/translator/openai/gemini-cli/openai_gemini_response.go +++ /dev/null @@ -1,53 +0,0 @@ -// Package geminiCLI provides response translation functionality for OpenAI to Gemini API. -// This package handles the conversion of OpenAI Chat Completions API responses into Gemini API-compatible -// JSON format, transforming streaming events and non-streaming responses into the format -// expected by Gemini API clients. 
It supports both streaming and non-streaming modes, -// handling text content, tool calls, and usage metadata appropriately. -package geminiCLI - -import ( - "context" - - translatorcommon "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/common" - . "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/openai/gemini" -) - -// ConvertOpenAIResponseToGeminiCLI converts OpenAI Chat Completions streaming response format to Gemini API format. -// This function processes OpenAI streaming chunks and transforms them into Gemini-compatible JSON responses. -// It handles text content, tool calls, and usage metadata, outputting responses that match the Gemini API format. -// -// Parameters: -// - ctx: The context for the request. -// - modelName: The name of the model. -// - rawJSON: The raw JSON response from the OpenAI API. -// - param: A pointer to a parameter object for the conversion. -// -// Returns: -// - [][]byte: A slice of Gemini-compatible JSON responses. -func ConvertOpenAIResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { - outputs := ConvertOpenAIResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) - newOutputs := make([][]byte, 0, len(outputs)) - for i := 0; i < len(outputs); i++ { - newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i])) - } - return newOutputs -} - -// ConvertOpenAIResponseToGeminiCLINonStream converts a non-streaming OpenAI response to a non-streaming Gemini CLI response. -// -// Parameters: -// - ctx: The context for the request. -// - modelName: The name of the model. -// - rawJSON: The raw JSON response from the OpenAI API. -// - param: A pointer to a parameter object for the conversion. -// -// Returns: -// - []byte: A Gemini-compatible JSON response. 
-func ConvertOpenAIResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte { - out := ConvertOpenAIResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) - return translatorcommon.WrapGeminiCLIResponse(out) -} - -func GeminiCLITokenCount(ctx context.Context, count int64) []byte { - return translatorcommon.GeminiTokenCountJSON(count) -} diff --git a/internal/translator/openai/gemini/openai_gemini_request.go b/internal/translator/openai/gemini/openai_gemini_request.go index 7369de88df..53773806d0 100644 --- a/internal/translator/openai/gemini/openai_gemini_request.go +++ b/internal/translator/openai/gemini/openai_gemini_request.go @@ -113,6 +113,7 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream // Process contents (Gemini messages) -> OpenAI messages var toolCallIDs []string // Track tool call IDs for matching with tool results + toolCallConsumeIdx := 0 // System instruction -> OpenAI system message // Gemini may provide `systemInstruction` or `system_instruction`; support both keys. 
@@ -241,12 +242,9 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream } } - // Try to match with previous tool call ID - _ = functionResponse.Get("name").String() // functionName not used for now - if len(toolCallIDs) > 0 { - // Use the last tool call ID (simple matching by function name) - // In a real implementation, you might want more sophisticated matching - toolMsg, _ = sjson.SetBytes(toolMsg, "tool_call_id", toolCallIDs[len(toolCallIDs)-1]) + if toolCallConsumeIdx < len(toolCallIDs) { + toolMsg, _ = sjson.SetBytes(toolMsg, "tool_call_id", toolCallIDs[toolCallConsumeIdx]) + toolCallConsumeIdx++ } else { // Generate a tool call ID if none available toolMsg, _ = sjson.SetBytes(toolMsg, "tool_call_id", genToolCallID()) diff --git a/internal/translator/openai/gemini/openai_gemini_request_test.go b/internal/translator/openai/gemini/openai_gemini_request_test.go new file mode 100644 index 0000000000..7bfbaad54e --- /dev/null +++ b/internal/translator/openai/gemini/openai_gemini_request_test.go @@ -0,0 +1,106 @@ +package gemini + +import ( + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertGeminiRequestToOpenAI_FunctionResponsesConsumeToolCallIDsFIFO(t *testing.T) { + inputJSON := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "read_file", "args": {"path": "a.txt"}}}, + {"functionCall": {"name": "grep", "args": {"pattern": "needle"}}}, + {"functionCall": {"name": "list_dir", "args": {"path": "."}}} + ] + }, + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "read_file", "response": {"result": "a"}}}, + {"functionResponse": {"name": "grep", "response": {"result": "b"}}}, + {"functionResponse": {"name": "list_dir", "response": {"result": "c"}}} + ] + } + ] + }`) + + out := ConvertGeminiRequestToOpenAI("test-model", inputJSON, false) + firstID := gjson.GetBytes(out, "messages.0.tool_calls.0.id").String() + secondID := gjson.GetBytes(out, 
"messages.0.tool_calls.1.id").String() + thirdID := gjson.GetBytes(out, "messages.0.tool_calls.2.id").String() + + if firstID == "" || secondID == "" || thirdID == "" { + t.Fatalf("expected all assistant tool call IDs to be set. Output: %s", string(out)) + } + if firstID == secondID || secondID == thirdID || firstID == thirdID { + t.Fatalf("expected distinct assistant tool call IDs, got %q, %q, %q", firstID, secondID, thirdID) + } + if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != firstID { + t.Fatalf("messages.1.tool_call_id = %q, want %q. Output: %s", got, firstID, string(out)) + } + if got := gjson.GetBytes(out, "messages.2.tool_call_id").String(); got != secondID { + t.Fatalf("messages.2.tool_call_id = %q, want %q. Output: %s", got, secondID, string(out)) + } + if got := gjson.GetBytes(out, "messages.3.tool_call_id").String(); got != thirdID { + t.Fatalf("messages.3.tool_call_id = %q, want %q. Output: %s", got, thirdID, string(out)) + } +} + +func TestConvertGeminiRequestToOpenAI_FunctionResponseWithoutPriorCallGetsFallbackID(t *testing.T) { + inputJSON := []byte(`{ + "contents": [ + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "read_file", "response": {"result": "ok"}}} + ] + } + ] + }`) + + out := ConvertGeminiRequestToOpenAI("test-model", inputJSON, false) + toolCallID := gjson.GetBytes(out, "messages.0.tool_call_id").String() + if !strings.HasPrefix(toolCallID, "call_") { + t.Fatalf("fallback tool_call_id = %q, want call_ prefix. 
Output: %s", toolCallID, string(out)) + } +} + +func TestConvertGeminiRequestToOpenAI_ExtraFunctionResponsesUseFallbackID(t *testing.T) { + inputJSON := []byte(`{ + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "read_file", "args": {"path": "a.txt"}}} + ] + }, + { + "role": "function", + "parts": [ + {"functionResponse": {"name": "read_file", "response": {"result": "a"}}}, + {"functionResponse": {"name": "read_file", "response": {"result": "extra"}}} + ] + } + ] + }`) + + out := ConvertGeminiRequestToOpenAI("test-model", inputJSON, false) + callID := gjson.GetBytes(out, "messages.0.tool_calls.0.id").String() + firstResponseID := gjson.GetBytes(out, "messages.1.tool_call_id").String() + extraResponseID := gjson.GetBytes(out, "messages.2.tool_call_id").String() + + if firstResponseID != callID { + t.Fatalf("messages.1.tool_call_id = %q, want %q. Output: %s", firstResponseID, callID, string(out)) + } + if !strings.HasPrefix(extraResponseID, "call_") { + t.Fatalf("extra response fallback tool_call_id = %q, want call_ prefix. Output: %s", extraResponseID, string(out)) + } + if extraResponseID == callID { + t.Fatalf("extra response reused consumed tool_call_id %q. Output: %s", extraResponseID, string(out)) + } +} diff --git a/internal/translator/openai/openai/chat-completions/openai_openai_request.go b/internal/translator/openai/openai/chat-completions/openai_openai_request.go index a74cded6c7..f2e6fadc80 100644 --- a/internal/translator/openai/openai/chat-completions/openai_openai_request.go +++ b/internal/translator/openai/openai/chat-completions/openai_openai_request.go @@ -1,5 +1,5 @@ -// Package openai provides request translation functionality for OpenAI to Gemini CLI API compatibility. -// It converts OpenAI Chat Completions requests into Gemini CLI compatible JSON using gjson/sjson only. +// Package openai provides request translation functionality for OpenAI to OpenAI API compatibility. 
+// It converts OpenAI Chat Completions requests into OpenAI-compatible JSON using gjson/sjson only. package chat_completions import ( @@ -7,7 +7,7 @@ import ( ) // ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON) -// into a complete Gemini CLI request JSON. All JSON construction uses sjson and lookups use gjson. +// into a complete OpenAI request JSON. All JSON construction uses sjson and lookups use gjson. // // Parameters: // - modelName: The name of the model to use for the request @@ -15,7 +15,7 @@ import ( // - stream: A boolean indicating if the request is for a streaming response (unused in current implementation) // // Returns: -// - []byte: The transformed request data in Gemini CLI API format +// - []byte: The transformed request data in OpenAI API format func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte { // Update the "model" field in the JSON payload with the provided modelName // The sjson.SetBytes function returns a new byte slice with the updated JSON. 
diff --git a/internal/translator/openai/openai/chat-completions/openai_openai_response.go b/internal/translator/openai/openai/chat-completions/openai_openai_response.go index 9320a3ded4..0ecc96bffd 100644 --- a/internal/translator/openai/openai/chat-completions/openai_openai_response.go +++ b/internal/translator/openai/openai/chat-completions/openai_openai_response.go @@ -14,7 +14,7 @@ import ( // Parameters: // - ctx: The context for the request, used for cancellation and timeout handling // - modelName: The name of the model being used for the response (unused in current implementation) -// - rawJSON: The raw JSON response from the Gemini CLI API +// - rawJSON: The raw JSON response from the OpenAI API // - param: A pointer to a parameter object for maintaining state between calls // // Returns: @@ -34,7 +34,7 @@ func ConvertOpenAIResponseToOpenAI(_ context.Context, _ string, originalRequestR // Parameters: // - ctx: The context for the request, used for cancellation and timeout handling // - modelName: The name of the model being used for the response -// - rawJSON: The raw JSON response from the Gemini CLI API +// - rawJSON: The raw JSON response from the OpenAI API // - param: A pointer to a parameter object for the conversion // // Returns: diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_request.go b/internal/translator/openai/openai/responses/openai_openai-responses_request.go index 1cd77c4d0e..c071076df2 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_request.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_request.go @@ -72,15 +72,24 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu pendingToolCalls := make([]interface{}, 0) pendingToolCallIDs := make([]string, 0) + pendingReasoningContent := "" awaitingToolOutputs := make(map[string]struct{}) deferredMessages := make([][]byte, 0) + takePendingReasoningContent := func() string { + 
reasoningContent := pendingReasoningContent + pendingReasoningContent = "" + return reasoningContent + } flushPendingToolCalls := func() { if len(pendingToolCalls) == 0 { return } assistantMessage := []byte(`{"role":"assistant","tool_calls":[]}`) assistantMessage, _ = sjson.SetBytes(assistantMessage, "tool_calls", pendingToolCalls) + if reasoningContent := takePendingReasoningContent(); reasoningContent != "" { + assistantMessage, _ = sjson.SetBytes(assistantMessage, "reasoning_content", reasoningContent) + } out, _ = sjson.SetRawBytes(out, "messages.-1", assistantMessage) for _, id := range pendingToolCallIDs { if strings.TrimSpace(id) == "" { @@ -114,6 +123,15 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu } out, _ = sjson.SetRawBytes(out, "messages.-1", message) } + appendPendingReasoningMessage := func() { + reasoningContent := takePendingReasoningContent() + if reasoningContent == "" { + return + } + message := []byte(`{"role":"assistant","content":"","reasoning_content":""}`) + message, _ = sjson.SetBytes(message, "reasoning_content", reasoningContent) + appendRegularMessage(message) + } for _, item := range inputItems { itemType := item.Get("type").String() @@ -131,6 +149,9 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu if role == "developer" { role = "user" } + if role != "assistant" { + appendPendingReasoningMessage() + } message := []byte(`{"role":"","content":[]}`) message, _ = sjson.SetBytes(message, "role", role) @@ -154,6 +175,9 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu imageURL := contentItem.Get("image_url").String() contentPart := []byte(`{"type":"image_url","image_url":{"url":""}}`) contentPart, _ = sjson.SetBytes(contentPart, "image_url.url", imageURL) + if detail := contentItem.Get("detail"); detail.Exists() { + contentPart, _ = sjson.SetBytes(contentPart, "image_url.detail", detail.String()) + } message, _ = 
sjson.SetRawBytes(message, "content.-1", contentPart) } return true @@ -170,8 +194,28 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu message, _ = sjson.SetBytes(message, "content", content.String()) } + if role == "assistant" { + reasoningContent := item.Get("reasoning_content").String() + if reasoningContent == "" { + reasoningContent = takePendingReasoningContent() + } else { + pendingReasoningContent = "" + } + if reasoningContent != "" { + message, _ = sjson.SetBytes(message, "reasoning_content", reasoningContent) + } + } + appendRegularMessage(message) + case "reasoning": + reasoningContent := collectOpenAIResponsesReasoningContent(item) + if pendingReasoningContent == "" { + pendingReasoningContent = reasoningContent + } else { + pendingReasoningContent += reasoningContent + } + case "function_call": // Buffer consecutive function calls and emit them as one assistant message. toolCall := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`) @@ -217,6 +261,7 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu } flushPendingToolCalls() + appendPendingReasoningMessage() flushDeferredMessages() } else if input.Type == gjson.String { msg := []byte(`{}`) @@ -250,8 +295,25 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu // Convert tool_choice if present if toolChoice := root.Get("tool_choice"); toolChoice.Exists() { - out, _ = sjson.SetBytes(out, "tool_choice", toolChoice.String()) + out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(toolChoice.Raw)) } return out } + +func collectOpenAIResponsesReasoningContent(item gjson.Result) string { + var reasoningText strings.Builder + if summary := item.Get("summary"); summary.Exists() && summary.IsArray() { + summary.ForEach(func(_, summaryItem gjson.Result) bool { + if summaryItem.Get("type").String() != "summary_text" { + return true + } + 
reasoningText.WriteString(summaryItem.Get("text").String()) + return true + }) + } + if reasoningText.Len() == 0 { + return "[reasoning unavailable]" + } + return reasoningText.String() +} diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_request_test.go b/internal/translator/openai/openai/responses/openai_openai-responses_request_test.go index c35f0e2be7..26a7fc0d3e 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_request_test.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_request_test.go @@ -123,6 +123,107 @@ func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_DefersMessageUntil } } +func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_AttachesReasoningToAssistantMessage(t *testing.T) { + raw := []byte(`{ + "input": [ + { + "type": "reasoning", + "id": "rs_1", + "summary": [ + {"type": "summary_text", "text": "first line\n"}, + {"type": "summary_text", "text": "second line"} + ] + }, + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "answer"}] + }, + {"type": "message", "role": "user", "content": "next"} + ] + }`) + t.Logf("input json:\n%s", prettyJSONForTest(raw)) + + out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("deepseek-v4-flash", raw, false) + t.Logf("output json:\n%s", prettyJSONForTest(out)) + + if got := gjson.GetBytes(out, "messages.#").Int(); got != 2 { + t.Fatalf("messages count = %d, want 2; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.0.role").String(); got != "assistant" { + t.Fatalf("messages.0.role = %q, want assistant; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.0.reasoning_content").String(); got != "first line\nsecond line" { + t.Fatalf("messages.0.reasoning_content = %q, want %q; output=%s", got, "first line\nsecond line", out) + } + if got := gjson.GetBytes(out, "messages.0.content.0.text").String(); got != "answer" { + 
t.Fatalf("messages.0.content.0.text = %q, want answer; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.1.role").String(); got != "user" { + t.Fatalf("messages.1.role = %q, want user; output=%s", got, out) + } +} + +func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_AttachesReasoningToToolCallMessage(t *testing.T) { + raw := []byte(`{ + "input": [ + { + "type": "reasoning", + "id": "rs_tool", + "summary": [{"type": "summary_text", "text": "tool reasoning"}] + }, + {"type":"function_call","call_id":"call_1","name":"exec_command","arguments":"{\"cmd\":\"pwd\"}"}, + {"type":"function_call_output","call_id":"call_1","output":"ok"} + ] + }`) + t.Logf("input json:\n%s", prettyJSONForTest(raw)) + + out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("deepseek-v4-flash", raw, true) + t.Logf("output json:\n%s", prettyJSONForTest(out)) + + if got := gjson.GetBytes(out, "messages.#").Int(); got != 2 { + t.Fatalf("messages count = %d, want 2; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.0.role").String(); got != "assistant" { + t.Fatalf("messages.0.role = %q, want assistant; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.0.reasoning_content").String(); got != "tool reasoning" { + t.Fatalf("messages.0.reasoning_content = %q, want tool reasoning; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.0.tool_calls.0.id").String(); got != "call_1" { + t.Fatalf("messages.0.tool_calls.0.id = %q, want call_1; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.1.role").String(); got != "tool" { + t.Fatalf("messages.1.role = %q, want tool; output=%s", got, out) + } +} + +func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_KeepsReasoningBeforeUserMessage(t *testing.T) { + raw := []byte(`{ + "input": [ + {"type": "reasoning", "id": "rs_empty", "summary": []}, + {"type": "message", "role": "user", "content": "continue"} + ] + }`) + t.Logf("input json:\n%s", 
prettyJSONForTest(raw)) + + out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("deepseek-v4-flash", raw, false) + t.Logf("output json:\n%s", prettyJSONForTest(out)) + + if got := gjson.GetBytes(out, "messages.#").Int(); got != 2 { + t.Fatalf("messages count = %d, want 2; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.0.role").String(); got != "assistant" { + t.Fatalf("messages.0.role = %q, want assistant; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.0.reasoning_content").String(); got != "[reasoning unavailable]" { + t.Fatalf("messages.0.reasoning_content = %q, want placeholder; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.1.role").String(); got != "user" { + t.Fatalf("messages.1.role = %q, want user; output=%s", got, out) + } +} + func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_FlattensNamespaceTools(t *testing.T) { raw := []byte(`{ "input": [ @@ -173,3 +274,56 @@ func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_FlattensNamespaceT t.Fatalf("tools.0.function.parameters.required.0 = %q, want a; output=%s", got, out) } } + +func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_PreservesStructuredToolChoice(t *testing.T) { + raw := []byte(`{ + "input": [ + {"role":"user","content":"Run command."} + ], + "tool_choice": { + "type": "function", + "function": { + "name": "run_command" + } + } + }`) + t.Logf("input json:\n%s", prettyJSONForTest(raw)) + + out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("gpt-5.4", raw, false) + t.Logf("output json:\n%s", prettyJSONForTest(out)) + + if got := gjson.GetBytes(out, "tool_choice.type").String(); got != "function" { + t.Fatalf("tool_choice.type = %q, want function; output=%s", got, out) + } + if got := gjson.GetBytes(out, "tool_choice.function.name").String(); got != "run_command" { + t.Fatalf("tool_choice.function.name = %q, want run_command; output=%s", got, out) + } +} + +func 
TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_PreservesInputImageDetail(t *testing.T) { + raw := []byte(`{ + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": "https://example.com/image.png", + "detail": "high" + } + ] + } + ] + }`) + t.Logf("input json:\n%s", prettyJSONForTest(raw)) + + out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("gpt-5.4", raw, false) + t.Logf("output json:\n%s", prettyJSONForTest(out)) + + if got := gjson.GetBytes(out, "messages.0.content.0.image_url.url").String(); got != "https://example.com/image.png" { + t.Fatalf("messages.0.content.0.image_url.url = %q, want https://example.com/image.png; output=%s", got, out) + } + if got := gjson.GetBytes(out, "messages.0.content.0.image_url.detail").String(); got != "high" { + t.Fatalf("messages.0.content.0.image_url.detail = %q, want high; output=%s", got, out) + } +} diff --git a/internal/tui/oauth_tab.go b/internal/tui/oauth_tab.go index bd3aac3f68..1cfe1a1a6b 100644 --- a/internal/tui/oauth_tab.go +++ b/internal/tui/oauth_tab.go @@ -19,7 +19,6 @@ type oauthProvider struct { } var oauthProviders = []oauthProvider{ - {"Gemini CLI", "gemini-cli-auth-url", "🟦"}, {"Claude (Anthropic)", "anthropic-auth-url", "🟧"}, {"Codex (OpenAI)", "codex-auth-url", "🟩"}, {"Antigravity", "antigravity-auth-url", "🟪"}, @@ -271,8 +270,6 @@ func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd { if p.name == m.providerName { // Map provider name to the canonical key the API expects switch p.apiPath { - case "gemini-cli-auth-url": - providerKey = "gemini" case "anthropic-auth-url": providerKey = "anthropic" case "codex-auth-url": diff --git a/internal/util/gemini_schema.go b/internal/util/gemini_schema.go index 4cc946d5f3..010669a811 100644 --- a/internal/util/gemini_schema.go +++ b/internal/util/gemini_schema.go @@ -440,7 +440,7 @@ func removeUnsupportedKeywords(jsonStr string) string { keywords := append(unsupportedConstraints, "$schema", 
"$defs", "definitions", "const", "$ref", "$id", "additionalProperties", "propertyNames", "patternProperties", // Gemini doesn't support these schema keywords - "enumTitles", "prefill", "deprecated", // Schema metadata fields unsupported by Gemini + "$comment", "enumDescriptions", "enumTitles", "prefill", "deprecated", // Schema metadata fields unsupported by Gemini ) deletePaths := make([]string, 0) diff --git a/internal/util/gemini_schema_test.go b/internal/util/gemini_schema_test.go index 92bce013f6..bb581cdcd3 100644 --- a/internal/util/gemini_schema_test.go +++ b/internal/util/gemini_schema_test.go @@ -874,15 +874,18 @@ func TestCleanJSONSchemaForGemini_RemovesGeminiUnsupportedMetadataFields(t *test input := `{ "$schema": "http://json-schema.org/draft-07/schema#", "$id": "root-schema", + "$comment": "root comment should be removed", "type": "object", "properties": { "payload": { "type": "object", + "$comment": "nested comment should be removed", "prefill": "hello", "properties": { "mode": { "type": "string", "enum": ["a", "b"], + "enumDescriptions": ["Alpha", "Beta"], "enumTitles": ["A", "B"] } }, @@ -893,6 +896,14 @@ func TestCleanJSONSchemaForGemini_RemovesGeminiUnsupportedMetadataFields(t *test "$id": { "type": "string", "description": "property name should not be removed" + }, + "$comment": { + "type": "string", + "description": "property name should not be removed" + }, + "enumDescriptions": { + "type": "array", + "description": "property name should not be removed" } } }` @@ -913,6 +924,14 @@ func TestCleanJSONSchemaForGemini_RemovesGeminiUnsupportedMetadataFields(t *test "$id": { "type": "string", "description": "property name should not be removed" + }, + "$comment": { + "type": "string", + "description": "property name should not be removed" + }, + "enumDescriptions": { + "type": "array", + "description": "property name should not be removed" } } }` diff --git a/internal/util/provider.go b/internal/util/provider.go index 6313f58e32..ae25a63148 100644 
--- a/internal/util/provider.go +++ b/internal/util/provider.go @@ -12,6 +12,20 @@ import ( log "github.com/sirupsen/logrus" ) +const openAICompatibleProviderPrefix = "openai-compatible-" + +// OpenAICompatibleProviderKey returns the internal provider key for an OpenAI-compatible provider. +func OpenAICompatibleProviderKey(name string) string { + name = strings.ToLower(strings.TrimSpace(name)) + if name == "" || name == "openai-compatibility" || strings.HasPrefix(name, openAICompatibleProviderPrefix) { + if name == "" { + return "openai-compatibility" + } + return name + } + return openAICompatibleProviderPrefix + name +} + // GetProviderName determines all AI service providers capable of serving a registered model. // It first queries the global model registry to retrieve the providers backing the supplied model name. // When the model has not been registered yet, it falls back to legacy string heuristics to infer diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index 903526951b..80cc44ddc5 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -185,6 +185,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldExcluded.hash != newExcluded.hash { changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) } + if o.RebuildMidSystemMessage != n.RebuildMidSystemMessage { + changes = append(changes, fmt.Sprintf("claude[%d].rebuild-mid-system-message: %t -> %t", i, o.RebuildMidSystemMessage, n.RebuildMidSystemMessage)) + } if o.Cloak != nil && n.Cloak != nil { if strings.TrimSpace(o.Cloak.Mode) != strings.TrimSpace(n.Cloak.Mode) { changes = append(changes, fmt.Sprintf("claude[%d].cloak.mode: %s -> %s", i, o.Cloak.Mode, n.Cloak.Mode)) diff --git a/internal/watcher/synthesizer/config.go b/internal/watcher/synthesizer/config.go index 1eea3dc112..82a75cf789 100644 --- 
a/internal/watcher/synthesizer/config.go +++ b/internal/watcher/synthesizer/config.go @@ -5,6 +5,7 @@ import ( "strconv" "strings" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" "github.com/router-for-me/CLIProxyAPI/v7/internal/watcher/diff" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" ) @@ -125,6 +126,9 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea if base != "" { attrs["base_url"] = base } + if ck.RebuildMidSystemMessage { + attrs["rebuild_mid_system_message"] = "true" + } if hash := diff.ComputeClaudeModelsHash(ck.Models); hash != "" { attrs["models_hash"] = hash } @@ -226,6 +230,7 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor if providerName == "" { providerName = "openai-compatibility" } + internalProviderKey := util.OpenAICompatibleProviderKey(providerName) base := strings.TrimSpace(compat.BaseURL) disableCooling := compat.DisableCooling @@ -241,7 +246,7 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor "source": fmt.Sprintf("config:%s[%s]", providerName, token), "base_url": base, "compat_name": compat.Name, - "provider_key": providerName, + "provider_key": internalProviderKey, } metadata := map[string]any{} if disableCooling { @@ -259,7 +264,7 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor addConfigHeadersToAttrs(compat.Headers, attrs) a := &coreauth.Auth{ ID: id, - Provider: providerName, + Provider: internalProviderKey, Label: compat.Name, Prefix: prefix, Status: coreauth.StatusActive, @@ -283,7 +288,7 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor "source": fmt.Sprintf("config:%s[%s]", providerName, token), "base_url": base, "compat_name": compat.Name, - "provider_key": providerName, + "provider_key": internalProviderKey, } metadata := map[string]any{} if disableCooling { @@ -298,7 +303,7 @@ func (s *ConfigSynthesizer) 
synthesizeOpenAICompat(ctx *SynthesisContext) []*cor addConfigHeadersToAttrs(compat.Headers, attrs) a := &coreauth.Auth{ ID: id, - Provider: providerName, + Provider: internalProviderKey, Label: compat.Name, Prefix: prefix, Status: coreauth.StatusActive, diff --git a/internal/watcher/synthesizer/config_test.go b/internal/watcher/synthesizer/config_test.go index c8526a654a..5646ef871e 100644 --- a/internal/watcher/synthesizer/config_test.go +++ b/internal/watcher/synthesizer/config_test.go @@ -175,10 +175,11 @@ func TestConfigSynthesizer_ClaudeKeys(t *testing.T) { Config: &config.Config{ ClaudeKey: []config.ClaudeKey{ { - APIKey: "sk-ant-api-xxx", - Prefix: "main", - BaseURL: "https://api.anthropic.com", - DisableCooling: true, + APIKey: "sk-ant-api-xxx", + Prefix: "main", + BaseURL: "https://api.anthropic.com", + DisableCooling: true, + RebuildMidSystemMessage: true, Models: []config.ClaudeModel{ {Name: "claude-3-opus"}, {Name: "claude-3-sonnet"}, @@ -213,6 +214,9 @@ func TestConfigSynthesizer_ClaudeKeys(t *testing.T) { if _, ok := auths[0].Attributes["models_hash"]; !ok { t.Error("expected models_hash in attributes") } + if got := auths[0].Attributes["rebuild_mid_system_message"]; got != "true" { + t.Errorf("expected rebuild_mid_system_message=true, got %s", got) + } if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v { t.Errorf("expected disable_cooling=true, got %v", auths[0].Metadata["disable_cooling"]) } @@ -400,6 +404,43 @@ func TestConfigSynthesizer_OpenAICompat(t *testing.T) { } } +func TestConfigSynthesizer_OpenAICompat_UsesNamespacedProviderKey(t *testing.T) { + synth := NewConfigSynthesizer() + ctx := &SynthesisContext{ + Config: &config.Config{ + OpenAICompatibility: []config.OpenAICompatibility{ + { + Name: "kimi", + BaseURL: "https://kimi-compatible.example.com/v1", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "test-key"}, + }, + }, + }, + }, + Now: time.Now(), + IDGenerator: NewStableIDGenerator(), + } + + auths, err 
:= synth.Synthesize(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(auths) != 1 { + t.Fatalf("expected 1 auth, got %d", len(auths)) + } + auth := auths[0] + if auth.Provider != "openai-compatible-kimi" { + t.Fatalf("provider = %q, want openai-compatible-kimi", auth.Provider) + } + if auth.Attributes["provider_key"] != "openai-compatible-kimi" { + t.Fatalf("provider_key = %q, want openai-compatible-kimi", auth.Attributes["provider_key"]) + } + if auth.Attributes["compat_name"] != "kimi" { + t.Fatalf("compat_name = %q, want kimi", auth.Attributes["compat_name"]) + } +} + func TestConfigSynthesizer_VertexCompat(t *testing.T) { synth := NewConfigSynthesizer() ctx := &SynthesisContext{ @@ -639,7 +680,7 @@ func TestConfigSynthesizer_AllProviders(t *testing.T) { providers[a.Provider] = true } - expected := []string{"gemini", "claude", "codex", "compat", "vertex"} + expected := []string{"gemini", "claude", "codex", "openai-compatible-compat", "vertex"} for _, p := range expected { if !providers[p] { t.Errorf("expected provider %s not found", p) diff --git a/internal/watcher/synthesizer/context.go b/internal/watcher/synthesizer/context.go index 4572f8bb8f..dce219c47c 100644 --- a/internal/watcher/synthesizer/context.go +++ b/internal/watcher/synthesizer/context.go @@ -14,6 +14,12 @@ type PluginAuthParser interface { ParseAuth(context.Context, pluginapi.AuthParseRequest) (*coreauth.Auth, bool, error) } +// PluginMultiAuthParser expands one auth JSON payload into multiple plugin auth records. +// Returning handled=true with an empty slice means the plugin intentionally suppresses built-in parsing. +type PluginMultiAuthParser interface { + ParseAuths(context.Context, pluginapi.AuthParseRequest) ([]*coreauth.Auth, bool, error) +} + // SynthesisContext provides the context needed for auth synthesis. 
type SynthesisContext struct { // Config is the current configuration diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index 1712670577..03233562e7 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -3,22 +3,20 @@ package synthesizer import ( "context" "encoding/json" - "fmt" "os" "path/filepath" "runtime" "strconv" "strings" - "time" "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex" - "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/geminicli" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" ) // FileSynthesizer generates Auth entries from OAuth JSON files. -// It handles file-based authentication and Gemini virtual auth generation. +// It handles file-based authentication. type FileSynthesizer struct{} // NewFileSynthesizer creates a new FileSynthesizer instance. 
@@ -79,33 +77,47 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] } t, _ := metadata["type"].(string) provider := strings.ToLower(strings.TrimSpace(t)) + if provider == "gemini" { + provider = "gemini-cli" + } if ctx.PluginAuthParser != nil { - auth, handled, errParse := ctx.PluginAuthParser.ParseAuth(context.Background(), pluginapi.AuthParseRequest{ + auths, handled, errParse := parsePluginFileAuths(ctx.PluginAuthParser, pluginapi.AuthParseRequest{ Provider: provider, Path: fullPath, FileName: filepath.Base(fullPath), RawJSON: data, }) - if errParse == nil && handled && auth != nil { - auth.CreatedAt = now - auth.UpdatedAt = now - if auth.Attributes == nil { - auth.Attributes = make(map[string]string) + if errParse == nil && handled { + auths = compactPluginAuths(auths) + if len(auths) == 0 { + return nil } - auth.Attributes["path"] = fullPath - auth.Attributes["source"] = fullPath perAccountExcluded := extractExcludedModelsFromMetadata(metadata) - ApplyAuthExcludedModelsMeta(auth, cfg, perAccountExcluded, "oauth") - coreauth.ApplyCustomHeadersFromMetadata(auth) - return []*coreauth.Auth{auth} + perAccountModelAliases := extractOAuthModelAliasesFromMetadata(metadata) + for index, auth := range auths { + if auth == nil { + continue + } + if len(auths) > 1 { + coreauth.MarkPluginVirtualAuth(auth, fullPath, index) + } + auth.CreatedAt = now + auth.UpdatedAt = now + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + auth.Attributes["path"] = fullPath + auth.Attributes["source"] = fullPath + coreauth.SetOAuthModelAliasesAttribute(auth, perAccountModelAliases) + ApplyAuthExcludedModelsMeta(auth, cfg, perAccountExcluded, "oauth") + coreauth.ApplyCustomHeadersFromMetadata(auth) + } + return auths } } - if provider == "" { + if provider == "" || provider == "gemini-cli" { return nil } - if provider == "gemini" { - provider = "gemini-cli" - } label := provider if email, _ := metadata["email"].(string); email != 
"" { label = email @@ -143,6 +155,7 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] // Read per-account excluded models from the OAuth JSON file. perAccountExcluded := extractExcludedModelsFromMetadata(metadata) + perAccountModelAliases := extractOAuthModelAliasesFromMetadata(metadata) a := &coreauth.Auth{ ID: id, @@ -181,6 +194,7 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] } } coreauth.ApplyCustomHeadersFromMetadata(a) + coreauth.SetOAuthModelAliasesAttribute(a, perAccountModelAliases) ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth") // For codex auth files, extract plan_type from the JWT id_token. if provider == "codex" { @@ -192,147 +206,65 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] } } } - if provider == "gemini-cli" { - if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { - for _, v := range virtuals { - ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth") - } - out := make([]*coreauth.Auth, 0, 1+len(virtuals)) - out = append(out, a) - out = append(out, virtuals...) - return out - } - } return []*coreauth.Auth{a} } -// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials. -// It disables the primary auth and creates one virtual auth per project. 
-func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth { - if primary == nil || metadata == nil { - return nil +func parsePluginFileAuths(parser PluginAuthParser, req pluginapi.AuthParseRequest) ([]*coreauth.Auth, bool, error) { + if parser == nil { + return nil, false, nil } - projects := splitGeminiProjectIDs(metadata) - if len(projects) <= 1 { - return nil + if multiParser, ok := parser.(PluginMultiAuthParser); ok { + return multiParser.ParseAuths(context.Background(), req) } - email, _ := metadata["email"].(string) - shared := geminicli.NewSharedCredential(primary.ID, email, metadata, projects) - primary.Disabled = true - primary.Status = coreauth.StatusDisabled - primary.Runtime = shared - if primary.Attributes == nil { - primary.Attributes = make(map[string]string) - } - primary.Attributes["gemini_virtual_primary"] = "true" - primary.Attributes["virtual_children"] = strings.Join(projects, ",") - source := primary.Attributes["source"] - authPath := primary.Attributes["path"] - originalProvider := primary.Provider - if originalProvider == "" { - originalProvider = "gemini-cli" - } - label := primary.Label - if label == "" { - label = originalProvider - } - virtuals := make([]*coreauth.Auth, 0, len(projects)) - for _, projectID := range projects { - attrs := map[string]string{ - "runtime_only": "true", - "gemini_virtual_parent": primary.ID, - "gemini_virtual_project": projectID, - } - if source != "" { - attrs["source"] = source - } - if authPath != "" { - attrs["path"] = authPath - } - // Propagate priority from primary auth to virtual auths - if priorityVal, hasPriority := primary.Attributes["priority"]; hasPriority && priorityVal != "" { - attrs["priority"] = priorityVal - } - // Propagate note from primary auth to virtual auths - if noteVal, hasNote := primary.Attributes["note"]; hasNote && noteVal != "" { - attrs["note"] = noteVal - } - for k, v := range primary.Attributes { - if 
strings.HasPrefix(k, "header:") && strings.TrimSpace(v) != "" { - attrs[k] = v - } - } - metadataCopy := map[string]any{ - "email": email, - "project_id": projectID, - "virtual": true, - "virtual_parent_id": primary.ID, - "type": metadata["type"], - } - if v, ok := metadata["disable_cooling"]; ok { - metadataCopy["disable_cooling"] = v - } else if v, ok := metadata["disable-cooling"]; ok { - metadataCopy["disable_cooling"] = v - } - if v, ok := metadata["request_retry"]; ok { - metadataCopy["request_retry"] = v - } else if v, ok := metadata["request-retry"]; ok { - metadataCopy["request_retry"] = v - } - proxy := strings.TrimSpace(primary.ProxyURL) - if proxy != "" { - metadataCopy["proxy_url"] = proxy - } - virtual := &coreauth.Auth{ - ID: buildGeminiVirtualID(primary.ID, projectID), - Provider: originalProvider, - Label: fmt.Sprintf("%s [%s]", label, projectID), - Status: coreauth.StatusActive, - Attributes: attrs, - Metadata: metadataCopy, - ProxyURL: primary.ProxyURL, - Prefix: primary.Prefix, - CreatedAt: primary.CreatedAt, - UpdatedAt: primary.UpdatedAt, - Runtime: geminicli.NewVirtualCredential(projectID, shared), - } - virtuals = append(virtuals, virtual) + auth, handled, errParse := parser.ParseAuth(context.Background(), req) + if errParse != nil || !handled || auth == nil { + return nil, handled, errParse } - return virtuals + return []*coreauth.Auth{auth}, true, nil } -// splitGeminiProjectIDs extracts and deduplicates project IDs from metadata. 
-func splitGeminiProjectIDs(metadata map[string]any) []string { - raw, _ := metadata["project_id"].(string) - trimmed := strings.TrimSpace(raw) - if trimmed == "" { +func compactPluginAuths(auths []*coreauth.Auth) []*coreauth.Auth { + if len(auths) == 0 { return nil } - parts := strings.Split(trimmed, ",") - result := make([]string, 0, len(parts)) - seen := make(map[string]struct{}, len(parts)) - for _, part := range parts { - id := strings.TrimSpace(part) - if id == "" { + out := auths[:0] + for _, auth := range auths { + if auth == nil { continue } - if _, ok := seen[id]; ok { - continue - } - seen[id] = struct{}{} - result = append(result, id) + out = append(out, auth) } - return result + return out } -// buildGeminiVirtualID constructs a virtual auth ID from base ID and project ID. -func buildGeminiVirtualID(baseID, projectID string) string { - project := strings.TrimSpace(projectID) - if project == "" { - project = "project" - } - replacer := strings.NewReplacer("/", "_", "\\", "_", " ", "_") - return fmt.Sprintf("%s::%s", baseID, replacer.Replace(project)) +// extractOAuthModelAliasesFromMetadata reads per-account model aliases from OAuth JSON metadata. +// Supports both "model_aliases" and "model-aliases" keys. +func extractOAuthModelAliasesFromMetadata(metadata map[string]any) []config.OAuthModelAlias { + if metadata == nil { + return nil + } + raw, ok := metadata["model_aliases"] + if !ok { + raw, ok = metadata["model-aliases"] + } + if !ok || raw == nil { + return nil + } + data, errMarshal := json.Marshal(raw) + if errMarshal != nil { + return nil + } + var aliases []config.OAuthModelAlias + if errUnmarshal := json.Unmarshal(data, &aliases); errUnmarshal != nil { + return nil + } + cfg := config.Config{ + OAuthModelAlias: map[string][]config.OAuthModelAlias{ + "auth": aliases, + }, + } + cfg.SanitizeOAuthModelAlias() + return cfg.OAuthModelAlias["auth"] } // extractExcludedModelsFromMetadata reads per-account excluded models from the OAuth JSON metadata. 
diff --git a/internal/watcher/synthesizer/file_test.go b/internal/watcher/synthesizer/file_test.go index 63b394aaf5..b52385549c 100644 --- a/internal/watcher/synthesizer/file_test.go +++ b/internal/watcher/synthesizer/file_test.go @@ -1,15 +1,16 @@ package synthesizer import ( + "context" "encoding/json" "os" "path/filepath" - "strings" "testing" "time" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" ) func TestNewFileSynthesizer(t *testing.T) { @@ -131,10 +132,9 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) { } } -func TestFileSynthesizer_Synthesize_GeminiProviderMapping(t *testing.T) { +func TestFileSynthesizer_Synthesize_IgnoresGeminiProviderFile(t *testing.T) { tempDir := t.TempDir() - // Gemini type should be mapped to gemini-cli authData := map[string]any{ "type": "gemini", "email": "gemini@example.com", @@ -157,15 +157,110 @@ func TestFileSynthesizer_Synthesize_GeminiProviderMapping(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(auths) != 1 { - t.Fatalf("expected 1 auth, got %d", len(auths)) + if len(auths) != 0 { + t.Fatalf("expected Gemini auth file to be ignored, got %d auths", len(auths)) } +} - if auths[0].Provider != "gemini-cli" { - t.Errorf("gemini should be mapped to gemini-cli, got %s", auths[0].Provider) +func TestSynthesizeAuthFileExpandsPluginMultiAuths(t *testing.T) { + tempDir := t.TempDir() + fullPath := filepath.Join(tempDir, "geminicli.json") + raw := []byte(`{"type":"gemini-cli","excluded_models":["model-a"],"headers":{"X-Test":"value"}}`) + + ctx := &SynthesisContext{ + Config: &config.Config{}, + AuthDir: tempDir, + Now: time.Date(2026, 6, 21, 0, 0, 0, 0, time.UTC), + PluginAuthParser: multiAuthParserFunc(func(ctx context.Context, req pluginapi.AuthParseRequest) ([]*coreauth.Auth, bool, error) { + if req.Provider != "gemini-cli" || req.Path != 
fullPath || req.FileName != "geminicli.json" { + t.Fatalf("ParseAuths request = %#v, want file context", req) + } + return []*coreauth.Auth{ + { + ID: "geminicli.json", + Provider: "gemini-cli", + Metadata: map[string]any{ + "type": "gemini-cli", + "headers": map[string]any{ + "X-Test": "value", + }, + }, + }, + nil, + { + ID: "geminicli-project-a.json", + Provider: "gemini-cli", + Metadata: map[string]any{ + "type": "gemini-cli", + "project_id": "project-a", + "headers": map[string]any{ + "X-Test": "value", + }, + }, + }, + }, true, nil + }), + } + + auths := SynthesizeAuthFile(ctx, fullPath, raw) + if len(auths) != 2 { + t.Fatalf("SynthesizeAuthFile() len = %d, want two plugin auths", len(auths)) + } + if firstIndex, secondIndex := auths[0].EnsureIndex(), auths[1].EnsureIndex(); firstIndex == "" || firstIndex == secondIndex { + t.Fatalf("auth indexes = %q/%q, want distinct non-empty indexes", firstIndex, secondIndex) + } + for _, auth := range auths { + if !coreauth.IsPluginVirtualAuth(auth) { + t.Fatalf("auth attributes = %#v, want plugin virtual marker", auth.Attributes) + } + if auth.Attributes[coreauth.AttributeVirtualSource] != fullPath { + t.Fatalf("virtual_source = %q, want %q", auth.Attributes[coreauth.AttributeVirtualSource], fullPath) + } + if auth.Attributes["path"] != fullPath || auth.Attributes["source"] != fullPath { + t.Fatalf("auth attributes = %#v, want source path", auth.Attributes) + } + if gotHeader := auth.Attributes["header:X-Test"]; gotHeader != "value" { + t.Fatalf("header:X-Test = %q, want value", gotHeader) + } + if gotKind := auth.Attributes["auth_kind"]; gotKind != "oauth" { + t.Fatalf("auth_kind = %q, want oauth", gotKind) + } + } + if gotProject := auths[1].Metadata["project_id"]; gotProject != "project-a" { + t.Fatalf("project_id = %#v, want project-a", gotProject) } } +func TestSynthesizeAuthFilePluginHandledEmptySuppressesBuiltin(t *testing.T) { + tempDir := t.TempDir() + fullPath := filepath.Join(tempDir, "codex.json") + raw := 
[]byte(`{"type":"codex","access_token":"token"}`) + + ctx := &SynthesisContext{ + Config: &config.Config{}, + AuthDir: tempDir, + Now: time.Date(2026, 6, 21, 0, 0, 0, 0, time.UTC), + PluginAuthParser: multiAuthParserFunc(func(context.Context, pluginapi.AuthParseRequest) ([]*coreauth.Auth, bool, error) { + return nil, true, nil + }), + } + + auths := SynthesizeAuthFile(ctx, fullPath, raw) + if len(auths) != 0 { + t.Fatalf("SynthesizeAuthFile() len = %d, want plugin-handled empty result", len(auths)) + } +} + +type multiAuthParserFunc func(context.Context, pluginapi.AuthParseRequest) ([]*coreauth.Auth, bool, error) + +func (f multiAuthParserFunc) ParseAuth(context.Context, pluginapi.AuthParseRequest) (*coreauth.Auth, bool, error) { + return nil, false, nil +} + +func (f multiAuthParserFunc) ParseAuths(ctx context.Context, req pluginapi.AuthParseRequest) ([]*coreauth.Auth, bool, error) { + return f(ctx, req) +} + func TestFileSynthesizer_Synthesize_SkipsInvalidFiles(t *testing.T) { tempDir := t.TempDir() @@ -418,242 +513,50 @@ func TestFileSynthesizer_Synthesize_OAuthExcludedModelsMerged(t *testing.T) { } } -func TestSynthesizeGeminiVirtualAuths_NilInputs(t *testing.T) { - now := time.Now() - - if SynthesizeGeminiVirtualAuths(nil, nil, now) != nil { - t.Error("expected nil for nil primary") - } - if SynthesizeGeminiVirtualAuths(&coreauth.Auth{}, nil, now) != nil { - t.Error("expected nil for nil metadata") - } - if SynthesizeGeminiVirtualAuths(nil, map[string]any{}, now) != nil { - t.Error("expected nil for nil primary with metadata") - } -} - -func TestSynthesizeGeminiVirtualAuths_SingleProject(t *testing.T) { - now := time.Now() - primary := &coreauth.Auth{ - ID: "test-id", - Provider: "gemini-cli", - Label: "test@example.com", - } - metadata := map[string]any{ - "project_id": "single-project", - "email": "test@example.com", - "type": "gemini", - } - - virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now) - if virtuals != nil { - t.Error("single project 
should not create virtuals") - } -} - -func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) { - now := time.Now() - primary := &coreauth.Auth{ - ID: "primary-id", - Provider: "gemini-cli", - Label: "test@example.com", - Prefix: "test-prefix", - ProxyURL: "http://proxy.local", - Attributes: map[string]string{ - "source": "test-source", - "path": "/path/to/auth", - "header:X-Tra": "value", +func TestFileSynthesizer_Synthesize_OAuthModelAliases(t *testing.T) { + tempDir := t.TempDir() + authData := map[string]any{ + "type": "codex", + "email": "codex@example.com", + "model-aliases": []map[string]any{ + {"name": " gpt-5.3-codex-spark ", "alias": " gpt-5.5 "}, + {"name": "gpt-5.3-codex-spark", "alias": "gpt-5.4", "fork": true}, + {"name": "gpt-5.3-codex-spark", "alias": "gpt-5.5"}, + {"name": "", "alias": "ignored"}, }, } - metadata := map[string]any{ - "project_id": "project-a, project-b, project-c", - "email": "test@example.com", - "type": "gemini", - "request_retry": 2, - "disable_cooling": true, - } - - virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now) - - if len(virtuals) != 3 { - t.Fatalf("expected 3 virtuals, got %d", len(virtuals)) - } - - // Check primary is disabled - if !primary.Disabled { - t.Error("expected primary to be disabled") - } - if primary.Status != coreauth.StatusDisabled { - t.Errorf("expected primary status disabled, got %s", primary.Status) - } - if primary.Attributes["gemini_virtual_primary"] != "true" { - t.Error("expected gemini_virtual_primary=true") - } - if !strings.Contains(primary.Attributes["virtual_children"], "project-a") { - t.Error("expected virtual_children to contain project-a") - } - - // Check virtuals - projectIDs := []string{"project-a", "project-b", "project-c"} - for i, v := range virtuals { - if v.Provider != "gemini-cli" { - t.Errorf("expected provider gemini-cli, got %s", v.Provider) - } - if v.Status != coreauth.StatusActive { - t.Errorf("expected status active, got %s", v.Status) - } - if 
v.Prefix != "test-prefix" { - t.Errorf("expected prefix test-prefix, got %s", v.Prefix) - } - if v.ProxyURL != "http://proxy.local" { - t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL) - } - if vv, ok := v.Metadata["disable_cooling"].(bool); !ok || !vv { - t.Errorf("expected disable_cooling true, got %v", v.Metadata["disable_cooling"]) - } - if vv, ok := v.Metadata["request_retry"].(int); !ok || vv != 2 { - t.Errorf("expected request_retry 2, got %v", v.Metadata["request_retry"]) - } - if v.Attributes["runtime_only"] != "true" { - t.Error("expected runtime_only=true") - } - if got := v.Attributes["header:X-Tra"]; got != "value" { - t.Errorf("expected virtual %d header:X-Tra %q, got %q", i, "value", got) - } - if v.Attributes["gemini_virtual_parent"] != "primary-id" { - t.Errorf("expected gemini_virtual_parent=primary-id, got %s", v.Attributes["gemini_virtual_parent"]) - } - if v.Attributes["gemini_virtual_project"] != projectIDs[i] { - t.Errorf("expected gemini_virtual_project=%s, got %s", projectIDs[i], v.Attributes["gemini_virtual_project"]) - } - if !strings.Contains(v.Label, "["+projectIDs[i]+"]") { - t.Errorf("expected label to contain [%s], got %s", projectIDs[i], v.Label) - } - } -} - -func TestSynthesizeGeminiVirtualAuths_EmptyProviderAndLabel(t *testing.T) { - now := time.Now() - // Test with empty Provider and Label to cover fallback branches - primary := &coreauth.Auth{ - ID: "primary-id", - Provider: "", // empty provider - should default to gemini-cli - Label: "", // empty label - should default to provider - Attributes: map[string]string{}, - } - metadata := map[string]any{ - "project_id": "proj-a, proj-b", - "email": "user@example.com", - "type": "gemini", - } - - virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now) - - if len(virtuals) != 2 { - t.Fatalf("expected 2 virtuals, got %d", len(virtuals)) - } - - // Check that empty provider defaults to gemini-cli - if virtuals[0].Provider != "gemini-cli" { - 
t.Errorf("expected provider gemini-cli (default), got %s", virtuals[0].Provider) - } - // Check that empty label defaults to provider - if !strings.Contains(virtuals[0].Label, "gemini-cli") { - t.Errorf("expected label to contain gemini-cli, got %s", virtuals[0].Label) + data, _ := json.Marshal(authData) + errWriteFile := os.WriteFile(filepath.Join(tempDir, "codex-auth.json"), data, 0644) + if errWriteFile != nil { + t.Fatalf("failed to write auth file: %v", errWriteFile) } -} -func TestSynthesizeGeminiVirtualAuths_NilPrimaryAttributes(t *testing.T) { - now := time.Now() - primary := &coreauth.Auth{ - ID: "primary-id", - Provider: "gemini-cli", - Label: "test@example.com", - Attributes: nil, // nil attributes - } - metadata := map[string]any{ - "project_id": "proj-a, proj-b", - "email": "test@example.com", - "type": "gemini", + synth := NewFileSynthesizer() + ctx := &SynthesisContext{ + Config: &config.Config{}, + AuthDir: tempDir, + Now: time.Now(), + IDGenerator: NewStableIDGenerator(), } - virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now) - - if len(virtuals) != 2 { - t.Fatalf("expected 2 virtuals, got %d", len(virtuals)) - } - // Nil attributes should be initialized - if primary.Attributes == nil { - t.Error("expected primary.Attributes to be initialized") - } - if primary.Attributes["gemini_virtual_primary"] != "true" { - t.Error("expected gemini_virtual_primary=true") + auths, errSynthesize := synth.Synthesize(ctx) + if errSynthesize != nil { + t.Fatalf("unexpected error: %v", errSynthesize) } -} - -func TestSplitGeminiProjectIDs(t *testing.T) { - tests := []struct { - name string - metadata map[string]any - want []string - }{ - { - name: "single project", - metadata: map[string]any{"project_id": "proj-a"}, - want: []string{"proj-a"}, - }, - { - name: "multiple projects", - metadata: map[string]any{"project_id": "proj-a, proj-b, proj-c"}, - want: []string{"proj-a", "proj-b", "proj-c"}, - }, - { - name: "with duplicates", - metadata: 
map[string]any{"project_id": "proj-a, proj-b, proj-a"}, - want: []string{"proj-a", "proj-b"}, - }, - { - name: "with empty parts", - metadata: map[string]any{"project_id": "proj-a, , proj-b, "}, - want: []string{"proj-a", "proj-b"}, - }, - { - name: "empty project_id", - metadata: map[string]any{"project_id": ""}, - want: nil, - }, - { - name: "no project_id", - metadata: map[string]any{}, - want: nil, - }, - { - name: "whitespace only", - metadata: map[string]any{"project_id": " "}, - want: nil, - }, + if len(auths) != 1 { + t.Fatalf("expected 1 auth, got %d", len(auths)) } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := splitGeminiProjectIDs(tt.metadata) - if len(got) != len(tt.want) { - t.Fatalf("expected %v, got %v", tt.want, got) - } - for i := range got { - if got[i] != tt.want[i] { - t.Errorf("expected %v, got %v", tt.want, got) - break - } - } - }) + got := auths[0].Attributes["model_aliases"] + want := `[{"name":"gpt-5.3-codex-spark","alias":"gpt-5.5"},{"name":"gpt-5.3-codex-spark","alias":"gpt-5.4","fork":true}]` + if got != want { + t.Fatalf("expected model_aliases %q, got %q", want, got) } } -func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) { +func TestFileSynthesizer_Synthesize_IgnoresGeminiOAuthFile(t *testing.T) { tempDir := t.TempDir() - // Create a gemini auth file with multiple projects authData := map[string]any{ "type": "gemini", "email": "multi@example.com", @@ -678,149 +581,8 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - // Should have 4 auths: 1 primary (disabled) + 3 virtuals - if len(auths) != 4 { - t.Fatalf("expected 4 auths (1 primary + 3 virtuals), got %d", len(auths)) - } - - // First auth should be the primary (disabled) - primary := auths[0] - if !primary.Disabled { - t.Error("expected primary to be disabled") - } - if primary.Status != coreauth.StatusDisabled { - t.Errorf("expected primary status 
disabled, got %s", primary.Status) - } - if gotPriority := primary.Attributes["priority"]; gotPriority != "10" { - t.Errorf("expected primary priority 10, got %q", gotPriority) - } - - // Remaining auths should be virtuals - for i := 1; i < 4; i++ { - v := auths[i] - if v.Status != coreauth.StatusActive { - t.Errorf("expected virtual %d to be active, got %s", i, v.Status) - } - if v.Attributes["gemini_virtual_parent"] != primary.ID { - t.Errorf("expected virtual %d parent to be %s, got %s", i, primary.ID, v.Attributes["gemini_virtual_parent"]) - } - if gotPriority := v.Attributes["priority"]; gotPriority != "10" { - t.Errorf("expected virtual %d priority 10, got %q", i, gotPriority) - } - } -} - -func TestBuildGeminiVirtualID(t *testing.T) { - tests := []struct { - name string - baseID string - projectID string - want string - }{ - { - name: "basic", - baseID: "auth.json", - projectID: "my-project", - want: "auth.json::my-project", - }, - { - name: "with slashes", - baseID: "path/to/auth.json", - projectID: "project/with/slashes", - want: "path/to/auth.json::project_with_slashes", - }, - { - name: "with spaces", - baseID: "auth.json", - projectID: "my project", - want: "auth.json::my_project", - }, - { - name: "empty project", - baseID: "auth.json", - projectID: "", - want: "auth.json::project", - }, - { - name: "whitespace project", - baseID: "auth.json", - projectID: " ", - want: "auth.json::project", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := buildGeminiVirtualID(tt.baseID, tt.projectID) - if got != tt.want { - t.Errorf("expected %q, got %q", tt.want, got) - } - }) - } -} - -func TestSynthesizeGeminiVirtualAuths_NotePropagated(t *testing.T) { - now := time.Now() - primary := &coreauth.Auth{ - ID: "primary-id", - Provider: "gemini-cli", - Label: "test@example.com", - Attributes: map[string]string{ - "source": "test-source", - "path": "/path/to/auth", - "priority": "5", - "note": "my test note", - }, - } - metadata := 
map[string]any{ - "project_id": "proj-a, proj-b", - "email": "test@example.com", - "type": "gemini", - } - - virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now) - - if len(virtuals) != 2 { - t.Fatalf("expected 2 virtuals, got %d", len(virtuals)) - } - - for i, v := range virtuals { - if got := v.Attributes["note"]; got != "my test note" { - t.Errorf("virtual %d: expected note %q, got %q", i, "my test note", got) - } - if got := v.Attributes["priority"]; got != "5" { - t.Errorf("virtual %d: expected priority %q, got %q", i, "5", got) - } - } -} - -func TestSynthesizeGeminiVirtualAuths_NoteAbsentWhenEmpty(t *testing.T) { - now := time.Now() - primary := &coreauth.Auth{ - ID: "primary-id", - Provider: "gemini-cli", - Label: "test@example.com", - Attributes: map[string]string{ - "source": "test-source", - "path": "/path/to/auth", - }, - } - metadata := map[string]any{ - "project_id": "proj-a, proj-b", - "email": "test@example.com", - "type": "gemini", - } - - virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now) - - if len(virtuals) != 2 { - t.Fatalf("expected 2 virtuals, got %d", len(virtuals)) - } - - for i, v := range virtuals { - if _, hasNote := v.Attributes["note"]; hasNote { - t.Errorf("virtual %d: expected no note attribute when primary has no note", i) - } + if len(auths) != 0 { + t.Fatalf("expected Gemini auth file to be ignored, got %d auths", len(auths)) } } @@ -905,53 +667,3 @@ func TestFileSynthesizer_Synthesize_NoteParsing(t *testing.T) { }) } } - -func TestFileSynthesizer_Synthesize_MultiProjectGeminiWithNote(t *testing.T) { - tempDir := t.TempDir() - - authData := map[string]any{ - "type": "gemini", - "email": "multi@example.com", - "project_id": "project-a, project-b", - "priority": 5, - "note": "production keys", - } - data, _ := json.Marshal(authData) - err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644) - if err != nil { - t.Fatalf("failed to write auth file: %v", err) - } - - synth := 
NewFileSynthesizer() - ctx := &SynthesisContext{ - Config: &config.Config{}, - AuthDir: tempDir, - Now: time.Now(), - IDGenerator: NewStableIDGenerator(), - } - - auths, err := synth.Synthesize(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - // Should have 3 auths: 1 primary (disabled) + 2 virtuals - if len(auths) != 3 { - t.Fatalf("expected 3 auths (1 primary + 2 virtuals), got %d", len(auths)) - } - - primary := auths[0] - if gotNote := primary.Attributes["note"]; gotNote != "production keys" { - t.Errorf("expected primary note %q, got %q", "production keys", gotNote) - } - - // Verify virtuals inherit note - for i := 1; i < len(auths); i++ { - v := auths[i] - if gotNote := v.Attributes["note"]; gotNote != "production keys" { - t.Errorf("expected virtual %d note %q, got %q", i, "production keys", gotNote) - } - if gotPriority := v.Attributes["priority"]; gotPriority != "5" { - t.Errorf("expected virtual %d priority %q, got %q", i, "5", gotPriority) - } - } -} diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index 98740df2e2..319aa5ab9c 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -141,30 +141,20 @@ func TestSnapshotCoreAuths_ConfigAndAuthFiles(t *testing.T) { Headers: map[string]string{"X-Req": "1"}, }, }, - OAuthExcludedModels: map[string][]string{ - "gemini-cli": {"Foo", "bar"}, - }, } w := &Watcher{authDir: authDir} w.SetConfig(cfg) auths := w.SnapshotCoreAuths() - if len(auths) != 4 { - t.Fatalf("expected 4 auth entries (1 config + 1 primary + 2 virtual), got %d", len(auths)) + if len(auths) != 1 { + t.Fatalf("expected 1 config auth entry, got %d", len(auths)) } var geminiAPIKeyAuth *coreauth.Auth - var geminiPrimary *coreauth.Auth - virtuals := make([]*coreauth.Auth, 0) for _, a := range auths { - switch { - case a.Provider == "gemini" && a.Attributes["api_key"] == "g-key": + if a.Provider == "gemini" && a.Attributes["api_key"] == "g-key" { geminiAPIKeyAuth = a - 
case a.Attributes["gemini_virtual_primary"] == "true": - geminiPrimary = a - case strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) != "": - virtuals = append(virtuals, a) } } if geminiAPIKeyAuth == nil { @@ -177,35 +167,6 @@ func TestSnapshotCoreAuths_ConfigAndAuthFiles(t *testing.T) { if geminiAPIKeyAuth.Attributes["auth_kind"] != "apikey" { t.Fatalf("expected auth_kind=apikey, got %s", geminiAPIKeyAuth.Attributes["auth_kind"]) } - - if geminiPrimary == nil { - t.Fatal("expected primary gemini-cli auth from file") - } - if !geminiPrimary.Disabled || geminiPrimary.Status != coreauth.StatusDisabled { - t.Fatal("expected primary gemini-cli auth to be disabled when virtual auths are synthesized") - } - expectedOAuthHash := diff.ComputeExcludedModelsHash([]string{"Foo", "bar"}) - if geminiPrimary.Attributes["excluded_models_hash"] != expectedOAuthHash { - t.Fatalf("expected OAuth excluded hash %s, got %s", expectedOAuthHash, geminiPrimary.Attributes["excluded_models_hash"]) - } - if geminiPrimary.Attributes["auth_kind"] != "oauth" { - t.Fatalf("expected auth_kind=oauth, got %s", geminiPrimary.Attributes["auth_kind"]) - } - - if len(virtuals) != 2 { - t.Fatalf("expected 2 virtual auths, got %d", len(virtuals)) - } - for _, v := range virtuals { - if v.Attributes["gemini_virtual_parent"] != geminiPrimary.ID { - t.Fatalf("virtual auth missing parent link to %s", geminiPrimary.ID) - } - if v.Attributes["excluded_models_hash"] != expectedOAuthHash { - t.Fatalf("expected virtual excluded hash %s, got %s", expectedOAuthHash, v.Attributes["excluded_models_hash"]) - } - if v.Status != coreauth.StatusActive { - t.Fatalf("expected virtual auth to be active, got %s", v.Status) - } - } } func TestReloadConfigIfChanged_TriggersOnChangeAndSkipsUnchanged(t *testing.T) { diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go deleted file mode 100644 index de79f05b7c..0000000000 --- 
a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ /dev/null @@ -1,248 +0,0 @@ -// Package gemini provides HTTP handlers for Gemini CLI API functionality. -// This package implements handlers that process CLI-specific requests for Gemini API operations, -// including content generation and streaming content generation endpoints. -// The handlers restrict access to localhost only and manage communication with the backend service. -package gemini - -import ( - "bytes" - "context" - "fmt" - "io" - "net" - "net/http" - "strings" - "time" - - "github.com/gin-gonic/gin" - . "github.com/router-for-me/CLIProxyAPI/v7/internal/constant" - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v7/internal/util" - "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" -) - -// GeminiCLIAPIHandler contains the handlers for Gemini CLI API endpoints. -// It holds a pool of clients to interact with the backend service. -type GeminiCLIAPIHandler struct { - *handlers.BaseAPIHandler -} - -// NewGeminiCLIAPIHandler creates a new Gemini CLI API handlers instance. -// It takes an BaseAPIHandler instance as input and returns a GeminiCLIAPIHandler. -func NewGeminiCLIAPIHandler(apiHandlers *handlers.BaseAPIHandler) *GeminiCLIAPIHandler { - return &GeminiCLIAPIHandler{ - BaseAPIHandler: apiHandlers, - } -} - -// HandlerType returns the type of this handler. -func (h *GeminiCLIAPIHandler) HandlerType() string { - return GeminiCLI -} - -// Models returns a list of models supported by this handler. -func (h *GeminiCLIAPIHandler) Models() []map[string]any { - return make([]map[string]any, 0) -} - -// CLIHandler handles CLI-specific requests for Gemini API operations. -// It restricts access to localhost only and routes requests to appropriate internal handlers. 
-func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { - if h.Cfg == nil || !h.Cfg.EnableGeminiCLIEndpoint { - c.JSON(http.StatusForbidden, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "Gemini CLI endpoint is disabled", - Type: "forbidden", - }, - }) - return - } - - requestHost := c.Request.Host - requestHostname := requestHost - if hostname, _, errSplitHostPort := net.SplitHostPort(requestHost); errSplitHostPort == nil { - requestHostname = hostname - } - - if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") || requestHostname != "127.0.0.1" { - c.JSON(http.StatusForbidden, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "CLI reply only allow local access", - Type: "forbidden", - }, - }) - return - } - - rawJSON, _ := c.GetRawData() - requestRawURI := c.Request.URL.Path - - if requestRawURI == "/v1internal:generateContent" { - h.handleInternalGenerateContent(c, rawJSON) - } else if requestRawURI == "/v1internal:streamGenerateContent" { - h.handleInternalStreamGenerateContent(c, rawJSON) - } else { - reqBody := bytes.NewBuffer(rawJSON) - req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody) - if err != nil { - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - for key, value := range c.Request.Header { - req.Header[key] = value - } - - httpClient := util.SetProxy(h.Cfg, &http.Client{}) - - resp, err := httpClient.Do(req) - if err != nil { - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer func() { - if err = resp.Body.Close(); err != nil { - log.Printf("warn: failed to close response body: %v", 
err) - } - }() - bodyBytes, _ := io.ReadAll(resp.Body) - - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: string(bodyBytes), - Type: "invalid_request_error", - }, - }) - return - } - - defer func() { - _ = resp.Body.Close() - }() - - for key, value := range resp.Header { - c.Header(key, value[0]) - } - output, err := io.ReadAll(resp.Body) - if err != nil { - log.Errorf("Failed to read response body: %v", err) - return - } - c.Set("API_RESPONSE_TIMESTAMP", time.Now()) - _, _ = c.Writer.Write(output) - c.Set("API_RESPONSE", output) - } -} - -// handleInternalStreamGenerateContent handles streaming content generation requests. -// It sets up a server-sent event stream and forwards the request to the backend client. -// The function continuously proxies response chunks from the backend to the client. -func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) { - alt := h.GetAlt(c) - - if alt == "" { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - } - - // Get the http.Flusher interface to manually flush the response. 
- flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "Streaming not supported", - Type: "server_error", - }, - }) - return - } - - modelResult := gjson.GetBytes(rawJSON, "model") - modelName := modelResult.String() - - cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") - handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) - h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan) - return -} - -// handleInternalGenerateContent handles non-streaming content generation requests. -// It sends a request to the backend client and proxies the entire response back to the client at once. -func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "application/json") - modelResult := gjson.GetBytes(rawJSON, "model") - modelName := modelResult.String() - - cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - cliCancel(errMsg.Error) - return - } - handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) - _, _ = c.Writer.Write(resp) - cliCancel() -} - -func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - var keepAliveInterval *time.Duration - if alt != "" { - keepAliveInterval = new(time.Duration(0)) - } - - h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ - KeepAliveInterval: keepAliveInterval, - WriteChunk: func(chunk []byte) { - if alt == "" { - if bytes.Equal(chunk, []byte("data: 
[DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) { - return - } - - if !bytes.HasPrefix(chunk, []byte("data:")) { - _, _ = c.Writer.Write([]byte("data: ")) - } - - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n\n")) - } else { - _, _ = c.Writer.Write(chunk) - } - }, - WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { - if errMsg == nil { - return - } - status := http.StatusInternalServerError - if errMsg.StatusCode > 0 { - status = errMsg.StatusCode - } - errText := http.StatusText(status) - if errMsg.Error != nil && errMsg.Error.Error() != "" { - errText = errMsg.Error.Error() - } - body := handlers.BuildErrorResponseBody(status, errText) - if alt == "" { - _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) - } else { - _, _ = c.Writer.Write(body) - } - }, - }) -} diff --git a/sdk/api/handlers/openai/codex_client_models.go b/sdk/api/handlers/openai/codex_client_models.go index 7d05db22ec..41d8e120d2 100644 --- a/sdk/api/handlers/openai/codex_client_models.go +++ b/sdk/api/handlers/openai/codex_client_models.go @@ -26,6 +26,7 @@ var codexClientAllowedReasoningLevels = map[string]struct{}{ "medium": {}, "high": {}, "xhigh": {}, + "max": {}, } func (h *OpenAIAPIHandler) codexClientModelsResponse() map[string]any { @@ -66,6 +67,8 @@ func buildCodexClientModels(models []map[string]any) []map[string]any { result = append(result, entry) } + applyCodexClientNonTemplatePriorities(result, templates) + sort.SliceStable(result, func(i, j int) bool { return codexClientModelPriority(result[i]) < codexClientModelPriority(result[j]) }) @@ -73,6 +76,60 @@ func buildCodexClientModels(models []map[string]any) []map[string]any { return result } +func maxCodexClientTemplatePriority(templates map[string]map[string]any) int { + maxPriority := 0 + for _, template := range templates { + priority := codexClientModelPriority(template) + if priority > maxPriority { + maxPriority = priority + } + } + return maxPriority +} + +func 
applyCodexClientNonTemplatePriorities(result []map[string]any, templates map[string]map[string]any) { + if len(result) == 0 { + return + } + + basePriority := maxCodexClientTemplatePriority(templates) + type nonTemplateEntry struct { + index int + displayName string + slug string + } + + pending := make([]nonTemplateEntry, 0) + for index, entry := range result { + slug := stringModelValue(entry, "slug") + if _, ok := templates[slug]; ok { + continue + } + displayName := stringModelValue(entry, "display_name") + if displayName == "" { + displayName = slug + } + pending = append(pending, nonTemplateEntry{ + index: index, + displayName: displayName, + slug: slug, + }) + } + + sort.SliceStable(pending, func(i, j int) bool { + left := strings.ToLower(pending[i].displayName) + right := strings.ToLower(pending[j].displayName) + if left == right { + return pending[i].slug < pending[j].slug + } + return left < right + }) + + for rank, entry := range pending { + result[entry.index]["priority"] = basePriority + 100*(rank+1) + } +} + func loadCodexClientModelTemplates() (map[string]map[string]any, map[string]any, error) { codexClientModelTemplatesOnce.Do(func() { var payload codexClientModelsPayload @@ -130,8 +187,8 @@ func applyCodexClientModelMetadata(entry map[string]any, id string, model map[st entry["slug"] = id entry["display_name"] = displayName entry["description"] = description - entry["priority"] = 100 entry["prefer_websockets"] = false + entry["service_tiers"] = []any{} delete(entry, "apply_patch_tool_type") delete(entry, "upgrade") delete(entry, "availability_nux") @@ -249,6 +306,8 @@ func codexClientReasoningDescription(level string) string { return "Greater reasoning depth for complex problems" case "xhigh": return "Extra high reasoning depth for complex problems" + case "max": + return "Maximum available reasoning depth for complex problems" default: return level } diff --git a/sdk/api/management.go b/sdk/api/management.go index 689cda3dca..8a03909af4 100644 --- 
a/sdk/api/management.go +++ b/sdk/api/management.go @@ -19,7 +19,6 @@ type Handler = internalmanagement.Handler // ManagementTokenRequester exposes a limited subset of management endpoints for requesting tokens. type ManagementTokenRequester interface { RequestAnthropicToken(*gin.Context) - RequestGeminiCLIToken(*gin.Context) RequestCodexToken(*gin.Context) RequestAntigravityToken(*gin.Context) RequestKimiToken(*gin.Context) @@ -52,10 +51,6 @@ func (m *managementTokenRequester) RequestAnthropicToken(c *gin.Context) { m.handler.RequestAnthropicToken(c) } -func (m *managementTokenRequester) RequestGeminiCLIToken(c *gin.Context) { - m.handler.RequestGeminiCLIToken(c) -} - func (m *managementTokenRequester) RequestCodexToken(c *gin.Context) { m.handler.RequestCodexToken(c) } diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index 73743df4ef..ee41cbdbd2 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -172,7 +172,7 @@ waitForCallback: return nil, fmt.Errorf("antigravity: empty email returned from user info") } - // Fetch project ID via loadCodeAssist (same approach as Gemini CLI) + // Fetch project ID via loadCodeAssist. projectID := "" if accessToken != "" { fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken) diff --git a/sdk/auth/errors.go b/sdk/auth/errors.go index f950e925ff..eee4019f31 100644 --- a/sdk/auth/errors.go +++ b/sdk/auth/errors.go @@ -1,32 +1,5 @@ package auth -import ( - "fmt" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" -) - -// ProjectSelectionError indicates that the user must choose a specific project ID. 
-type ProjectSelectionError struct { - Email string - Projects []interfaces.GCPProjectProjects -} - -func (e *ProjectSelectionError) Error() string { - if e == nil { - return "cliproxy auth: project selection required" - } - return fmt.Sprintf("cliproxy auth: project selection required for %s", e.Email) -} - -// ProjectsDisplay returns the projects list for caller presentation. -func (e *ProjectSelectionError) ProjectsDisplay() []interfaces.GCPProjectProjects { - if e == nil { - return nil - } - return e.Projects -} - // EmailRequiredError indicates that the calling context must provide an email or alias. type EmailRequiredError struct { Prompt string diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 584481ad3e..bc89b32238 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -4,10 +4,8 @@ import ( "context" "encoding/json" "fmt" - "io" "io/fs" "net/http" - "net/url" "os" "path/filepath" "runtime" @@ -25,6 +23,12 @@ type PluginAuthParser interface { ParseAuth(context.Context, pluginapi.AuthParseRequest) (*cliproxyauth.Auth, bool, error) } +// PluginMultiAuthParser expands one auth JSON payload into multiple plugin auth records. +// Returning handled=true with an empty slice means the plugin intentionally suppresses built-in parsing. +type PluginMultiAuthParser interface { + ParseAuths(context.Context, pluginapi.AuthParseRequest) ([]*cliproxyauth.Auth, bool, error) +} + type pluginAuthParserHolder struct { parser PluginAuthParser } @@ -173,12 +177,12 @@ func (s *FileTokenStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error) if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") { return nil } - auth, err := s.readAuthFile(path, dir) - if err != nil { + auths, errReadAuths := s.readAuthFiles(path, dir) + if errReadAuths != nil { return nil } - if auth != nil { - entries = append(entries, auth) + if len(auths) > 0 { + entries = append(entries, auths...) 
} return nil }) @@ -215,7 +219,7 @@ func (s *FileTokenStore) resolveDeletePath(id string) (string, error) { return filepath.Join(dir, id), nil } -func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) { +func (s *FileTokenStore) readAuthFiles(path, baseDir string) ([]*cliproxyauth.Auth, error) { data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read file: %w", err) @@ -229,48 +233,54 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, } provider, _ := metadata["type"].(string) provider = strings.TrimSpace(provider) + if strings.EqualFold(provider, "gemini") { + return nil, nil + } info, errStat := os.Stat(path) if errStat != nil { return nil, fmt.Errorf("stat file: %w", errStat) } if parser := currentPluginAuthParser(); parser != nil { - auth, handled, errParse := parser.ParseAuth(context.Background(), pluginapi.AuthParseRequest{ + auths, handled, errParse := parsePluginAuthFile(parser, pluginapi.AuthParseRequest{ Provider: provider, Path: path, FileName: s.idFor(path, baseDir), RawJSON: data, }) - if errParse == nil && handled && auth != nil { - auth.CreatedAt = info.ModTime() - auth.UpdatedAt = info.ModTime() - if auth.Attributes == nil { - auth.Attributes = make(map[string]string) + if errParse == nil && handled { + auths = compactPluginAuths(auths) + if len(auths) == 0 { + return nil, nil } - auth.Attributes["path"] = path - auth.Attributes["source"] = path - cliproxyauth.ApplyCustomHeadersFromMetadata(auth) - return auth, nil + for index, auth := range auths { + if auth == nil { + continue + } + if len(auths) > 1 { + cliproxyauth.MarkPluginVirtualAuth(auth, path, index) + } + auth.CreatedAt = info.ModTime() + auth.UpdatedAt = info.ModTime() + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + auth.Attributes["path"] = path + auth.Attributes["source"] = path + cliproxyauth.ApplyCustomHeadersFromMetadata(auth) + } + return auths, nil } } if provider 
== "" { provider = "unknown" } - if provider == "antigravity" || provider == "gemini" { + if provider == "antigravity" { projectID := "" if pid, ok := metadata["project_id"].(string); ok { projectID = strings.TrimSpace(pid) } if projectID == "" { accessToken := extractAccessToken(metadata) - // For gemini type, the stored access_token is likely expired (~1h lifetime). - // Refresh it using the long-lived refresh_token before querying. - if provider == "gemini" { - if tokenMap, ok := metadata["token"].(map[string]any); ok { - if refreshed, errRefresh := refreshGeminiAccessToken(tokenMap, http.DefaultClient); errRefresh == nil { - accessToken = refreshed - } - } - } if accessToken != "" { fetchedProjectID, errFetch := FetchAntigravityProjectID(context.Background(), accessToken, http.DefaultClient) if errFetch == nil && strings.TrimSpace(fetchedProjectID) != "" { @@ -313,7 +323,43 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, auth.Attributes["email"] = email } cliproxyauth.ApplyCustomHeadersFromMetadata(auth) - return auth, nil + return []*cliproxyauth.Auth{auth}, nil +} + +func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) { + auths, errReadAuths := s.readAuthFiles(path, baseDir) + if errReadAuths != nil || len(auths) == 0 { + return nil, errReadAuths + } + return auths[0], nil +} + +func parsePluginAuthFile(parser PluginAuthParser, req pluginapi.AuthParseRequest) ([]*cliproxyauth.Auth, bool, error) { + if parser == nil { + return nil, false, nil + } + if multiParser, ok := parser.(PluginMultiAuthParser); ok { + return multiParser.ParseAuths(context.Background(), req) + } + auth, handled, errParse := parser.ParseAuth(context.Background(), req) + if errParse != nil || !handled || auth == nil { + return nil, handled, errParse + } + return []*cliproxyauth.Auth{auth}, true, nil +} + +func compactPluginAuths(auths []*cliproxyauth.Auth) []*cliproxyauth.Auth { + if len(auths) == 0 { + return nil + } 
+ out := auths[:0] + for _, auth := range auths { + if auth == nil { + continue + } + out = append(out, auth) + } + return out } func (s *FileTokenStore) idFor(path, baseDir string) string { @@ -399,51 +445,6 @@ func extractAccessToken(metadata map[string]any) string { return "" } -func refreshGeminiAccessToken(tokenMap map[string]any, httpClient *http.Client) (string, error) { - refreshToken, _ := tokenMap["refresh_token"].(string) - clientID, _ := tokenMap["client_id"].(string) - clientSecret, _ := tokenMap["client_secret"].(string) - tokenURI, _ := tokenMap["token_uri"].(string) - - if refreshToken == "" || clientID == "" || clientSecret == "" { - return "", fmt.Errorf("missing refresh credentials") - } - if tokenURI == "" { - tokenURI = "https://oauth2.googleapis.com/token" - } - - data := url.Values{ - "grant_type": {"refresh_token"}, - "refresh_token": {refreshToken}, - "client_id": {clientID}, - "client_secret": {clientSecret}, - } - - resp, err := httpClient.PostForm(tokenURI, data) - if err != nil { - return "", fmt.Errorf("refresh request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("refresh failed: status %d", resp.StatusCode) - } - - var result map[string]any - if errUnmarshal := json.Unmarshal(body, &result); errUnmarshal != nil { - return "", fmt.Errorf("decode refresh response: %w", errUnmarshal) - } - - newAccessToken, _ := result["access_token"].(string) - if newAccessToken == "" { - return "", fmt.Errorf("no access_token in refresh response") - } - - tokenMap["access_token"] = newAccessToken - return newAccessToken, nil -} - // jsonEqual compares two JSON blobs by parsing them into Go objects and deep comparing. 
func jsonEqual(a, b []byte) bool { var objA any diff --git a/sdk/auth/filestore_test.go b/sdk/auth/filestore_test.go index 9e135ad4c9..32164bed16 100644 --- a/sdk/auth/filestore_test.go +++ b/sdk/auth/filestore_test.go @@ -1,6 +1,14 @@ package auth -import "testing" +import ( + "context" + "os" + "path/filepath" + "testing" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" +) func TestExtractAccessToken(t *testing.T) { t.Parallel() @@ -78,3 +86,109 @@ func TestExtractAccessToken(t *testing.T) { }) } } + +func TestFileTokenStoreListExpandsPluginMultiAuths(t *testing.T) { + baseDir := t.TempDir() + path := filepath.Join(baseDir, "geminicli.json") + if errWrite := os.WriteFile(path, []byte(`{"type":"gemini-cli","headers":{"X-Test":"value"}}`), 0o600); errWrite != nil { + t.Fatalf("write auth file: %v", errWrite) + } + + RegisterPluginAuthParser(fileStoreMultiAuthParserFunc(func(ctx context.Context, req pluginapi.AuthParseRequest) ([]*cliproxyauth.Auth, bool, error) { + if req.Provider != "gemini-cli" || req.Path != path || req.FileName != "geminicli.json" { + t.Fatalf("ParseAuths request = %#v, want file context", req) + } + return []*cliproxyauth.Auth{ + { + ID: "geminicli.json", + Provider: "gemini-cli", + Metadata: map[string]any{ + "type": "gemini-cli", + "headers": map[string]any{ + "X-Test": "value", + }, + }, + }, + nil, + { + ID: "geminicli-project-a.json", + Provider: "gemini-cli", + Metadata: map[string]any{ + "type": "gemini-cli", + "project_id": "project-a", + "headers": map[string]any{ + "X-Test": "value", + }, + }, + }, + }, true, nil + })) + t.Cleanup(func() { + RegisterPluginAuthParser(nil) + }) + + store := NewFileTokenStore() + store.SetBaseDir(baseDir) + auths, errList := store.List(context.Background()) + if errList != nil { + t.Fatalf("List() error = %v", errList) + } + if len(auths) != 2 { + t.Fatalf("List() len = %d, want two plugin auths", len(auths)) + } + if 
firstIndex, secondIndex := auths[0].EnsureIndex(), auths[1].EnsureIndex(); firstIndex == "" || firstIndex == secondIndex { + t.Fatalf("auth indexes = %q/%q, want distinct non-empty indexes", firstIndex, secondIndex) + } + for _, auth := range auths { + if !cliproxyauth.IsPluginVirtualAuth(auth) { + t.Fatalf("auth attributes = %#v, want plugin virtual marker", auth.Attributes) + } + if auth.Attributes[cliproxyauth.AttributeVirtualSource] != path { + t.Fatalf("virtual_source = %q, want %q", auth.Attributes[cliproxyauth.AttributeVirtualSource], path) + } + if auth.Attributes["path"] != path || auth.Attributes["source"] != path { + t.Fatalf("auth attributes = %#v, want source path", auth.Attributes) + } + if gotHeader := auth.Attributes["header:X-Test"]; gotHeader != "value" { + t.Fatalf("header:X-Test = %q, want value", gotHeader) + } + } + if gotProject := auths[1].Metadata["project_id"]; gotProject != "project-a" { + t.Fatalf("project_id = %#v, want project-a", gotProject) + } +} + +func TestFileTokenStoreListPluginHandledEmptySuppressesBuiltin(t *testing.T) { + baseDir := t.TempDir() + path := filepath.Join(baseDir, "codex.json") + if errWrite := os.WriteFile(path, []byte(`{"type":"codex","access_token":"token"}`), 0o600); errWrite != nil { + t.Fatalf("write auth file: %v", errWrite) + } + + RegisterPluginAuthParser(fileStoreMultiAuthParserFunc(func(context.Context, pluginapi.AuthParseRequest) ([]*cliproxyauth.Auth, bool, error) { + return nil, true, nil + })) + t.Cleanup(func() { + RegisterPluginAuthParser(nil) + }) + + store := NewFileTokenStore() + store.SetBaseDir(baseDir) + auths, errList := store.List(context.Background()) + if errList != nil { + t.Fatalf("List() error = %v", errList) + } + if len(auths) != 0 { + t.Fatalf("List() len = %d, want plugin-handled empty result", len(auths)) + } +} + +type fileStoreMultiAuthParserFunc func(context.Context, pluginapi.AuthParseRequest) ([]*cliproxyauth.Auth, bool, error) + +func (f fileStoreMultiAuthParserFunc) 
ParseAuth(context.Context, pluginapi.AuthParseRequest) (*cliproxyauth.Auth, bool, error) { + return nil, false, nil +} + +func (f fileStoreMultiAuthParserFunc) ParseAuths(ctx context.Context, req pluginapi.AuthParseRequest) ([]*cliproxyauth.Auth, bool, error) { + return f(ctx, req) +} diff --git a/sdk/auth/gemini.go b/sdk/auth/gemini.go deleted file mode 100644 index ba7c7728ad..0000000000 --- a/sdk/auth/gemini.go +++ /dev/null @@ -1,73 +0,0 @@ -package auth - -import ( - "context" - "fmt" - "time" - - "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini" - // legacy client removed - "github.com/router-for-me/CLIProxyAPI/v7/internal/config" - coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" -) - -// GeminiAuthenticator implements the login flow for Google Gemini CLI accounts. -type GeminiAuthenticator struct{} - -// NewGeminiAuthenticator constructs a Gemini authenticator. -func NewGeminiAuthenticator() *GeminiAuthenticator { - return &GeminiAuthenticator{} -} - -func (a *GeminiAuthenticator) Provider() string { - return "gemini" -} - -func (a *GeminiAuthenticator) RefreshLead() *time.Duration { - return nil -} - -func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("cliproxy auth: configuration is required") - } - if ctx == nil { - ctx = context.Background() - } - if opts == nil { - opts = &LoginOptions{} - } - - var ts gemini.GeminiTokenStorage - if opts.ProjectID != "" { - ts.ProjectID = opts.ProjectID - } - - geminiAuth := gemini.NewGeminiAuth() - _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{ - NoBrowser: opts.NoBrowser, - CallbackPort: opts.CallbackPort, - Prompt: opts.Prompt, - }) - if err != nil { - return nil, fmt.Errorf("gemini authentication failed: %w", err) - } - - // Skip onboarding here; rely on upstream configuration - - fileName := fmt.Sprintf("%s-%s.json", ts.Email, 
ts.ProjectID) - metadata := map[string]any{ - "email": ts.Email, - "project_id": ts.ProjectID, - } - - fmt.Println("Gemini authentication successful") - - return &coreauth.Auth{ - ID: fileName, - Provider: a.Provider(), - FileName: fileName, - Storage: &ts, - Metadata: metadata, - }, nil -} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index 634c69d3e5..e2c0aba9e6 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -9,8 +9,6 @@ import ( func init() { registerRefreshLead("codex", func() Authenticator { return NewCodexAuthenticator() }) registerRefreshLead("claude", func() Authenticator { return NewClaudeAuthenticator() }) - registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() }) - registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() }) registerRefreshLead("xai", func() Authenticator { return NewXAIAuthenticator() }) diff --git a/sdk/cliproxy/auth/antigravity_credits_test.go b/sdk/cliproxy/auth/antigravity_credits_test.go index 540a4ef056..52754095cc 100644 --- a/sdk/cliproxy/auth/antigravity_credits_test.go +++ b/sdk/cliproxy/auth/antigravity_credits_test.go @@ -263,6 +263,6 @@ func TestIsAuthBlockedForModel_KeepsGeminiBlockedWithoutCreditsBypass(t *testing blocked, reason, _ := isAuthBlockedForModel(auth, "gemini-3-flash", time.Now()) if !blocked || reason != blockReasonCooldown { - t.Fatalf("expected gemini auth to remain blocked, got blocked=%v reason=%v", blocked, reason) + t.Fatalf("expected gemini model to remain blocked, got blocked=%v reason=%v", blocked, reason) } } diff --git a/sdk/cliproxy/auth/api_key_model_alias_test.go b/sdk/cliproxy/auth/api_key_model_alias_test.go index 25da4df4ed..7f0e49c06d 100644 --- a/sdk/cliproxy/auth/api_key_model_alias_test.go 
+++ b/sdk/cliproxy/auth/api_key_model_alias_test.go @@ -145,7 +145,7 @@ func TestApplyAPIKeyModelAlias(t *testing.T) { ctx := context.Background() apiKeyAuth := &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k"}} - oauthAuth := &Auth{ID: "oauth-auth", Provider: "gemini", Attributes: map[string]string{"auth_kind": "oauth"}} + oauthAuth := &Auth{ID: "oauth-auth", Provider: "claude", Attributes: map[string]string{"auth_kind": "oauth"}} _, _ = mgr.Register(ctx, apiKeyAuth) tests := []struct { diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 993bd4724c..54a52559fe 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "path/filepath" @@ -686,6 +687,119 @@ func clearCooldownStateForAuth(auth *Auth, now time.Time) bool { return changed } +func dedupeStrings(values []string) []string { + if len(values) < 2 { + return values + } + seen := make(map[string]struct{}, len(values)) + out := values[:0] + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +} + +// ResetQuota clears quota/cooldown state for an auth and resumes registry routing. 
+func (m *Manager) ResetQuota(ctx context.Context, authID string) (*Auth, []string, error) { + if m == nil { + return nil, nil, nil + } + authID = strings.TrimSpace(authID) + if authID == "" { + return nil, nil, fmt.Errorf("auth id is required") + } + + now := time.Now() + var snapshot *Auth + models := make([]string, 0) + registeredModels := modelsForRegisteredAuth(authID) + cooldownStateChanged := false + + m.mu.Lock() + auth, ok := m.auths[authID] + if !ok || auth == nil { + m.mu.Unlock() + return nil, nil, nil + } + + var cooldownRecordsBefore []CooldownStateRecord + trackCooldownState := m.cooldownStore != nil + if trackCooldownState { + cooldownRecordsBefore = m.cooldownStateRecordsForAuthLocked(auth, now) + } + + for modelKey, state := range auth.ModelStates { + if strings.TrimSpace(modelKey) == "" { + continue + } + models = append(models, modelKey) + if state != nil { + resetModelState(state, now) + } + } + if clearCooldownStateForAuth(auth, now) { + if len(models) == 0 { + models = append(models, registeredModels...) + } + } else if len(auth.ModelStates) > 0 { + updateAggregatedAvailability(auth, now) + } + + if len(models) == 0 { + models = append(models, registeredModels...) 
+ } + models = dedupeStrings(models) + + if !auth.Disabled && auth.Status != StatusDisabled && !hasModelError(auth, now) { + auth.LastError = nil + auth.StatusMessage = "" + auth.Status = StatusActive + } + auth.UpdatedAt = now + if errPersist := m.persist(ctx, auth); errPersist != nil { + m.mu.Unlock() + return nil, nil, errPersist + } + snapshot = auth.Clone() + if trackCooldownState { + cooldownRecordsAfter := m.cooldownStateRecordsForAuthLocked(auth, now) + cooldownStateChanged = !cooldownStateRecordsEqual(cooldownRecordsBefore, cooldownRecordsAfter) + } + m.mu.Unlock() + + for _, modelKey := range models { + registry.GetGlobalRegistry().ClearModelQuotaExceeded(authID, modelKey) + registry.GetGlobalRegistry().ResumeClientModel(authID, modelKey) + } + if m.scheduler != nil && snapshot != nil { + m.scheduler.upsertAuth(snapshot) + } + if snapshot != nil && cooldownStateChanged { + m.persistCooldownStates(ctx) + } + return snapshot, models, nil +} + +func modelsForRegisteredAuth(authID string) []string { + supportedModels := registry.GetGlobalRegistry().GetModelsForClient(authID) + models := make([]string, 0, len(supportedModels)) + for _, supportedModel := range supportedModels { + if supportedModel == nil || strings.TrimSpace(supportedModel.ID) == "" { + continue + } + models = append(models, supportedModel.ID) + } + return models +} + func (m *Manager) persistCooldownStates(ctx context.Context) { if m == nil { return @@ -913,13 +1027,13 @@ func openAICompatProviderKey(auth *Auth) string { } if auth.Attributes != nil { if providerKey := strings.TrimSpace(auth.Attributes["provider_key"]); providerKey != "" { - return strings.ToLower(providerKey) + return util.OpenAICompatibleProviderKey(providerKey) } if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" { - return strings.ToLower(compatName) + return util.OpenAICompatibleProviderKey(compatName) } } - return strings.ToLower(strings.TrimSpace(auth.Provider)) + return 
util.OpenAICompatibleProviderKey(auth.Provider) } func openAICompatModelPoolKey(auth *Auth, requestedModel string) string { @@ -2126,8 +2240,6 @@ func requestToFormat(provider string, executor ProviderExecutor, req cliproxyexe return sdktranslator.FormatClaude case "gemini", "vertex", "aistudio": return sdktranslator.FormatGemini - case "gemini-cli": - return sdktranslator.FormatGeminiCLI case "kimi": return sdktranslator.FormatOpenAI case "antigravity": @@ -3065,7 +3177,7 @@ func (m *Manager) closestCooldownWait(providers []string, model string, attempt if auth == nil { continue } - providerKey := strings.TrimSpace(strings.ToLower(auth.Provider)) + providerKey := executorKeyFromAuth(auth) if _, ok := providerSet[providerKey]; !ok { continue } @@ -3125,7 +3237,7 @@ func (m *Manager) retryAllowed(attempt int, providers []string) bool { if auth == nil { continue } - providerKey := strings.TrimSpace(strings.ToLower(auth.Provider)) + providerKey := executorKeyFromAuth(auth) if _, ok := providerSet[providerKey]; !ok { continue } @@ -4019,7 +4131,7 @@ func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, op } registryRef := registry.GetGlobalRegistry() for _, candidate := range m.auths { - if candidate.Provider != provider || candidate.Disabled { + if candidate == nil || executorKeyFromAuth(candidate) != provider || candidate.Disabled { continue } if pinnedAuthID != "" && candidate.ID != pinnedAuthID { @@ -4085,7 +4197,7 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli if strings.TrimSpace(model) != "" { m.mu.RLock() for _, candidate := range m.auths { - if candidate == nil || candidate.Provider != provider || candidate.Disabled { + if candidate == nil || executorKeyFromAuth(candidate) != provider || candidate.Disabled { continue } if _, used := tried[candidate.ID]; used { @@ -4178,7 +4290,7 @@ func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, m if disallowFreeAuth && 
isFreeCodexAuth(candidate) { continue } - providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider)) + providerKey := executorKeyFromAuth(candidate) if providerKey == "" { continue } @@ -4221,7 +4333,7 @@ func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, m if selected == nil { return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"} } - providerKey := strings.TrimSpace(strings.ToLower(selected.Provider)) + providerKey := executorKeyFromAuth(selected) executor, okExecutor := m.Executor(providerKey) if !okExecutor { return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"} @@ -4276,7 +4388,7 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s if candidate == nil || candidate.Disabled { continue } - if _, ok := providerSet[strings.TrimSpace(strings.ToLower(candidate.Provider))]; !ok { + if _, ok := providerSet[executorKeyFromAuth(candidate)]; !ok { continue } if _, used := tried[candidate.ID]; used { @@ -4556,7 +4668,7 @@ func (m *Manager) homeRuntimeAuthByID(sessionID string, authID string) (*Auth, P if auth == nil || !authWebsocketsEnabled(auth) { return nil, nil, "", false } - providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + providerKey := executorKeyFromAuth(auth) if providerKey == "" { return nil, nil, "", false } @@ -4654,7 +4766,7 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro if homeAuthAlreadyTried(tried, auth.ID) { return nil, nil, "", repeatedHomeAuthError() } - providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + providerKey := executorKeyFromAuth(&auth) if providerKey == "" { return nil, nil, "", &Error{Code: "invalid_auth", Message: "home returned auth without provider", HTTPStatus: http.StatusBadGateway} } @@ -4727,7 +4839,7 @@ func (m *Manager) findAllAntigravityCreditsCandidateAuths(ctx context.Context, r if 
!strings.Contains(strings.ToLower(strings.TrimSpace(routeModel)), "claude") { continue } - providerKey := strings.TrimSpace(strings.ToLower(auth.Provider)) + providerKey := executorKeyFromAuth(auth) executor, ok := m.executors[providerKey] if !ok { continue @@ -4920,6 +5032,9 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error { return nil } } + if IsPluginVirtualAuth(auth) { + return nil + } // Skip persistence when metadata is absent (e.g., runtime-only auths). if auth.Metadata == nil { return nil @@ -5340,8 +5455,15 @@ func executorKeyFromAuth(auth *Auth) string { if providerKey == "" { providerKey = compatName } - return strings.ToLower(providerKey) + return util.OpenAICompatibleProviderKey(providerKey) + } + } + if strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") { + providerKey := strings.TrimSpace(auth.Label) + if providerKey == "" { + providerKey = "openai-compatibility" } + return util.OpenAICompatibleProviderKey(providerKey) } return strings.ToLower(strings.TrimSpace(auth.Provider)) } diff --git a/sdk/cliproxy/auth/conductor_availability_test.go b/sdk/cliproxy/auth/conductor_availability_test.go index 831df3b023..7e07cc0714 100644 --- a/sdk/cliproxy/auth/conductor_availability_test.go +++ b/sdk/cliproxy/auth/conductor_availability_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" "time" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" ) func TestUpdateAggregatedAvailability_UnavailableWithoutNextRetryDoesNotBlockAuth(t *testing.T) { @@ -102,3 +104,75 @@ func TestManager_AvailableProvidersAndHasProviderAuth_ExcludeDisabled(t *testing t.Errorf("HasProviderAuth(codex) = true, want false (only StatusDisabled auth registered)") } } + +func TestManager_ResetQuotaClearsRuntimeAndRegistryState(t *testing.T) { + manager := NewManager(nil, nil, nil) + ctx := context.Background() + authID := "reset-quota-auth" + model := "reset-quota-model" + next := time.Now().Add(time.Hour) + + reg := 
registry.GetGlobalRegistry() + reg.RegisterClient(authID, "claude", []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { + reg.UnregisterClient(authID) + }) + + if _, errRegister := manager.Register(ctx, &Auth{ + ID: authID, + Provider: "claude", + Status: StatusError, + StatusMessage: "quota exhausted", + Unavailable: true, + NextRetryAfter: next, + Quota: QuotaState{Exceeded: true, Reason: "quota", NextRecoverAt: next, BackoffLevel: 2}, + ModelStates: map[string]*ModelState{ + model: { + Status: StatusError, + StatusMessage: "quota exhausted", + Unavailable: true, + NextRetryAfter: next, + Quota: QuotaState{Exceeded: true, Reason: "quota", NextRecoverAt: next, BackoffLevel: 2}, + UpdatedAt: next, + }, + }, + }); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + reg.SetModelQuotaExceeded(authID, model) + reg.SuspendClientModel(authID, model, "quota") + if count := reg.GetModelCount(model); count != 0 { + t.Fatalf("registry model count before reset = %d, want 0", count) + } + + updated, models, errReset := manager.ResetQuota(ctx, authID) + if errReset != nil { + t.Fatalf("ResetQuota() error = %v", errReset) + } + if updated == nil { + t.Fatalf("ResetQuota() updated auth is nil") + } + if len(models) != 1 || models[0] != model { + t.Fatalf("ResetQuota() models = %v, want [%s]", models, model) + } + if updated.Status != StatusActive || updated.StatusMessage != "" || updated.Unavailable || !updated.NextRetryAfter.IsZero() { + t.Fatalf("updated auth state = status %q message %q unavailable %v next %v", updated.Status, updated.StatusMessage, updated.Unavailable, updated.NextRetryAfter) + } + if updated.Quota.Exceeded || updated.Quota.Reason != "" || !updated.Quota.NextRecoverAt.IsZero() || updated.Quota.BackoffLevel != 0 { + t.Fatalf("updated auth quota = %+v, want cleared", updated.Quota) + } + state := updated.ModelStates[model] + if state == nil { + t.Fatalf("updated model state missing") + } + if state.Status != StatusActive || 
state.StatusMessage != "" || state.Unavailable || !state.NextRetryAfter.IsZero() { + t.Fatalf("updated model state = status %q message %q unavailable %v next %v", state.Status, state.StatusMessage, state.Unavailable, state.NextRetryAfter) + } + if state.Quota.Exceeded || state.Quota.Reason != "" || !state.Quota.NextRecoverAt.IsZero() || state.Quota.BackoffLevel != 0 { + t.Fatalf("updated model quota = %+v, want cleared", state.Quota) + } + if count := reg.GetModelCount(model); count != 1 { + t.Fatalf("registry model count after reset = %d, want 1", count) + } +} diff --git a/sdk/cliproxy/auth/oauth_model_alias.go b/sdk/cliproxy/auth/oauth_model_alias.go index 1de65afd2a..f936fa5a68 100644 --- a/sdk/cliproxy/auth/oauth_model_alias.go +++ b/sdk/cliproxy/auth/oauth_model_alias.go @@ -1,12 +1,15 @@ package auth import ( + "encoding/json" "strings" internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config" "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" ) +const oauthModelAliasesAttributeKey = "model_aliases" + type modelAliasEntry interface { GetName() string GetAlias() string @@ -183,7 +186,105 @@ func resolveModelAliasFromConfigModels(requestedModel string, models []modelAlia // the suffix is preserved in the returned model name. However, if the alias's // original name already contains a suffix, the config suffix takes priority. 
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string { - return resolveUpstreamModelFromAliasTable(m, auth, requestedModel, modelAliasChannel(auth)) + channel := modelAliasChannel(auth) + if channel == "" { + return "" + } + if resolved := resolveUpstreamModelFromAliases(OAuthModelAliasesFromAttributes(authAttributes(auth)), requestedModel); resolved != "" { + return resolved + } + return resolveUpstreamModelFromAliasTable(m, auth, requestedModel, channel) +} + +func authAttributes(auth *Auth) map[string]string { + if auth == nil { + return nil + } + return auth.Attributes +} + +// SetOAuthModelAliasesAttribute stores sanitized per-auth OAuth model aliases on an auth entry. +func SetOAuthModelAliasesAttribute(auth *Auth, aliases []internalconfig.OAuthModelAlias) { + if auth == nil { + return + } + aliases = sanitizeOAuthModelAliases(aliases) + if len(aliases) == 0 { + return + } + data, errMarshal := json.Marshal(aliases) + if errMarshal != nil { + return + } + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + auth.Attributes[oauthModelAliasesAttributeKey] = string(data) +} + +// OAuthModelAliasesFromAttributes returns sanitized per-auth OAuth model aliases from auth attributes. 
+func OAuthModelAliasesFromAttributes(attributes map[string]string) []internalconfig.OAuthModelAlias { + if len(attributes) == 0 { + return nil + } + raw := strings.TrimSpace(attributes[oauthModelAliasesAttributeKey]) + if raw == "" { + return nil + } + var aliases []internalconfig.OAuthModelAlias + if errUnmarshal := json.Unmarshal([]byte(raw), &aliases); errUnmarshal != nil { + return nil + } + return sanitizeOAuthModelAliases(aliases) +} + +func sanitizeOAuthModelAliases(aliases []internalconfig.OAuthModelAlias) []internalconfig.OAuthModelAlias { + if len(aliases) == 0 { + return nil + } + cfg := internalconfig.Config{ + OAuthModelAlias: map[string][]internalconfig.OAuthModelAlias{ + "auth": aliases, + }, + } + cfg.SanitizeOAuthModelAlias() + clean := cfg.OAuthModelAlias["auth"] + if len(clean) == 0 { + return nil + } + return append([]internalconfig.OAuthModelAlias(nil), clean...) +} + +func resolveUpstreamModelFromAliases(aliases []internalconfig.OAuthModelAlias, requestedModel string) string { + if len(aliases) == 0 { + return "" + } + requestResult, candidates := modelAliasLookupCandidates(requestedModel) + if len(candidates) == 0 { + return "" + } + baseModel := requestResult.ModelName + if baseModel == "" { + baseModel = strings.TrimSpace(requestedModel) + } + for _, entry := range aliases { + original := strings.TrimSpace(entry.Name) + alias := strings.TrimSpace(entry.Alias) + if original == "" || alias == "" { + continue + } + for _, candidate := range candidates { + key := strings.TrimSpace(candidate) + if key == "" || !strings.EqualFold(alias, key) { + continue + } + if strings.EqualFold(original, baseModel) { + return "" + } + return preserveResolvedModelSuffix(original, requestResult) + } + } + return "" } func resolveUpstreamModelFromAliasTable(m *Manager, auth *Auth, requestedModel, channel string) string { @@ -265,7 +366,7 @@ func modelAliasChannel(auth *Auth) string { // and auth kind. 
Returns empty string if the provider/authKind combination doesn't support // OAuth model alias (e.g., API key authentication). // -// Built-in channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi. +// Built-in channels: vertex, aistudio, antigravity, claude, codex, kimi. // Plugin OAuth providers use their normalized provider key as the channel. func OAuthModelAliasChannel(provider, authKind string) string { provider = strings.ToLower(strings.TrimSpace(provider)) @@ -275,8 +376,6 @@ func OAuthModelAliasChannel(provider, authKind string) string { } switch provider { case "gemini": - // gemini provider uses gemini-api-key config, not oauth-model-alias. - // OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer. return "" case "vertex": return "vertex" @@ -284,7 +383,7 @@ func OAuthModelAliasChannel(provider, authKind string) string { return "claude" case "codex": return "codex" - case "gemini-cli", "aistudio", "antigravity", "kimi": + case "aistudio", "antigravity", "kimi": return provider default: return provider diff --git a/sdk/cliproxy/auth/oauth_model_alias_test.go b/sdk/cliproxy/auth/oauth_model_alias_test.go index 7f6e2325d6..3504a62297 100644 --- a/sdk/cliproxy/auth/oauth_model_alias_test.go +++ b/sdk/cliproxy/auth/oauth_model_alias_test.go @@ -19,9 +19,9 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) { { name: "numeric suffix preserved", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, }, - channel: "gemini-cli", + channel: "antigravity", input: "gemini-2.5-pro(8192)", want: "gemini-2.5-pro-exp-03-25(8192)", }, @@ -37,9 +37,9 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) { { name: "no suffix unchanged", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", 
Alias: "gemini-2.5-pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, }, - channel: "gemini-cli", + channel: "antigravity", input: "gemini-2.5-pro", want: "gemini-2.5-pro-exp-03-25", }, @@ -55,18 +55,18 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) { { name: "auto suffix preserved", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, }, - channel: "gemini-cli", + channel: "antigravity", input: "gemini-2.5-pro(auto)", want: "gemini-2.5-pro-exp-03-25(auto)", }, { name: "none suffix preserved", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, }, - channel: "gemini-cli", + channel: "antigravity", input: "gemini-2.5-pro(none)", want: "gemini-2.5-pro-exp-03-25(none)", }, @@ -82,25 +82,25 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) { { name: "case insensitive alias lookup with suffix", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "Gemini-2.5-Pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "Gemini-2.5-Pro"}}, }, - channel: "gemini-cli", + channel: "antigravity", input: "gemini-2.5-pro(high)", want: "gemini-2.5-pro-exp-03-25(high)", }, { name: "no alias returns empty", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, }, - channel: "gemini-cli", + channel: "antigravity", input: "unknown-model(high)", want: "", }, { name: "wrong channel returns empty", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": 
{{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, }, channel: "claude", input: "gemini-2.5-pro(high)", @@ -109,18 +109,18 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) { { name: "empty suffix filtered out", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, }, - channel: "gemini-cli", + channel: "antigravity", input: "gemini-2.5-pro()", want: "gemini-2.5-pro-exp-03-25", }, { name: "incomplete suffix treated as no suffix", aliases: map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro(high"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro(high"}}, }, - channel: "gemini-cli", + channel: "antigravity", input: "gemini-2.5-pro(high", want: "gemini-2.5-pro-exp-03-25", }, @@ -145,8 +145,8 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) { func createAuthForChannel(channel string) *Auth { switch channel { - case "gemini-cli": - return &Auth{Provider: "gemini-cli"} + case "antigravity": + return &Auth{Provider: "antigravity", Attributes: map[string]string{"auth_kind": "oauth"}} case "claude": return &Auth{Provider: "claude", Attributes: map[string]string{"auth_kind": "oauth"}} case "vertex": @@ -155,8 +155,6 @@ func createAuthForChannel(channel string) *Auth { return &Auth{Provider: "codex", Attributes: map[string]string{"auth_kind": "oauth"}} case "aistudio": return &Auth{Provider: "aistudio"} - case "antigravity": - return &Auth{Provider: "antigravity"} case "kimi": return &Auth{Provider: "kimi"} default: @@ -164,6 +162,14 @@ func createAuthForChannel(channel string) *Auth { } } +func TestOAuthModelAliasChannel_APIKeyOnlyProviderUnsupported(t *testing.T) { + t.Parallel() + + if got 
:= OAuthModelAliasChannel("gemini", "oauth"); got != "" { + t.Fatalf("OAuthModelAliasChannel() = %q, want empty channel for API-key-only provider", got) + } +} + func TestOAuthModelAliasChannel_Kimi(t *testing.T) { t.Parallel() @@ -187,14 +193,14 @@ func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) { t.Parallel() aliases := map[string][]internalconfig.OAuthModelAlias{ - "gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, + "antigravity": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}}, } mgr := NewManager(nil, nil, nil) mgr.SetConfig(&internalconfig.Config{}) mgr.SetOAuthModelAlias(aliases) - auth := &Auth{ID: "test-auth-id", Provider: "gemini-cli"} + auth := &Auth{ID: "test-auth-id", Provider: "antigravity"} resolvedModel := mgr.applyOAuthModelAlias(auth, "gemini-2.5-pro(8192)") if resolvedModel != "gemini-2.5-pro-exp-03-25(8192)" { @@ -202,6 +208,53 @@ func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) { } } +func TestApplyOAuthModelAlias_PerAuthOverridesGlobalAlias(t *testing.T) { + t.Parallel() + + globalAliases := map[string][]internalconfig.OAuthModelAlias{ + "codex": {{Name: "gpt-5-global", Alias: "gpt-5.5"}}, + } + + mgr := NewManager(nil, nil, nil) + mgr.SetConfig(&internalconfig.Config{}) + mgr.SetOAuthModelAlias(globalAliases) + + auth := &Auth{ + ID: "codex-auth-id", + Provider: "codex", + Attributes: map[string]string{ + "auth_kind": "oauth", + "model_aliases": `[{"name":"gpt-5.3-codex-spark","alias":"gpt-5.5"}]`, + }, + } + + resolvedModel := mgr.applyOAuthModelAlias(auth, "gpt-5.5(high)") + if resolvedModel != "gpt-5.3-codex-spark(high)" { + t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "gpt-5.3-codex-spark(high)") + } +} + +func TestApplyOAuthModelAlias_PerAuthAliasSkipsAPIKey(t *testing.T) { + t.Parallel() + + mgr := NewManager(nil, nil, nil) + mgr.SetConfig(&internalconfig.Config{}) + + auth := &Auth{ + ID: "codex-api-key-auth", + Provider: "codex", + 
Attributes: map[string]string{ + "auth_kind": "api_key", + "model_aliases": `[{"name":"gpt-5.3-codex-spark","alias":"gpt-5.5"}]`, + }, + } + + resolvedModel := mgr.applyOAuthModelAlias(auth, "gpt-5.5") + if resolvedModel != "gpt-5.5" { + t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "gpt-5.5") + } +} + func TestApplyOAuthModelAlias_PluginProvider(t *testing.T) { t.Parallel() diff --git a/sdk/cliproxy/auth/openai_compat_pool_test.go b/sdk/cliproxy/auth/openai_compat_pool_test.go index f052c486f4..33e40e57ea 100644 --- a/sdk/cliproxy/auth/openai_compat_pool_test.go +++ b/sdk/cliproxy/auth/openai_compat_pool_test.go @@ -12,6 +12,8 @@ import ( cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" ) +const openAICompatPoolProviderKey = "openai-compatible-pool" + type openAICompatPoolExecutor struct { id string @@ -169,18 +171,18 @@ func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []interna m := NewManager(nil, nil, nil) m.SetConfig(cfg) if executor == nil { - executor = &openAICompatPoolExecutor{id: "pool"} + executor = &openAICompatPoolExecutor{id: openAICompatPoolProviderKey} } m.RegisterExecutor(executor) auth := &Auth{ ID: "pool-auth-" + t.Name(), - Provider: "pool", + Provider: openAICompatPoolProviderKey, Status: StatusActive, Attributes: map[string]string{ "api_key": "test-key", "compat_name": "pool", - "provider_key": "pool", + "provider_key": openAICompatPoolProviderKey, }, } if _, err := m.Register(context.Background(), auth); err != nil { @@ -188,7 +190,7 @@ func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []interna } reg := registry.GetGlobalRegistry() - reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}}) + reg.RegisterClient(auth.ID, openAICompatPoolProviderKey, []*registry.ModelInfo{{ID: alias}}) t.Cleanup(func() { reg.UnregisterClient(auth.ID) }) @@ -214,7 +216,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t 
*testi alias := "claude-opus-4.66" invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"} executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, countErrors: map[string]error{"deepseek-v3.1": invalidErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -222,7 +224,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testi {Name: "glm-5", Alias: alias}, }, executor) - _, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + _, err := m.ExecuteCount(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err == nil || err.Error() != invalidErr.Error() { t.Fatalf("execute count error = %v, want %v", err, invalidErr) } @@ -251,14 +253,14 @@ func TestResolveModelAliasPoolFromConfigModels(t *testing.T) { func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { alias := "claude-opus-4.66" - executor := &openAICompatPoolExecutor{id: "pool"} + executor := &openAICompatPoolExecutor{id: openAICompatPoolProviderKey} m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) for i := 0; i < 3; i++ { - resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + resp, err := m.Execute(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute %d: %v", i, err) } @@ -283,7 +285,7 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) { alias := "claude-opus-4.66" invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: 
malformed payload"} executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, executeErrors: map[string]error{"deepseek-v3.1": invalidErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -291,7 +293,7 @@ func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) { {Name: "glm-5", Alias: alias}, }, executor) - _, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + _, err := m.Execute(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err == nil || err.Error() != invalidErr.Error() { t.Fatalf("execute error = %v, want %v", err, invalidErr) } @@ -308,7 +310,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t Message: "invalid_request_error: The requested model is not supported.", } executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -316,7 +318,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportBadRequest(t {Name: "glm-5", Alias: alias}, }, executor) - resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + resp, err := m.Execute(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute error = %v, want fallback success", err) } @@ -354,7 +356,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl Message: "The requested model is not supported.", } executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, executeErrors: 
map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -362,7 +364,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl {Name: "glm-5", Alias: alias}, }, executor) - resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + resp, err := m.Execute(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute error = %v, want fallback success", err) } @@ -384,7 +386,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackOnModelSupportUnprocessabl func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) { alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, executeErrors: map[string]error{"deepseek-v3.1": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -392,7 +394,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing. {Name: "glm-5", Alias: alias}, }, executor) - resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + resp, err := m.Execute(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute: %v", err) } @@ -411,7 +413,7 @@ func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing. 
func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) { alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, streamPayloads: map[string][]cliproxyexecutor.StreamChunk{ "deepseek-v3.1": {}, }, @@ -421,7 +423,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *te {Name: "glm-5", Alias: alias}, }, executor) - streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + streamResult, err := m.ExecuteStream(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute stream: %v", err) } @@ -447,7 +449,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *te func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) { alias := "claude-opus-4.66" executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, streamFirstErrors: map[string]error{"deepseek-v3.1": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -455,7 +457,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *t {Name: "glm-5", Alias: alias}, }, executor) - streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + streamResult, err := m.ExecuteStream(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute stream: %v", err) } @@ -485,7 +487,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test alias := "claude-opus-4.66" invalidErr := &Error{HTTPStatus: 
http.StatusUnprocessableEntity, Message: "unprocessable entity"} executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, streamFirstErrors: map[string]error{"deepseek-v3.1": invalidErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -493,7 +495,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *test {Name: "glm-5", Alias: alias}, }, executor) - _, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + _, err := m.ExecuteStream(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err == nil || err.Error() != invalidErr.Error() { t.Fatalf("execute stream error = %v, want %v", err, invalidErr) } @@ -510,7 +512,7 @@ func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterReques Message: "invalid_request_error: The requested model is not supported.", } executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, executeErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -519,7 +521,7 @@ func TestManagerExecute_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterReques }, executor) for i := 0; i < 3; i++ { - resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + resp, err := m.Execute(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute %d: %v", i, err) } @@ -547,7 +549,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater Message: "The requested model is not supported.", } executor := &openAICompatPoolExecutor{ - id: "pool", + id: 
openAICompatPoolProviderKey, streamFirstErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -556,7 +558,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater }, executor) for i := 0; i < 3; i++ { - streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + streamResult, err := m.ExecuteStream(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute stream %d: %v", i, err) } @@ -582,14 +584,14 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLater func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { alias := "claude-opus-4.66" - executor := &openAICompatPoolExecutor{id: "pool"} + executor := &openAICompatPoolExecutor{id: openAICompatPoolProviderKey} m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ {Name: "deepseek-v3.1", Alias: alias}, {Name: "glm-5", Alias: alias}, }, executor) for i := 0; i < 2; i++ { - resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + resp, err := m.ExecuteCount(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute count %d: %v", i, err) } @@ -614,7 +616,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterR Message: "invalid_request_error: The requested model is unsupported.", } executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, countErrors: map[string]error{"deepseek-v3.1": modelSupportErr}, } m := newOpenAICompatPoolTestManager(t, alias, 
[]internalconfig.OpenAICompatibilityModel{ @@ -623,7 +625,7 @@ func TestManagerExecuteCount_OpenAICompatAliasPoolSkipsSuspendedUpstreamOnLaterR }, executor) for i := 0; i < 3; i++ { - resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + resp, err := m.ExecuteCount(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute count %d: %v", i, err) } @@ -659,27 +661,27 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge m.SetConfig(cfg) m.SetRetryConfig(0, 0, 1) - executor := &authScopedOpenAICompatPoolExecutor{id: "pool"} + executor := &authScopedOpenAICompatPoolExecutor{id: openAICompatPoolProviderKey} m.RegisterExecutor(executor) badAuth := &Auth{ ID: "aa-blocked-auth", - Provider: "pool", + Provider: openAICompatPoolProviderKey, Status: StatusActive, Attributes: map[string]string{ "api_key": "bad-key", "compat_name": "pool", - "provider_key": "pool", + "provider_key": openAICompatPoolProviderKey, }, } goodAuth := &Auth{ ID: "bb-good-auth", - Provider: "pool", + Provider: openAICompatPoolProviderKey, Status: StatusActive, Attributes: map[string]string{ "api_key": "good-key", "compat_name": "pool", - "provider_key": "pool", + "provider_key": openAICompatPoolProviderKey, }, } if _, err := m.Register(context.Background(), badAuth); err != nil { @@ -690,8 +692,8 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge } reg := registry.GetGlobalRegistry() - reg.RegisterClient(badAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}}) - reg.RegisterClient(goodAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}}) + reg.RegisterClient(badAuth.ID, openAICompatPoolProviderKey, []*registry.ModelInfo{{ID: alias}}) + reg.RegisterClient(goodAuth.ID, openAICompatPoolProviderKey, []*registry.ModelInfo{{ID: alias}}) t.Cleanup(func() { 
reg.UnregisterClient(badAuth.ID) reg.UnregisterClient(goodAuth.ID) @@ -704,14 +706,14 @@ func TestManagerExecute_OpenAICompatAliasPoolBlockedAuthDoesNotConsumeRetryBudge for _, upstreamModel := range []string{"deepseek-v3.1", "glm-5"} { m.MarkResult(context.Background(), Result{ AuthID: badAuth.ID, - Provider: "pool", + Provider: openAICompatPoolProviderKey, Model: upstreamModel, Success: false, Error: modelSupportErr, }) } - resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + resp, err := m.Execute(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err != nil { t.Fatalf("execute error = %v, want success via fallback auth", err) } @@ -732,7 +734,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *te alias := "claude-opus-4.66" invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"} executor := &openAICompatPoolExecutor{ - id: "pool", + id: openAICompatPoolProviderKey, streamFirstErrors: map[string]error{"deepseek-v3.1": invalidErr}, } m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ @@ -740,7 +742,7 @@ func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *te {Name: "glm-5", Alias: alias}, }, executor) - streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + streamResult, err := m.ExecuteStream(context.Background(), []string{openAICompatPoolProviderKey}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) if err == nil { t.Fatal("expected invalid request error") } diff --git a/sdk/cliproxy/auth/scheduler.go b/sdk/cliproxy/auth/scheduler.go index b3b61534f6..8c86422117 100644 --- a/sdk/cliproxy/auth/scheduler.go +++ b/sdk/cliproxy/auth/scheduler.go @@ -52,7 
+52,6 @@ type scheduledAuthMeta struct { auth *Auth providerKey string priority int - virtualParent string websocketEnabled bool supportedModelSet map[string]struct{} } @@ -80,18 +79,9 @@ type readyBucket struct { ws readyView } -// readyView holds the selection order for flat or grouped round-robin traversal. +// readyView holds the selection order for flat round-robin traversal. type readyView struct { - flat []*scheduledAuth - cursor int - parentOrder []string - parentCursor int - children map[string]*childBucket -} - -// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths. -type childBucket struct { - items []*scheduledAuth + flat []*scheduledAuth cursor int } @@ -99,9 +89,7 @@ type childBucket struct { type cooldownQueue []*scheduledAuth type readyViewCursorState struct { - cursor int - parentCursor int - childCursors map[string]int + cursor int } type readyBucketCursorState struct { @@ -110,21 +98,7 @@ type readyBucketCursorState struct { } func snapshotReadyViewCursors(view readyView) readyViewCursorState { - state := readyViewCursorState{ - cursor: view.cursor, - parentCursor: view.parentCursor, - } - if len(view.children) == 0 { - return state - } - state.childCursors = make(map[string]int, len(view.children)) - for parent, child := range view.children { - if child == nil { - continue - } - state.childCursors[parent] = child.cursor - } - return state + return readyViewCursorState{cursor: view.cursor} } func restoreReadyViewCursors(view *readyView, state readyViewCursorState) { @@ -134,23 +108,6 @@ func restoreReadyViewCursors(view *readyView, state readyViewCursorState) { if len(view.flat) > 0 { view.cursor = normalizeCursor(state.cursor, len(view.flat)) } - if len(view.parentOrder) == 0 || len(view.children) == 0 { - return - } - view.parentCursor = normalizeCursor(state.parentCursor, len(view.parentOrder)) - if len(state.childCursors) == 0 { - return - } - for parent, child := range view.children { - if child == nil || 
len(child.items) == 0 { - continue - } - cursor, ok := state.childCursors[parent] - if !ok { - continue - } - child.cursor = normalizeCursor(cursor, len(child.items)) - } } func normalizeCursor(cursor, size int) int { @@ -534,7 +491,7 @@ func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) { return } authID := strings.TrimSpace(auth.ID) - providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + providerKey := executorKeyFromAuth(auth) if authID == "" || providerKey == "" || auth.Disabled { s.removeAuthLocked(authID) return @@ -581,16 +538,11 @@ func (s *authScheduler) ensureProviderLocked(providerKey string) *providerSchedu // buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping. func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta { - providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) - virtualParent := "" - if auth.Attributes != nil { - virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"]) - } + providerKey := executorKeyFromAuth(auth) return &scheduledAuthMeta{ auth: auth, providerKey: providerKey, priority: authPriority(auth), - virtualParent: virtualParent, websocketEnabled: authWebsocketsEnabled(auth), supportedModelSet: supportedModelSetForAuth(auth.ID), } @@ -702,11 +654,9 @@ func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Tim previousState := entry.state previousNextRetryAt := entry.nextRetryAt previousPriority := 0 - previousParent := "" previousWebsocketEnabled := false if entry.meta != nil { previousPriority = entry.meta.priority - previousParent = entry.meta.virtualParent previousWebsocketEnabled = entry.meta.websocketEnabled } @@ -727,7 +677,7 @@ func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Tim entry.nextRetryAt = next } - if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && 
previousWebsocketEnabled == meta.websocketEnabled { + if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousWebsocketEnabled == meta.websocketEnabled { return } m.rebuildIndexesLocked() @@ -989,32 +939,9 @@ func buildReadyBucket(entries []*scheduledAuth) *readyBucket { return bucket } -// buildReadyView creates either a flat view or a grouped parent/child view for rotation. +// buildReadyView creates a flat view for rotation. func buildReadyView(entries []*scheduledAuth) readyView { - view := readyView{flat: append([]*scheduledAuth(nil), entries...)} - if len(entries) == 0 { - return view - } - groups := make(map[string][]*scheduledAuth) - for _, entry := range entries { - if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" { - return view - } - groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry) - } - if len(groups) <= 1 { - return view - } - view.children = make(map[string]*childBucket, len(groups)) - view.parentOrder = make([]string, 0, len(groups)) - for parent := range groups { - view.parentOrder = append(view.parentOrder, parent) - } - sort.Strings(view.parentOrder) - for _, parent := range view.parentOrder { - view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)} - } - return view + return readyView{flat: append([]*scheduledAuth(nil), entries...)} } // pickFirst returns the first ready entry that satisfies predicate without advancing cursors. @@ -1027,11 +954,8 @@ func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAut return nil } -// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal. +// pickRoundRobin returns the next ready entry using flat round-robin traversal. 
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth { - if len(v.parentOrder) > 1 && len(v.children) > 0 { - return v.pickGroupedRoundRobin(predicate) - } if len(v.flat) == 0 { return nil } @@ -1050,31 +974,3 @@ func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *schedul } return nil } - -// pickGroupedRoundRobin rotates across parents first and then within the selected parent. -func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth { - start := 0 - if len(v.parentOrder) > 0 { - start = v.parentCursor % len(v.parentOrder) - } - for offset := 0; offset < len(v.parentOrder); offset++ { - parentIndex := (start + offset) % len(v.parentOrder) - parent := v.parentOrder[parentIndex] - child := v.children[parent] - if child == nil || len(child.items) == 0 { - continue - } - itemStart := child.cursor % len(child.items) - for itemOffset := 0; itemOffset < len(child.items); itemOffset++ { - itemIndex := (itemStart + itemOffset) % len(child.items) - entry := child.items[itemIndex] - if predicate != nil && !predicate(entry) { - continue - } - child.cursor = itemIndex + 1 - v.parentCursor = parentIndex + 1 - return entry - } - } - return nil -} diff --git a/sdk/cliproxy/auth/scheduler_test.go b/sdk/cliproxy/auth/scheduler_test.go index 5843eaed33..99f4f9dc77 100644 --- a/sdk/cliproxy/auth/scheduler_test.go +++ b/sdk/cliproxy/auth/scheduler_test.go @@ -180,37 +180,6 @@ func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) { } } -func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) { - t.Parallel() - - registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2") - scheduler := newSchedulerForTest( - &RoundRobinSelector{}, - &Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}}, - &Auth{ID: "cred-a::proj-2", 
Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}}, - &Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}}, - &Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}}, - ) - - wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"} - wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"} - for index := range wantIDs { - got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil) - if errPick != nil { - t.Fatalf("pickSingle() #%d error = %v", index, errPick) - } - if got == nil { - t.Fatalf("pickSingle() #%d auth = nil", index) - } - if got.ID != wantIDs[index] { - t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index]) - } - if got.Attributes["gemini_virtual_parent"] != wantParents[index] { - t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index]) - } - } -} - func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) { t.Parallel() diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index 0dcb32d938..b761086533 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -6,7 +6,6 @@ import ( "fmt" "hash/fnv" "math" - "math/rand/v2" "net/http" "regexp" "sort" @@ -255,9 +254,6 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([] } // Pick selects the next available auth for the provider in a round-robin manner. -// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute), -// a two-level round-robin is used: first cycling across credential groups (parent -// accounts), then cycling within each group's project auths. 
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = opts now := time.Now() @@ -276,39 +272,6 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o limit = 4096 } - // Check if any available auth has gemini_virtual_parent attribute, - // indicating gemini-cli virtual auths that should use credential-level polling. - groups, parentOrder := groupByVirtualParent(available) - if len(parentOrder) > 1 { - // Two-level round-robin: first select a credential group, then pick within it. - groupKey := key + "::group" - s.ensureCursorKey(groupKey, limit) - if _, exists := s.cursors[groupKey]; !exists { - // Seed with a random initial offset so the starting credential is randomized. - s.cursors[groupKey] = rand.IntN(len(parentOrder)) - } - groupIndex := s.cursors[groupKey] - if groupIndex >= 2_147_483_640 { - groupIndex = 0 - } - s.cursors[groupKey] = groupIndex + 1 - - selectedParent := parentOrder[groupIndex%len(parentOrder)] - group := groups[selectedParent] - - // Second level: round-robin within the selected credential group. - innerKey := key + "::cred:" + selectedParent - s.ensureCursorKey(innerKey, limit) - innerIndex := s.cursors[innerKey] - if innerIndex >= 2_147_483_640 { - innerIndex = 0 - } - s.cursors[innerKey] = innerIndex + 1 - s.mu.Unlock() - return group[innerIndex%len(group)], nil - } - - // Flat round-robin for non-grouped auths (original behavior). s.ensureCursorKey(key, limit) index := s.cursors[key] if index >= 2_147_483_640 { @@ -327,35 +290,6 @@ func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) { } } -// groupByVirtualParent groups auths by their gemini_virtual_parent attribute. -// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration. 
-// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks -// this attribute, nil/nil is returned so the caller falls back to flat round-robin. -func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) { - if len(auths) == 0 { - return nil, nil - } - groups := make(map[string][]*Auth) - for _, a := range auths { - parent := "" - if a.Attributes != nil { - parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) - } - if parent == "" { - // Non-virtual auth present; fall back to flat round-robin. - return nil, nil - } - groups[parent] = append(groups[parent], a) - } - // Collect parent IDs in sorted order for stable cursor indexing. - parentOrder := make([]string, 0, len(groups)) - for p := range groups { - parentOrder = append(parentOrder, p) - } - sort.Strings(parentOrder) - return groups, parentOrder -} - // Pick selects the first available auth for the provider in a deterministic manner. func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = opts diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index c2d752a49a..4896422b4f 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -405,61 +405,6 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) { } } -func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) { - t.Parallel() - - selector := &RoundRobinSelector{} - - // Simulate two gemini-cli credentials, each with multiple projects: - // Credential A (parent = "cred-a.json") has 3 projects - // Credential B (parent = "cred-b.json") has 2 projects - auths := []*Auth{ - {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, - {ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, - {ID: "cred-a.json::proj-a3", Attributes: 
map[string]string{"gemini_virtual_parent": "cred-a.json"}}, - {ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}}, - {ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}}, - } - - // Two-level round-robin: consecutive picks must alternate between credentials. - // Credential group order is randomized, but within each call the group cursor - // advances by 1, so consecutive picks should cycle through different parents. - picks := make([]string, 6) - parents := make([]string, 6) - for i := 0; i < 6; i++ { - got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths) - if err != nil { - t.Fatalf("Pick() #%d error = %v", i, err) - } - if got == nil { - t.Fatalf("Pick() #%d auth = nil", i) - } - picks[i] = got.ID - parents[i] = got.Attributes["gemini_virtual_parent"] - } - - // Verify property: consecutive picks must alternate between credential groups. - for i := 1; i < len(parents); i++ { - if parents[i] == parents[i-1] { - t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials", - i-1, i, parents[i], picks[i-1], picks[i]) - } - } - - // Verify property: each credential's projects are picked in sequence (round-robin within group). 
- credPicks := map[string][]string{} - for i, id := range picks { - credPicks[parents[i]] = append(credPicks[parents[i]], id) - } - for parent, ids := range credPicks { - for i := 1; i < len(ids); i++ { - if ids[i] == ids[i-1] { - t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i]) - } - } - } -} - func TestExtractSessionID(t *testing.T) { t.Parallel() @@ -613,42 +558,6 @@ func TestSessionAffinitySelector_DifferentSessionsDifferentAuths(t *testing.T) { } } -func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) { - t.Parallel() - - selector := &RoundRobinSelector{} - - // All auths from the same parent - should fall back to flat round-robin - // because there's only one credential group (no benefit from two-level). - auths := []*Auth{ - {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, - {ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, - {ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, - } - - // With single parent group, parentOrder has length 1, so it uses flat round-robin. 
- // Sorted by ID: proj-a1, proj-a2, proj-a3 - want := []string{ - "cred-a.json::proj-a1", - "cred-a.json::proj-a2", - "cred-a.json::proj-a3", - "cred-a.json::proj-a1", - } - - for i, expectedID := range want { - got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths) - if err != nil { - t.Fatalf("Pick() #%d error = %v", i, err) - } - if got == nil { - t.Fatalf("Pick() #%d auth = nil", i) - } - if got.ID != expectedID { - t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID) - } - } -} - func TestSessionAffinitySelector_FailoverWhenAuthUnavailable(t *testing.T) { t.Parallel() @@ -700,39 +609,6 @@ func TestSessionAffinitySelector_FailoverWhenAuthUnavailable(t *testing.T) { } } -func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) { - t.Parallel() - - selector := &RoundRobinSelector{} - - // Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects - // alongside virtual ones). Should fall back to flat round-robin. - auths := []*Auth{ - {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, - {ID: "cred-regular.json"}, // no gemini_virtual_parent - } - - // groupByVirtualParent returns nil when any auth lacks the attribute, - // so flat round-robin is used. 
Sorted by ID: cred-a.json::proj-a1, cred-regular.json - want := []string{ - "cred-a.json::proj-a1", - "cred-regular.json", - "cred-a.json::proj-a1", - } - - for i, expectedID := range want { - got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths) - if err != nil { - t.Fatalf("Pick() #%d error = %v", i, err) - } - if got == nil { - t.Fatalf("Pick() #%d auth = nil", i) - } - if got.ID != expectedID { - t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID) - } - } -} func TestExtractSessionID_ClaudeCodePriorityOverHeader(t *testing.T) { t.Parallel() diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 882c25eabd..8c90095117 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -100,6 +100,49 @@ type Auth struct { indexAssigned bool `json:"-"` } +const ( + AttributeAuthIndexSeed = "auth_index_seed" + AttributePluginVirtual = "plugin_virtual" + AttributeVirtualSource = "virtual_source" + pluginVirtualAttrEnabled = "true" +) + +// MarkPluginVirtualAuth marks an auth that was expanded from a plugin-owned source file. +func MarkPluginVirtualAuth(auth *Auth, sourcePath string, ordinal int) { + if auth == nil { + return + } + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + auth.Attributes[AttributePluginVirtual] = pluginVirtualAttrEnabled + sourcePath = strings.TrimSpace(sourcePath) + if sourcePath != "" { + auth.Attributes[AttributeVirtualSource] = sourcePath + } + seedID := strings.TrimSpace(auth.ID) + if seedID == "" { + seedID = strings.TrimSpace(auth.FileName) + } + if seedID == "" { + seedID = strconv.Itoa(ordinal) + } + auth.Attributes[AttributeAuthIndexSeed] = strings.Join([]string{ + strings.ToLower(strings.TrimSpace(auth.Provider)), + sourcePath, + seedID, + strconv.Itoa(ordinal), + }, "|") +} + +// IsPluginVirtualAuth reports whether an auth was expanded from a plugin-owned source file. 
+func IsPluginVirtualAuth(auth *Auth) bool { + if auth == nil || len(auth.Attributes) == 0 { + return false + } + return strings.EqualFold(strings.TrimSpace(auth.Attributes[AttributePluginVirtual]), pluginVirtualAttrEnabled) +} + const ( recentRequestBucketSeconds int64 = 10 * 60 recentRequestBucketCount = 20 @@ -257,6 +300,12 @@ func (a *Auth) indexSeed() string { return "" } + if a.Attributes != nil { + if seed := strings.TrimSpace(a.Attributes[AttributeAuthIndexSeed]); seed != "" { + return AttributeAuthIndexSeed + ":" + seed + } + } + provider := strings.ToLower(strings.TrimSpace(a.Provider)) compatName := "" baseURL := "" @@ -508,23 +557,6 @@ func (a *Auth) AccountInfo() (string, string) { if a == nil { return "", "" } - // For Gemini CLI, include project ID in the OAuth account info if present. - if strings.ToLower(a.Provider) == "gemini-cli" { - if a.Metadata != nil { - email, _ := a.Metadata["email"].(string) - email = strings.TrimSpace(email) - if email != "" { - if p, ok := a.Metadata["project_id"].(string); ok { - p = strings.TrimSpace(p) - if p != "" { - return "oauth", email + " (" + p + ")" - } - } - return "oauth", email - } - } - } - // Check metadata for email first (OAuth-style auth) if a.Metadata != nil { if v, ok := a.Metadata["email"].(string); ok { diff --git a/sdk/cliproxy/auth/types_test.go b/sdk/cliproxy/auth/types_test.go index f579bfda2e..83f3392444 100644 --- a/sdk/cliproxy/auth/types_test.go +++ b/sdk/cliproxy/auth/types_test.go @@ -113,16 +113,16 @@ func TestEnsureIndexUsesOAuthTypeAndAbsolutePath(t *testing.T) { relPath := "test-oauth.json" absPath := filepath.Join(wd, relPath) - expectedSeed := "gemini:" + filepath.Clean(absPath) + expectedSeed := "antigravity:" + filepath.Clean(absPath) expectedIndex := stableAuthIndex(expectedSeed) a := &Auth{ - Provider: "gemini-cli", + Provider: "antigravity", Attributes: map[string]string{ "path": relPath, }, Metadata: map[string]any{ - "type": "gemini", + "type": "antigravity", }, } diff --git 
a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 7912c0a5b7..6f2d996735 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -413,11 +413,10 @@ func (s *Service) registerModelRefreshCallback() { }) } -// newDefaultAuthManager creates a default authentication manager with all supported providers. +// newDefaultAuthManager creates a default authentication manager with supported OAuth providers. func newDefaultAuthManager() *sdkAuth.Manager { return sdkAuth.NewManager( sdkAuth.GetTokenStore(), - sdkAuth.NewGeminiAuthenticator(), sdkAuth.NewCodexAuthenticator(), sdkAuth.NewClaudeAuthenticator(), sdkAuth.NewXAIAuthenticator(), @@ -788,11 +787,16 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName if providerKey == "" { providerKey = compatName } - return strings.ToLower(providerKey), compatName, true + return util.OpenAICompatibleProviderKey(providerKey), compatName, true } } if strings.EqualFold(strings.TrimSpace(a.Provider), "openai-compatibility") { - return "openai-compatibility", strings.TrimSpace(a.Label), true + compatName = strings.TrimSpace(a.Label) + providerKey = compatName + if providerKey == "" { + providerKey = "openai-compatibility" + } + return util.OpenAICompatibleProviderKey(providerKey), compatName, true } return "", "", false } @@ -848,6 +852,24 @@ func (s *Service) hasNativeOpenAICompatExecutorConfig(a *coreauth.Auth, provider return false } +func (s *Service) unregisterOpenAICompatExecutor(providerKey string) { + if s == nil || s.coreManager == nil { + return + } + providerKey = strings.ToLower(strings.TrimSpace(providerKey)) + if providerKey == "" { + return + } + existing, okExecutor := s.coreManager.Executor(providerKey) + if !okExecutor || existing == nil { + return + } + if _, okOpenAICompat := existing.(*executor.OpenAICompatExecutor); !okOpenAICompat { + return + } + s.coreManager.UnregisterExecutor(providerKey) +} + func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { 
s.ensureExecutorsForAuthWithMode(a, false) } @@ -888,7 +910,6 @@ func baselineExecutorAuths() []*coreauth.Auth { "claude", "gemini", "vertex", - "gemini-cli", "aistudio", "antigravity", "kimi", @@ -960,8 +981,6 @@ func (s *Service) registerExecutorForAuth(a *coreauth.Auth, forceReplace bool) { s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg)) case "vertex": s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg)) - case "gemini-cli": - s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg)) case "aistudio": if s.wsGateway != nil { s.coreManager.RegisterExecutor(executor.NewAIStudioExecutor(s.cfg, a.ID, s.wsGateway)) @@ -983,6 +1002,7 @@ func (s *Service) registerExecutorForAuth(a *coreauth.Auth, forceReplace bool) { if s.pluginHost != nil && s.pluginHost.HasExecutorCandidateProvider(providerKey) && !s.hasNativeOpenAICompatExecutorConfig(a, providerKey) { + s.unregisterOpenAICompatExecutor(providerKey) return } s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(providerKey, s.cfg)) @@ -1120,7 +1140,7 @@ func (s *Service) tryRegisterPluginModelsForAuth(ctx context.Context, a *coreaut } } models := applyExcludedModels(result.Models, activeExcluded) - models = applyOAuthModelAlias(s.cfg, providerKey, activeAuthKind, models) + models = applyOAuthModelAliasForAuth(s.cfg, providerKey, activeAuthKind, activeAuth.Attributes, models) if len(models) > 0 { s.registerResolvedModelsForAuth(activeAuth, providerKey, applyModelPrefixes(models, activeAuth.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix)) return true @@ -1222,6 +1242,8 @@ func (s *Service) applyConfigUpdateWithAuthSynthesis(newCfg *config.Config, synt s.coreManager.SetConfig(newCfg) s.coreManager.SetOAuthModelAlias(newCfg.OAuthModelAlias) } + ctx := coreauth.WithSkipPersist(context.Background()) + s.syncPluginRuntimeConfig(ctx) var auths []*coreauth.Auth if s.coreManager != nil { auths = s.coreManager.List() @@ -1231,7 +1253,6 @@ func (s *Service) 
applyConfigUpdateWithAuthSynthesis(newCfg *config.Config, synt forceReplaceAuths: true, auths: auths, }) - ctx := coreauth.WithSkipPersist(context.Background()) if synthesizeConfigAuths { s.registerConfigAPIKeyAuths(ctx, newCfg) } @@ -1240,7 +1261,7 @@ func (s *Service) applyConfigUpdateWithAuthSynthesis(newCfg *config.Config, synt log.Warnf("failed to restore cooldown state after config update: %v", errRestoreCooldown) } } - s.syncPluginRuntime(ctx) + s.syncPluginModelRuntime(ctx) } func (s *Service) reloadConfigFromWatcher() bool { @@ -1298,7 +1319,6 @@ func forceHomeRuntimeConfig(cfg *config.Config) { cfg.DisableCooling = true cfg.SaveCooldownStatus = false cfg.WebsocketAuth = false - cfg.EnableGeminiCLIEndpoint = false cfg.RemoteManagement.AllowRemote = false cfg.RemoteManagement.DisableControlPanel = true } @@ -1776,12 +1796,6 @@ func (s *Service) registerModelsForAuth(ctx context.Context, a *coreauth.Auth) { authKind = "apikey" } } - if a.Attributes != nil { - if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") { - GlobalModelRegistry().UnregisterClient(a.ID) - return - } - } // Unregister legacy client ID (if present) to avoid double counting if a.Runtime != nil { if idGetter, ok := a.Runtime.(interface{ GetClientID() string }); ok { @@ -1831,9 +1845,6 @@ func (s *Service) registerModelsForAuth(ctx context.Context, a *coreauth.Auth) { } } models = applyExcludedModels(models, excluded) - case "gemini-cli": - models = registry.GetGeminiCLIModels() - models = applyExcludedModels(models, excluded) case "aistudio": models = registry.GetAIStudioModels() models = applyExcludedModels(models, excluded) @@ -1962,7 +1973,7 @@ func (s *Service) registerModelsForAuth(ctx context.Context, a *coreauth.Auth) { } } } - models = applyOAuthModelAlias(s.cfg, provider, authKind, models) + models = applyOAuthModelAliasForAuth(s.cfg, provider, authKind, a.Attributes, models) key := provider if key == "" { key = 
strings.ToLower(strings.TrimSpace(a.Provider)) @@ -2427,18 +2438,58 @@ func rewriteModelInfoName(name, oldID, newID string) string { } func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo { - if cfg == nil || len(models) == 0 { + return applyOAuthModelAliasForAuth(cfg, provider, authKind, nil, models) +} + +func applyOAuthModelAliasForAuth(cfg *config.Config, provider, authKind string, attributes map[string]string, models []*ModelInfo) []*ModelInfo { + if len(models) == 0 { return models } channel := coreauth.OAuthModelAliasChannel(provider, authKind) - if channel == "" || len(cfg.OAuthModelAlias) == 0 { + if channel == "" { return models } - aliases := cfg.OAuthModelAlias[channel] + aliases := oauthModelAliasesForAuth(cfg, channel, attributes) if len(aliases) == 0 { return models } + return applyOAuthModelAliasEntries(aliases, models) +} + +func oauthModelAliasesForAuth(cfg *config.Config, channel string, attributes map[string]string) []config.OAuthModelAlias { + perAuthAliases := coreauth.OAuthModelAliasesFromAttributes(attributes) + if cfg == nil || len(cfg.OAuthModelAlias) == 0 { + return perAuthAliases + } + globalAliases := cfg.OAuthModelAlias[channel] + if len(perAuthAliases) == 0 { + return globalAliases + } + if len(globalAliases) == 0 { + return perAuthAliases + } + out := make([]config.OAuthModelAlias, 0, len(perAuthAliases)+len(globalAliases)) + seenAlias := make(map[string]struct{}, len(perAuthAliases)+len(globalAliases)) + add := func(aliases []config.OAuthModelAlias) { + for _, entry := range aliases { + alias := strings.TrimSpace(entry.Alias) + if alias == "" { + continue + } + key := strings.ToLower(alias) + if _, exists := seenAlias[key]; exists { + continue + } + seenAlias[key] = struct{}{} + out = append(out, entry) + } + } + add(perAuthAliases) + add(globalAliases) + return out +} +func applyOAuthModelAliasEntries(aliases []config.OAuthModelAlias, models []*ModelInfo) []*ModelInfo { type 
aliasEntry struct { alias string fork bool diff --git a/sdk/cliproxy/service_excluded_models_test.go b/sdk/cliproxy/service_excluded_models_test.go index fd44436fac..96490743b1 100644 --- a/sdk/cliproxy/service_excluded_models_test.go +++ b/sdk/cliproxy/service_excluded_models_test.go @@ -16,13 +16,13 @@ func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T service := &Service{ cfg: &config.Config{ OAuthExcludedModels: map[string][]string{ - "gemini-cli": {"gemini-2.5-pro"}, + "gemini": {"gemini-2.5-pro"}, }, }, } auth := &coreauth.Auth{ - ID: "auth-gemini-cli", - Provider: "gemini-cli", + ID: "auth-gemini", + Provider: "gemini", Status: coreauth.StatusActive, Attributes: map[string]string{ "auth_kind": "oauth", @@ -38,9 +38,9 @@ func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T service.registerModelsForAuth(context.Background(), auth) - models := registry.GetAvailableModelsByProvider("gemini-cli") + models := registry.GetAvailableModelsByProvider("gemini") if len(models) == 0 { - t.Fatal("expected gemini-cli models to be registered") + t.Fatal("expected gemini models to be registered") } for _, model := range models { diff --git a/sdk/cliproxy/service_executor_registration_test.go b/sdk/cliproxy/service_executor_registration_test.go index d3867987d3..5366fa09ab 100644 --- a/sdk/cliproxy/service_executor_registration_test.go +++ b/sdk/cliproxy/service_executor_registration_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/router-for-me/CLIProxyAPI/v7/internal/pluginhost" + runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor" coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" "github.com/router-for-me/CLIProxyAPI/v7/sdk/config" @@ -79,7 +80,6 @@ func TestRegisterAvailableExecutors(t *testing.T) { "claude", "gemini", "vertex", - "gemini-cli", "aistudio", "antigravity", "kimi", 
@@ -99,3 +99,64 @@ func TestRegisterAvailableExecutors(t *testing.T) { t.Fatalf("executor type = %T, want serviceTestPluginExecutor", resolved) } } + +func TestRegisterExecutorForAuth_OpenAICompatUsesNamespacedProviderKey(t *testing.T) { + testCases := []struct { + name string + auths []*coreauth.Auth + }{ + { + name: "native first", + auths: []*coreauth.Auth{ + {ID: "native-kimi", Provider: "kimi"}, + openAICompatKimiAuth(), + }, + }, + { + name: "compat first", + auths: []*coreauth.Auth{ + openAICompatKimiAuth(), + {ID: "native-kimi", Provider: "kimi"}, + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + service := &Service{ + cfg: &config.Config{}, + coreManager: coreauth.NewManager(nil, nil, nil), + } + + service.registerExecutorsForAuths(tt.auths, true) + + nativeExecutor, okNative := service.coreManager.Executor("kimi") + if !okNative { + t.Fatal("expected native kimi executor") + } + if _, okKimi := nativeExecutor.(*runtimeexecutor.KimiExecutor); !okKimi { + t.Fatalf("native executor type = %T, want *executor.KimiExecutor", nativeExecutor) + } + + compatExecutor, okCompat := service.coreManager.Executor("openai-compatible-kimi") + if !okCompat { + t.Fatal("expected namespaced OpenAI-compatible executor") + } + if _, okOpenAICompat := compatExecutor.(*runtimeexecutor.OpenAICompatExecutor); !okOpenAICompat { + t.Fatalf("compat executor type = %T, want *executor.OpenAICompatExecutor", compatExecutor) + } + }) + } +} + +func openAICompatKimiAuth() *coreauth.Auth { + return &coreauth.Auth{ + ID: "compat-kimi", + Provider: "openai-compatibility", + Label: "kimi", + Attributes: map[string]string{ + "compat_name": "kimi", + "provider_key": "kimi", + }, + } +} diff --git a/sdk/cliproxy/service_oauth_model_alias_test.go b/sdk/cliproxy/service_oauth_model_alias_test.go index c39fbb7b11..df77cfa4aa 100644 --- a/sdk/cliproxy/service_oauth_model_alias_test.go +++ b/sdk/cliproxy/service_oauth_model_alias_test.go @@ -132,3 +132,23 @@ 
func TestApplyOAuthModelAlias_PluginProviderSkipsAPIKey(t *testing.T) { t.Fatalf("expected API key plugin model to remain unchanged, got %#v", out) } } + +func TestApplyOAuthModelAlias_PerAuthAlias(t *testing.T) { + models := []*ModelInfo{ + {ID: "gpt-5.3-codex-spark", Name: "models/gpt-5.3-codex-spark"}, + } + attributes := map[string]string{ + "model_aliases": `[{"name":"gpt-5.3-codex-spark","alias":"gpt-5.5"}]`, + } + + out := applyOAuthModelAliasForAuth(nil, "codex", "oauth", attributes, models) + if len(out) != 1 { + t.Fatalf("expected 1 model, got %d", len(out)) + } + if out[0].ID != "gpt-5.5" { + t.Fatalf("expected per-auth alias id %q, got %q", "gpt-5.5", out[0].ID) + } + if out[0].Name != "models/gpt-5.5" { + t.Fatalf("expected per-auth alias name %q, got %q", "models/gpt-5.5", out[0].Name) + } +} diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index 719e03095d..d6c2b39909 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -86,6 +86,12 @@ type PluginAuthParser interface { ParseAuth(context.Context, pluginapi.AuthParseRequest) (*coreauth.Auth, bool, error) } +// PluginMultiAuthParser expands one auth JSON payload into multiple plugin auth records. +// Returning handled=true with an empty slice means the plugin intentionally suppresses built-in parsing. +type PluginMultiAuthParser interface { + ParseAuths(context.Context, pluginapi.AuthParseRequest) ([]*coreauth.Auth, bool, error) +} + // WatcherWrapper exposes the subset of watcher methods required by the SDK. 
type WatcherWrapper struct { start func(ctx context.Context) error diff --git a/sdk/pluginabi/types.go b/sdk/pluginabi/types.go index a1ab574663..5db85b0d66 100644 --- a/sdk/pluginabi/types.go +++ b/sdk/pluginabi/types.go @@ -86,7 +86,8 @@ type Envelope struct { } type Error struct { - Code string `json:"code"` - Message string `json:"message"` - Retryable bool `json:"retryable,omitempty"` + Code string `json:"code"` + Message string `json:"message"` + Retryable bool `json:"retryable,omitempty"` + HTTPStatus int `json:"http_status,omitempty"` } diff --git a/sdk/pluginapi/types.go b/sdk/pluginapi/types.go index 6f9f53f756..5bd97508b2 100644 --- a/sdk/pluginapi/types.go +++ b/sdk/pluginapi/types.go @@ -253,6 +253,8 @@ type AuthParseResponse struct { Handled bool // Auth is the parsed auth record when Handled is true. Auth AuthData + // Auths contains multiple parsed auth records when one auth material expands into several runtime auths. + Auths []AuthData } // AuthProvider parses, logs in, polls, and refreshes plugin provider auths. @@ -326,6 +328,8 @@ type AuthLoginPollResponse struct { Message string // Auth is the completed auth record when Status is success. Auth AuthData + // Auths contains multiple completed auth records when one login flow expands into several runtime auths. + Auths []AuthData } // AuthRefreshRequest asks a plugin to refresh provider auth data. 
diff --git a/sdk/pluginapi/types_test.go b/sdk/pluginapi/types_test.go index 6a5556efdc..de0d5c4e1d 100644 --- a/sdk/pluginapi/types_test.go +++ b/sdk/pluginapi/types_test.go @@ -51,6 +51,61 @@ func TestMetadataConfigFieldsExposePluginSchema(t *testing.T) { } } +func TestAuthParseResponseSupportsMultipleAuths(t *testing.T) { + resp := AuthParseResponse{ + Handled: true, + Auth: AuthData{ + Provider: "gemini-cli", + ID: "primary.json", + }, + Auths: []AuthData{ + {Provider: "gemini-cli", ID: "primary.json"}, + {Provider: "gemini-cli", ID: "primary-project-a.json"}, + }, + } + + raw, errMarshal := json.Marshal(resp) + if errMarshal != nil { + t.Fatalf("Marshal() error = %v", errMarshal) + } + var decoded AuthParseResponse + if errUnmarshal := json.Unmarshal(raw, &decoded); errUnmarshal != nil { + t.Fatalf("Unmarshal() error = %v", errUnmarshal) + } + if !decoded.Handled || len(decoded.Auths) != 2 || decoded.Auths[1].ID != "primary-project-a.json" { + t.Fatalf("decoded response = %#v, want two auths", decoded) + } + if decoded.Auth.ID != "primary.json" { + t.Fatalf("decoded Auth.ID = %q, want primary.json", decoded.Auth.ID) + } +} + +func TestAuthLoginPollResponseSupportsMultipleAuths(t *testing.T) { + resp := AuthLoginPollResponse{ + Status: AuthLoginStatusSuccess, + Auth: AuthData{ + Provider: "gemini-cli", + ID: "primary.json", + }, + Auths: []AuthData{ + {Provider: "gemini-cli", ID: "primary.json"}, + {Provider: "gemini-cli", ID: "primary-project-a.json"}, + }, + } + + raw, errMarshal := json.Marshal(resp) + if errMarshal != nil { + t.Fatalf("Marshal() error = %v", errMarshal) + } + var decoded AuthLoginPollResponse + if errUnmarshal := json.Unmarshal(raw, &decoded); errUnmarshal != nil { + t.Fatalf("Unmarshal() error = %v", errUnmarshal) + } + if decoded.Status != AuthLoginStatusSuccess || len(decoded.Auths) != 2 { + t.Fatalf("decoded response = %#v, want success with two auths", decoded) + } +} + func TestResourceRouteMenuFieldsExposeManagementUIHints(t 
*testing.T) { route := ResourceRoute{ Path: "/status", diff --git a/sdk/translator/formats.go b/sdk/translator/formats.go index aafe9e056c..d03bbf74d8 100644 --- a/sdk/translator/formats.go +++ b/sdk/translator/formats.go @@ -6,7 +6,6 @@ const ( FormatOpenAIResponse Format = "openai-response" FormatClaude Format = "claude" FormatGemini Format = "gemini" - FormatGeminiCLI Format = "gemini-cli" FormatCodex Format = "codex" FormatAntigravity Format = "antigravity" ) diff --git a/test/thinking_conversion_test.go b/test/thinking_conversion_test.go index 8e000be6cb..fa0e3313f1 100644 --- a/test/thinking_conversion_test.go +++ b/test/thinking_conversion_test.go @@ -12,7 +12,6 @@ import ( _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/claude" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/codex" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/gemini" - _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/geminicli" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/kimi" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/openai" _ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/xai" @@ -1041,10 +1040,10 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { expectValue: "128000", expectErr: false, }, - // Case 88: Gemini-CLI to Antigravity, budget 8192 → passthrough thinkingBudget + // Case 88: Antigravity to Antigravity, budget 8192 → passthrough thinkingBudget { name: "88", - from: "gemini-cli", + from: "antigravity", to: "antigravity", model: "antigravity-budget-model(8192)", inputJSON: `{"model":"antigravity-budget-model(8192)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`, @@ -1053,10 +1052,10 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 89: Gemini-CLI to Antigravity, budget 64000 → clamped to Max + // Case 89: Antigravity to Antigravity, budget 
64000 → clamped to Max { name: "89", - from: "gemini-cli", + from: "antigravity", to: "antigravity", model: "antigravity-budget-model(64000)", inputJSON: `{"model":"antigravity-budget-model(64000)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`, @@ -1067,7 +1066,7 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { }, // Gemini Family Cross-Channel Consistency (Cases 90-95) - // Tests that gemini/gemini-cli/antigravity as same API family should have consistent validation behavior + // Tests that gemini/antigravity as same API family should have consistent validation behavior // Case 90: Gemini to Antigravity, budget 64000 (suffix) → clamped to Max { @@ -1081,42 +1080,6 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 91: Gemini to Gemini-CLI, budget 64000 (suffix) → clamped to Max - { - name: "91", - from: "gemini", - to: "gemini-cli", - model: "gemini-budget-model(64000)", - inputJSON: `{"model":"gemini-budget-model(64000)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`, - expectField: "request.generationConfig.thinkingConfig.thinkingBudget", - expectValue: "20000", - includeThoughts: "true", - expectErr: false, - }, - // Case 92: Gemini-CLI to Antigravity, budget 64000 (suffix) → clamped to Max - { - name: "92", - from: "gemini-cli", - to: "antigravity", - model: "gemini-budget-model(64000)", - inputJSON: `{"model":"gemini-budget-model(64000)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`, - expectField: "request.generationConfig.thinkingConfig.thinkingBudget", - expectValue: "20000", - includeThoughts: "true", - expectErr: false, - }, - // Case 93: Gemini-CLI to Gemini, budget 64000 (suffix) → clamped to Max - { - name: "93", - from: "gemini-cli", - to: "gemini", - model: "gemini-budget-model(64000)", - inputJSON: `{"model":"gemini-budget-model(64000)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`, - expectField: 
"generationConfig.thinkingConfig.thinkingBudget", - expectValue: "20000", - includeThoughts: "true", - expectErr: false, - }, // Case 94: Gemini to Antigravity, budget 8192 → passthrough (normal value) { name: "94", @@ -1129,18 +1092,6 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 95: Gemini-CLI to Antigravity, budget 8192 → passthrough (normal value) - { - name: "95", - from: "gemini-cli", - to: "antigravity", - model: "gemini-budget-model(8192)", - inputJSON: `{"model":"gemini-budget-model(8192)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`, - expectField: "request.generationConfig.thinkingConfig.thinkingBudget", - expectValue: "8192", - includeThoughts: "true", - expectErr: false, - }, } runThinkingTests(t, cases) @@ -1525,16 +1476,6 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { expectField: "", expectErr: false, }, - // Case 31B: reasoning_effort=none with zero allowed to Gemini CLI → delete thinkingConfig - { - name: "31B", - from: "openai", - to: "gemini-cli", - model: "gemini-zero-mixed-model", - inputJSON: `{"model":"gemini-zero-mixed-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`, - expectField: "", - expectErr: false, - }, // Case 31C: reasoning_effort=none with zero allowed to Antigravity → delete thinkingConfig { name: "31C", @@ -1555,16 +1496,6 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { expectField: "", expectErr: false, }, - // Case 31E: reasoning.effort=none with zero allowed to Gemini CLI → delete thinkingConfig - { - name: "31E", - from: "openai-response", - to: "gemini-cli", - model: "gemini-zero-mixed-model", - inputJSON: `{"model":"gemini-zero-mixed-model","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"none"}}`, - expectField: "", - expectErr: false, - }, // Case 31F: reasoning.effort=none with zero allowed to Antigravity → delete thinkingConfig { name: "31F", @@ -2204,10 +2135,10 @@ func 
TestThinkingE2EMatrix_Body(t *testing.T) { expectField: "", expectErr: true, }, - // Case 88: Gemini-CLI to Antigravity, thinkingBudget=8192 → passthrough + // Case 88: Antigravity to Antigravity, thinkingBudget=8192 → passthrough { name: "88", - from: "gemini-cli", + from: "antigravity", to: "antigravity", model: "antigravity-budget-model", inputJSON: `{"model":"antigravity-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}}`, @@ -2216,10 +2147,10 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 89: Gemini-CLI to Antigravity, thinkingBudget=64000 → exceeds Max error + // Case 89: Antigravity to Antigravity, thinkingBudget=64000 → exceeds Max error { name: "89", - from: "gemini-cli", + from: "antigravity", to: "antigravity", model: "antigravity-budget-model", inputJSON: `{"model":"antigravity-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}}`, @@ -2228,7 +2159,7 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { }, // Gemini Family Cross-Channel Consistency (Cases 90-95) - // Tests that gemini/gemini-cli/antigravity as same API family should have consistent validation behavior + // Tests that gemini/antigravity as same API family should have consistent validation behavior // Case 90: Gemini to Antigravity, thinkingBudget=64000 → exceeds Max error (same family strict validation) { @@ -2240,36 +2171,6 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { expectField: "", expectErr: true, }, - // Case 91: Gemini to Gemini-CLI, thinkingBudget=64000 → exceeds Max error (same family strict validation) - { - name: "91", - from: "gemini", - to: "gemini-cli", - model: "gemini-budget-model", - inputJSON: 
`{"model":"gemini-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}`, - expectField: "", - expectErr: true, - }, - // Case 92: Gemini-CLI to Antigravity, thinkingBudget=64000 → exceeds Max error (same family strict validation) - { - name: "92", - from: "gemini-cli", - to: "antigravity", - model: "gemini-budget-model", - inputJSON: `{"model":"gemini-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}}`, - expectField: "", - expectErr: true, - }, - // Case 93: Gemini-CLI to Gemini, thinkingBudget=64000 → exceeds Max error (same family strict validation) - { - name: "93", - from: "gemini-cli", - to: "gemini", - model: "gemini-budget-model", - inputJSON: `{"model":"gemini-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}}`, - expectField: "", - expectErr: true, - }, // Case 94: Gemini to Antigravity, thinkingBudget=8192 → passthrough (normal value) { name: "94", @@ -2282,18 +2183,6 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 95: Gemini-CLI to Antigravity, thinkingBudget=8192 → passthrough (normal value) - { - name: "95", - from: "gemini-cli", - to: "antigravity", - model: "gemini-budget-model", - inputJSON: `{"model":"gemini-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}}`, - expectField: "request.generationConfig.thinkingConfig.thinkingBudget", - expectValue: "8192", - includeThoughts: "true", - expectErr: false, - }, } runThinkingTests(t, cases) @@ -3063,7 +2952,7 @@ func getTestModels() []*registry.ModelInfo { Object: "model", Created: 1700000000, OwnedBy: "test", - Type: "gemini-cli", + Type: "antigravity", DisplayName: "Antigravity Budget Model", 
Thinking: &registry.ThinkingSupport{Min: 128, Max: 20000, ZeroAllowed: true, DynamicAllowed: true}, }, @@ -3153,8 +3042,6 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) { switch tc.to { case "gemini": hasThinking = gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists() - case "gemini-cli": - hasThinking = gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() case "antigravity": hasThinking = gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() case "claude": @@ -3189,9 +3076,9 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) { assertField(tc.expectField2, tc.expectValue2) } - if tc.includeThoughts != "" && (tc.to == "gemini" || tc.to == "gemini-cli" || tc.to == "antigravity") { + if tc.includeThoughts != "" && (tc.to == "gemini" || tc.to == "antigravity") { path := "generationConfig.thinkingConfig.includeThoughts" - if tc.to == "gemini-cli" || tc.to == "antigravity" { + if tc.to == "antigravity" { path = "request.generationConfig.thinkingConfig.includeThoughts" } itVal := gjson.GetBytes(body, path)