Skip to content

Commit 80b7284

Browse files
committed
caching email
1 parent 39ff195 commit 80b7284

7 files changed

Lines changed: 265 additions & 19 deletions

File tree

pkg/cmd/cmd.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import (
5858
"github.com/brevdev/brev-cli/pkg/cmd/workspacegroups"
5959
"github.com/brevdev/brev-cli/pkg/cmd/writeconnectionevent"
6060
"github.com/brevdev/brev-cli/pkg/config"
61+
"github.com/brevdev/brev-cli/pkg/entity"
6162
"github.com/brevdev/brev-cli/pkg/featureflag"
6263
"github.com/brevdev/brev-cli/pkg/files"
6364
"github.com/brevdev/brev-cli/pkg/remoteversion"
@@ -257,8 +258,20 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin
257258
cmds.SetUsageTemplate(usageTemplate)
258259

259260
// In-memory auth for external node commands — never touches credentials.json.
260-
memAuthStore := store.NewMemoryAuthStore()
261-
memAuthenticator := auth.StandardLogin("", "", nil)
261+
// Pre-fill the cached email so the user sees a confirmation prompt instead of
262+
// having to type it from scratch every time.
263+
cachedEmail, _ := fsStore.GetCachedEmail()
264+
memAuthenticator := auth.StandardLogin("", cachedEmail, nil)
265+
if cachedEmail != "" {
266+
if kas, ok := memAuthenticator.(auth.KasAuthenticator); ok {
267+
kas.ShouldPromptEmail = true
268+
memAuthenticator = kas
269+
}
270+
}
271+
memAuthStore := &emailCachingAuthStore{
272+
MemoryAuthStore: store.NewMemoryAuthStore(),
273+
fileStore: fsStore,
274+
}
262275
memLoginAuth := auth.NewLoginAuth(memAuthStore, memAuthenticator)
263276
memLoginAuth.WithShouldLogin(func() (bool, error) { return true, nil })
264277

@@ -555,4 +568,22 @@ var (
555568
_ store.Auth = auth.NoLoginAuth{}
556569
_ auth.AuthStore = store.FileStore{}
557570
_ auth.AuthStore = &store.MemoryAuthStore{}
571+
_ auth.AuthStore = &emailCachingAuthStore{}
558572
)
573+
574+
// emailCachingAuthStore wraps MemoryAuthStore and persists the login email
575+
// to ~/.brev/cached-email after each successful authentication.
576+
type emailCachingAuthStore struct {
577+
*store.MemoryAuthStore
578+
fileStore *store.FileStore
579+
}
580+
581+
func (e *emailCachingAuthStore) SaveAuthTokens(tokens entity.AuthTokens) error {
582+
if err := e.MemoryAuthStore.SaveAuthTokens(tokens); err != nil {
583+
return breverrors.WrapAndTrace(err)
584+
}
585+
if email := auth.GetEmailFromToken(tokens.AccessToken); email != "" {
586+
_ = e.fileStore.SaveCachedEmail(email)
587+
}
588+
return nil
589+
}

pkg/cmd/register/device_registration_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func (s *FileRegistrationStore) Load() (*DeviceRegistration, error) {
8686
return nil, breverrors.WrapAndTrace(err)
8787
}
8888
if reg.ExternalNodeID == "" || reg.OrgID == "" {
89-
return nil, breverrors.New("corrupt registration file: missing external_node_id or org_id")
89+
return nil, breverrors.New("malformed registration")
9090
}
9191
return &reg, nil
9292
}

pkg/cmd/register/register.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
type RegisterStore interface {
2828
GetCurrentUser() (*entity.User, error)
2929
GetActiveOrganizationOrDefault() (*entity.Organization, error)
30+
GetOrganizationsByName(name string) ([]entity.Organization, error)
3031
GetAccessToken() (string, error)
3132
}
3233

@@ -91,6 +92,8 @@ This command sets up network connectivity and registers this machine with Brev.`
9192
)
9293

9394
func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command {
95+
var orgFlag string
96+
9497
cmd := &cobra.Command{
9598
Annotations: map[string]string{"configuration": ""},
9699
Use: "register [name]",
@@ -104,14 +107,16 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command {
104107
if len(args) > 0 {
105108
name = args[0]
106109
}
107-
return runRegister(cmd.Context(), t, store, name, defaultRegisterDeps())
110+
return runRegister(cmd.Context(), t, store, name, orgFlag, defaultRegisterDeps())
108111
},
109112
}
110113

114+
cmd.Flags().StringVarP(&orgFlag, "org", "o", "", "organization name (overrides active org)")
115+
111116
return cmd
112117
}
113118

114-
func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { //nolint:funlen // registration flow
119+
func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, orgName string, deps registerDeps) error { //nolint:funlen // registration flow
115120
if !deps.platform.IsCompatible() {
116121
return breverrors.New("brev register is only supported on Linux")
117122
}
@@ -132,7 +137,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
132137
if err != nil {
133138
return breverrors.WrapAndTrace(err)
134139
}
135-
org, err := getOrgToRegisterFor(s)
140+
org, err := getOrgToRegisterFor(s, orgName)
136141
if err != nil {
137142
return breverrors.WrapAndTrace(err)
138143
}
@@ -219,15 +224,28 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
219224
return nil
220225
}
221226

222-
func getOrgToRegisterFor(s RegisterStore) (*entity.Organization, error) {
227+
func getOrgToRegisterFor(s RegisterStore, orgName string) (*entity.Organization, error) {
228+
if orgName != "" {
229+
orgs, err := s.GetOrganizationsByName(orgName)
230+
if err != nil {
231+
return nil, breverrors.WrapAndTrace(err)
232+
}
233+
if len(orgs) == 0 {
234+
return nil, fmt.Errorf("no organization found with name %q", orgName)
235+
}
236+
if len(orgs) > 1 {
237+
return nil, fmt.Errorf("multiple organizations found with name %q", orgName)
238+
}
239+
return &orgs[0], nil
240+
}
241+
223242
org, err := s.GetActiveOrganizationOrDefault()
224243
if err != nil {
225244
return nil, breverrors.WrapAndTrace(err)
226245
}
227246
if org == nil {
228247
return nil, fmt.Errorf("no organization found; please create or join an organization first")
229248
}
230-
231249
return org, nil
232250
}
233251

pkg/cmd/register/register_test.go

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
type mockRegisterStore struct {
2121
user *entity.User
2222
org *entity.Organization
23+
orgs []entity.Organization
2324
token string
2425
err error
2526
}
@@ -35,6 +36,16 @@ func (m *mockRegisterStore) GetActiveOrganizationOrDefault() (*entity.Organizati
3536
return m.org, nil
3637
}
3738

39+
func (m *mockRegisterStore) GetOrganizationsByName(name string) ([]entity.Organization, error) {
40+
var matched []entity.Organization
41+
for _, o := range m.orgs {
42+
if o.Name == name {
43+
matched = append(matched, o)
44+
}
45+
}
46+
return matched, nil
47+
}
48+
3849
func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil }
3950

4051
// mockRegistrationStore satisfies RegistrationStore for orchestration tests.
@@ -173,7 +184,7 @@ func Test_runRegister_HappyPath(t *testing.T) {
173184
defer ClearTestSSHPort()
174185

175186
term := terminal.New()
176-
err := runRegister(context.Background(), term, store, "my-spark", deps)
187+
err := runRegister(context.Background(), term, store, "my-spark", "", deps)
177188
if err != nil {
178189
t.Fatalf("runRegister failed: %v", err)
179190
}
@@ -224,7 +235,7 @@ func Test_runRegister_UserCancels(t *testing.T) {
224235
deps.prompter = mockConfirmer{confirm: false}
225236

226237
term := terminal.New()
227-
err := runRegister(context.Background(), term, store, "my-spark", deps)
238+
err := runRegister(context.Background(), term, store, "my-spark", "", deps)
228239
if err != nil {
229240
t.Fatalf("expected nil error on cancel, got: %v", err)
230241
}
@@ -312,7 +323,7 @@ func Test_runRegister_AlreadyRegistered(t *testing.T) {
312323
term := terminal.New()
313324
// Pass the same name as the existing registration so we go through
314325
// the checkExistingRegistration path (not the different-name path).
315-
err := runRegister(context.Background(), term, store, "Existing", deps)
326+
err := runRegister(context.Background(), term, store, "Existing", "", deps)
316327
if err != nil {
317328
t.Fatalf("expected nil error, got: %v", err)
318329
}
@@ -340,12 +351,90 @@ func Test_runRegister_NoOrganization(t *testing.T) {
340351
defer server.Close()
341352

342353
term := terminal.New()
343-
err := runRegister(context.Background(), term, store, "my-spark", deps)
354+
err := runRegister(context.Background(), term, store, "my-spark", "", deps)
344355
if err == nil {
345356
t.Fatal("expected error when no org exists")
346357
}
347358
}
348359

360+
func Test_runRegister_WithOrgFlag(t *testing.T) {
361+
regStore := &mockRegistrationStore{}
362+
363+
store := &mockRegisterStore{
364+
user: &entity.User{ID: "user_1"},
365+
org: &entity.Organization{ID: "org_default", Name: "DefaultOrg"},
366+
orgs: []entity.Organization{
367+
{ID: "org_456", Name: "SpecificOrg"},
368+
},
369+
token: "tok",
370+
}
371+
372+
var capturedOrgID string
373+
svc := &fakeNodeService{
374+
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
375+
capturedOrgID = req.GetOrganizationId()
376+
return &nodev1.AddNodeResponse{
377+
ExternalNode: &nodev1.ExternalNode{
378+
ExternalNodeId: "unode_abc",
379+
OrganizationId: req.GetOrganizationId(),
380+
Name: req.GetName(),
381+
DeviceId: req.GetDeviceId(),
382+
},
383+
}, nil
384+
},
385+
}
386+
387+
setupRunner := &mockSetupRunner{}
388+
deps, server := testRegisterDeps(t, svc, regStore)
389+
defer server.Close()
390+
deps.setupRunner = setupRunner
391+
392+
SetTestSSHPort(22)
393+
defer ClearTestSSHPort()
394+
395+
term := terminal.New()
396+
err := runRegister(context.Background(), term, store, "my-spark", "SpecificOrg", deps)
397+
if err != nil {
398+
t.Fatalf("runRegister with --org failed: %v", err)
399+
}
400+
401+
if capturedOrgID != "org_456" {
402+
t.Errorf("expected org_456, got %s", capturedOrgID)
403+
}
404+
405+
reg, err := regStore.Load()
406+
if err != nil {
407+
t.Fatalf("Load failed: %v", err)
408+
}
409+
if reg.OrgID != "org_456" {
410+
t.Errorf("expected registration org org_456, got %s", reg.OrgID)
411+
}
412+
}
413+
414+
func Test_runRegister_WithOrgFlag_NotFound(t *testing.T) {
415+
regStore := &mockRegistrationStore{}
416+
417+
store := &mockRegisterStore{
418+
user: &entity.User{ID: "user_1"},
419+
org: &entity.Organization{ID: "org_default", Name: "DefaultOrg"},
420+
orgs: []entity.Organization{},
421+
token: "tok",
422+
}
423+
424+
svc := &fakeNodeService{}
425+
deps, server := testRegisterDeps(t, svc, regStore)
426+
defer server.Close()
427+
428+
term := terminal.New()
429+
err := runRegister(context.Background(), term, store, "my-spark", "NonexistentOrg", deps)
430+
if err == nil {
431+
t.Fatal("expected error when org not found")
432+
}
433+
if !strings.Contains(err.Error(), "no organization found") {
434+
t.Errorf("expected 'no organization found' error, got: %v", err)
435+
}
436+
}
437+
349438
func Test_runRegister_AddNodeFails(t *testing.T) {
350439
regStore := &mockRegistrationStore{}
351440

@@ -366,7 +455,7 @@ func Test_runRegister_AddNodeFails(t *testing.T) {
366455
defer server.Close()
367456

368457
term := terminal.New()
369-
err := runRegister(context.Background(), term, store, "my-spark", deps)
458+
err := runRegister(context.Background(), term, store, "my-spark", "", deps)
370459
if err == nil {
371460
t.Fatal("expected error when AddNode fails")
372461
}
@@ -416,7 +505,7 @@ func Test_runRegister_NoSetupCommand(t *testing.T) {
416505
defer ClearTestSSHPort()
417506

418507
term := terminal.New()
419-
err := runRegister(context.Background(), term, store, "my-spark", deps)
508+
err := runRegister(context.Background(), term, store, "my-spark", "", deps)
420509
if err != nil {
421510
t.Fatalf("runRegister failed: %v", err)
422511
}
@@ -549,7 +638,7 @@ func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *test
549638
defer ClearTestSSHPort()
550639

551640
term := terminal.New()
552-
err := runRegister(context.Background(), term, store, "my-spark", deps)
641+
err := runRegister(context.Background(), term, store, "my-spark", "", deps)
553642
if err != nil {
554643
t.Fatalf("runRegister failed: %v", err)
555644
}
@@ -598,7 +687,7 @@ func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) {
598687
defer ClearTestSSHPort()
599688

600689
term := terminal.New()
601-
err := runRegister(context.Background(), term, store, "my-spark", deps)
690+
err := runRegister(context.Background(), term, store, "my-spark", "", deps)
602691
if err != nil {
603692
t.Fatalf("runRegister should not fail the overall flow when SSH grant fails: %v", err)
604693
}
@@ -658,7 +747,7 @@ func Test_runRegister_NameValidation(t *testing.T) {
658747
defer ClearTestSSHPort()
659748

660749
term := terminal.New()
661-
err := runRegister(context.Background(), term, store, tt.input, deps)
750+
err := runRegister(context.Background(), term, store, tt.input, "", deps)
662751
if tt.wantErr {
663752
if err == nil {
664753
t.Fatal("expected error, got nil")
@@ -687,7 +776,7 @@ func Test_runRegister_NoNameNotRegistered(t *testing.T) {
687776
defer server.Close()
688777

689778
term := terminal.New()
690-
err := runRegister(context.Background(), term, store, "", deps)
779+
err := runRegister(context.Background(), term, store, "", "", deps)
691780
if err == nil {
692781
t.Fatal("expected error when no name provided and not registered")
693782
}
@@ -728,7 +817,7 @@ func Test_runRegister_NoNameAlreadyRegistered(t *testing.T) {
728817
defer server.Close()
729818

730819
term := terminal.New()
731-
err := runRegister(context.Background(), term, store, "", deps)
820+
err := runRegister(context.Background(), term, store, "", "", deps)
732821
if err != nil {
733822
t.Fatalf("expected nil error when already registered with no name, got: %v", err)
734823
}

pkg/store/email_cache.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package store
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"strings"
7+
8+
breverrors "github.com/brevdev/brev-cli/pkg/errors"
9+
"github.com/spf13/afero"
10+
)
11+
12+
const cachedEmailFile = "cached-email"
13+
14+
// GetCachedEmail returns the previously cached login email, or "" if none exists.
15+
func (f FileStore) GetCachedEmail() (string, error) {
16+
brevHome, err := f.GetBrevHomePath()
17+
if err != nil {
18+
return "", breverrors.WrapAndTrace(err)
19+
}
20+
data, err := f.fs.Open(filepath.Join(brevHome, cachedEmailFile))
21+
if err != nil {
22+
if os.IsNotExist(err) {
23+
return "", nil
24+
}
25+
return "", breverrors.WrapAndTrace(err)
26+
}
27+
defer data.Close() //nolint:errcheck // best-effort close
28+
buf := make([]byte, 512)
29+
n, err := data.Read(buf)
30+
if err != nil && n == 0 {
31+
return "", breverrors.WrapAndTrace(err)
32+
}
33+
return strings.TrimSpace(string(buf[:n])), nil
34+
}
35+
36+
// SaveCachedEmail writes the login email to ~/.brev/cached-email (0600).
37+
func (f FileStore) SaveCachedEmail(email string) error {
38+
brevHome, err := f.GetBrevHomePath()
39+
if err != nil {
40+
return breverrors.WrapAndTrace(err)
41+
}
42+
path := filepath.Join(brevHome, cachedEmailFile)
43+
err = f.fs.MkdirAll(brevHome, 0o755)
44+
if err != nil {
45+
return breverrors.WrapAndTrace(err)
46+
}
47+
err = afero.WriteFile(f.fs, path, []byte(email), 0o600)
48+
if err != nil {
49+
return breverrors.WrapAndTrace(err)
50+
}
51+
return nil
52+
}

0 commit comments

Comments
 (0)