Skip to content

Commit 67058c5

Browse files
committed
review
1 parent 0cd362b commit 67058c5

19 files changed

Lines changed: 274 additions & 92 deletions

File tree

pkg/auth/auth.go

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,37 +103,38 @@ type Auth struct {
103103

104104
const BrevAPIKeyPrefix = "bak-"
105105

106-
type TokenProvider interface {
107-
GetAccessToken() (string, error)
108-
}
106+
const MissingAPIKeyOrgIDMessage = "api key auth requires an org id; run brev login --api-key <api-key> --org-id <org-id>"
109107

110-
type APIKeyOrgProvider interface {
108+
type APIKeyAuthStore interface {
111109
GetAuthTokens() (*entity.AuthTokens, error)
112110
}
113111

114112
func IsBrevAPIKey(token string) bool {
115113
return strings.HasPrefix(strings.TrimSpace(token), BrevAPIKeyPrefix)
116114
}
117115

118-
func IsAPIKeyAuthStore(tokenProvider TokenProvider) bool {
119-
token, err := tokenProvider.GetAccessToken()
116+
func IsAPIKeyAuthStore(authTokensProvider APIKeyAuthStore) bool {
117+
tokens, err := authTokensProvider.GetAuthTokens()
120118
if err != nil {
121119
return false
122120
}
123-
return IsBrevAPIKey(token)
121+
if tokens == nil {
122+
return false
123+
}
124+
return IsBrevAPIKey(tokens.APIKey)
124125
}
125126

126-
func GetAPIKeyOrgID(authTokensProvider APIKeyOrgProvider) (string, error) {
127+
func GetAPIKeyOrgID(authTokensProvider APIKeyAuthStore) (string, error) {
127128
tokens, err := authTokensProvider.GetAuthTokens()
128129
if err != nil {
129130
return "", breverrors.WrapAndTrace(err)
130131
}
131132
if tokens == nil {
132-
return "", breverrors.NewValidationError("api key auth requires an org id; run brev login --api-key <api-key> --org-id <org-id>")
133+
return "", breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
133134
}
134135
orgID := strings.TrimSpace(tokens.APIKeyOrgID)
135136
if orgID == "" {
136-
return "", breverrors.NewValidationError("api key auth requires an org id; run brev login --api-key <api-key> --org-id <org-id>")
137+
return "", breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
137138
}
138139
return orgID, nil
139140
}
@@ -183,11 +184,8 @@ func (t Auth) GetFreshAccessTokenOrNil() (string, error) {
183184
return "", nil
184185
}
185186

186-
if tokens.APIKey != "" {
187-
apiKey := strings.TrimSpace(tokens.APIKey)
188-
if apiKey == "" {
189-
return "", breverrors.NewValidationError("api key is empty")
190-
}
187+
apiKey := strings.TrimSpace(tokens.APIKey)
188+
if apiKey != "" {
191189
return apiKey, nil
192190
}
193191

@@ -277,7 +275,7 @@ func (t Auth) LoginWithAPIKey(apiKey string, orgID string) error {
277275
}
278276
orgID = strings.TrimSpace(orgID)
279277
if orgID == "" {
280-
return breverrors.NewValidationError("org-id is required with api-key")
278+
return breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
281279
}
282280

283281
tokens, err := t.getSavedTokensOrNil()

pkg/auth/auth_test.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/brevdev/brev-cli/pkg/entity"
99
breverrors "github.com/brevdev/brev-cli/pkg/errors"
1010
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
1112
"github.com/stretchr/testify/suite"
1213
)
1314

@@ -94,6 +95,41 @@ func TestIsBrevAPIKey(t *testing.T) {
9495
assert.False(t, IsBrevAPIKey(""))
9596
}
9697

98+
type sideEffectingTokenStore struct {
99+
tokens *entity.AuthTokens
100+
getAccessTokenCalled bool
101+
}
102+
103+
func (s *sideEffectingTokenStore) GetAuthTokens() (*entity.AuthTokens, error) {
104+
return s.tokens, nil
105+
}
106+
107+
func (s *sideEffectingTokenStore) GetAccessToken() (string, error) {
108+
s.getAccessTokenCalled = true
109+
return testAPIKey, nil
110+
}
111+
112+
func TestIsAPIKeyAuthStore_ReadsSavedTokensWithoutAccessTokenSideEffects(t *testing.T) {
113+
s := &sideEffectingTokenStore{
114+
tokens: &entity.AuthTokens{APIKey: testAPIKey},
115+
}
116+
117+
assert.True(t, IsAPIKeyAuthStore(s))
118+
assert.False(t, s.getAccessTokenCalled)
119+
}
120+
121+
func TestIsAPIKeyAuthStore_LegacyCredentialsAreNotAPIKeyAuth(t *testing.T) {
122+
s := &sideEffectingTokenStore{
123+
tokens: &entity.AuthTokens{
124+
AccessToken: validToken,
125+
RefreshToken: "refresh",
126+
},
127+
}
128+
129+
assert.False(t, IsAPIKeyAuthStore(s))
130+
assert.False(t, s.getAccessTokenCalled)
131+
}
132+
97133
func TestGetFreshAccessTokenOrNil_APIKeySkipsJWTValidationAndRefresh(t *testing.T) {
98134
s := MockAuthStore{authTokens: &entity.AuthTokens{
99135
AccessToken: "expired-jwt",
@@ -202,8 +238,11 @@ func TestLoginWithAPIKey_EmptyOrgIDReturnsError(t *testing.T) {
202238

203239
func TestStandardLogin_APIKeyCredentialDoesNotProbeOAuthProviders(t *testing.T) {
204240
oldStdout := os.Stdout
241+
t.Cleanup(func() {
242+
os.Stdout = oldStdout
243+
})
205244
readPipe, writePipe, err := os.Pipe()
206-
assert.NoError(t, err)
245+
require.NoError(t, err)
207246
os.Stdout = writePipe
208247

209248
_ = StandardLogin("", "", &entity.AuthTokens{

pkg/cmd/completions/completions.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
)
1010

1111
type CompletionStore interface {
12+
auth.APIKeyAuthStore
1213
GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error)
1314
GetActiveOrganizationOrDefault() (*entity.Organization, error)
1415
GetCurrentUser() (*entity.User, error)
@@ -29,7 +30,7 @@ func GetAllWorkspaceNameCompletionHandler(completionStore CompletionStore, t *te
2930
}
3031

3132
var options *store.GetWorkspacesOptions
32-
if tokenProvider, ok := completionStore.(auth.TokenProvider); !ok || !auth.IsAPIKeyAuthStore(tokenProvider) {
33+
if !auth.IsAPIKeyAuthStore(completionStore) {
3334
user, err := completionStore.GetCurrentUser()
3435
if err != nil {
3536
t.Errprint(err, "")
@@ -55,7 +56,7 @@ func GetAllWorkspaceNameCompletionHandler(completionStore CompletionStore, t *te
5556

5657
func GetOrgsNameCompletionHandler(completionStore CompletionStore, t *terminal.Terminal) CompletionHandler {
5758
return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
58-
if tokenProvider, ok := completionStore.(auth.TokenProvider); ok && auth.IsAPIKeyAuthStore(tokenProvider) {
59+
if auth.IsAPIKeyAuthStore(completionStore) {
5960
return []string{}, cobra.ShellCompDirectiveNoFileComp
6061
}
6162

pkg/cmd/delete/delete.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func deleteWorkspace(workspaceName string, t *terminal.Terminal, deleteStore Del
100100

101101
func handleAdminUser(err error, deleteStore DeleteStore, piped bool) error {
102102
if strings.Contains(err.Error(), "not found") {
103-
if tokenProvider, ok := deleteStore.(auth.TokenProvider); ok && auth.IsAPIKeyAuthStore(tokenProvider) {
103+
if auth.IsAPIKeyAuthStore(deleteStore) {
104104
return breverrors.WrapAndTrace(err)
105105
}
106106
user, err1 := deleteStore.GetCurrentUser()

pkg/cmd/gpucreate/gpucreate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ type GPUCreateStore interface {
9898
util.GetWorkspaceByNameOrIDErrStore
9999
gpusearch.GPUSearchStore
100100
GetActiveOrganizationOrDefault() (*entity.Organization, error)
101-
GetAccessToken() (string, error)
101+
GetAuthTokens() (*entity.AuthTokens, error)
102102
GetCurrentUser() (*entity.User, error)
103103
GetWorkspace(workspaceID string) (*entity.Workspace, error)
104104
CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error)

pkg/cmd/gpucreate/gpucreate_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ func (m *MockGPUCreateStore) GetCurrentUser() (*entity.User, error) {
4545
return m.User, nil
4646
}
4747

48-
func (m *MockGPUCreateStore) GetAccessToken() (string, error) {
49-
return "", nil
48+
func (m *MockGPUCreateStore) GetAuthTokens() (*entity.AuthTokens, error) {
49+
return nil, nil
5050
}
5151

5252
func (m *MockGPUCreateStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) {

pkg/cmd/invite/invite.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,9 @@ import (
1717
)
1818

1919
type InviteStore interface {
20-
GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error)
21-
GetActiveOrganizationOrDefault() (*entity.Organization, error)
22-
GetCurrentUser() (*entity.User, error)
20+
completions.CompletionStore
2321
GetUsers(queryParams map[string]string) ([]entity.User, error)
2422
GetWorkspace(workspaceID string) (*entity.Workspace, error)
25-
GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error)
2623
CreateInviteLink(organizationID string) (string, error)
2724
}
2825

pkg/cmd/login/login.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra.
6464
var skipBrowser bool
6565
var emailFlag string
6666
var authProviderFlag string
67-
apiKeyLogin := false
6867

6968
cmd := &cobra.Command{
7069
Annotations: map[string]string{"configuration": ""},
@@ -75,7 +74,7 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra.
7574
Example: "brev login",
7675
Args: cmderrors.TransformToValidationError(cobra.NoArgs),
7776
RunE: func(cmd *cobra.Command, args []string) error {
78-
apiKeyLogin = strings.TrimSpace(apiKey) != ""
77+
apiKeyLogin := strings.TrimSpace(apiKey) != ""
7978
err := opts.RunLogin(t, loginToken, apiKey, apiKeyOrgID, skipBrowser, emailFlag, authProviderFlag)
8079
if err != nil {
8180
// if err is ImportIDEConfigError, log err with sentry but continue
@@ -216,7 +215,7 @@ func (o LoginOptions) doApiKeyLogin(t *terminal.Terminal, loginToken string, api
216215
apiKey = strings.TrimSpace(apiKey)
217216
orgID := strings.TrimSpace(apiKeyOrgID)
218217
if orgID == "" {
219-
return breverrors.NewValidationError("org-id is required with api-key")
218+
return breverrors.NewValidationError(auth.MissingAPIKeyOrgIDMessage)
220219
}
221220
if err := o.Auth.LoginWithAPIKey(apiKey, orgID); err != nil {
222221
return breverrors.WrapAndTrace(err)

pkg/cmd/ls/ls.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@ import (
3838
type LsStore interface {
3939
GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error)
4040
GetActiveOrganizationOrDefault() (*entity.Organization, error)
41-
GetCachedActiveOrganizationOrNil() (*entity.Organization, error)
4241
GetCurrentUser() (*entity.User, error)
4342
GetUsers(queryParams map[string]string) ([]entity.User, error)
4443
GetWorkspace(workspaceID string) (*entity.Workspace, error)
4544
GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error)
46-
GetAccessToken() (string, error)
45+
externalnode.TokenProvider
4746
GetAuthTokens() (*entity.AuthTokens, error)
4847
GetInstanceTypes(includeCPU bool) (*gpusearch.InstanceTypesResponse, error)
4948
hello.HelloStore
@@ -139,11 +138,14 @@ func getOrgForRunLs(lsStore LsStore, orgflag string, apiKeyAuth bool) (*entity.O
139138
if orgflag != "" {
140139
return nil, breverrors.NewValidationError("api key auth is scoped to the org saved during login; --org is not supported")
141140
}
142-
orgID, err := auth.GetAPIKeyOrgID(lsStore)
141+
org, err := lsStore.GetActiveOrganizationOrDefault()
143142
if err != nil {
144143
return nil, breverrors.WrapAndTrace(err)
145144
}
146-
return &entity.Organization{ID: orgID, Name: orgID}, nil
145+
if org == nil {
146+
return nil, breverrors.NewValidationError("no orgs exist")
147+
}
148+
return org, nil
147149
}
148150

149151
if orgflag != "" {

pkg/cmd/ls/ls_test.go

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ type mockLsStore struct {
2424
org *entity.Organization
2525
orgs []entity.Organization
2626
workspaces []entity.Workspace
27-
accessToken string
2827
authTokens *entity.AuthTokens
2928
workspaceOrgID string
3029
currentUserCalls int
@@ -36,26 +35,26 @@ func (m *mockLsStore) GetCurrentUser() (*entity.User, error) {
3635
return m.user, nil
3736
}
3837

39-
func (m *mockLsStore) GetAccessToken() (string, error) {
40-
if m.accessToken != "" {
41-
return m.accessToken, nil
42-
}
43-
return "tok", nil
44-
}
45-
4638
func (m *mockLsStore) GetAuthTokens() (*entity.AuthTokens, error) {
4739
return m.authTokens, nil
4840
}
4941

42+
func (m *mockLsStore) GetAccessToken() (string, error) {
43+
return "tok", nil
44+
}
45+
5046
func (m *mockLsStore) GetWorkspace(_ string) (*entity.Workspace, error) {
5147
return nil, nil
5248
}
5349

5450
func (m *mockLsStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) {
55-
return m.org, nil
56-
}
57-
58-
func (m *mockLsStore) GetCachedActiveOrganizationOrNil() (*entity.Organization, error) {
51+
if m.authTokens != nil && authpkg.IsBrevAPIKey(m.authTokens.APIKey) {
52+
orgID, err := authpkg.GetAPIKeyOrgID(m)
53+
if err != nil {
54+
return nil, err
55+
}
56+
return &entity.Organization{ID: orgID, Name: m.org.Name}, nil
57+
}
5958
return m.org, nil
6059
}
6160

@@ -99,7 +98,6 @@ func newTestStore() *mockLsStore {
9998

10099
func TestRunLs_APIKeyJSONSkipsUserAndOrgList(t *testing.T) {
101100
s := newTestStore()
102-
s.accessToken = testAPIKey
103101
s.authTokens = &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org1"}
104102
s.workspaces = []entity.Workspace{
105103
{
@@ -143,7 +141,6 @@ func TestRunLs_APIKeyJSONSkipsUserAndOrgList(t *testing.T) {
143141

144142
func TestRunLs_APIKeyUsesCredentialOrgNotCachedActiveOrg(t *testing.T) {
145143
s := newTestStore()
146-
s.accessToken = testAPIKey
147144
s.authTokens = &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org-login"}
148145
s.org = &entity.Organization{ID: "org-set", Name: "set-org"}
149146
s.workspaces = []entity.Workspace{
@@ -171,9 +168,25 @@ func TestRunLs_APIKeyUsesCredentialOrgNotCachedActiveOrg(t *testing.T) {
171168
}
172169
}
173170

171+
func TestGetOrgForRunLs_APIKeyUsesActiveOrgDisplayName(t *testing.T) {
172+
s := newTestStore()
173+
s.authTokens = &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org-login"}
174+
s.org = &entity.Organization{ID: "org-login", Name: "friendly-org"}
175+
176+
org, err := getOrgForRunLs(s, "", true)
177+
if err != nil {
178+
t.Fatalf("getOrgForRunLs returned error: %v", err)
179+
}
180+
if org.ID != "org-login" {
181+
t.Fatalf("expected org-login, got %s", org.ID)
182+
}
183+
if org.Name != "friendly-org" {
184+
t.Fatalf("expected friendly org name, got %s", org.Name)
185+
}
186+
}
187+
174188
func TestRunLs_APIKeyRequiresCredentialOrg(t *testing.T) {
175189
s := newTestStore()
176-
s.accessToken = testAPIKey
177190
s.authTokens = &entity.AuthTokens{APIKey: testAPIKey}
178191
term := terminal.New()
179192

0 commit comments

Comments
 (0)