Skip to content

Commit 8f06d5f

Browse files
committed
fix(BRE2-804): various polish: name sanitization
1 parent 78b382a commit 8f06d5f

5 files changed

Lines changed: 162 additions & 16 deletions

File tree

pkg/cmd/gpucreate/gpucreate.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/brevdev/brev-cli/pkg/entity"
1919
breverrors "github.com/brevdev/brev-cli/pkg/errors"
2020
"github.com/brevdev/brev-cli/pkg/featureflag"
21+
"github.com/brevdev/brev-cli/pkg/names"
2122
"github.com/brevdev/brev-cli/pkg/store"
2223
"github.com/brevdev/brev-cli/pkg/terminal"
2324
"github.com/spf13/cobra"
@@ -194,8 +195,8 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra
194195
}
195196
}
196197

197-
if name == "" {
198-
return breverrors.NewValidationError("name is required (as argument or --name flag)")
198+
if err := names.ValidateNodeName(name); err != nil {
199+
return err
199200
}
200201

201202
if count < 1 {

pkg/cmd/register/register.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/brevdev/brev-cli/pkg/entity"
1919
breverrors "github.com/brevdev/brev-cli/pkg/errors"
2020
"github.com/brevdev/brev-cli/pkg/externalnode"
21+
"github.com/brevdev/brev-cli/pkg/names"
2122
"github.com/brevdev/brev-cli/pkg/terminal"
2223

2324
"github.com/spf13/cobra"
@@ -120,8 +121,8 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
120121
return checkExistingRegistration(ctx, t, s, name, deps)
121122
}
122123

123-
if name == "" {
124-
return fmt.Errorf("please provide a name for this device\n\nUsage: brev register <name>\nExample: brev register \"my-DGX-Spark\"")
124+
if err := names.ValidateNodeName(name); err != nil {
125+
return err
125126
}
126127

127128
brevUser, err := s.GetCurrentUser()

pkg/cmd/register/register_test.go

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func Test_runRegister_HappyPath(t *testing.T) {
144144
if req.GetOrganizationId() != "org_123" {
145145
t.Errorf("unexpected org: %s", req.GetOrganizationId())
146146
}
147-
if req.GetName() != "My Spark" {
147+
if req.GetName() != "my-spark" {
148148
t.Errorf("unexpected name: %s", req.GetName())
149149
}
150150
return &nodev1.AddNodeResponse{
@@ -172,7 +172,7 @@ func Test_runRegister_HappyPath(t *testing.T) {
172172
defer ClearTestSSHPort()
173173

174174
term := terminal.New()
175-
err := runRegister(context.Background(), term, store, "My Spark", deps)
175+
err := runRegister(context.Background(), term, store, "my-spark", deps)
176176
if err != nil {
177177
t.Fatalf("runRegister failed: %v", err)
178178
}
@@ -193,8 +193,8 @@ func Test_runRegister_HappyPath(t *testing.T) {
193193
if reg.ExternalNodeID != "unode_abc" {
194194
t.Errorf("expected ExternalNodeID unode_abc, got %s", reg.ExternalNodeID)
195195
}
196-
if reg.DisplayName != "My Spark" {
197-
t.Errorf("expected display name 'My Spark', got %s", reg.DisplayName)
196+
if reg.DisplayName != "my-spark" {
197+
t.Errorf("expected display name 'my-spark', got %s", reg.DisplayName)
198198
}
199199
if reg.OrgID != "org_123" {
200200
t.Errorf("expected org org_123, got %s", reg.OrgID)
@@ -223,7 +223,7 @@ func Test_runRegister_UserCancels(t *testing.T) {
223223
deps.prompter = mockConfirmer{confirm: false}
224224

225225
term := terminal.New()
226-
err := runRegister(context.Background(), term, store, "My Spark", deps)
226+
err := runRegister(context.Background(), term, store, "my-spark", deps)
227227
if err != nil {
228228
t.Fatalf("expected nil error on cancel, got: %v", err)
229229
}
@@ -339,7 +339,7 @@ func Test_runRegister_NoOrganization(t *testing.T) {
339339
defer server.Close()
340340

341341
term := terminal.New()
342-
err := runRegister(context.Background(), term, store, "My Spark", deps)
342+
err := runRegister(context.Background(), term, store, "my-spark", deps)
343343
if err == nil {
344344
t.Fatal("expected error when no org exists")
345345
}
@@ -365,7 +365,7 @@ func Test_runRegister_AddNodeFails(t *testing.T) {
365365
defer server.Close()
366366

367367
term := terminal.New()
368-
err := runRegister(context.Background(), term, store, "My Spark", deps)
368+
err := runRegister(context.Background(), term, store, "my-spark", deps)
369369
if err == nil {
370370
t.Fatal("expected error when AddNode fails")
371371
}
@@ -415,7 +415,7 @@ func Test_runRegister_NoSetupCommand(t *testing.T) {
415415
defer ClearTestSSHPort()
416416

417417
term := terminal.New()
418-
err := runRegister(context.Background(), term, store, "My Spark", deps)
418+
err := runRegister(context.Background(), term, store, "my-spark", deps)
419419
if err != nil {
420420
t.Fatalf("runRegister failed: %v", err)
421421
}
@@ -548,7 +548,7 @@ func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *test
548548
defer ClearTestSSHPort()
549549

550550
term := terminal.New()
551-
err := runRegister(context.Background(), term, store, "My Spark", deps)
551+
err := runRegister(context.Background(), term, store, "my-spark", deps)
552552
if err != nil {
553553
t.Fatalf("runRegister failed: %v", err)
554554
}
@@ -597,7 +597,7 @@ func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) {
597597
defer ClearTestSSHPort()
598598

599599
term := terminal.New()
600-
err := runRegister(context.Background(), term, store, "My Spark", deps)
600+
err := runRegister(context.Background(), term, store, "my-spark", deps)
601601
if err != nil {
602602
t.Fatalf("runRegister should not fail the overall flow when SSH grant fails: %v", err)
603603
}
@@ -607,6 +607,73 @@ func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) {
607607
}
608608
}
609609

610+
func Test_runRegister_NameValidation(t *testing.T) {
611+
tests := []struct {
612+
name string
613+
input string
614+
wantErr bool
615+
errSubstr string
616+
}{
617+
{"Valid", "my-dgx-spark", false, ""},
618+
{"WithDots", "node.local.1", false, ""},
619+
{"WithUnderscore", "my_node", false, ""},
620+
{"Spaces", "My Spark", true, "letters, digits"},
621+
{"ShellInjection", "$(whoami)", true, "letters, digits"},
622+
{"PathTraversal", "../etc/passwd", true, "letters, digits"},
623+
{"Backticks", "`rm -rf`", true, "letters, digits"},
624+
{"Semicolon", "a;rm -rf /", true, "letters, digits"},
625+
{"LeadingHyphen", "-node", true, "start with"},
626+
{"LeadingDot", ".hidden", true, "start with"},
627+
{"TooLong", strings.Repeat("a", 64), true, "63 characters"},
628+
{"Empty", "", true, "name is required"},
629+
}
630+
631+
for _, tt := range tests {
632+
t.Run(tt.name, func(t *testing.T) {
633+
regStore := &mockRegistrationStore{}
634+
store := &mockRegisterStore{
635+
user: &entity.User{ID: "user_1"},
636+
org: &entity.Organization{ID: "org_123", Name: "TestOrg"},
637+
token: "tok",
638+
}
639+
640+
svc := &fakeNodeService{
641+
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
642+
return &nodev1.AddNodeResponse{
643+
ExternalNode: &nodev1.ExternalNode{
644+
ExternalNodeId: "unode_abc",
645+
OrganizationId: "org_123",
646+
Name: req.GetName(),
647+
DeviceId: req.GetDeviceId(),
648+
},
649+
}, nil
650+
},
651+
}
652+
653+
deps, server := testRegisterDeps(t, svc, regStore)
654+
defer server.Close()
655+
656+
SetTestSSHPort(22)
657+
defer ClearTestSSHPort()
658+
659+
term := terminal.New()
660+
err := runRegister(context.Background(), term, store, tt.input, deps)
661+
if tt.wantErr {
662+
if err == nil {
663+
t.Fatal("expected error, got nil")
664+
}
665+
if !strings.Contains(err.Error(), tt.errSubstr) {
666+
t.Errorf("expected error containing %q, got: %v", tt.errSubstr, err)
667+
}
668+
} else {
669+
if err != nil {
670+
t.Errorf("unexpected error: %v", err)
671+
}
672+
}
673+
})
674+
}
675+
}
676+
610677
func Test_runRegister_NoNameNotRegistered(t *testing.T) {
611678
regStore := &mockRegistrationStore{}
612679

@@ -625,8 +692,8 @@ func Test_runRegister_NoNameNotRegistered(t *testing.T) {
625692
if err == nil {
626693
t.Fatal("expected error when no name provided and not registered")
627694
}
628-
if !strings.Contains(err.Error(), "please provide a name") {
629-
t.Errorf("expected 'please provide a name' error, got: %v", err)
695+
if !strings.Contains(err.Error(), "name is required") {
696+
t.Errorf("expected 'name is required' error, got: %v", err)
630697
}
631698
}
632699

pkg/names/validate.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package names
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
7+
breverrors "github.com/brevdev/brev-cli/pkg/errors"
8+
)
9+
10+
var validNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`)
11+
12+
const maxNameLen = 63
13+
14+
func ValidateNodeName(name string) error {
15+
if name == "" {
16+
return breverrors.NewValidationError("name is required")
17+
}
18+
if len(name) > maxNameLen {
19+
return breverrors.NewValidationError(
20+
fmt.Sprintf("name must be %d characters or fewer (got %d)", maxNameLen, len(name)))
21+
}
22+
if !validNameRe.MatchString(name) {
23+
return breverrors.NewValidationError(
24+
"name must start with a letter or digit and contain only letters, digits, hyphens, underscores, and dots")
25+
}
26+
return nil
27+
}

pkg/names/validate_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package names
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestValidateNodeName(t *testing.T) {
9+
tests := []struct {
10+
name string
11+
input string
12+
wantErr bool
13+
errSubstr string
14+
}{
15+
{"Valid", "my-dgx-spark", false, ""},
16+
{"WithDots", "node.local.1", false, ""},
17+
{"WithUnderscore", "my_node", false, ""},
18+
{"SingleChar", "a", false, ""},
19+
{"MaxLength", strings.Repeat("a", 63), false, ""},
20+
{"Spaces", "My Spark", true, "letters, digits"},
21+
{"ShellInjection", "$(whoami)", true, "letters, digits"},
22+
{"PathTraversal", "../etc/passwd", true, "letters, digits"},
23+
{"Backticks", "`rm -rf`", true, "letters, digits"},
24+
{"Semicolon", "a;rm -rf /", true, "letters, digits"},
25+
{"Pipe", "a|cat", true, "letters, digits"},
26+
{"Ampersand", "a&bg", true, "letters, digits"},
27+
{"LeadingHyphen", "-node", true, "start with"},
28+
{"LeadingDot", ".hidden", true, "start with"},
29+
{"TooLong", strings.Repeat("a", 64), true, "63 characters"},
30+
{"Empty", "", true, "name is required"},
31+
}
32+
33+
for _, tt := range tests {
34+
t.Run(tt.name, func(t *testing.T) {
35+
err := ValidateNodeName(tt.input)
36+
if tt.wantErr {
37+
if err == nil {
38+
t.Fatal("expected error, got nil")
39+
}
40+
if !strings.Contains(err.Error(), tt.errSubstr) {
41+
t.Errorf("expected error containing %q, got: %v", tt.errSubstr, err)
42+
}
43+
} else {
44+
if err != nil {
45+
t.Errorf("unexpected error: %v", err)
46+
}
47+
}
48+
})
49+
}
50+
}

0 commit comments

Comments
 (0)