diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 41036a50666..acd010685dc 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -529,6 +529,15 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H { if claims := extractCodexIDTokenClaims(auth); claims != nil { entry["id_token"] = claims } + if subscription := extractCodexSubscriptionMetadata(auth); subscription != nil { + entry["codex_subscription"] = subscription + if v, ok := subscription["plan_type"]; ok { + entry["plan_type"] = v + } + if v, ok := subscription["subscription_active_until"]; ok { + entry["subscription_active_until"] = v + } + } // Expose priority from Attributes (set by synthesizer from JSON "priority" field). // Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer). if p := strings.TrimSpace(authAttribute(auth, "priority")); p != "" { @@ -659,6 +668,71 @@ func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H { return result } +func extractCodexSubscriptionMetadata(auth *coreauth.Auth) gin.H { + if auth == nil || auth.Metadata == nil { + return nil + } + if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") { + return nil + } + result := gin.H{} + copyMetadataValue(result, auth.Metadata, "account_id") + copyMetadataValue(result, auth.Metadata, "chatgpt_account_id") + copyMetadataValue(result, auth.Metadata, "plan_type") + copyMetadataValue(result, auth.Metadata, "subscription_active_until") + copyMetadataValue(result, auth.Metadata, "chatgpt_subscription_active_until") + // Derive the expired flag from the expiry at response time rather than + // exposing the cached boolean, which goes stale once the stored expiry + // passes without a reload/enrichment. Pass the raw metadata value (which may + // be a JSON number for Unix timestamps) so IsSubscriptionExpired can apply + // the same scalar normalization the enrichment uses, instead of a stringified + // float in scientific notation. + rawActiveUntil := metadataActiveUntilValue(auth.Metadata) + if rawActiveUntil != nil { + result["subscription_expired"] = codex.IsSubscriptionExpired(rawActiveUntil) + } else { + copyMetadataValue(result, auth.Metadata, "subscription_expired") + } + copyMetadataValue(result, auth.Metadata, "chatgpt_subscription_last_checked") + if len(result) == 0 { + return nil + } + return result +} + +// metadataActiveUntilValue returns the raw subscription expiry value (string or +// JSON number) from metadata, preferring subscription_active_until, or nil when +// neither key holds a usable value. +func metadataActiveUntilValue(metadata map[string]any) any { + for _, key := range []string{"subscription_active_until", "chatgpt_subscription_active_until"} { + value, ok := metadata[key] + if !ok || value == nil { + continue + } + if text, isString := value.(string); isString && strings.TrimSpace(text) == "" { + continue + } + return value + } + return nil +} + +func copyMetadataValue(dst gin.H, metadata map[string]any, key string) { + if dst == nil || metadata == nil { + return + } + value, ok := metadata[key] + if !ok || value == nil { + return + } + if text, isString := value.(string); isString { + if strings.TrimSpace(text) == "" { + return + } + } + dst[key] = value +} + func authEmail(auth *coreauth.Auth) string { if auth == nil { return "" @@ -2228,16 +2302,34 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { // Create token storage and persist tokenStorage := openaiAuth.CreateTokenStorage(bundle) fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true) + metadata := map[string]any{ + "email": tokenStorage.Email, + "account_id": tokenStorage.AccountID, + } + // Bound this best-effort lookup so a slow/unresponsive ChatGPT backend + // cannot block the token save and OAuth session completion. Mirrors the + // SDK device-flow path (sdk/auth/codex_device.go). + enrichCtx, cancelEnrich := context.WithTimeout(ctx, 20*time.Second) + if _, errEnrich := openaiAuth.EnrichSubscriptionMetadata( + enrichCtx, + metadata, + tokenStorage.IDToken, + tokenStorage.AccessToken, + tokenStorage.AccountID, + ); errEnrich != nil { + log.Warnf("Codex subscription metadata enrichment failed: %v", errEnrich) + } + cancelEnrich() record := &coreauth.Auth{ ID: fileName, Provider: "codex", FileName: fileName, Storage: tokenStorage, - Metadata: map[string]any{ - "email": tokenStorage.Email, - "account_id": tokenStorage.AccountID, - }, + Metadata: metadata, } + // Mirror the enriched subscription fields into attributes the runtime + // reads (Codex model-catalog selection keys off Attributes["plan_type"]). + coreauth.ApplyCodexSubscriptionAttributes(record) savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { SetOAuthSessionError(state, "Failed to save authentication tokens") diff --git a/internal/api/handlers/management/auth_files_subscription_expired_test.go b/internal/api/handlers/management/auth_files_subscription_expired_test.go new file mode 100644 index 00000000000..39fd08690b4 --- /dev/null +++ b/internal/api/handlers/management/auth_files_subscription_expired_test.go @@ -0,0 +1,58 @@ +package management + +import ( + "testing" + "time" + + coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" +) + +func TestExtractCodexSubscriptionMetadata_RecomputesExpired(t *testing.T) { + past := time.Now().UTC().Add(-24 * time.Hour).Format(time.RFC3339) + future := time.Now().UTC().Add(24 * time.Hour).Format(time.RFC3339) + + t.Run("stale cached false is recomputed to true once expiry passed", func(t *testing.T) { + auth := &coreauth.Auth{ + Provider: "codex", + Metadata: map[string]any{ + "subscription_active_until": past, + "subscription_expired": false, // stale cached value + }, + } + got := extractCodexSubscriptionMetadata(auth) + if got == nil { + t.Fatalf("expected result") + } + if v, _ := got["subscription_expired"].(bool); !v { + t.Fatalf("subscription_expired=%v, want true (recomputed from past expiry)", got["subscription_expired"]) + } + }) + + t.Run("future expiry yields not expired", func(t *testing.T) { + auth := &coreauth.Auth{ + Provider: "codex", + Metadata: map[string]any{ + "subscription_active_until": future, + "subscription_expired": true, // stale cached value + }, + } + got := extractCodexSubscriptionMetadata(auth) + if v, _ := got["subscription_expired"].(bool); v { + t.Fatalf("subscription_expired=%v, want false (recomputed from future expiry)", got["subscription_expired"]) + } + }) + t.Run("numeric unix-seconds expiry parses without scientific notation", func(t *testing.T) { + futureUnix := float64(time.Now().UTC().Add(24 * time.Hour).Unix()) + auth := &coreauth.Auth{ + Provider: "codex", + Metadata: map[string]any{ + "subscription_active_until": futureUnix, // JSON number, not string + "subscription_expired": true, // stale cached value + }, + } + got := extractCodexSubscriptionMetadata(auth) + if v, _ := got["subscription_expired"].(bool); v { + t.Fatalf("subscription_expired=%v, want false (numeric future expiry)", got["subscription_expired"]) + } + }) +} diff --git a/internal/auth/codex/subscription.go b/internal/auth/codex/subscription.go new file mode 100644 index 00000000000..92a55f422a3 --- /dev/null +++ b/internal/auth/codex/subscription.go @@ -0,0 +1,500 @@ +package codex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + subscriptionAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27" + subscriptionsURL = "https://chatgpt.com/backend-api/subscriptions" + chatGPTWebReferer = "https://chatgpt.com/" + chatGPTWebUserAgent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/147.0.0.0 Safari/537.36" +) + +// SubscriptionSnapshot is the normalized ChatGPT subscription status for a Codex account. +type SubscriptionSnapshot struct { + AccountID string + PlanType string + ActiveUntil string +} + +type subscriptionAccountRecord struct { + key string + node any +} + +// EnrichSubscriptionMetadata adds ChatGPT subscription fields to metadata using values already in the auth JSON. +func EnrichSubscriptionMetadata(ctx context.Context, metadata map[string]any, client *http.Client) (bool, error) { + if metadata == nil { + return false, nil + } + return EnrichSubscriptionMetadataForTokens( + ctx, + metadata, + stringMetadata(metadata, "id_token"), + stringMetadata(metadata, "access_token"), + firstNonEmptyString( + stringMetadata(metadata, "chatgpt_account_id"), + stringMetadata(metadata, "account_id"), + ), + client, + ) +} + +// EnrichSubscriptionMetadata adds ChatGPT subscription fields using this auth service HTTP client. +func (o *CodexAuth) EnrichSubscriptionMetadata(ctx context.Context, metadata map[string]any, idToken, accessToken, accountID string) (bool, error) { + var client *http.Client + if o != nil { + client = o.httpClient + } + return EnrichSubscriptionMetadataForTokens(ctx, metadata, idToken, accessToken, accountID, client) +} + +// EnrichSubscriptionMetadataForTokens fills metadata from JWT claims first, then falls back to ChatGPT backend APIs. +func EnrichSubscriptionMetadataForTokens(ctx context.Context, metadata map[string]any, idToken, accessToken, accountID string, client *http.Client) (bool, error) { + if metadata == nil { + return false, nil + } + changed := false + log.Debugf("Codex subscription metadata enrichment attempt: has_id_token=%t has_access_token=%t account_id=%s", strings.TrimSpace(idToken) != "", strings.TrimSpace(accessToken) != "", strings.TrimSpace(accountID)) + + if claims, err := ParseJWTToken(strings.TrimSpace(idToken)); err == nil && claims != nil { + if setStringMetadata(metadata, "chatgpt_account_id", strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)) { + changed = true + } + if setStringMetadata(metadata, "account_id", strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)) { + changed = true + } + if setStringMetadata(metadata, "plan_type", normalizeSubscriptionPlan(claims.CodexAuthInfo.ChatgptPlanType)) { + changed = true + } + if activeUntil := normalizeSubscriptionScalar(claims.CodexAuthInfo.ChatgptSubscriptionActiveUntil); activeUntil != "" { + if setStringMetadata(metadata, "chatgpt_subscription_active_until", activeUntil) { + changed = true + } + if setStringMetadata(metadata, "subscription_active_until", activeUntil) { + changed = true + } + } + } + + currentActiveUntil := firstNonEmptyString( + stringMetadata(metadata, "subscription_active_until"), + stringMetadata(metadata, "chatgpt_subscription_active_until"), + ) + if !subscriptionMissingOrExpired(currentActiveUntil) { + log.Debugf("Codex subscription metadata enrichment using existing expiry: account_id=%s active_until=%s", firstNonEmptyString(accountID, stringMetadata(metadata, "account_id")), currentActiveUntil) + if updateSubscriptionExpiredMetadata(metadata, currentActiveUntil) { + changed = true + } + return changed, nil + } + + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + log.Debugf("Codex subscription metadata enrichment skipped backend fallback: missing access token account_id=%s", firstNonEmptyString(accountID, stringMetadata(metadata, "account_id"))) + if currentActiveUntil != "" && updateSubscriptionExpiredMetadata(metadata, currentActiveUntil) { + changed = true + } + return changed, nil + } + + preferredAccountID := firstNonEmptyString( + strings.TrimSpace(accountID), + stringMetadata(metadata, "chatgpt_account_id"), + stringMetadata(metadata, "account_id"), + ) + snapshot, err := FetchSubscriptionStatus(ctx, accessToken, preferredAccountID, client) + if err != nil { + log.Debugf("Codex subscription metadata backend fallback failed: account_id=%s error=%v", preferredAccountID, err) + if currentActiveUntil != "" && updateSubscriptionExpiredMetadata(metadata, currentActiveUntil) { + changed = true + } + return changed, err + } + + if setStringMetadata(metadata, "chatgpt_account_id", snapshot.AccountID) { + changed = true + } + if setStringMetadata(metadata, "account_id", snapshot.AccountID) { + changed = true + } + if setStringMetadata(metadata, "plan_type", normalizeSubscriptionPlan(snapshot.PlanType)) { + changed = true + } + if setStringMetadata(metadata, "chatgpt_subscription_active_until", snapshot.ActiveUntil) { + changed = true + } + if setStringMetadata(metadata, "subscription_active_until", snapshot.ActiveUntil) { + changed = true + } + if setStringMetadata(metadata, "chatgpt_subscription_last_checked", time.Now().UTC().Format(time.RFC3339)) { + changed = true + } + if updateSubscriptionExpiredMetadata(metadata, snapshot.ActiveUntil) { + changed = true + } + + return changed, nil +} + +// FetchSubscriptionStatus returns ChatGPT subscription state using accounts/check with subscriptions fallback. +func FetchSubscriptionStatus(ctx context.Context, accessToken, preferredAccountID string, client *http.Client) (*SubscriptionSnapshot, error) { + if ctx == nil { + ctx = context.Background() + } + if client == nil { + client = http.DefaultClient + } + log.Debugf("Codex subscription status fetch attempt: account_id=%s", strings.TrimSpace(preferredAccountID)) + + snapshot, err := fetchAccountsCheckSnapshot(ctx, client, accessToken, preferredAccountID) + if err != nil { + return nil, err + } + if snapshot == nil { + return nil, fmt.Errorf("accounts/check returned no account records") + } + if !subscriptionMissingOrExpired(snapshot.ActiveUntil) { + log.Debugf("Codex subscription status fetched from accounts/check: account_id=%s active_until=%s", snapshot.AccountID, snapshot.ActiveUntil) + return snapshot, nil + } + + accountID := firstNonEmptyString(snapshot.AccountID, strings.TrimSpace(preferredAccountID)) + if accountID == "" { + log.Debug("Codex subscription subscriptions fallback skipped: missing account_id") + return snapshot, nil + } + log.Debugf("Codex subscription subscriptions fallback attempt: account_id=%s", accountID) + subscriptionSnapshot, err := fetchSubscriptionsSnapshot(ctx, client, accessToken, accountID) + if err != nil { + log.Debugf("Codex subscription subscriptions fallback failed: account_id=%s error=%v", accountID, err) + return snapshot, nil + } + if subscriptionSnapshot.PlanType != "" { + snapshot.PlanType = subscriptionSnapshot.PlanType + } + if subscriptionSnapshot.ActiveUntil != "" { + snapshot.ActiveUntil = subscriptionSnapshot.ActiveUntil + } + if subscriptionSnapshot.AccountID != "" { + snapshot.AccountID = subscriptionSnapshot.AccountID + } + return snapshot, nil +} + +func fetchAccountsCheckSnapshot(ctx context.Context, client *http.Client, accessToken, preferredAccountID string) (*SubscriptionSnapshot, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, subscriptionAccountsCheckURL, nil) + if err != nil { + return nil, fmt.Errorf("create accounts/check request: %w", err) + } + q := req.URL.Query() + q.Set("timezone_offset_min", "0") + req.URL.RawQuery = q.Encode() + setSubscriptionHeaders(req, accessToken, "/backend-api/accounts/check/v4-2023-04-27") + log.Debugf("Codex subscription accounts/check request attempt: account_id=%s", strings.TrimSpace(preferredAccountID)) + + payload, err := doSubscriptionJSON(client, req) + if err != nil { + return nil, err + } + return parseAccountsCheckSnapshot(payload, preferredAccountID), nil +} + +func fetchSubscriptionsSnapshot(ctx context.Context, client *http.Client, accessToken, accountID string) (*SubscriptionSnapshot, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, subscriptionsURL, nil) + if err != nil { + return nil, fmt.Errorf("create subscriptions request: %w", err) + } + q := req.URL.Query() + q.Set("account_id", accountID) + req.URL.RawQuery = q.Encode() + setSubscriptionHeaders(req, accessToken, "/backend-api/subscriptions") + log.Debugf("Codex subscription subscriptions request attempt: account_id=%s", strings.TrimSpace(accountID)) + + payload, err := doSubscriptionJSON(client, req) + if err != nil { + return nil, err + } + return &SubscriptionSnapshot{ + AccountID: strings.TrimSpace(accountID), + PlanType: firstJSONScalar(payload, "subscription_plan", "plan_type"), + ActiveUntil: firstJSONScalar(payload, "active_until", "expires_at"), + }, nil +} + +func setSubscriptionHeaders(req *http.Request, accessToken, targetPath string) { + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(accessToken)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Referer", chatGPTWebReferer) + req.Header.Set("User-Agent", chatGPTWebUserAgent) + req.Header.Set("x-openai-target-path", targetPath) + req.Header.Set("x-openai-target-route", targetPath) +} + +func doSubscriptionJSON(client *http.Client, req *http.Request) (map[string]any, error) { + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("subscription request failed: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read subscription response: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("subscription request failed with status %d body_len=%d", resp.StatusCode, len(body)) + } + var payload map[string]any + if err = json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("parse subscription response: %w", err) + } + return payload, nil +} + +func parseAccountsCheckSnapshot(payload map[string]any, preferredAccountID string) *SubscriptionSnapshot { + records := collectSubscriptionAccountRecords(payload) + if len(records) == 0 { + return nil + } + + preferredAccountID = strings.TrimSpace(preferredAccountID) + selected := records[0] + if preferredAccountID != "" { + for _, record := range records { + accountRecord := accountObjectFromRecord(record.node) + candidateID := firstJSONScalar(accountRecord, "account_id", "id", "chatgpt_account_id", "workspace_id") + // Object-keyed responses carry the account id as the map key, which + // may not be repeated inside the value; match on it too. + if candidateID == preferredAccountID || strings.TrimSpace(record.key) == preferredAccountID { + selected = record + break + } + } + } + + node, _ := selected.node.(map[string]any) + if node == nil { + return nil + } + accountRecord := accountObjectFromRecord(node) + entitlement, _ := node["entitlement"].(map[string]any) + + return &SubscriptionSnapshot{ + AccountID: firstNonEmptyString( + firstJSONScalar(accountRecord, "account_id", "id", "chatgpt_account_id", "workspace_id"), + strings.TrimSpace(selected.key), + ), + PlanType: firstNonEmptyString( + firstJSONScalar(entitlement, "subscription_plan"), + firstJSONScalar(accountRecord, "plan_type", "planType"), + ), + ActiveUntil: firstNonEmptyString( + firstJSONScalar(entitlement, "expires_at"), + firstJSONScalar(accountRecord, "expires_at"), + ), + } +} + +func collectSubscriptionAccountRecords(payload map[string]any) []subscriptionAccountRecord { + var records []subscriptionAccountRecord + for _, key := range []string{"accounts", "account_items", "items", "data"} { + value := payload[key] + switch typed := value.(type) { + case []any: + for _, item := range typed { + records = append(records, subscriptionAccountRecord{node: item}) + } + case map[string]any: + for recordKey, item := range typed { + records = append(records, subscriptionAccountRecord{key: recordKey, node: item}) + } + } + } + return records +} + +func accountObjectFromRecord(record any) map[string]any { + node, _ := record.(map[string]any) + if node == nil { + return nil + } + if account, ok := node["account"].(map[string]any); ok && account != nil { + return account + } + return node +} + +func firstJSONScalar(obj map[string]any, keys ...string) string { + if obj == nil { + return "" + } + for _, key := range keys { + if value := normalizeSubscriptionScalar(obj[key]); value != "" { + return value + } + } + return "" +} + +func normalizeSubscriptionScalar(value any) string { + switch typed := value.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(typed) + case json.Number: + return strings.TrimSpace(typed.String()) + case float64: + if math.Trunc(typed) == typed { + return strconv.FormatInt(int64(typed), 10) + } + return strconv.FormatFloat(typed, 'f', -1, 64) + case float32: + asFloat := float64(typed) + if math.Trunc(asFloat) == asFloat { + return strconv.FormatInt(int64(asFloat), 10) + } + return strconv.FormatFloat(asFloat, 'f', -1, 64) + case int: + return strconv.Itoa(typed) + case int64: + return strconv.FormatInt(typed, 10) + case bool: + return strconv.FormatBool(typed) + default: + return "" + } +} + +func parseSubscriptionTime(value any) (time.Time, bool) { + raw := normalizeSubscriptionScalar(value) + if raw == "" { + return time.Time{}, false + } + if isDigits(raw) { + timestamp, err := strconv.ParseInt(raw, 10, 64) + if err != nil { + return time.Time{}, false + } + if timestamp > 1_000_000_000_000 { + timestamp /= 1000 + } + return time.Unix(timestamp, 0).UTC(), true + } + parsed, err := time.Parse(time.RFC3339, raw) + if err == nil { + return parsed.UTC(), true + } + return time.Time{}, false +} + +func subscriptionMissingOrExpired(value any) bool { + parsed, ok := parseSubscriptionTime(value) + return !ok || !parsed.After(time.Now().UTC()) +} + +// IsSubscriptionExpired reports whether a subscription expiry value is missing, +// unparseable, or already in the past relative to now (UTC). Callers can use it +// to derive the expired state at response time instead of trusting a cached +// boolean that may have gone stale since the last enrichment. +func IsSubscriptionExpired(activeUntil any) bool { + return subscriptionMissingOrExpired(activeUntil) +} + +func updateSubscriptionExpiredMetadata(metadata map[string]any, activeUntil string) bool { + parsed, ok := parseSubscriptionTime(activeUntil) + if !ok { + return false + } + expired := !parsed.After(time.Now().UTC()) + return setBoolMetadata(metadata, "subscription_expired", expired) +} + +func stringMetadata(metadata map[string]any, key string) string { + return normalizeSubscriptionScalar(metadata[key]) +} + +func setStringMetadata(metadata map[string]any, key, value string) bool { + value = strings.TrimSpace(value) + if value == "" { + return false + } + if stringMetadata(metadata, key) == value { + return false + } + metadata[key] = value + return true +} + +func setBoolMetadata(metadata map[string]any, key string, value bool) bool { + if current, ok := metadata[key].(bool); ok && current == value { + return false + } + metadata[key] = value + return true +} + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + +// normalizeSubscriptionPlan maps web subscription_plan values (e.g. +// "chatgptplusplan", "chatgpt_free_plan", "ChatGPT Pro") to the canonical plan +// tokens the rest of the app expects (e.g. "free", "plus", "pro"); values that +// are already canonical pass through unchanged. The free-plan check and model +// registration read this normalized form from Attributes["plan_type"]. +func normalizeSubscriptionPlan(plan string) string { + trimmed := strings.TrimSpace(plan) + if trimmed == "" { + return "" + } + // Collapse to lowercase alphanumerics so separators/casing don't matter. + collapsed := strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= 'A' && r <= 'Z': + return r + ('a' - 'A') + case r >= '0' && r <= '9': + return r + default: + return -1 + } + }, trimmed) + stripped := strings.TrimSuffix(strings.TrimPrefix(collapsed, "chatgpt"), "plan") + if stripped == "" { + // Degenerate input like "chatgpt" or "plan"; keep the collapsed form. + return collapsed + } + return stripped +} + +func isDigits(value string) bool { + if value == "" { + return false + } + for _, ch := range value { + if ch < '0' || ch > '9' { + return false + } + } + return true +} diff --git a/internal/auth/codex/subscription_test.go b/internal/auth/codex/subscription_test.go new file mode 100644 index 00000000000..27e855b06bd --- /dev/null +++ b/internal/auth/codex/subscription_test.go @@ -0,0 +1,453 @@ +package codex + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + "time" +) + +func jsonResponse(req *http.Request, status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + Request: req, + } +} + +func TestParseSubscriptionTime(t *testing.T) { + rfc := "2030-01-02T03:04:05Z" + wantRFC, _ := time.Parse(time.RFC3339, rfc) + + cases := []struct { + name string + value any + ok bool + want time.Time + }{ + {name: "rfc3339 string", value: rfc, ok: true, want: wantRFC.UTC()}, + {name: "unix seconds", value: "1893553445", ok: true, want: time.Unix(1893553445, 0).UTC()}, + {name: "unix millis normalized to seconds", value: "1893553445000", ok: true, want: time.Unix(1893553445, 0).UTC()}, + {name: "numeric float seconds", value: float64(1893553445), ok: true, want: time.Unix(1893553445, 0).UTC()}, + {name: "empty", value: "", ok: false}, + {name: "garbage", value: "not-a-time", ok: false}, + {name: "nil", value: nil, ok: false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, ok := parseSubscriptionTime(tc.value) + if ok != tc.ok { + t.Fatalf("ok=%v, want %v", ok, tc.ok) + } + if tc.ok && !got.Equal(tc.want) { + t.Fatalf("time=%v, want %v", got, tc.want) + } + }) + } +} + +func TestNormalizeSubscriptionScalar(t *testing.T) { + cases := []struct { + name string + value any + want string + }{ + {name: "trimmed string", value: " hello ", want: "hello"}, + {name: "json number", value: json.Number("42"), want: "42"}, + {name: "integral float", value: float64(7), want: "7"}, + {name: "fractional float", value: float64(7.5), want: "7.5"}, + {name: "int", value: 9, want: "9"}, + {name: "int64", value: int64(11), want: "11"}, + {name: "bool", value: true, want: "true"}, + {name: "nil", value: nil, want: ""}, + {name: "unsupported", value: []string{"x"}, want: ""}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := normalizeSubscriptionScalar(tc.value); got != tc.want { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } +} + +func TestFirstJSONScalarPrefersEarlierKeys(t *testing.T) { + obj := map[string]any{"id": "second", "account_id": "first"} + if got := firstJSONScalar(obj, "account_id", "id"); got != "first" { + t.Fatalf("got %q, want first", got) + } + if got := firstJSONScalar(obj, "missing", "id"); got != "second" { + t.Fatalf("got %q, want second", got) + } + if got := firstJSONScalar(nil, "id"); got != "" { + t.Fatalf("nil obj: got %q, want empty", got) + } +} + +func TestSubscriptionMissingOrExpired(t *testing.T) { + future := time.Now().UTC().Add(48 * time.Hour).Format(time.RFC3339) + past := time.Now().UTC().Add(-48 * time.Hour).Format(time.RFC3339) + + if subscriptionMissingOrExpired(future) { + t.Fatalf("future expiry should not be expired") + } + if !subscriptionMissingOrExpired(past) { + t.Fatalf("past expiry should be expired") + } + if !subscriptionMissingOrExpired("") { + t.Fatalf("missing expiry should be treated as expired") + } +} + +func TestUpdateSubscriptionExpiredMetadata(t *testing.T) { + future := time.Now().UTC().Add(72 * time.Hour).Format(time.RFC3339) + past := time.Now().UTC().Add(-72 * time.Hour).Format(time.RFC3339) + + t.Run("active sets expired false", func(t *testing.T) { + meta := map[string]any{} + if !updateSubscriptionExpiredMetadata(meta, future) { + t.Fatalf("expected changed=true on first write") + } + if v, _ := meta["subscription_expired"].(bool); v { + t.Fatalf("subscription_expired=%v, want false", meta["subscription_expired"]) + } + // Idempotent: second write with same value reports no change. + if updateSubscriptionExpiredMetadata(meta, future) { + t.Fatalf("expected changed=false when value unchanged") + } + }) + + t.Run("past sets expired true", func(t *testing.T) { + meta := map[string]any{} + if !updateSubscriptionExpiredMetadata(meta, past) { + t.Fatalf("expected changed=true") + } + if v, _ := meta["subscription_expired"].(bool); !v { + t.Fatalf("subscription_expired=%v, want true", meta["subscription_expired"]) + } + }) + + t.Run("unparseable leaves metadata untouched", func(t *testing.T) { + meta := map[string]any{} + if updateSubscriptionExpiredMetadata(meta, "nope") { + t.Fatalf("expected changed=false for unparseable expiry") + } + if _, ok := meta["subscription_expired"]; ok { + t.Fatalf("subscription_expired should not be set for unparseable expiry") + } + }) +} + +func TestSetStringAndBoolMetadata(t *testing.T) { + meta := map[string]any{} + if !setStringMetadata(meta, "k", " v ") { + t.Fatalf("expected changed=true on first set") + } + if got := stringMetadata(meta, "k"); got != "v" { + t.Fatalf("stringMetadata=%q, want v", got) + } + if setStringMetadata(meta, "k", "v") { + t.Fatalf("expected changed=false on identical set") + } + if setStringMetadata(meta, "blank", " ") { + t.Fatalf("expected changed=false for blank value") + } + if _, ok := meta["blank"]; ok { + t.Fatalf("blank key should not be written") + } + + if !setBoolMetadata(meta, "flag", true) { + t.Fatalf("expected changed=true on first bool set") + } + if setBoolMetadata(meta, "flag", true) { + t.Fatalf("expected changed=false on identical bool set") + } + if !setBoolMetadata(meta, "flag", false) { + t.Fatalf("expected changed=true when bool flips") + } +} + +func TestFirstNonEmptyStringAndIsDigits(t *testing.T) { + if got := firstNonEmptyString("", " ", "x", "y"); got != "x" { + t.Fatalf("got %q, want x", got) + } + if got := firstNonEmptyString("", " "); got != "" { + t.Fatalf("got %q, want empty", got) + } + if !isDigits("12345") { + t.Fatalf("12345 should be digits") + } + if isDigits("12a45") || isDigits("") { + t.Fatalf("non-digit / empty should be false") + } +} + +func TestParseAccountsCheckSnapshot(t *testing.T) { + t.Run("selects preferred account and reads entitlement", func(t *testing.T) { + payload := map[string]any{ + "accounts": []any{ + map[string]any{ + "account": map[string]any{"account_id": "acc-1", "plan_type": "plus"}, + "entitlement": map[string]any{"subscription_plan": "pro", "expires_at": "2030-01-01T00:00:00Z"}, + }, + map[string]any{ + "account": map[string]any{"account_id": "acc-2"}, + "entitlement": map[string]any{"subscription_plan": "team", "expires_at": "2031-01-01T00:00:00Z"}, + }, + }, + } + snap := parseAccountsCheckSnapshot(payload, "acc-2") + if snap == nil { + t.Fatalf("expected snapshot") + } + if snap.AccountID != "acc-2" { + t.Fatalf("AccountID=%q, want acc-2", snap.AccountID) + } + if snap.PlanType != "team" { + t.Fatalf("PlanType=%q, want team", snap.PlanType) + } + if snap.ActiveUntil != "2031-01-01T00:00:00Z" { + t.Fatalf("ActiveUntil=%q, want 2031-01-01T00:00:00Z", snap.ActiveUntil) + } + }) + + t.Run("defaults to first record when preferred missing", func(t *testing.T) { + payload := map[string]any{ + "accounts": []any{ + map[string]any{"account_id": "only", "plan_type": "plus", "expires_at": "2030-06-01T00:00:00Z"}, + }, + } + snap := parseAccountsCheckSnapshot(payload, "does-not-exist") + if snap == nil || snap.AccountID != "only" || snap.PlanType != "plus" { + t.Fatalf("unexpected snapshot: %#v", snap) + } + }) + + t.Run("no records returns nil", func(t *testing.T) { + if snap := parseAccountsCheckSnapshot(map[string]any{}, ""); snap != nil { + t.Fatalf("expected nil, got %#v", snap) + } + }) + + t.Run("selects object-keyed account by its map key", func(t *testing.T) { + // accounts is keyed by account id; the values do not repeat account_id, + // so selection must fall back to the map key. + payload := map[string]any{ + "accounts": map[string]any{ + "acc-1": map[string]any{"entitlement": map[string]any{"subscription_plan": "free", "expires_at": "2030-01-01T00:00:00Z"}}, + "acc-2": map[string]any{"entitlement": map[string]any{"subscription_plan": "pro", "expires_at": "2031-01-01T00:00:00Z"}}, + }, + } + snap := parseAccountsCheckSnapshot(payload, "acc-2") + if snap == nil { + t.Fatalf("expected snapshot") + } + if snap.PlanType != "pro" { + t.Fatalf("PlanType=%q, want pro (selected wrong keyed account)", snap.PlanType) + } + if snap.AccountID != "acc-2" { + t.Fatalf("AccountID=%q, want acc-2 (should fall back to map key)", snap.AccountID) + } + }) +} + +func TestFetchSubscriptionStatus_AccountsCheckPrimary(t *testing.T) { + future := time.Now().UTC().Add(30 * 24 * time.Hour).Format(time.RFC3339) + client := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if strings.Contains(req.URL.Path, "/accounts/check/") { + body := `{"accounts":[{"account":{"account_id":"acc-1"},"entitlement":{"subscription_plan":"pro","expires_at":"` + future + `"}}]}` + return jsonResponse(req, http.StatusOK, body), nil + } + t.Fatalf("unexpected request to %s", req.URL) + return nil, nil + })} + + snap, err := FetchSubscriptionStatus(context.Background(), "token", "acc-1", client) + if err != nil { + t.Fatalf("FetchSubscriptionStatus error: %v", err) + } + if snap.AccountID != "acc-1" || snap.PlanType != "pro" || snap.ActiveUntil != future { + t.Fatalf("snapshot=%#v", snap) + } +} + +func TestFetchSubscriptionStatus_FallsBackToSubscriptions(t *testing.T) { + past := time.Now().UTC().Add(-30 * 24 * time.Hour).Format(time.RFC3339) + future := time.Now().UTC().Add(30 * 24 * time.Hour).Format(time.RFC3339) + var hitSubscriptions bool + + client := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + switch { + case strings.Contains(req.URL.Path, "/accounts/check/"): + // Expired entitlement triggers the subscriptions fallback. + body := `{"accounts":[{"account":{"account_id":"acc-9"},"entitlement":{"subscription_plan":"free","expires_at":"` + past + `"}}]}` + return jsonResponse(req, http.StatusOK, body), nil + case strings.Contains(req.URL.Path, "/subscriptions"): + hitSubscriptions = true + body := `{"subscription_plan":"pro","active_until":"` + future + `"}` + return jsonResponse(req, http.StatusOK, body), nil + default: + t.Fatalf("unexpected request to %s", req.URL) + return nil, nil + } + })} + + snap, err := FetchSubscriptionStatus(context.Background(), "token", "acc-9", client) + if err != nil { + t.Fatalf("FetchSubscriptionStatus error: %v", err) + } + if !hitSubscriptions { + t.Fatalf("expected subscriptions fallback to be called") + } + if snap.PlanType != "pro" || snap.ActiveUntil != future { + t.Fatalf("snapshot=%#v, want plan pro / active %s", snap, future) + } +} + +func TestEnrichSubscriptionMetadataForTokens_UsesExistingExpiryWithoutBackend(t *testing.T) { + future := time.Now().UTC().Add(20 * 24 * time.Hour).Format(time.RFC3339) + meta := map[string]any{"subscription_active_until": future} + // A client that fails the test if any request is made: a still-valid expiry + // must short-circuit before any backend call. + client := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + t.Fatalf("unexpected backend call to %s", req.URL) + return nil, nil + })} + + changed, err := EnrichSubscriptionMetadataForTokens(context.Background(), meta, "", "", "", client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !changed { + t.Fatalf("expected changed=true (subscription_expired written)") + } + if v, _ := meta["subscription_expired"].(bool); v { + t.Fatalf("subscription_expired=%v, want false for active subscription", meta["subscription_expired"]) + } +} + +func TestEnrichSubscriptionMetadataForTokens_BackendFallbackPopulatesMetadata(t *testing.T) { + future := time.Now().UTC().Add(20 * 24 * time.Hour).Format(time.RFC3339) + client := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if strings.Contains(req.URL.Path, "/accounts/check/") { + body := `{"accounts":[{"account":{"account_id":"acc-7"},"entitlement":{"subscription_plan":"pro","expires_at":"` + future + `"}}]}` + return jsonResponse(req, http.StatusOK, body), nil + } + t.Fatalf("unexpected request to %s", req.URL) + return nil, nil + })} + + meta := map[string]any{} + changed, err := EnrichSubscriptionMetadataForTokens(context.Background(), meta, "", "access-token", "acc-7", client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !changed { + t.Fatalf("expected changed=true") + } + if got := stringMetadata(meta, "plan_type"); got != "pro" { + t.Fatalf("plan_type=%q, want pro", got) + } + if got := stringMetadata(meta, "subscription_active_until"); got != future { + t.Fatalf("subscription_active_until=%q, want %s", got, future) + } + if v, _ := meta["subscription_expired"].(bool); v { + t.Fatalf("subscription_expired=%v, want false", meta["subscription_expired"]) + } +} + +func TestNormalizeSubscriptionPlan(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"chatgptplusplan", "plus"}, + {"chatgptfreeplan", "free"}, + {"chatgptproplan", "pro"}, + {"chatgpt_team_plan", "team"}, + {"ChatGPT Pro", "pro"}, + {"plus", "plus"}, + {"free", "free"}, + {"enterprise", "enterprise"}, + {" Plus ", "plus"}, + {"", ""}, + } + for _, tc := range cases { + if got := normalizeSubscriptionPlan(tc.in); got != tc.want { + t.Fatalf("normalizeSubscriptionPlan(%q)=%q, want %q", tc.in, got, tc.want) + } + } +} + +func TestEnrichSubscriptionMetadataForTokens_NormalizesBackendPlan(t *testing.T) { + future := time.Now().UTC().Add(20 * 24 * time.Hour).Format(time.RFC3339) + client := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if strings.Contains(req.URL.Path, "/accounts/check/") { + // Backend returns a raw web plan value. + body := `{"accounts":[{"account":{"account_id":"acc-1"},"entitlement":{"subscription_plan":"chatgptplusplan","expires_at":"` + future + `"}}]}` + return jsonResponse(req, http.StatusOK, body), nil + } + t.Fatalf("unexpected request to %s", req.URL) + return nil, nil + })} + + meta := map[string]any{} + if _, err := EnrichSubscriptionMetadataForTokens(context.Background(), meta, "", "access-token", "acc-1", client); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := stringMetadata(meta, "plan_type"); got != "plus" { + t.Fatalf("plan_type=%q, want normalized plus", got) + } +} + +func TestEnrichSubscriptionMetadataForTokens_NoTokensNoChange(t *testing.T) { + meta := map[string]any{} + changed, err := EnrichSubscriptionMetadataForTokens(context.Background(), meta, "", "", "", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if changed { + t.Fatalf("expected changed=false when no tokens and no expiry present") + } + if len(meta) != 0 { + t.Fatalf("metadata should remain empty, got %#v", meta) + } +} + +func TestEnrichSubscriptionMetadataWrappers(t *testing.T) { + future := time.Now().UTC().Add(15 * 24 * time.Hour).Format(time.RFC3339) + + t.Run("package-level wrapper", func(t *testing.T) { + meta := map[string]any{"subscription_active_until": future} + changed, err := EnrichSubscriptionMetadata(context.Background(), meta, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !changed { + t.Fatalf("expected changed=true (subscription_expired written)") + } + }) + + t.Run("nil metadata is a no-op", func(t *testing.T) { + changed, err := EnrichSubscriptionMetadata(context.Background(), nil, nil) + if err != nil || changed { + t.Fatalf("nil metadata: changed=%v err=%v, want false/nil", changed, err) + } + }) + + t.Run("CodexAuth method wrapper", func(t *testing.T) { + meta := map[string]any{"subscription_active_until": future} + auth := &CodexAuth{} + changed, err := auth.EnrichSubscriptionMetadata(context.Background(), meta, "", "", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !changed { + t.Fatalf("expected changed=true") + } + }) +} diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index 17126705774..00eff19a937 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -191,6 +191,8 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] } } } + // Enriched metadata (when present) takes precedence over the JWT claim. + coreauth.ApplyCodexSubscriptionAttributes(a) } if provider == "gemini-cli" { if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index be58c9c5a60..bcc8153579e 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -194,5 +194,5 @@ waitForCallback: return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) } - return a.buildAuthRecord(authSvc, authBundle) + return a.buildAuthRecord(ctx, authSvc, authBundle) } diff --git a/sdk/auth/codex_device.go b/sdk/auth/codex_device.go index d7ea4e1fe93..6bf951ac0bd 100644 --- a/sdk/auth/codex_device.go +++ b/sdk/auth/codex_device.go @@ -122,7 +122,7 @@ func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *confi return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) } - return a.buildAuthRecord(authSvc, authBundle) + return a.buildAuthRecord(ctx, authSvc, authBundle) } func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) { @@ -251,7 +251,7 @@ func codexDeviceIsSuccessStatus(code int) bool { return code >= 200 && code < 300 } -func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) { +func (a *CodexAuthenticator) buildAuthRecord(ctx context.Context, authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) { tokenStorage := authSvc.CreateTokenStorage(authBundle) if tokenStorage == nil || tokenStorage.Email == "" { @@ -273,22 +273,40 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true) metadata := map[string]any{ - "email": tokenStorage.Email, + "email": tokenStorage.Email, + "account_id": tokenStorage.AccountID, } + if ctx == nil { + ctx = context.Background() + } + subscriptionCtx, cancelSubscription := context.WithTimeout(ctx, 20*time.Second) + if _, errEnrich := authSvc.EnrichSubscriptionMetadata( + subscriptionCtx, + metadata, + tokenStorage.IDToken, + tokenStorage.AccessToken, + tokenStorage.AccountID, + ); errEnrich != nil { + log.Warnf("Codex subscription metadata enrichment failed: %v", errEnrich) + } + cancelSubscription() fmt.Println("Codex authentication successful") if authBundle.APIKey != "" { fmt.Println("Codex API key obtained and stored") } - return &coreauth.Auth{ + record := &coreauth.Auth{ ID: fileName, Provider: a.Provider(), FileName: fileName, Storage: tokenStorage, Metadata: metadata, - Attributes: map[string]string{ - "plan_type": planType, - }, - }, nil + // Seed with the pre-enrichment JWT plan; the shared helper overrides it + // with the enriched/normalized value from metadata when available, so + // SDK callers select the correct Codex catalog before a file reload. + Attributes: map[string]string{"plan_type": planType}, + } + coreauth.ApplyCodexSubscriptionAttributes(record) + return record, nil } diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 584481ad3ea..4a5aab17b65 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "time" + "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v7/sdk/pluginapi" ) @@ -173,7 +174,7 @@ func (s *FileTokenStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error) if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") { return nil } - auth, err := s.readAuthFile(path, dir) + auth, err := s.readAuthFile(ctx, path, dir) if err != nil { return nil } @@ -215,7 +216,10 @@ func (s *FileTokenStore) resolveDeletePath(id string) (string, error) { return filepath.Join(dir, id), nil } -func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) { +func (s *FileTokenStore) readAuthFile(ctx context.Context, path, baseDir string) (*cliproxyauth.Auth, error) { + if ctx == nil { + ctx = context.Background() + } data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read file: %w", err) @@ -285,12 +289,25 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, } } } + disabled, _ := metadata["disabled"].(bool) + // Skip the subscription lookup for disabled credentials: they are excluded + // from runtime use, so contacting ChatGPT here would only delay List by up + // to the 20s timeout per file for no benefit. + if provider == "codex" && !disabled { + // Derive the timeout from the caller's context so a cancelled load + // (startup/shutdown) aborts promptly instead of blocking ~20s per file. + enrichCtx, cancelEnrich := context.WithTimeout(ctx, 20*time.Second) + changed, _ := codex.EnrichSubscriptionMetadata(enrichCtx, metadata, http.DefaultClient) + cancelEnrich() + if changed { + s.persistCodexSubscriptionFields(path, metadata) + } + } info, errStat = os.Stat(path) if errStat != nil { return nil, fmt.Errorf("stat file: %w", errStat) } id := s.idFor(path, baseDir) - disabled, _ := metadata["disabled"].(bool) status := cliproxyauth.StatusActive if disabled { status = cliproxyauth.StatusDisabled @@ -312,10 +329,70 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, if email, ok := metadata["email"].(string); ok && email != "" { auth.Attributes["email"] = email } + // Mirror the Codex subscription fields into attributes the runtime reads + // (the model catalog selects on Attributes["plan_type"]). + if provider == "codex" { + cliproxyauth.ApplyCodexSubscriptionAttributes(auth) + } cliproxyauth.ApplyCustomHeadersFromMetadata(auth) return auth, nil } +// codexSubscriptionMetadataKeys are the only keys the List-path enrichment is +// allowed to write back, so a concurrent token Save is never rolled back. +var codexSubscriptionMetadataKeys = []string{ + "plan_type", + "subscription_active_until", + "subscription_expired", + "chatgpt_account_id", + "account_id", + "chatgpt_subscription_active_until", + "chatgpt_subscription_last_checked", +} + +// persistCodexSubscriptionFields writes the enriched subscription fields back +// to disk under the store mutex (the same lock Save uses). It re-reads the +// current file and merges only the subscription keys, so a token refresh/login +// Save racing the enrichment cannot have its fresh access/refresh tokens +// clobbered by a stale read taken before the network call. +func (s *FileTokenStore) persistCodexSubscriptionFields(path string, enriched map[string]any) { + s.mu.Lock() + defer s.mu.Unlock() + + data, errRead := os.ReadFile(path) + if errRead != nil || len(data) == 0 { + return + } + current := make(map[string]any) + if errUnmarshal := json.Unmarshal(data, ¤t); errUnmarshal != nil { + return + } + for _, key := range codexSubscriptionMetadataKeys { + if value, ok := enriched[key]; ok { + current[key] = value + } + } + _ = writeAuthMetadataFile(path, current) +} + +func writeAuthMetadataFile(path string, metadata map[string]any) error { + raw, errMarshal := json.Marshal(metadata) + if errMarshal != nil { + return errMarshal + } + // Write to a sibling temp file and atomically rename into place so a crash + // or concurrent read never observes a truncated/empty credential file. + tmpPath := path + ".tmp" + if errWrite := os.WriteFile(tmpPath, raw, 0o600); errWrite != nil { + return errWrite + } + if errRename := os.Rename(tmpPath, path); errRename != nil { + _ = os.Remove(tmpPath) + return errRename + } + return nil +} + func (s *FileTokenStore) idFor(path, baseDir string) string { id := path if baseDir != "" { diff --git a/sdk/auth/filestore_metadata_write_test.go b/sdk/auth/filestore_metadata_write_test.go new file mode 100644 index 00000000000..8cb250d5c58 --- /dev/null +++ b/sdk/auth/filestore_metadata_write_test.go @@ -0,0 +1,168 @@ +package auth + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" +) + +func TestWriteAuthMetadataFile_WritesAtomicallyWithoutTempLeftover(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "codex.json") + // Seed an existing file to ensure it is replaced, not appended to. + if err := os.WriteFile(path, []byte(`{"old":true}`), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + + meta := map[string]any{"type": "codex", "access_token": "tok"} + if err := writeAuthMetadataFile(path, meta); err != nil { + t.Fatalf("writeAuthMetadataFile: %v", err) + } + + raw, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read: %v", err) + } + var got map[string]any + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("unmarshal %q: %v", string(raw), err) + } + if got["access_token"] != "tok" || got["type"] != "codex" { + t.Fatalf("unexpected content: %s", string(raw)) + } + if _, stale := got["old"]; stale { + t.Fatalf("old content not replaced: %s", string(raw)) + } + + // The atomic rename must not leave the sibling temp file behind. + if _, err := os.Stat(path + ".tmp"); !os.IsNotExist(err) { + t.Fatalf("temp file should not remain after atomic write (err=%v)", err) + } +} + +func TestList_HonorsContextCancellationDuringCodexEnrichment(t *testing.T) { + dir := t.TempDir() + // A codex auth with an expired subscription forces the backend enrichment + // path (which would otherwise wait up to 20s on a slow/unreachable host). + seed := `{"type":"codex","access_token":"a","id_token":"","subscription_active_until":"2000-01-01T00:00:00Z"}` + if err := os.WriteFile(filepath.Join(dir, "codex.json"), []byte(seed), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + + store := NewFileTokenStore() + store.SetBaseDir(dir) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // already cancelled + + done := make(chan struct{}) + go func() { + _, _ = store.List(ctx) + close(done) + }() + + select { + case <-done: + // Returned promptly: the enrichment honored the cancelled context. + case <-time.After(5 * time.Second): + t.Fatal("List did not return promptly under a cancelled context (enrichment ignored ctx)") + } +} + +func TestList_CopiesCodexPlanTypeIntoAttributes(t *testing.T) { + dir := t.TempDir() + // Future expiry keeps enrichment offline (no backend call) while still + // exercising the attribute copy. + future := time.Now().UTC().Add(30 * 24 * time.Hour).Format("2006-01-02T15:04:05Z") + seed := `{"type":"codex","access_token":"a","plan_type":"plus","subscription_active_until":"` + future + `"}` + if err := os.WriteFile(filepath.Join(dir, "codex.json"), []byte(seed), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + + store := NewFileTokenStore() + store.SetBaseDir(dir) + auths, err := store.List(context.Background()) + if err != nil { + t.Fatalf("List: %v", err) + } + if len(auths) != 1 { + t.Fatalf("want 1 auth, got %d", len(auths)) + } + if got := auths[0].Attributes["plan_type"]; got != "plus" { + t.Fatalf("Attributes[plan_type] = %q, want plus (runtime catalog selection)", got) + } + if got := auths[0].Attributes["subscription_active_until"]; got != future { + t.Fatalf("Attributes[subscription_active_until] = %q, want %s", got, future) + } +} + +func TestPersistCodexSubscriptionFields_DoesNotClobberTokens(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "codex.json") + // Simulate the latest on-disk state written by a concurrent token Save. + if err := os.WriteFile(path, []byte(`{"type":"codex","access_token":"fresh","refresh_token":"fresh-r"}`), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + + store := NewFileTokenStore() + store.SetBaseDir(dir) + // Enrichment built from a STALE read (old tokens) plus new subscription info. + enriched := map[string]any{ + "access_token": "stale", + "refresh_token": "stale-r", + "plan_type": "plus", + "subscription_active_until": "2030-01-01T00:00:00Z", + "subscription_expired": false, + } + store.persistCodexSubscriptionFields(path, enriched) + + raw, _ := os.ReadFile(path) + var got map[string]any + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("unmarshal %q: %v", string(raw), err) + } + // Tokens must remain the fresh on-disk values, not be rolled back. + if got["access_token"] != "fresh" || got["refresh_token"] != "fresh-r" { + t.Fatalf("tokens were clobbered: %s", string(raw)) + } + // Subscription fields must be written. + if got["plan_type"] != "plus" { + t.Fatalf("plan_type not persisted: %s", string(raw)) + } + if got["subscription_active_until"] != "2030-01-01T00:00:00Z" { + t.Fatalf("expiry not persisted: %s", string(raw)) + } +} + +func TestList_SkipsEnrichmentForDisabledCodex(t *testing.T) { + dir := t.TempDir() + // Disabled codex auth with an expired/missing expiry would otherwise hit the + // network; List must return promptly without contacting ChatGPT. + seed := `{"type":"codex","access_token":"a","id_token":"","disabled":true,"subscription_active_until":"2000-01-01T00:00:00Z"}` + if err := os.WriteFile(filepath.Join(dir, "codex.json"), []byte(seed), 0o600); err != nil { + t.Fatalf("seed: %v", err) + } + + store := NewFileTokenStore() + store.SetBaseDir(dir) + + done := make(chan struct{}) + var auths []*cliproxyauth.Auth + go func() { + auths, _ = store.List(context.Background()) + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("List blocked on enrichment for a disabled codex auth") + } + if len(auths) != 1 || !auths[0].Disabled { + t.Fatalf("expected one disabled auth, got %#v", auths) + } +} diff --git a/sdk/cliproxy/auth/codex_subscription.go b/sdk/cliproxy/auth/codex_subscription.go new file mode 100644 index 00000000000..dc25e8ddfa6 --- /dev/null +++ b/sdk/cliproxy/auth/codex_subscription.go @@ -0,0 +1,33 @@ +package auth + +import "strings" + +// codexSubscriptionAttributeKeys are the metadata keys mirrored into runtime +// attributes for Codex auths. +var codexSubscriptionAttributeKeys = []string{"plan_type", "subscription_active_until"} + +// ApplyCodexSubscriptionAttributes mirrors the Codex subscription fields stored +// in metadata into the attributes the runtime reads. Codex model-catalog +// selection keys off Attributes["plan_type"], so every code path that builds a +// Codex Auth must keep this in sync. Routing them all through this single +// helper avoids the per-site drift that otherwise leaves Free/Plus/Team +// accounts defaulting to the Pro catalog until a file reload. +func ApplyCodexSubscriptionAttributes(auth *Auth) { + if auth == nil || auth.Metadata == nil { + return + } + for _, key := range codexSubscriptionAttributeKeys { + value, ok := auth.Metadata[key].(string) + if !ok { + continue + } + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + auth.Attributes[key] = trimmed + } +} diff --git a/sdk/cliproxy/auth/codex_subscription_test.go b/sdk/cliproxy/auth/codex_subscription_test.go new file mode 100644 index 00000000000..7df701b03eb --- /dev/null +++ b/sdk/cliproxy/auth/codex_subscription_test.go @@ -0,0 +1,39 @@ +package auth + +import "testing" + +func TestApplyCodexSubscriptionAttributes(t *testing.T) { + t.Run("copies plan and expiry from metadata", func(t *testing.T) { + a := &Auth{Metadata: map[string]any{ + "plan_type": " plus ", + "subscription_active_until": "2030-01-01T00:00:00Z", + }} + ApplyCodexSubscriptionAttributes(a) + if a.Attributes["plan_type"] != "plus" { + t.Fatalf("plan_type=%q, want plus", a.Attributes["plan_type"]) + } + if a.Attributes["subscription_active_until"] != "2030-01-01T00:00:00Z" { + t.Fatalf("subscription_active_until=%q", a.Attributes["subscription_active_until"]) + } + }) + + t.Run("overrides an existing seed value", func(t *testing.T) { + a := &Auth{ + Attributes: map[string]string{"plan_type": "free"}, + Metadata: map[string]any{"plan_type": "pro"}, + } + ApplyCodexSubscriptionAttributes(a) + if a.Attributes["plan_type"] != "pro" { + t.Fatalf("plan_type=%q, want pro (metadata overrides seed)", a.Attributes["plan_type"]) + } + }) + + t.Run("no metadata is a safe no-op", func(t *testing.T) { + a := &Auth{Attributes: map[string]string{"plan_type": "keep"}} + ApplyCodexSubscriptionAttributes(a) + if a.Attributes["plan_type"] != "keep" { + t.Fatalf("plan_type=%q, want keep", a.Attributes["plan_type"]) + } + ApplyCodexSubscriptionAttributes(nil) // must not panic + }) +}