Skip to content

Commit 20d46a0

Browse files
committed
fix(BRE2-824): validate SSH port and allow re-entry
1 parent f5105a0 commit 20d46a0

5 files changed

Lines changed: 103 additions & 22 deletions

File tree

pkg/cmd/enablessh/enablessh.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func enableSSH(
115115
t.Vprint("")
116116
port, err = register.PromptSSHPort(t)
117117
if err != nil {
118-
return fmt.Errorf("SSH port: %w", err)
118+
return fmt.Errorf("invalid SSH port: %w", err)
119119
}
120120

121121
if err := register.OpenSSHPort(ctx, t, deps.nodeClients, tokenProvider, reg, port); err != nil {

pkg/cmd/register/register.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, opt
250250
return fmt.Errorf("failed to determine current Linux user: %w", err)
251251
}
252252
if err := grantSSHAccessWithPort(ctx, t, deps, s, reg, brevUser, osUser, sshPortForGrant, opts.interactive, opts.skipConfirm); err != nil {
253-
t.Vprintf(" Warning: SSH access not granted: %v\n", err)
253+
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: %v", err)))
254254
}
255255
}
256256

@@ -465,7 +465,7 @@ func grantSSHAccessWithPort(ctx context.Context, t *terminal.Terminal, deps regi
465465
t.Vprint("")
466466
port, err = PromptSSHPort(t)
467467
if err != nil {
468-
return fmt.Errorf("SSH port: %w", err)
468+
return fmt.Errorf("invalid SSH port: %w", err)
469469
}
470470
} else {
471471
t.Vprintf(" %s %s\n", t.Green(fmt.Sprintf("%-14s", "SSH port:")), t.BoldBlue(fmt.Sprintf("%d", port)))

pkg/cmd/register/register_test.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ func Test_runRegister_NoNameAlreadyRegistered(t *testing.T) {
980980
}
981981
}
982982

983-
func Test_runRegister_OpenSSHPort(t *testing.T) { // nolint:funlen // test
983+
func Test_runRegister_OpenSSHPort(t *testing.T) { // nolint:funlen, gocyclo, gocognit // test
984984
tests := []struct {
985985
name string
986986
port int32
@@ -1035,6 +1035,23 @@ func Test_runRegister_OpenSSHPort(t *testing.T) { // nolint:funlen // test
10351035
}
10361036
},
10371037
},
1038+
{
1039+
name: "InvalidPortNoAPICall",
1040+
port: 99999,
1041+
verify: func(t *testing.T, openReq *nodev1.OpenPortRequest, _ *nodev1.GrantNodeSSHAccessRequest, regStore *mockRegistrationStore, err error) {
1042+
t.Helper()
1043+
if err != nil {
1044+
t.Fatalf("registration should succeed even when SSH port is invalid (soft error), got: %v", err)
1045+
}
1046+
if openReq != nil {
1047+
t.Error("expected OpenPort NOT to be called for invalid port")
1048+
}
1049+
exists, _ := regStore.Exists()
1050+
if !exists {
1051+
t.Error("expected registration to still exist after invalid port")
1052+
}
1053+
},
1054+
},
10381055
{
10391056
name: "GrantRequestHasNoPort",
10401057
port: 22,

pkg/cmd/register/sshkeys.go

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ func OpenSSHPort(
177177
reg *DeviceRegistration,
178178
port int32,
179179
) error {
180+
if port < 1 || port > 65535 {
181+
return fmt.Errorf("invalid SSH port %d: port must be between 1 and 65535", port)
182+
}
180183
client := nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL())
181184
_, err := client.OpenPort(ctx, connect.NewRequest(&nodev1.OpenPortRequest{
182185
ExternalNodeId: reg.ExternalNodeID,
@@ -275,30 +278,34 @@ func SetTestSSHPort(port int32) { testSSHPort = &port }
275278
func ClearTestSSHPort() { testSSHPort = nil }
276279

277280
// PromptSSHPort prompts the user for the target SSH port, defaulting to 22 if
278-
// they press Enter or leave it empty. Returns an error for invalid port numbers.
279-
// Uses a single print + stdin read to avoid promptui rendering the label twice.
281+
// they press Enter or leave it empty. Re-prompts on invalid input until a valid
282+
// port is provided. Only returns an error for unrecoverable I/O failures.
280283
func PromptSSHPort(t *terminal.Terminal) (int32, error) {
281284
if testSSHPort != nil {
282285
return *testSSHPort, nil
283286
}
284-
t.Vprintf(" %s ", t.Green("SSH port (default 22):"))
285287
reader := bufio.NewReader(os.Stdin)
286-
line, err := reader.ReadString('\n')
287-
if err != nil {
288-
return 0, fmt.Errorf("reading input: %w", err)
289-
}
290-
portStr := strings.TrimSpace(line)
291-
if portStr == "" {
292-
return defaultSSHPort, nil
293-
}
294-
n, err := strconv.ParseUint(portStr, 10, 16)
295-
if err != nil {
296-
return 0, fmt.Errorf("invalid port %q: %w", portStr, err)
297-
}
298-
if n < 1 || n > 65535 {
299-
return 0, fmt.Errorf("port must be between 1 and 65535, got %d", n)
288+
for {
289+
t.Vprintf(" %s ", t.Green("SSH port (default 22):"))
290+
line, err := reader.ReadString('\n')
291+
if err != nil {
292+
return 0, fmt.Errorf("reading input: %w", err)
293+
}
294+
portStr := strings.TrimSpace(line)
295+
if portStr == "" {
296+
return defaultSSHPort, nil
297+
}
298+
n, err := strconv.ParseUint(portStr, 10, 16)
299+
if err != nil {
300+
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Invalid port %q: port must be a number between 1 and 65535", portStr)))
301+
continue
302+
}
303+
if n < 1 || n > 65535 { // explicit gate even though ParseUint will already fail values out of range
304+
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Invalid port %q: port must be between 1 and 65535", portStr)))
305+
continue
306+
}
307+
return int32(n), nil
300308
}
301-
return int32(n), nil
302309
}
303310

304311
// InstallAuthorizedKey appends the given public key to the user's

pkg/cmd/register/sshkeys_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"path/filepath"
77
"strings"
88
"testing"
9+
10+
"github.com/brevdev/brev-cli/pkg/terminal"
911
)
1012

1113
func tempUser(t *testing.T) *user.User {
@@ -215,6 +217,61 @@ func TestInstallAuthorizedKey_IncludesUserID(t *testing.T) {
215217
}
216218
}
217219

220+
// --- PromptSSHPort ---
221+
222+
func promptSSHPortWithInput(t *testing.T, input string) (int32, error) {
223+
t.Helper()
224+
225+
r, w, err := os.Pipe()
226+
if err != nil {
227+
t.Fatalf("creating pipe: %v", err)
228+
}
229+
defer r.Close()
230+
231+
// Write input and close writer so ReadString sees EOF after newline.
232+
if _, err := w.WriteString(input); err != nil {
233+
t.Fatalf("writing to pipe: %v", err)
234+
}
235+
w.Close()
236+
237+
origStdin := os.Stdin
238+
os.Stdin = r
239+
defer func() { os.Stdin = origStdin }()
240+
241+
ClearTestSSHPort() // ensure we go through the real path
242+
term := terminal.New()
243+
return PromptSSHPort(term)
244+
}
245+
246+
func TestPromptSSHPort(t *testing.T) {
247+
tests := []struct {
248+
name string
249+
input string
250+
want int32
251+
}{
252+
{"Default", "\n", 22},
253+
{"CustomPort", "2222\n", 2222},
254+
{"MinPort", "1\n", 1},
255+
{"MaxPort", "65535\n", 65535},
256+
{"RetryAfterOutOfRange", "99999\n22\n", 22},
257+
{"RetryAfterZero", "0\n443\n", 443},
258+
{"RetryAfterNonNumeric", "abc\n8080\n", 8080},
259+
{"RetryAfterNegative", "-1\n22\n", 22},
260+
{"RetryMultipleThenValid", "foo\n99999\n22\n", 22},
261+
}
262+
for _, tt := range tests {
263+
t.Run(tt.name, func(t *testing.T) {
264+
port, err := promptSSHPortWithInput(t, tt.input)
265+
if err != nil {
266+
t.Fatalf("unexpected error: %v", err)
267+
}
268+
if port != tt.want {
269+
t.Errorf("expected port %d, got %d", tt.want, port)
270+
}
271+
})
272+
}
273+
}
274+
218275
func TestInstallAuthorizedKey_EmptyUserID_UsesPrefix(t *testing.T) {
219276
u := tempUser(t)
220277

0 commit comments

Comments
 (0)