@@ -913,3 +913,134 @@ func Test_runRegister_NoNameAlreadyRegistered(t *testing.T) {
913913 t .Error ("expected registration to still exist" )
914914 }
915915}
916+
917+ func Test_runRegister_OpenSSHPort (t * testing.T ) { // nolint:funlen // test
918+ tests := []struct {
919+ name string
920+ port int32
921+ openFn func (* nodev1.OpenPortRequest ) (* nodev1.OpenPortResponse , error )
922+ verify func (t * testing.T , openReq * nodev1.OpenPortRequest , grantReq * nodev1.GrantNodeSSHAccessRequest , reg * mockRegistrationStore , err error )
923+ }{
924+ {
925+ name : "SendsCorrectArgs" ,
926+ port : 2222 ,
927+ openFn : func (req * nodev1.OpenPortRequest ) (* nodev1.OpenPortResponse , error ) {
928+ return & nodev1.OpenPortResponse {
929+ Port : & nodev1.Port {
930+ PortId : "port_ssh" ,
931+ Protocol : req .GetProtocol (),
932+ PortNumber : req .GetPortNumber (),
933+ },
934+ }, nil
935+ },
936+ verify : func (t * testing.T , openReq * nodev1.OpenPortRequest , _ * nodev1.GrantNodeSSHAccessRequest , _ * mockRegistrationStore , err error ) {
937+ t .Helper ()
938+ if err != nil {
939+ t .Fatalf ("runRegister failed: %v" , err )
940+ }
941+ if openReq == nil {
942+ t .Fatal ("expected OpenPort to be called" )
943+ }
944+ if openReq .GetExternalNodeId () != "unode_abc" {
945+ t .Errorf ("expected node ID unode_abc, got %s" , openReq .GetExternalNodeId ())
946+ }
947+ if openReq .GetProtocol () != nodev1 .PortProtocol_PORT_PROTOCOL_SSH {
948+ t .Errorf ("expected PORT_PROTOCOL_SSH, got %s" , openReq .GetProtocol ())
949+ }
950+ if openReq .GetPortNumber () != 2222 {
951+ t .Errorf ("expected port 2222, got %d" , openReq .GetPortNumber ())
952+ }
953+ },
954+ },
955+ {
956+ name : "FailureIsSoftError" ,
957+ port : 22 ,
958+ openFn : func (_ * nodev1.OpenPortRequest ) (* nodev1.OpenPortResponse , error ) {
959+ return nil , connect .NewError (connect .CodeInternal , fmt .Errorf ("skybridge unavailable" ))
960+ },
961+ verify : func (t * testing.T , _ * nodev1.OpenPortRequest , _ * nodev1.GrantNodeSSHAccessRequest , regStore * mockRegistrationStore , err error ) {
962+ t .Helper ()
963+ if err != nil {
964+ t .Fatalf ("registration should succeed even when OpenSSHPort fails (soft error), got: %v" , err )
965+ }
966+ exists , _ := regStore .Exists ()
967+ if ! exists {
968+ t .Error ("expected registration to still exist after OpenSSHPort failure" )
969+ }
970+ },
971+ },
972+ {
973+ name : "GrantRequestHasNoPort" ,
974+ port : 22 ,
975+ verify : func (t * testing.T , _ * nodev1.OpenPortRequest , grantReq * nodev1.GrantNodeSSHAccessRequest , _ * mockRegistrationStore , err error ) {
976+ t .Helper ()
977+ if err != nil {
978+ t .Fatalf ("runRegister failed: %v" , err )
979+ }
980+ if grantReq == nil {
981+ t .Fatal ("expected GrantNodeSSHAccess to be called" )
982+ }
983+ if grantReq .GetExternalNodeId () != "unode_abc" {
984+ t .Errorf ("expected node ID unode_abc, got %s" , grantReq .GetExternalNodeId ())
985+ }
986+ if grantReq .GetUserId () != "user_1" {
987+ t .Errorf ("expected user ID user_1, got %s" , grantReq .GetUserId ())
988+ }
989+ },
990+ },
991+ }
992+
993+ for _ , tt := range tests {
994+ t .Run (tt .name , func (t * testing.T ) {
995+ regStore := & mockRegistrationStore {}
996+ store := & mockRegisterStore {
997+ user : & entity.User {ID : "user_1" },
998+ org : & entity.Organization {ID : "org_123" , Name : "TestOrg" },
999+ token : "tok" ,
1000+ }
1001+
1002+ var gotOpenReq * nodev1.OpenPortRequest
1003+ var gotGrantReq * nodev1.GrantNodeSSHAccessRequest
1004+ svc := & fakeNodeService {
1005+ addNodeFn : func (req * nodev1.AddNodeRequest ) (* nodev1.AddNodeResponse , error ) {
1006+ return & nodev1.AddNodeResponse {
1007+ ExternalNode : & nodev1.ExternalNode {
1008+ ExternalNodeId : "unode_abc" ,
1009+ OrganizationId : "org_123" ,
1010+ Name : req .GetName (),
1011+ DeviceId : req .GetDeviceId (),
1012+ ConnectivityInfo : & nodev1.ConnectivityInfo {
1013+ RegistrationCommand : "netbird up --key abc" ,
1014+ },
1015+ },
1016+ }, nil
1017+ },
1018+ openPortFn : func (req * nodev1.OpenPortRequest ) (* nodev1.OpenPortResponse , error ) {
1019+ gotOpenReq = req
1020+ if tt .openFn != nil {
1021+ return tt .openFn (req )
1022+ }
1023+ return & nodev1.OpenPortResponse {
1024+ Port : & nodev1.Port {PortId : "port_ssh" , Protocol : req .GetProtocol (), PortNumber : req .GetPortNumber ()},
1025+ }, nil
1026+ },
1027+ grantNodeSSHAccessFn : func (_ * nodev1.GrantNodeSSHAccessRequest ) (* nodev1.GrantNodeSSHAccessResponse , error ) {
1028+ return & nodev1.GrantNodeSSHAccessResponse {}, nil
1029+ },
1030+ }
1031+
1032+ deps , server := testRegisterDeps (t , svc , regStore )
1033+ defer server .Close ()
1034+
1035+ deps .prompter = mockConfirmer {confirm : true }
1036+
1037+ SetTestSSHPort (tt .port )
1038+ defer ClearTestSSHPort ()
1039+
1040+ term := terminal .New ()
1041+ err := runRegister (context .Background (), term , store , "my-spark" , "" , deps )
1042+
1043+ tt .verify (t , gotOpenReq , gotGrantReq , regStore , err )
1044+ })
1045+ }
1046+ }
0 commit comments