Skip to content

Commit 8cbe47b

Browse files
authored
fix(BRE2-766): Backoff on ssh Attempt During Registration (#307)
* backoff on ssh attempt during registration * retry cleanup * retry error * lint, mod tidy, fmt * cleanup
1 parent 88dc335 commit 8cbe47b

7 files changed

Lines changed: 191 additions & 46 deletions

File tree

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ require (
99
github.com/alessio/shellescape v1.4.1
1010
github.com/brevdev/parse v0.0.11
1111
github.com/briandowns/spinner v1.16.0
12+
github.com/cenkalti/backoff/v4 v4.3.0
1213
github.com/fatih/color v1.13.0
1314
github.com/getsentry/sentry-go v0.14.0
1415
github.com/gin-gonic/gin v1.10.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc
7373
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
7474
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
7575
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
76+
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
77+
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
7678
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
7779
github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE=
7880
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=

pkg/cmd/enablessh/enablessh_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func readAuthorizedKeys(t *testing.T, u *user.User) string {
3131
func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) {
3232
u := tempUser(t)
3333

34-
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
34+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
3535
t.Fatalf("InstallAuthorizedKey: %v", err)
3636
}
3737

@@ -44,10 +44,10 @@ func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) {
4444
func Test_InstallAuthorizedKey_SkipsDuplicate(t *testing.T) {
4545
u := tempUser(t)
4646

47-
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
47+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
4848
t.Fatalf("first install: %v", err)
4949
}
50-
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
50+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
5151
t.Fatalf("second install: %v", err)
5252
}
5353

@@ -70,7 +70,7 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) {
7070
t.Fatal(err)
7171
}
7272

73-
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
73+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
7474
t.Fatalf("InstallAuthorizedKey: %v", err)
7575
}
7676

@@ -84,10 +84,10 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) {
8484
func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) {
8585
u := tempUser(t)
8686

87-
if err := register.InstallAuthorizedKey(u, ""); err != nil {
87+
if _, err := register.InstallAuthorizedKey(u, ""); err != nil {
8888
t.Fatalf("InstallAuthorizedKey: %v", err)
8989
}
90-
if err := register.InstallAuthorizedKey(u, " "); err != nil {
90+
if _, err := register.InstallAuthorizedKey(u, " "); err != nil {
9191
t.Fatalf("InstallAuthorizedKey (whitespace): %v", err)
9292
}
9393

@@ -101,7 +101,7 @@ func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) {
101101
func Test_InstallAuthorizedKey_CreatesSSHDir(t *testing.T) {
102102
u := tempUser(t)
103103

104-
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
104+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
105105
t.Fatalf("InstallAuthorizedKey: %v", err)
106106
}
107107

@@ -126,7 +126,7 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) {
126126
t.Fatal(err)
127127
}
128128

129-
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
129+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
130130
t.Fatalf("InstallAuthorizedKey: %v", err)
131131
}
132132

@@ -152,7 +152,7 @@ func Test_InstallAuthorizedKey_TagsExistingUntaggedKey(t *testing.T) {
152152
}
153153

154154
// InstallAuthorizedKey should tag the existing key rather than adding a duplicate.
155-
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
155+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
156156
t.Fatalf("InstallAuthorizedKey: %v", err)
157157
}
158158

@@ -366,10 +366,10 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) {
366366
}
367367

368368
// Install two brev keys.
369-
if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil {
369+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil {
370370
t.Fatal(err)
371371
}
372-
if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil {
372+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil {
373373
t.Fatal(err)
374374
}
375375

@@ -393,10 +393,10 @@ func Test_InstallThenRemoveSpecificKey_RollbackScenario(t *testing.T) {
393393
u := tempUser(t)
394394

395395
// Install two brev keys (simulating two users granted access).
396-
if err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil {
396+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil {
397397
t.Fatal(err)
398398
}
399-
if err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil {
399+
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil {
400400
t.Fatal(err)
401401
}
402402

pkg/cmd/register/register.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
208208
runSetup(node, t, deps)
209209

210210
if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") {
211-
grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser)
211+
if err := grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser); err != nil {
212+
t.Vprintf(" Warning: SSH access not granted: %v\n", err)
213+
}
212214
}
213215

214216
return nil
@@ -330,7 +332,7 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps
330332
}
331333
}
332334

333-
func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) {
335+
func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) error {
334336
t.Vprint("")
335337
t.Vprint(t.Green("Enabling SSH access on this device"))
336338
t.Vprint("")
@@ -341,14 +343,9 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps
341343

342344
err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
343345
if err != nil {
344-
t.Vprint(" Retrying in 3 seconds...")
345-
time.Sleep(3 * time.Second)
346-
err = GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
347-
}
348-
if err != nil {
349-
t.Vprintf(" Warning: %v\n", err)
350-
return
346+
return fmt.Errorf("grant SSH failed: %w", err)
351347
}
352348

353349
t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName)))
350+
return nil
354351
}

pkg/cmd/register/register_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,3 +499,100 @@ Peers count: 0/0 Connected`
499499
})
500500
}
501501
}
502+
503+
func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *testing.T) {
504+
regStore := &mockRegistrationStore{}
505+
506+
store := &mockRegisterStore{
507+
user: &entity.User{ID: "user_1"},
508+
org: &entity.Organization{ID: "org_123", Name: "TestOrg"},
509+
home: "/home/testuser/.brev",
510+
token: "tok",
511+
}
512+
513+
var grantCalls int
514+
svc := &fakeNodeService{
515+
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
516+
return &nodev1.AddNodeResponse{
517+
ExternalNode: &nodev1.ExternalNode{
518+
ExternalNodeId: "unode_abc",
519+
OrganizationId: "org_123",
520+
Name: req.GetName(),
521+
DeviceId: req.GetDeviceId(),
522+
ConnectivityInfo: &nodev1.ConnectivityInfo{
523+
RegistrationCommand: "netbird up --key abc",
524+
},
525+
},
526+
}, nil
527+
},
528+
grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) {
529+
grantCalls++
530+
if grantCalls < 2 {
531+
return nil, connect.NewError(connect.CodeInternal, nil)
532+
}
533+
return &nodev1.GrantNodeSSHAccessResponse{}, nil
534+
},
535+
}
536+
537+
deps, server := testRegisterDeps(t, svc, regStore)
538+
defer server.Close()
539+
540+
deps.prompter = mockConfirmer{confirm: true}
541+
542+
term := terminal.New()
543+
err := runRegister(context.Background(), term, store, "My Spark", deps)
544+
if err != nil {
545+
t.Fatalf("runRegister failed: %v", err)
546+
}
547+
548+
if grantCalls != 2 {
549+
t.Errorf("expected GrantNodeSSHAccess to be called 2 times (retry once), got %d", grantCalls)
550+
}
551+
}
552+
553+
func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) {
554+
regStore := &mockRegistrationStore{}
555+
556+
store := &mockRegisterStore{
557+
user: &entity.User{ID: "user_1"},
558+
org: &entity.Organization{ID: "org_123", Name: "TestOrg"},
559+
home: "/home/testuser/.brev",
560+
token: "tok",
561+
}
562+
563+
var grantCalls int
564+
svc := &fakeNodeService{
565+
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
566+
return &nodev1.AddNodeResponse{
567+
ExternalNode: &nodev1.ExternalNode{
568+
ExternalNodeId: "unode_abc",
569+
OrganizationId: "org_123",
570+
Name: req.GetName(),
571+
DeviceId: req.GetDeviceId(),
572+
ConnectivityInfo: &nodev1.ConnectivityInfo{
573+
RegistrationCommand: "netbird up --key abc",
574+
},
575+
},
576+
}, nil
577+
},
578+
grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) {
579+
grantCalls++
580+
return nil, connect.NewError(connect.CodePermissionDenied, nil)
581+
},
582+
}
583+
584+
deps, server := testRegisterDeps(t, svc, regStore)
585+
defer server.Close()
586+
587+
deps.prompter = mockConfirmer{confirm: true}
588+
589+
term := terminal.New()
590+
err := runRegister(context.Background(), term, store, "My Spark", deps)
591+
if err != nil {
592+
t.Fatalf("runRegister should not fail the overall flow when SSH grant fails: %v", err)
593+
}
594+
595+
if grantCalls != 1 {
596+
t.Errorf("expected GrantNodeSSHAccess to be called once (no retry on permanent error), got %d", grantCalls)
597+
}
598+
}

pkg/cmd/register/rpcclient_test.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,10 @@ func Test_toProtoNodeSpec_MinimalFields(t *testing.T) {
143143
// fakeNodeService implements the server side of ExternalNodeService for testing.
144144
type fakeNodeService struct {
145145
nodev1connect.UnimplementedExternalNodeServiceHandler
146-
addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error)
147-
removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error)
148-
getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error)
146+
addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error)
147+
removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error)
148+
getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error)
149+
grantNodeSSHAccessFn func(*nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error)
149150
}
150151

151152
func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) {
@@ -175,6 +176,17 @@ func (f *fakeNodeService) GetNode(_ context.Context, req *connect.Request[nodev1
175176
return connect.NewResponse(resp), nil
176177
}
177178

179+
func (f *fakeNodeService) GrantNodeSSHAccess(_ context.Context, req *connect.Request[nodev1.GrantNodeSSHAccessRequest]) (*connect.Response[nodev1.GrantNodeSSHAccessResponse], error) {
180+
if f.grantNodeSSHAccessFn == nil {
181+
return nil, connect.NewError(connect.CodeUnimplemented, nil)
182+
}
183+
resp, err := f.grantNodeSSHAccessFn(req.Msg)
184+
if err != nil {
185+
return nil, err
186+
}
187+
return connect.NewResponse(resp), nil
188+
}
189+
178190
func Test_NewNodeServiceClient_AddNode(t *testing.T) {
179191
svc := &fakeNodeService{
180192
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {

0 commit comments

Comments
 (0)