Skip to content

Commit ace5dda

Browse files
aprimakinaclaude
andcommitted
refactor(auth): address OAuth Bearer review feedback
- Identify OAuth users for analytics on login, mirroring the PAT path's ValidateAPIKey. planType is omitted until the gateway returns it on the oauth /auth/info branch. - Build the token-authenticated client once in loginWithOAuth and reuse it for project selection and analytics identification (was constructed twice). - Route OAuth project-list errors through common.ExitWithErrorFromStatusCode so backend messages and exit codes surface in the CLI. - Reuse the pooled getHTTPClient in the OAuth client so Bearer requests and token refreshes share its connection limits. - Export analytics.Enabled() to skip the extra /auth/info round-trip when analytics is disabled. - Drop the deprecated credential test seam: GetStoredCredentials is now the single accessor. Migrate test sites to mockTestPAT/mockNotLoggedIn helpers, remove config.GetCredentials, and add OAuth-credential coverage asserting the client sends a Bearer token. - Restore doc comments dropped during the initial refactor. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent caba91e commit ace5dda

12 files changed

Lines changed: 195 additions & 272 deletions

File tree

internal/tiger/analytics/analytics.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func (a *Analytics) Identify(options ...Option) {
137137
)
138138

139139
// Check if analytics is disabled
140-
if !a.enabled() {
140+
if !a.Enabled() {
141141
logger.Debug("Analytics identify skipped (analytics disabled)")
142142
return
143143
}
@@ -202,7 +202,7 @@ func (a *Analytics) Track(event string, options ...Option) {
202202
)
203203

204204
// Check if analytics is disabled
205-
if !a.enabled() {
205+
if !a.Enabled() {
206206
logger.Debug("Analytics event skipped (analytics disabled)")
207207
return
208208
}
@@ -239,7 +239,9 @@ func (a *Analytics) Track(event string, options ...Option) {
239239
logger.Debug("Analytics event sent", zap.String("status", *resp.JSON200.Status))
240240
}
241241

242-
func (a *Analytics) enabled() bool {
242+
// Enabled reports whether analytics events will actually be sent given the
243+
// current config and environment.
244+
func (a *Analytics) Enabled() bool {
243245
if envVarIsTrue("DO_NOT_TRACK") ||
244246
envVarIsTrue("NO_TELEMETRY") ||
245247
envVarIsTrue("DISABLE_TELEMETRY") {

internal/tiger/api/client_util.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,15 @@ func NewTigerClientWithToken(cfg *config.Config, token *oauth2.Token, persist fu
7070
},
7171
}
7272

73-
var src oauth2.TokenSource = oauthCfg.TokenSource(context.Background(), token)
73+
// Stash our pooled client in the context.
74+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, getHTTPClient())
75+
76+
var src oauth2.TokenSource = oauthCfg.TokenSource(ctx, token)
7477
if persist != nil {
7578
src = &persistingTokenSource{base: src, persist: persist, last: token.AccessToken}
7679
}
7780

78-
httpClient := oauth2.NewClient(context.Background(), src)
81+
httpClient := oauth2.NewClient(ctx, src)
7982
httpClient.Timeout = 30 * time.Second
8083

8184
client, err := NewClientWithResponses(cfg.APIURL, WithHTTPClient(httpClient), WithRequestEditorFn(func(_ context.Context, req *http.Request) error {

internal/tiger/cmd/auth.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,15 @@ Examples:
9595
out: cmd.OutOrStdout(),
9696
}
9797

98-
token, projectID, err := l.loginWithOAuth(cmd.Context())
98+
token, client, projectID, err := l.loginWithOAuth(cmd.Context())
9999
if err != nil {
100100
return err
101101
}
102102
if err := config.StoreOAuthCredentials(token, projectID); err != nil {
103103
return fmt.Errorf("failed to store credentials: %w", err)
104104
}
105+
// Identify the user for analytics.
106+
common.IdentifyOAuthUser(cmd.Context(), cfg, client, projectID)
105107
finishLogin(cmd, projectID)
106108
return nil
107109
} else if creds.publicKey == "" || creds.secretKey == "" {

internal/tiger/cmd/auth_test.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,11 @@ func TestAuthLogin_KeyFlags(t *testing.T) {
114114
expectedAPIKey := "test-public-key:test-secret-key"
115115
expectedProjectID := "test-project-id" // Comes from mock validation function
116116

117-
apiKey, projectID, err := config.GetCredentials()
117+
creds, err := config.GetStoredCredentials()
118118
if err != nil {
119119
t.Fatalf("Credentials not stored in keyring or file: %v", err)
120120
}
121+
apiKey, projectID := creds.APIKey, creds.ProjectID
121122

122123
if apiKey != expectedAPIKey {
123124
t.Errorf("Expected API key '%s', got '%s'", expectedAPIKey, apiKey)
@@ -150,10 +151,11 @@ func TestAuthLogin_KeyEnvironmentVariables(t *testing.T) {
150151
// Verify credentials were stored
151152
expectedAPIKey := "env-public-key:env-secret-key"
152153
expectedProjectID := "test-project-id" // Auto-detected from mock
153-
storedKey, storedProjectID, err := config.GetCredentials()
154+
creds, err := config.GetStoredCredentials()
154155
if err != nil {
155156
t.Fatalf("Failed to get stored credentials: %v", err)
156157
}
158+
storedKey, storedProjectID := creds.APIKey, creds.ProjectID
157159
if storedKey != expectedAPIKey {
158160
t.Errorf("Expected API key '%s', got '%s'", expectedAPIKey, storedKey)
159161
}
@@ -529,10 +531,11 @@ func TestAuthLogin_KeyringFallback(t *testing.T) {
529531
}
530532

531533
// Verify file storage works
532-
storedKey, storedProjectID, err := config.GetCredentials()
534+
creds, err := config.GetStoredCredentials()
533535
if err != nil {
534536
t.Fatalf("Failed to get credentials from file fallback: %v", err)
535537
}
538+
storedKey, storedProjectID := creds.APIKey, creds.ProjectID
536539
if storedKey != expectedAPIKey {
537540
t.Errorf("Expected API key '%s', got '%s'", expectedAPIKey, storedKey)
538541
}
@@ -589,10 +592,11 @@ func TestAuthLogin_EnvironmentVariable_FileOnly(t *testing.T) {
589592
}
590593

591594
// Verify getCredentials works with file-only storage
592-
storedKey, storedProjectID, err := config.GetCredentials()
595+
creds, err := config.GetStoredCredentials()
593596
if err != nil {
594597
t.Fatalf("Failed to get credentials from file: %v", err)
595598
}
599+
storedKey, storedProjectID := creds.APIKey, creds.ProjectID
596600
if storedKey != expectedAPIKey {
597601
t.Errorf("Expected API key '%s', got '%s'", expectedAPIKey, storedKey)
598602
}
@@ -684,7 +688,7 @@ func TestAuthLogout_Success(t *testing.T) {
684688
}
685689

686690
// Verify credentials are stored
687-
_, _, err = config.GetCredentials()
691+
_, err = config.GetStoredCredentials()
688692
if err != nil {
689693
t.Fatalf("Credentials should be stored: %v", err)
690694
}
@@ -700,7 +704,7 @@ func TestAuthLogout_Success(t *testing.T) {
700704
}
701705

702706
// Verify credentials are removed
703-
_, _, err = config.GetCredentials()
707+
_, err = config.GetStoredCredentials()
704708
if err == nil {
705709
t.Fatal("Credentials should be removed after logout")
706710
}

internal/tiger/cmd/auth_validation_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestAuthLogin_APIKeyValidationFailure(t *testing.T) {
6161
}
6262

6363
// Verify that no credentials were stored
64-
if _, _, err := config.GetCredentials(); err == nil {
64+
if _, err := config.GetStoredCredentials(); err == nil {
6565
t.Error("Credentials should not be stored when validation fails")
6666
}
6767
}
@@ -114,10 +114,11 @@ func TestAuthLogin_APIKeyValidationSuccess(t *testing.T) {
114114
// Verify that credentials were stored
115115
expectedAPIKey := "valid-public:valid-secret"
116116
expectedProjectID := "test-project-valid"
117-
apiKey, projectID, err := config.GetCredentials()
117+
creds, err := config.GetStoredCredentials()
118118
if err != nil {
119119
t.Fatalf("Credentials not stored in keyring or file: %v", err)
120120
}
121+
apiKey, projectID := creds.APIKey, creds.ProjectID
121122
if apiKey != expectedAPIKey {
122123
t.Errorf("Expected API key '%s', got '%s'", expectedAPIKey, apiKey)
123124
}

internal/tiger/cmd/db_test.go

Lines changed: 40 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,30 @@ import (
1919
"github.com/timescale/tiger-cli/internal/tiger/util"
2020
)
2121

22+
// mockStoredCredentials overrides the common.GetStoredCredentials seam for the
23+
// duration of the test, restoring the original automatically via t.Cleanup.
24+
func mockStoredCredentials(t *testing.T, creds *config.Credentials, err error) {
25+
t.Helper()
26+
original := common.GetStoredCredentials
27+
common.GetStoredCredentials = func() (*config.Credentials, error) {
28+
return creds, err
29+
}
30+
t.Cleanup(func() { common.GetStoredCredentials = original })
31+
}
32+
33+
// mockTestPAT injects a fixed PAT credential.
34+
func mockTestPAT(t *testing.T) {
35+
mockStoredCredentials(t, &config.Credentials{
36+
APIKey: "test-api-key",
37+
ProjectID: "test-project-123",
38+
}, nil)
39+
}
40+
41+
// mockNotLoggedIn simulates the absence of stored credentials.
42+
func mockNotLoggedIn(t *testing.T) {
43+
mockStoredCredentials(t, nil, config.ErrNotLoggedIn)
44+
}
45+
2246
func setupDBTest(t *testing.T) string {
2347
t.Helper()
2448

@@ -81,11 +105,7 @@ func TestDBConnectionString_NoServiceID(t *testing.T) {
81105
}
82106

83107
// Mock authentication
84-
originalGetCredentials := common.GetCredentials
85-
common.GetCredentials = func() (string, string, error) {
86-
return "test-api-key", "test-project-123", nil
87-
}
88-
defer func() { common.GetCredentials = originalGetCredentials }()
108+
mockTestPAT(t)
89109

90110
// Execute db connection-string command without service ID
91111
_, err = executeDBCommand(t.Context(), "db", "connection-string")
@@ -111,11 +131,7 @@ func TestDBConnectionString_NoAuth(t *testing.T) {
111131
}
112132

113133
// Mock authentication failure
114-
originalGetCredentials := common.GetCredentials
115-
common.GetCredentials = func() (string, string, error) {
116-
return "", "", fmt.Errorf("not logged in")
117-
}
118-
defer func() { common.GetCredentials = originalGetCredentials }()
134+
mockNotLoggedIn(t)
119135

120136
// Execute db connection-string command
121137
_, err = executeDBCommand(t.Context(), "db", "connection-string")
@@ -174,11 +190,7 @@ func TestDBConnect_NoServiceID(t *testing.T) {
174190
}
175191

176192
// Mock authentication
177-
originalGetCredentials := common.GetCredentials
178-
common.GetCredentials = func() (string, string, error) {
179-
return "test-api-key", "test-project-123", nil
180-
}
181-
defer func() { common.GetCredentials = originalGetCredentials }()
193+
mockTestPAT(t)
182194

183195
// Execute db connect command without service ID
184196
_, err = executeDBCommand(t.Context(), "db", "connect")
@@ -204,11 +216,7 @@ func TestDBConnect_NoAuth(t *testing.T) {
204216
}
205217

206218
// Mock authentication failure
207-
originalGetCredentials := common.GetCredentials
208-
common.GetCredentials = func() (string, string, error) {
209-
return "", "", fmt.Errorf("not logged in")
210-
}
211-
defer func() { common.GetCredentials = originalGetCredentials }()
219+
mockNotLoggedIn(t)
212220

213221
// Execute db connect command
214222
_, err = executeDBCommand(t.Context(), "db", "connect")
@@ -234,11 +242,7 @@ func TestDBConnect_PsqlNotFound(t *testing.T) {
234242
}
235243

236244
// Mock authentication
237-
originalGetCredentials := common.GetCredentials
238-
common.GetCredentials = func() (string, string, error) {
239-
return "test-api-key", "test-project-123", nil
240-
}
241-
defer func() { common.GetCredentials = originalGetCredentials }()
245+
mockTestPAT(t)
242246

243247
// Test that psql alias works the same as connect
244248
_, err1 := executeDBCommand(t.Context(), "db", "connect")
@@ -540,11 +544,7 @@ func TestDBTestConnection_NoServiceID(t *testing.T) {
540544
}
541545

542546
// Mock authentication
543-
originalGetCredentials := common.GetCredentials
544-
common.GetCredentials = func() (string, string, error) {
545-
return "test-api-key", "test-project-123", nil
546-
}
547-
defer func() { common.GetCredentials = originalGetCredentials }()
547+
mockTestPAT(t)
548548

549549
// Execute db test-connection command without service ID
550550
_, err = executeDBCommand(t.Context(), "db", "test-connection")
@@ -570,11 +570,7 @@ func TestDBTestConnection_NoAuth(t *testing.T) {
570570
}
571571

572572
// Mock authentication failure
573-
originalGetCredentials := common.GetCredentials
574-
common.GetCredentials = func() (string, string, error) {
575-
return "", "", fmt.Errorf("not logged in")
576-
}
577-
defer func() { common.GetCredentials = originalGetCredentials }()
573+
mockNotLoggedIn(t)
578574

579575
// Execute db test-connection command
580576
_, err = executeDBCommand(t.Context(), "db", "test-connection")
@@ -830,11 +826,7 @@ func TestDBTestConnection_TimeoutParsing(t *testing.T) {
830826
}
831827

832828
// Mock authentication
833-
originalGetCredentials := common.GetCredentials
834-
common.GetCredentials = func() (string, string, error) {
835-
return "test-api-key", "test-project-123", nil
836-
}
837-
defer func() { common.GetCredentials = originalGetCredentials }()
829+
mockTestPAT(t)
838830

839831
// Execute db test-connection command with timeout flag
840832
_, err = executeDBCommand(t.Context(), "db", "test-connection", "--timeout", tc.timeoutFlag)
@@ -982,11 +974,7 @@ func TestDBSavePassword_ExplicitPassword(t *testing.T) {
982974
}
983975

984976
originalGetServiceDetails := getServiceDetailsFunc
985-
originalGetCredentials := common.GetCredentials
986-
common.GetCredentials = func() (string, string, error) {
987-
return "test-api-key", "test-project-123", nil
988-
}
989-
defer func() { common.GetCredentials = originalGetCredentials }()
977+
mockTestPAT(t)
990978
getServiceDetailsFunc = func(cmd *cobra.Command, cfg *common.Config, args []string) (api.Service, error) {
991979
return mockService, nil
992980
}
@@ -1056,11 +1044,7 @@ func TestDBSavePassword_EnvironmentVariable(t *testing.T) {
10561044
}
10571045

10581046
originalGetServiceDetails := getServiceDetailsFunc
1059-
originalGetCredentials := common.GetCredentials
1060-
common.GetCredentials = func() (string, string, error) {
1061-
return "test-api-key", "test-project-123", nil
1062-
}
1063-
defer func() { common.GetCredentials = originalGetCredentials }()
1047+
mockTestPAT(t)
10641048
getServiceDetailsFunc = func(cmd *cobra.Command, cfg *common.Config, args []string) (api.Service, error) {
10651049
return mockService, nil
10661050
}
@@ -1130,11 +1114,7 @@ func TestDBSavePassword_InteractivePrompt(t *testing.T) {
11301114
}
11311115

11321116
originalGetServiceDetails := getServiceDetailsFunc
1133-
originalGetCredentials := common.GetCredentials
1134-
common.GetCredentials = func() (string, string, error) {
1135-
return "test-api-key", "test-project-123", nil
1136-
}
1137-
defer func() { common.GetCredentials = originalGetCredentials }()
1117+
mockTestPAT(t)
11381118
getServiceDetailsFunc = func(cmd *cobra.Command, cfg *common.Config, args []string) (api.Service, error) {
11391119
return mockService, nil
11401120
}
@@ -1211,11 +1191,7 @@ func TestDBSavePassword_InteractivePromptEmpty(t *testing.T) {
12111191
}
12121192

12131193
originalGetServiceDetails := getServiceDetailsFunc
1214-
originalGetCredentials := common.GetCredentials
1215-
common.GetCredentials = func() (string, string, error) {
1216-
return "test-api-key", "test-project-123", nil
1217-
}
1218-
defer func() { common.GetCredentials = originalGetCredentials }()
1194+
mockTestPAT(t)
12191195
getServiceDetailsFunc = func(cmd *cobra.Command, cfg *common.Config, args []string) (api.Service, error) {
12201196
return mockService, nil
12211197
}
@@ -1285,11 +1261,7 @@ func TestDBSavePassword_CustomRole(t *testing.T) {
12851261
}
12861262

12871263
originalGetServiceDetails := getServiceDetailsFunc
1288-
originalGetCredentials := common.GetCredentials
1289-
common.GetCredentials = func() (string, string, error) {
1290-
return "test-api-key", "test-project-123", nil
1291-
}
1292-
defer func() { common.GetCredentials = originalGetCredentials }()
1264+
mockTestPAT(t)
12931265
getServiceDetailsFunc = func(cmd *cobra.Command, cfg *common.Config, args []string) (api.Service, error) {
12941266
return mockService, nil
12951267
}
@@ -1342,11 +1314,7 @@ func TestDBSavePassword_NoServiceID(t *testing.T) {
13421314
if err != nil {
13431315
t.Fatalf("Failed to save test config: %v", err)
13441316
}
1345-
originalGetCredentials := common.GetCredentials
1346-
common.GetCredentials = func() (string, string, error) {
1347-
return "test-api-key", "test-project-123", nil
1348-
}
1349-
defer func() { common.GetCredentials = originalGetCredentials }()
1317+
mockTestPAT(t)
13501318

13511319
// No need to mock service details since it should fail before reaching getServiceDetailsFunc
13521320

@@ -1375,11 +1343,7 @@ func TestDBSavePassword_NoAuth(t *testing.T) {
13751343
}
13761344

13771345
// Mock authentication failure
1378-
originalGetCredentials := common.GetCredentials
1379-
common.GetCredentials = func() (string, string, error) {
1380-
return "", "", fmt.Errorf("not logged in")
1381-
}
1382-
defer func() { common.GetCredentials = originalGetCredentials }()
1346+
mockNotLoggedIn(t)
13831347

13841348
// Execute save-password command
13851349
_, err = executeDBCommand(t.Context(), "db", "save-password", "--password=test-password")
@@ -1427,11 +1391,7 @@ func TestDBSavePassword_PgpassStorage(t *testing.T) {
14271391
}
14281392

14291393
originalGetServiceDetails := getServiceDetailsFunc
1430-
originalGetCredentials := common.GetCredentials
1431-
common.GetCredentials = func() (string, string, error) {
1432-
return "test-api-key", "test-project-123", nil
1433-
}
1434-
defer func() { common.GetCredentials = originalGetCredentials }()
1394+
mockTestPAT(t)
14351395
getServiceDetailsFunc = func(cmd *cobra.Command, cfg *common.Config, args []string) (api.Service, error) {
14361396
return mockService, nil
14371397
}

0 commit comments

Comments
 (0)