Skip to content

Commit 7a6bdc1

Browse files
committed
fix enable-ssh port collision
1 parent ab796b3 commit 7a6bdc1

2 files changed

Lines changed: 149 additions & 5 deletions

File tree

pkg/cmd/enablessh/enablessh.go

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ import (
88
"os/exec"
99
"os/user"
1010

11+
nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1"
12+
"connectrpc.com/connect"
13+
1114
"github.com/brevdev/brev-cli/pkg/cmd/register"
15+
"github.com/brevdev/brev-cli/pkg/config"
1216
"github.com/brevdev/brev-cli/pkg/entity"
1317
breverrors "github.com/brevdev/brev-cli/pkg/errors"
1418
"github.com/brevdev/brev-cli/pkg/externalnode"
@@ -99,14 +103,24 @@ func enableSSH(
99103
t.Vprintf(" Linux user: %s\n", linuxUsername)
100104
t.Vprint("")
101105

102-
t.Vprint("")
103-
port, err := register.PromptSSHPort(t)
106+
// Check if the node already has an SSH port allocated (e.g. for another linux user)
107+
port, err := existingSSHPort(ctx, deps, tokenProvider, reg)
104108
if err != nil {
105-
return fmt.Errorf("SSH port: %w", err)
109+
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: could not check for existing ports: %v", err)))
106110
}
107111

108-
if err := register.OpenSSHPort(ctx, t, deps.nodeClients, tokenProvider, reg, port); err != nil {
109-
return fmt.Errorf("enable SSH failed: %w", err)
112+
if port != 0 {
113+
t.Vprintf(" Using existing SSH port %d.\n", port)
114+
} else {
115+
t.Vprint("")
116+
port, err = register.PromptSSHPort(t)
117+
if err != nil {
118+
return fmt.Errorf("SSH port: %w", err)
119+
}
120+
121+
if err := register.OpenSSHPort(ctx, t, deps.nodeClients, tokenProvider, reg, port); err != nil {
122+
return fmt.Errorf("enable SSH failed: %w", err)
123+
}
110124
}
111125

112126
if err := register.SetupAndRegisterNodeSSHAccess(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, linuxUsername); err != nil {
@@ -117,6 +131,27 @@ func enableSSH(
117131
return nil
118132
}
119133

134+
// existingSSHPort calls GetNode and returns the PortNumber of an already-allocated
135+
// SSH port, or 0 if none exists
136+
func existingSSHPort(ctx context.Context, deps enableSSHDeps, tokenProvider externalnode.TokenProvider, reg *register.DeviceRegistration) (int32, error) {
137+
client := deps.nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL())
138+
resp, err := client.GetNode(ctx, connect.NewRequest(&nodev1.GetNodeRequest{
139+
ExternalNodeId: reg.ExternalNodeID,
140+
OrganizationId: reg.OrgID,
141+
}))
142+
if err != nil {
143+
return 0, fmt.Errorf("error retrieving node: %w", err)
144+
}
145+
146+
for _, p := range resp.Msg.GetExternalNode().GetPorts() {
147+
// TODO if we ever allow more than one SSH port, this should be modified
148+
if p.GetProtocol() == nodev1.PortProtocol_PORT_PROTOCOL_SSH {
149+
return p.GetPortNumber(), nil
150+
}
151+
}
152+
return 0, nil
153+
}
154+
120155
// checkSSHDaemon prints a warning if neither "ssh" nor "sshd" systemd services
121156
// appear to be active. It never returns an error — it is best-effort.
122157
func checkSSHDaemon(t *terminal.Terminal) {

pkg/cmd/enablessh/enablessh_test.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
package enablessh
22

33
import (
4+
"context"
5+
"fmt"
6+
"net/http/httptest"
47
"os"
58
"os/user"
69
"path/filepath"
710
"strings"
811
"testing"
912

13+
nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect"
14+
nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1"
15+
"connectrpc.com/connect"
16+
1017
"github.com/brevdev/brev-cli/pkg/cmd/register"
18+
"github.com/brevdev/brev-cli/pkg/externalnode"
1119
)
1220

1321
// tempUser returns a *user.User whose HomeDir points to a temporary directory.
@@ -413,3 +421,104 @@ func Test_InstallThenRemoveSpecificKey_RollbackScenario(t *testing.T) {
413421
t.Errorf("Alice's key was removed during Bob's rollback:\n%s", result)
414422
}
415423
}
424+
425+
// --- existingSSHPort ---
426+
427+
type mockNodeClientFactory struct{ serverURL string }
428+
429+
func (m mockNodeClientFactory) NewNodeClient(provider externalnode.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient {
430+
return register.NewNodeServiceClient(provider, m.serverURL)
431+
}
432+
433+
type mockEnableSSHStore struct {
434+
token string
435+
}
436+
437+
func (m *mockEnableSSHStore) GetCurrentUser() (interface{}, error) { return nil, nil }
438+
func (m *mockEnableSSHStore) GetAccessToken() (string, error) { return m.token, nil }
439+
440+
// fakeNodeService implements the server side of ExternalNodeService for testing.
441+
type fakeNodeService struct {
442+
nodev1connect.UnimplementedExternalNodeServiceHandler
443+
getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error)
444+
}
445+
446+
func (f *fakeNodeService) GetNode(_ context.Context, req *connect.Request[nodev1.GetNodeRequest]) (*connect.Response[nodev1.GetNodeResponse], error) {
447+
resp, err := f.getNodeFn(req.Msg)
448+
if err != nil {
449+
return nil, err
450+
}
451+
return connect.NewResponse(resp), nil
452+
}
453+
454+
func startFakeServer(t *testing.T, svc *fakeNodeService) (enableSSHDeps, *httptest.Server) {
455+
t.Helper()
456+
_, handler := nodev1connect.NewExternalNodeServiceHandler(svc)
457+
server := httptest.NewServer(handler)
458+
t.Cleanup(server.Close)
459+
return enableSSHDeps{
460+
nodeClients: mockNodeClientFactory{serverURL: server.URL},
461+
}, server
462+
}
463+
464+
func Test_existingSSHPort(t *testing.T) {
465+
ssh := nodev1.PortProtocol_PORT_PROTOCOL_SSH
466+
tcp := nodev1.PortProtocol_PORT_PROTOCOL_TCP
467+
468+
tests := []struct {
469+
name string
470+
resp *nodev1.GetNodeResponse
471+
rpcErr error
472+
wantPort int32
473+
wantErr bool
474+
}{
475+
{
476+
name: "ReturnsExistingPort",
477+
resp: &nodev1.GetNodeResponse{ExternalNode: &nodev1.ExternalNode{
478+
Ports: []*nodev1.Port{{Protocol: ssh, PortNumber: 2222}},
479+
}},
480+
wantPort: 2222,
481+
},
482+
{
483+
name: "ReturnsZeroWhenNoPorts",
484+
resp: &nodev1.GetNodeResponse{ExternalNode: &nodev1.ExternalNode{}},
485+
wantPort: 0,
486+
},
487+
{
488+
name: "ReturnsErrorOnRPCFailure",
489+
rpcErr: connect.NewError(connect.CodeInternal, fmt.Errorf("server error")),
490+
wantErr: true,
491+
},
492+
{
493+
name: "IgnoresNonSSHPorts",
494+
resp: &nodev1.GetNodeResponse{ExternalNode: &nodev1.ExternalNode{
495+
Ports: []*nodev1.Port{
496+
{Protocol: tcp, PortNumber: 8080},
497+
{Protocol: ssh, PortNumber: 3333},
498+
},
499+
}},
500+
wantPort: 3333,
501+
},
502+
}
503+
504+
for _, tt := range tests {
505+
t.Run(tt.name, func(t *testing.T) {
506+
svc := &fakeNodeService{
507+
getNodeFn: func(_ *nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) {
508+
return tt.resp, tt.rpcErr
509+
},
510+
}
511+
deps, _ := startFakeServer(t, svc)
512+
store := &mockEnableSSHStore{token: "tok"}
513+
reg := &register.DeviceRegistration{ExternalNodeID: "unode_abc"}
514+
515+
port, err := existingSSHPort(context.Background(), deps, store, reg)
516+
if (err != nil) != tt.wantErr {
517+
t.Fatalf("err = %v, wantErr = %v", err, tt.wantErr)
518+
}
519+
if port != tt.wantPort {
520+
t.Errorf("port = %d, want %d", port, tt.wantPort)
521+
}
522+
})
523+
}
524+
}

0 commit comments

Comments
 (0)