Skip to content

Commit edf020d

Browse files
committed
retry error
1 parent d109761 commit edf020d

2 files changed

Lines changed: 137 additions & 3 deletions

File tree

pkg/cmd/register/register_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,3 +499,125 @@ Peers count: 0/0 Connected`
499499
})
500500
}
501501
}
502+
503+
func TestIsSSHConnectionError(t *testing.T) {
504+
t.Run("nil", func(t *testing.T) {
505+
if IsSSHConnectionError(nil) {
506+
t.Error("IsSSHConnectionError(nil) should be false")
507+
}
508+
})
509+
t.Run("plain_error", func(t *testing.T) {
510+
if IsSSHConnectionError(fmt.Errorf("some error")) {
511+
t.Error("IsSSHConnectionError(plain error) should be false")
512+
}
513+
})
514+
t.Run("connection_error_type", func(t *testing.T) {
515+
err := &sshConnectionError{err: fmt.Errorf("transient")}
516+
if !IsSSHConnectionError(err) {
517+
t.Error("IsSSHConnectionError(sshConnectionError) should be true")
518+
}
519+
})
520+
t.Run("wrapped_connection_error", func(t *testing.T) {
521+
err := fmt.Errorf("wrapped: %w", &sshConnectionError{err: fmt.Errorf("transient")})
522+
if !IsSSHConnectionError(err) {
523+
t.Error("IsSSHConnectionError(wrapped sshConnectionError) should be true")
524+
}
525+
})
526+
}
527+
528+
func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *testing.T) {
529+
regStore := &mockRegistrationStore{}
530+
531+
store := &mockRegisterStore{
532+
user: &entity.User{ID: "user_1"},
533+
org: &entity.Organization{ID: "org_123", Name: "TestOrg"},
534+
home: "/home/testuser/.brev",
535+
token: "tok",
536+
}
537+
538+
var grantCalls int
539+
svc := &fakeNodeService{
540+
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
541+
return &nodev1.AddNodeResponse{
542+
ExternalNode: &nodev1.ExternalNode{
543+
ExternalNodeId: "unode_abc",
544+
OrganizationId: "org_123",
545+
Name: req.GetName(),
546+
DeviceId: req.GetDeviceId(),
547+
ConnectivityInfo: &nodev1.ConnectivityInfo{
548+
RegistrationCommand: "netbird up --key abc",
549+
},
550+
},
551+
}, nil
552+
},
553+
grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) {
554+
grantCalls++
555+
if grantCalls < 2 {
556+
return nil, connect.NewError(connect.CodeInternal, nil)
557+
}
558+
return &nodev1.GrantNodeSSHAccessResponse{}, nil
559+
},
560+
}
561+
562+
deps, server := testRegisterDeps(t, svc, regStore)
563+
defer server.Close()
564+
565+
deps.prompter = mockConfirmer{confirm: true}
566+
567+
term := terminal.New()
568+
err := runRegister(context.Background(), term, store, "My Spark", deps)
569+
if err != nil {
570+
t.Fatalf("runRegister failed: %v", err)
571+
}
572+
573+
if grantCalls != 2 {
574+
t.Errorf("expected GrantNodeSSHAccess to be called 2 times (retry once), got %d", grantCalls)
575+
}
576+
}
577+
578+
func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) {
579+
regStore := &mockRegistrationStore{}
580+
581+
store := &mockRegisterStore{
582+
user: &entity.User{ID: "user_1"},
583+
org: &entity.Organization{ID: "org_123", Name: "TestOrg"},
584+
home: "/home/testuser/.brev",
585+
token: "tok",
586+
}
587+
588+
var grantCalls int
589+
svc := &fakeNodeService{
590+
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
591+
return &nodev1.AddNodeResponse{
592+
ExternalNode: &nodev1.ExternalNode{
593+
ExternalNodeId: "unode_abc",
594+
OrganizationId: "org_123",
595+
Name: req.GetName(),
596+
DeviceId: req.GetDeviceId(),
597+
ConnectivityInfo: &nodev1.ConnectivityInfo{
598+
RegistrationCommand: "netbird up --key abc",
599+
},
600+
},
601+
}, nil
602+
},
603+
grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) {
604+
grantCalls++
605+
return nil, connect.NewError(connect.CodePermissionDenied, nil)
606+
},
607+
}
608+
609+
deps, server := testRegisterDeps(t, svc, regStore)
610+
defer server.Close()
611+
612+
deps.prompter = mockConfirmer{confirm: true}
613+
614+
term := terminal.New()
615+
err := runRegister(context.Background(), term, store, "My Spark", deps)
616+
if err != nil {
617+
t.Fatalf("runRegister should not fail the overall flow when SSH grant fails: %v", err)
618+
}
619+
620+
if grantCalls != 1 {
621+
t.Errorf("expected GrantNodeSSHAccess to be called once (no retry on permanent error), got %d", grantCalls)
622+
}
623+
}

pkg/cmd/register/rpcclient_test.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,10 @@ func Test_toProtoNodeSpec_MinimalFields(t *testing.T) {
143143
// fakeNodeService implements the server side of ExternalNodeService for testing.
144144
type fakeNodeService struct {
145145
nodev1connect.UnimplementedExternalNodeServiceHandler
146-
addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error)
147-
removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error)
148-
getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error)
146+
addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error)
147+
removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error)
148+
getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error)
149+
grantNodeSSHAccessFn func(*nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error)
149150
}
150151

151152
func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) {
@@ -175,6 +176,17 @@ func (f *fakeNodeService) GetNode(_ context.Context, req *connect.Request[nodev1
175176
return connect.NewResponse(resp), nil
176177
}
177178

179+
func (f *fakeNodeService) GrantNodeSSHAccess(_ context.Context, req *connect.Request[nodev1.GrantNodeSSHAccessRequest]) (*connect.Response[nodev1.GrantNodeSSHAccessResponse], error) {
180+
if f.grantNodeSSHAccessFn == nil {
181+
return nil, connect.NewError(connect.CodeUnimplemented, nil)
182+
}
183+
resp, err := f.grantNodeSSHAccessFn(req.Msg)
184+
if err != nil {
185+
return nil, err
186+
}
187+
return connect.NewResponse(resp), nil
188+
}
189+
178190
func Test_NewNodeServiceClient_AddNode(t *testing.T) {
179191
svc := &fakeNodeService{
180192
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {

0 commit comments

Comments
 (0)