Skip to content

Commit fef15d9

Browse files
committed
cleanup
1 parent ae75746 commit fef15d9

3 files changed

Lines changed: 49 additions & 86 deletions

File tree

pkg/cmd/register/register.go

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212

1313
nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1"
1414
"connectrpc.com/connect"
15-
"github.com/cenkalti/backoff/v4"
1615
"github.com/google/uuid"
1716

1817
"github.com/brevdev/brev-cli/pkg/config"
@@ -24,14 +23,6 @@ import (
2423
"github.com/spf13/cobra"
2524
)
2625

27-
const (
28-
backoffInitialInterval = 1 * time.Second
29-
backoffMaxInterval = 10 * time.Second
30-
backoffMaxElapsedTime = 1 * time.Minute
31-
32-
backoffPrintRound = 500 * time.Millisecond
33-
)
34-
3526
// RegisterStore defines the store methods needed by the register command.
3627
type RegisterStore interface {
3728
GetCurrentUser() (*entity.User, error)
@@ -217,7 +208,9 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
217208
runSetup(node, t, deps)
218209

219210
if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") {
220-
grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser)
211+
if err := grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser); err != nil {
212+
t.Vprintf(" Warning: SSH access not granted: %v\n", err)
213+
}
221214
}
222215

223216
return nil
@@ -339,7 +332,7 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps
339332
}
340333
}
341334

342-
func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) {
335+
func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) error {
343336
t.Vprint("")
344337
t.Vprint(t.Green("Enabling SSH access on this device"))
345338
t.Vprint("")
@@ -348,29 +341,11 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps
348341
t.Vprintf(" Linux user: %s\n", osUser.Username)
349342
t.Vprint("")
350343

351-
backoffCtx := backoff.WithContext(backoff.NewExponentialBackOff(
352-
backoff.WithInitialInterval(backoffInitialInterval),
353-
backoff.WithMaxInterval(backoffMaxInterval),
354-
backoff.WithMaxElapsedTime(backoffMaxElapsedTime),
355-
), ctx)
356-
357-
opToTry := func() error {
358-
err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
359-
if err != nil && !IsSSHConnectionError(err) {
360-
return backoff.Permanent(err)
361-
}
362-
return err
363-
}
364-
onOpErr := func(err error, d time.Duration) {
365-
t.Vprintf(" SSH access not yet granted; retrying in: %s...\n", d.Round(backoffPrintRound))
366-
}
367-
368-
// Retry until the operation succeeds or the context is canceled.
369-
err := backoff.RetryNotify(opToTry, backoffCtx, onOpErr)
344+
err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
370345
if err != nil {
371-
t.Vprintf(" Warning: SSH access not granted: %v\n", err)
372-
return
346+
return fmt.Errorf("grant SSH failed: %w", err)
373347
}
374348

375349
t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName)))
350+
return nil
376351
}

pkg/cmd/register/register_test.go

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -500,31 +500,6 @@ Peers count: 0/0 Connected`
500500
}
501501
}
502502

503-
func TestIsSSHConnectionError(t *testing.T) {
504-
t.Run("nil", func(t *testing.T) {
505-
if IsSSHConnectionError(nil) {
506-
t.Error("IsSSHConnectionError(nil) should be false")
507-
}
508-
})
509-
t.Run("plain_error", func(t *testing.T) {
510-
if IsSSHConnectionError(fmt.Errorf("some error")) {
511-
t.Error("IsSSHConnectionError(plain error) should be false")
512-
}
513-
})
514-
t.Run("connection_error_type", func(t *testing.T) {
515-
err := &sshConnectionError{err: fmt.Errorf("transient")}
516-
if !IsSSHConnectionError(err) {
517-
t.Error("IsSSHConnectionError(sshConnectionError) should be true")
518-
}
519-
})
520-
t.Run("wrapped_connection_error", func(t *testing.T) {
521-
err := fmt.Errorf("wrapped: %w", &sshConnectionError{err: fmt.Errorf("transient")})
522-
if !IsSSHConnectionError(err) {
523-
t.Error("IsSSHConnectionError(wrapped sshConnectionError) should be true")
524-
}
525-
})
526-
}
527-
528503
func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *testing.T) {
529504
regStore := &mockRegistrationStore{}
530505

pkg/cmd/register/sshkeys.go

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os/user"
99
"path/filepath"
1010
"strings"
11+
"time"
1112

1213
nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1"
1314
"connectrpc.com/connect"
@@ -16,20 +17,16 @@ import (
1617
"github.com/brevdev/brev-cli/pkg/entity"
1718
"github.com/brevdev/brev-cli/pkg/externalnode"
1819
"github.com/brevdev/brev-cli/pkg/terminal"
20+
"github.com/cenkalti/backoff/v4"
1921
)
2022

21-
// sshConnectionError marks an error as being due to a transient connection/transport failure
22-
type sshConnectionError struct{ err error }
23+
const (
24+
backoffInitialInterval = 1 * time.Second
25+
backoffMaxInterval = 10 * time.Second
26+
backoffMaxElapsedTime = 1 * time.Minute
2327

24-
func (e *sshConnectionError) Error() string { return e.err.Error() }
25-
func (e *sshConnectionError) Unwrap() error { return e.err }
26-
27-
// IsSSHConnectionError reports whether err indicates a transient connection/transport
28-
// failure that may be retried. Used by grantSSHAccess to decide whether to backoff-retry.
29-
func IsSSHConnectionError(err error) bool {
30-
var e *sshConnectionError
31-
return errors.As(err, &e)
32-
}
28+
backoffPrintRound = 500 * time.Millisecond
29+
)
3330

3431
// BrevKeyComment is the marker appended to every SSH key that Brev installs.
3532
// It allows RemoveBrevAuthorizedKeys to identify and remove exactly those keys.
@@ -56,28 +53,44 @@ func GrantSSHAccessToNode(
5653
}
5754

5855
client := nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL())
59-
_, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{
60-
ExternalNodeId: reg.ExternalNodeID,
61-
UserId: targetUser.ID,
62-
LinuxUser: osUser.Username,
63-
}))
64-
if err != nil {
65-
// Transport errors (connection reset, EOF) are transient — leave the key
66-
// installed so retries don't need to reinstall it, and signal the caller
67-
// with a distinct error type.
68-
var connectErr *connect.Error
69-
if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal {
70-
return &sshConnectionError{err: fmt.Errorf("failed to grant SSH access (transient): %w", err)}
71-
}
72-
// Permanent error — roll back the key so we don't leave an unrecorded entry.
73-
if targetUser.PublicKey != "" {
74-
if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil {
75-
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr)))
56+
57+
backoffCtx := backoff.WithContext(backoff.NewExponentialBackOff(
58+
backoff.WithInitialInterval(backoffInitialInterval),
59+
backoff.WithMaxInterval(backoffMaxInterval),
60+
backoff.WithMaxElapsedTime(backoffMaxElapsedTime),
61+
), ctx)
62+
63+
opToTry := func() error {
64+
_, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{
65+
ExternalNodeId: reg.ExternalNodeID,
66+
UserId: targetUser.ID,
67+
LinuxUser: osUser.Username,
68+
}))
69+
if err != nil {
70+
// Retryable error
71+
var connectErr *connect.Error
72+
if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal {
73+
return fmt.Errorf("failed to grant SSH access (transient): %w", err)
74+
}
75+
76+
// Permanent error — roll back the key so we don't leave an unrecorded entry and abort the backoff retry
77+
if targetUser.PublicKey != "" {
78+
if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil {
79+
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr)))
80+
}
7681
}
82+
return backoff.Permanent(fmt.Errorf("failed to grant SSH access: %w", err))
7783
}
84+
85+
return nil
86+
}
87+
onOpErr := func(err error, d time.Duration) {
88+
t.Vprintf(" SSH access not yet granted; retrying in: %s...\n", d.Round(backoffPrintRound))
89+
}
90+
err := backoff.RetryNotify(opToTry, backoffCtx, onOpErr)
91+
if err != nil {
7892
return fmt.Errorf("failed to grant SSH access: %w", err)
7993
}
80-
8194
return nil
8295
}
8396

0 commit comments

Comments
 (0)