@@ -499,3 +499,125 @@ Peers count: 0/0 Connected`
499499 })
500500 }
501501}
502+
503+ func TestIsSSHConnectionError (t * testing.T ) {
504+ t .Run ("nil" , func (t * testing.T ) {
505+ if IsSSHConnectionError (nil ) {
506+ t .Error ("IsSSHConnectionError(nil) should be false" )
507+ }
508+ })
509+ t .Run ("plain_error" , func (t * testing.T ) {
510+ if IsSSHConnectionError (fmt .Errorf ("some error" )) {
511+ t .Error ("IsSSHConnectionError(plain error) should be false" )
512+ }
513+ })
514+ t .Run ("connection_error_type" , func (t * testing.T ) {
515+ err := & sshConnectionError {err : fmt .Errorf ("transient" )}
516+ if ! IsSSHConnectionError (err ) {
517+ t .Error ("IsSSHConnectionError(sshConnectionError) should be true" )
518+ }
519+ })
520+ t .Run ("wrapped_connection_error" , func (t * testing.T ) {
521+ err := fmt .Errorf ("wrapped: %w" , & sshConnectionError {err : fmt .Errorf ("transient" )})
522+ if ! IsSSHConnectionError (err ) {
523+ t .Error ("IsSSHConnectionError(wrapped sshConnectionError) should be true" )
524+ }
525+ })
526+ }
527+
528+ func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds (t * testing.T ) {
529+ regStore := & mockRegistrationStore {}
530+
531+ store := & mockRegisterStore {
532+ user : & entity.User {ID : "user_1" },
533+ org : & entity.Organization {ID : "org_123" , Name : "TestOrg" },
534+ home : "/home/testuser/.brev" ,
535+ token : "tok" ,
536+ }
537+
538+ var grantCalls int
539+ svc := & fakeNodeService {
540+ addNodeFn : func (req * nodev1.AddNodeRequest ) (* nodev1.AddNodeResponse , error ) {
541+ return & nodev1.AddNodeResponse {
542+ ExternalNode : & nodev1.ExternalNode {
543+ ExternalNodeId : "unode_abc" ,
544+ OrganizationId : "org_123" ,
545+ Name : req .GetName (),
546+ DeviceId : req .GetDeviceId (),
547+ ConnectivityInfo : & nodev1.ConnectivityInfo {
548+ RegistrationCommand : "netbird up --key abc" ,
549+ },
550+ },
551+ }, nil
552+ },
553+ grantNodeSSHAccessFn : func (_ * nodev1.GrantNodeSSHAccessRequest ) (* nodev1.GrantNodeSSHAccessResponse , error ) {
554+ grantCalls ++
555+ if grantCalls < 2 {
556+ return nil , connect .NewError (connect .CodeInternal , nil )
557+ }
558+ return & nodev1.GrantNodeSSHAccessResponse {}, nil
559+ },
560+ }
561+
562+ deps , server := testRegisterDeps (t , svc , regStore )
563+ defer server .Close ()
564+
565+ deps .prompter = mockConfirmer {confirm : true }
566+
567+ term := terminal .New ()
568+ err := runRegister (context .Background (), term , store , "My Spark" , deps )
569+ if err != nil {
570+ t .Fatalf ("runRegister failed: %v" , err )
571+ }
572+
573+ if grantCalls != 2 {
574+ t .Errorf ("expected GrantNodeSSHAccess to be called 2 times (retry once), got %d" , grantCalls )
575+ }
576+ }
577+
578+ func Test_runRegister_GrantSSH_no_retry_on_permanent_error (t * testing.T ) {
579+ regStore := & mockRegistrationStore {}
580+
581+ store := & mockRegisterStore {
582+ user : & entity.User {ID : "user_1" },
583+ org : & entity.Organization {ID : "org_123" , Name : "TestOrg" },
584+ home : "/home/testuser/.brev" ,
585+ token : "tok" ,
586+ }
587+
588+ var grantCalls int
589+ svc := & fakeNodeService {
590+ addNodeFn : func (req * nodev1.AddNodeRequest ) (* nodev1.AddNodeResponse , error ) {
591+ return & nodev1.AddNodeResponse {
592+ ExternalNode : & nodev1.ExternalNode {
593+ ExternalNodeId : "unode_abc" ,
594+ OrganizationId : "org_123" ,
595+ Name : req .GetName (),
596+ DeviceId : req .GetDeviceId (),
597+ ConnectivityInfo : & nodev1.ConnectivityInfo {
598+ RegistrationCommand : "netbird up --key abc" ,
599+ },
600+ },
601+ }, nil
602+ },
603+ grantNodeSSHAccessFn : func (_ * nodev1.GrantNodeSSHAccessRequest ) (* nodev1.GrantNodeSSHAccessResponse , error ) {
604+ grantCalls ++
605+ return nil , connect .NewError (connect .CodePermissionDenied , nil )
606+ },
607+ }
608+
609+ deps , server := testRegisterDeps (t , svc , regStore )
610+ defer server .Close ()
611+
612+ deps .prompter = mockConfirmer {confirm : true }
613+
614+ term := terminal .New ()
615+ err := runRegister (context .Background (), term , store , "My Spark" , deps )
616+ if err != nil {
617+ t .Fatalf ("runRegister should not fail the overall flow when SSH grant fails: %v" , err )
618+ }
619+
620+ if grantCalls != 1 {
621+ t .Errorf ("expected GrantNodeSSHAccess to be called once (no retry on permanent error), got %d" , grantCalls )
622+ }
623+ }
0 commit comments