Skip to content

Commit 848737d

Browse files
committed
caching email
1 parent 39ff195 commit 848737d

8 files changed

Lines changed: 334 additions & 19 deletions

File tree

pkg/cmd/cmd.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import (
5858
"github.com/brevdev/brev-cli/pkg/cmd/workspacegroups"
5959
"github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent"
6060
"github.com/brevdev/brev-cli/pkg/config"
61+
"github.com/brevdev/brev-cli/pkg/entity"
6162
"github.com/brevdev/brev-cli/pkg/featureflag"
6263
"github.com/brevdev/brev-cli/pkg/files"
6364
"github.com/brevdev/brev-cli/pkg/remoteversion"
@@ -257,8 +258,20 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin
257258
cmds.SetUsageTemplate(usageTemplate)
258259

259260
// In-memory auth for external node commands — never touches credentials.json.
260-
memAuthStore := store.NewMemoryAuthStore()
261-
memAuthenticator := auth.StandardLogin("", "", nil)
261+
// Pre-fill the cached email so the user sees a confirmation prompt instead of
262+
// having to type it from scratch every time.
263+
cachedEmail, _ := fsStore.GetCachedEmail()
264+
memAuthenticator := auth.StandardLogin("", cachedEmail, nil)
265+
if cachedEmail != "" {
266+
if kas, ok := memAuthenticator.(auth.KasAuthenticator); ok {
267+
kas.ShouldPromptEmail = true
268+
memAuthenticator = kas
269+
}
270+
}
271+
memAuthStore := &emailCachingAuthStore{
272+
MemoryAuthStore: store.NewMemoryAuthStore(),
273+
fileStore: fsStore,
274+
}
262275
memLoginAuth := auth.NewLoginAuth(memAuthStore, memAuthenticator)
263276
memLoginAuth.WithShouldLogin(func() (bool, error) { return true, nil })
264277

@@ -555,4 +568,22 @@ var (
555568
_ store.Auth = auth.NoLoginAuth{}
556569
_ auth.AuthStore = store.FileStore{}
557570
_ auth.AuthStore = &store.MemoryAuthStore{}
571+
_ auth.AuthStore = &emailCachingAuthStore{}
558572
)
573+
574+
// emailCachingAuthStore wraps MemoryAuthStore and persists the login email
575+
// to ~/.brev/cached-email after each successful authentication.
576+
type emailCachingAuthStore struct {
577+
*store.MemoryAuthStore
578+
fileStore *store.FileStore
579+
}
580+
581+
func (e *emailCachingAuthStore) SaveAuthTokens(tokens entity.AuthTokens) error {
582+
if err := e.MemoryAuthStore.SaveAuthTokens(tokens); err != nil {
583+
return breverrors.WrapAndTrace(err)
584+
}
585+
if email := auth.GetEmailFromToken(tokens.AccessToken); email != "" {
586+
_ = e.fileStore.SaveCachedEmail(email)
587+
}
588+
return nil
589+
}

pkg/cmd/cmd_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package cmd
2+
3+
import (
4+
"encoding/base64"
5+
"encoding/json"
6+
"testing"
7+
8+
"github.com/brevdev/brev-cli/pkg/entity"
9+
"github.com/brevdev/brev-cli/pkg/store"
10+
"github.com/spf13/afero"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
// fakeJWT builds an unsigned JWT with the given claims (header.payload.signature).
16+
func fakeJWT(t *testing.T, claims map[string]interface{}) string {
17+
t.Helper()
18+
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
19+
payload, err := json.Marshal(claims)
20+
require.NoError(t, err)
21+
return header + "." + base64.RawURLEncoding.EncodeToString(payload) + "."
22+
}
23+
24+
func newTestFileStore(t *testing.T) *store.FileStore {
25+
t.Helper()
26+
fs := afero.NewMemMapFs()
27+
err := fs.MkdirAll("/home/testuser/.brev", 0o755)
28+
require.NoError(t, err)
29+
return store.NewBasicStore().WithFileSystem(fs).WithUserHomeDirGetter(
30+
func() (string, error) { return "/home/testuser", nil },
31+
)
32+
}
33+
34+
func TestEmailCachingAuthStore_SaveCachesEmail(t *testing.T) {
35+
fs := newTestFileStore(t)
36+
s := &emailCachingAuthStore{
37+
MemoryAuthStore: store.NewMemoryAuthStore(),
38+
fileStore: fs,
39+
}
40+
41+
token := fakeJWT(t, map[string]interface{}{"email": "user@example.com"})
42+
err := s.SaveAuthTokens(entity.AuthTokens{AccessToken: token})
43+
require.NoError(t, err)
44+
45+
cached, err := fs.GetCachedEmail()
46+
require.NoError(t, err)
47+
assert.Equal(t, "user@example.com", cached)
48+
}
49+
50+
func TestEmailCachingAuthStore_NoEmailInToken(t *testing.T) {
51+
fs := newTestFileStore(t)
52+
s := &emailCachingAuthStore{
53+
MemoryAuthStore: store.NewMemoryAuthStore(),
54+
fileStore: fs,
55+
}
56+
57+
token := fakeJWT(t, map[string]interface{}{"sub": "12345"})
58+
err := s.SaveAuthTokens(entity.AuthTokens{AccessToken: token})
59+
require.NoError(t, err)
60+
61+
cached, err := fs.GetCachedEmail()
62+
require.NoError(t, err)
63+
assert.Equal(t, "", cached)
64+
}
65+
66+
func TestEmailCachingAuthStore_EmptyAccessToken(t *testing.T) {
67+
fs := newTestFileStore(t)
68+
s := &emailCachingAuthStore{
69+
MemoryAuthStore: store.NewMemoryAuthStore(),
70+
fileStore: fs,
71+
}
72+
73+
err := s.SaveAuthTokens(entity.AuthTokens{AccessToken: ""})
74+
require.Error(t, err)
75+
}

pkg/cmd/register/device_registration_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func (s *FileRegistrationStore) Load() (*DeviceRegistration, error) {
8686
return nil, breverrors.WrapAndTrace(err)
8787
}
8888
if reg.ExternalNodeID == "" || reg.OrgID == "" {
89-
return nil, breverrors.New("corrupt registration file: missing external_node_id or org_id")
89+
return nil, breverrors.New("malformed registration")
9090
}
9191
return &reg, nil
9292
}

pkg/cmd/register/register.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
type RegisterStore interface {
2828
GetCurrentUser() (*entity.User, error)
2929
GetActiveOrganizationOrDefault() (*entity.Organization, error)
30+
GetOrganizationsByName(name string) ([]entity.Organization, error)
3031
GetAccessToken() (string, error)
3132
}
3233

@@ -91,6 +92,8 @@ This command sets up network connectivity and registers this machine with Brev.`
9192
)
9293

9394
func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command {
95+
var orgFlag string
96+
9497
cmd := &cobra.Command{
9598
Annotations: map[string]string{"configuration": ""},
9699
Use: "register [name]",
@@ -104,14 +107,16 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command {
104107
if len(args) > 0 {
105108
name = args[0]
106109
}
107-
return runRegister(cmd.Context(), t, store, name, defaultRegisterDeps())
110+
return runRegister(cmd.Context(), t, store, name, orgFlag, defaultRegisterDeps())
108111
},
109112
}
110113

114+
cmd.Flags().StringVarP(&orgFlag, "org", "o", "", "organization name (overrides active org)")
115+
111116
return cmd
112117
}
113118

114-
func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { //nolint:funlen // registration flow
119+
func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, orgName string, deps registerDeps) error { //nolint:funlen // registration flow
115120
if !deps.platform.IsCompatible() {
116121
return breverrors.New("brev register is only supported on Linux")
117122
}
@@ -132,7 +137,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
132137
if err != nil {
133138
return breverrors.WrapAndTrace(err)
134139
}
135-
org, err := getOrgToRegisterFor(s)
140+
org, err := getOrgToRegisterFor(s, orgName)
136141
if err != nil {
137142
return breverrors.WrapAndTrace(err)
138143
}
@@ -219,15 +224,28 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
219224
return nil
220225
}
221226

222-
func getOrgToRegisterFor(s RegisterStore) (*entity.Organization, error) {
227+
func getOrgToRegisterFor(s RegisterStore, orgName string) (*entity.Organization, error) {
228+
if orgName != "" {
229+
orgs, err := s.GetOrganizationsByName(orgName)
230+
if err != nil {
231+
return nil, breverrors.WrapAndTrace(err)
232+
}
233+
if len(orgs) == 0 {
234+
return nil, fmt.Errorf("no organization found with name %q", orgName)
235+
}
236+
if len(orgs) > 1 {
237+
return nil, fmt.Errorf("multiple organizations found with name %q", orgName)
238+
}
239+
return &orgs[0], nil
240+
}
241+
223242
org, err := s.GetActiveOrganizationOrDefault()
224243
if err != nil {
225244
return nil, breverrors.WrapAndTrace(err)
226245
}
227246
if org == nil {
228247
return nil, fmt.Errorf("no organization found; please create or join an organization first")
229248
}
230-
231249
return org, nil
232250
}
233251

0 commit comments

Comments
 (0)