Skip to content

Commit 653637e

Browse files
committed
interactive vs prompt mode
1 parent e5ec5f1 commit 653637e

3 files changed

Lines changed: 246 additions & 62 deletions

File tree

pkg/cmd/register/register.go

Lines changed: 200 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ type RegisterStore interface {
2828
GetCurrentUser() (*entity.User, error)
2929
GetActiveOrganizationOrDefault() (*entity.Organization, error)
3030
GetOrganizationsByName(name string) ([]entity.Organization, error)
31+
ListOrganizations() ([]entity.Organization, error)
3132
GetAccessToken() (string, error)
3233
}
3334

@@ -51,6 +52,7 @@ type SetupRunner interface {
5152
type registerDeps struct {
5253
platform externalnode.PlatformChecker
5354
prompter terminal.Confirmer
55+
selector terminal.Selector
5456
netbird NetBirdManager
5557
setupRunner SetupRunner
5658
nodeClients externalnode.NodeClientFactory
@@ -59,9 +61,11 @@ type registerDeps struct {
5961
}
6062

6163
func defaultRegisterDeps() registerDeps {
64+
p := TerminalPrompter{}
6265
return registerDeps{
6366
platform: LinuxPlatform{},
64-
prompter: TerminalPrompter{},
67+
prompter: p,
68+
selector: p,
6569
netbird: Netbird{},
6670
setupRunner: ShellSetupRunner{},
6771
nodeClients: DefaultNodeClientFactory{},
@@ -73,13 +77,27 @@ func defaultRegisterDeps() registerDeps {
7377
var (
7478
registerLong = `Register your device with NVIDIA Brev
7579
76-
This command sets up network connectivity and registers this machine with Brev.`
80+
This command sets up network connectivity and registers this machine with Brev.
7781
78-
registerExample = ` brev register "My DGX Spark"`
82+
Two modes are supported:
83+
• Interactive (default): run 'brev register' (or 'brev register <name>') and follow prompts for org and options.
84+
• Non-interactive: use any of --name, --org, --enable-ssh, or --ssh-port (or --non-interactive). No prompts; --name and --org are required. Use for scripts/CI.`
85+
86+
registerExample = ` # Interactive (prompts for org, confirmations)
87+
brev register
88+
brev register "My DGX Spark"
89+
90+
# Non-interactive (any flag implies no prompts; --name and --org required)
91+
brev register --name my-node --org my-org
92+
brev register --name my-node --org my-org --enable-ssh --ssh-port 22`
7993
)
8094

8195
func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command {
8296
var orgFlag string
97+
var nonInteractive bool
98+
var nameFlag string
99+
var enableSSH bool
100+
var sshPort int
83101

84102
cmd := &cobra.Command{
85103
Annotations: map[string]string{"configuration": ""},
@@ -90,20 +108,129 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command {
90108
Example: registerExample,
91109
Args: cobra.MaximumNArgs(1),
92110
RunE: func(cmd *cobra.Command, args []string) error {
93-
var name string
94-
if len(args) > 0 {
111+
name := nameFlag
112+
if name == "" && len(args) > 0 {
95113
name = args[0]
96114
}
97-
return runRegister(cmd.Context(), t, store, name, orgFlag, defaultRegisterDeps())
115+
// Non-interactive if explicit flag or any register-specific flag is set (implies script/CI).
116+
flagDriven := nonInteractive ||
117+
nameFlag != "" ||
118+
orgFlag != "" ||
119+
enableSSH ||
120+
cmd.Flags().Changed("ssh-port")
121+
if flagDriven {
122+
return runRegisterFlagDriven(cmd.Context(), t, store, name, orgFlag, enableSSH, int32(sshPort), defaultRegisterDeps())
123+
}
124+
return runRegisterPromptDriven(cmd.Context(), t, store, name, orgFlag, defaultRegisterDeps())
98125
},
99126
}
100127

101-
cmd.Flags().StringVarP(&orgFlag, "org", "o", "", "organization name (overrides active org)")
128+
cmd.Flags().StringVarP(&orgFlag, "org", "o", "", "organization name (required when using non-interactive mode)")
129+
cmd.Flags().BoolVar(&nonInteractive, "non-interactive", false, "non-interactive mode (also implied by --name, --org, --enable-ssh, or --ssh-port)")
130+
cmd.Flags().StringVar(&nameFlag, "name", "", "device name (required when using non-interactive mode)")
131+
cmd.Flags().BoolVar(&enableSSH, "enable-ssh", false, "enable SSH access after registration (non-interactive mode)")
132+
cmd.Flags().IntVar(&sshPort, "ssh-port", 22, "SSH port when using --enable-ssh")
102133

103134
return cmd
104135
}
105136

106-
func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, orgName string, deps registerDeps) error { //nolint:funlen // registration flow
137+
func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, orgName string, deps registerDeps) error {
138+
return runRegisterPromptDriven(ctx, t, s, name, orgName, deps)
139+
}
140+
141+
// runRegisterSteps performs netbird install, hardware profile, AddNode, save registration, and runSetup.
142+
// It does not prompt or enable SSH. Used by both flag-driven and prompt-driven flows.
143+
func runRegisterSteps(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, org *entity.Organization, deps registerDeps) (*DeviceRegistration, error) {
144+
t.Vprint("")
145+
t.Vprint(t.Yellow("[Step 1/3] Setting up Brev tunnel..."))
146+
if err := deps.netbird.Install(); err != nil {
147+
return nil, fmt.Errorf("brev tunnel setup failed: %w", err)
148+
}
149+
t.Vprint(t.Green(" Brev tunnel ready."))
150+
151+
t.Vprint("")
152+
t.Vprint(t.Yellow("[Step 2/3] Collecting hardware profile..."))
153+
t.Vprint("")
154+
155+
hwProfile, err := deps.hardwareProfiler.Profile()
156+
if err != nil {
157+
return nil, fmt.Errorf("failed to collect hardware profile: %w", err)
158+
}
159+
160+
t.Vprint(" Hardware profile:")
161+
t.Vprint(FormatHardwareProfile(hwProfile))
162+
163+
t.Vprint("")
164+
t.Vprint(t.Yellow("[Step 3/3] Registering with Brev..."))
165+
166+
deviceID := uuid.New().String()
167+
client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL())
168+
addResp, err := client.AddNode(ctx, connect.NewRequest(&nodev1.AddNodeRequest{
169+
OrganizationId: org.ID,
170+
Name: name,
171+
DeviceId: deviceID,
172+
NodeSpec: toProtoNodeSpec(hwProfile),
173+
}))
174+
if err != nil {
175+
return nil, fmt.Errorf("failed to register node: %w", err)
176+
}
177+
178+
node := addResp.Msg.GetExternalNode()
179+
reg := &DeviceRegistration{
180+
ExternalNodeID: node.GetExternalNodeId(),
181+
DisplayName: name,
182+
OrgID: org.ID,
183+
DeviceID: deviceID,
184+
RegisteredAt: time.Now().UTC().Format(time.RFC3339),
185+
HardwareProfile: *hwProfile,
186+
}
187+
if err := deps.registrationStore.Save(reg); err != nil {
188+
return nil, fmt.Errorf("node registered but failed to save locally: %w", err)
189+
}
190+
191+
t.Vprint(t.Green(" Registration complete."))
192+
runSetup(node, t, deps)
193+
return reg, nil
194+
}
195+
196+
// resolveOrgPromptDriven resolves organization for prompt-driven flow: by name if --org given, else always list and select with arrow keys.
197+
func resolveOrgPromptDriven(s RegisterStore, orgName string, deps registerDeps) (*entity.Organization, error) {
198+
if orgName != "" {
199+
orgs, err := s.GetOrganizationsByName(orgName)
200+
if err != nil {
201+
return nil, breverrors.WrapAndTrace(err)
202+
}
203+
if len(orgs) == 0 {
204+
return nil, fmt.Errorf("no organization found with name %q", orgName)
205+
}
206+
if len(orgs) > 1 {
207+
return nil, fmt.Errorf("multiple organizations found with name %q", orgName)
208+
}
209+
return &orgs[0], nil
210+
}
211+
212+
list, err := s.ListOrganizations()
213+
if err != nil {
214+
return nil, breverrors.WrapAndTrace(err)
215+
}
216+
if len(list) == 0 {
217+
return nil, fmt.Errorf("no organization found; please create or join an organization first")
218+
}
219+
220+
names := make([]string, len(list))
221+
for i := range list {
222+
names[i] = list[i].Name
223+
}
224+
chosen := deps.selector.Select("Select organization", names)
225+
for i := range list {
226+
if list[i].Name == chosen {
227+
return &list[i], nil
228+
}
229+
}
230+
return nil, fmt.Errorf("selected organization not found")
231+
}
232+
233+
func runRegisterPromptDriven(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, orgName string, deps registerDeps) error {
107234
if !deps.platform.IsCompatible() {
108235
return breverrors.New("brev register is only supported on Linux")
109236
}
@@ -120,15 +247,24 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
120247
return checkExistingRegistration(ctx, t, s, name, deps)
121248
}
122249

250+
if name == "" {
251+
name = terminal.PromptGetInput(terminal.PromptContent{
252+
Label: "Device name",
253+
ErrorMsg: "name is required",
254+
AllowEmpty: false,
255+
})
256+
name = strings.TrimSpace(name)
257+
}
123258
if err := names.ValidateNodeName(name); err != nil {
124259
return breverrors.WrapAndTrace(err)
125260
}
126261

127-
brevUser, err := s.GetCurrentUser()
262+
org, err := resolveOrgPromptDriven(s, orgName, deps)
128263
if err != nil {
129-
return breverrors.WrapAndTrace(err)
264+
return err
130265
}
131-
org, err := getOrgToRegisterFor(s, orgName)
266+
267+
brevUser, err := s.GetCurrentUser()
132268
if err != nil {
133269
return breverrors.WrapAndTrace(err)
134270
}
@@ -155,60 +291,62 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
155291
return nil
156292
}
157293

158-
t.Vprint("")
159-
t.Vprint(t.Yellow("[Step 1/3] Setting up Brev tunnel..."))
160-
if err := deps.netbird.Install(); err != nil {
161-
return fmt.Errorf("brev tunnel setup failed: %w", err)
294+
reg, err := runRegisterSteps(ctx, t, s, name, org, deps)
295+
if err != nil {
296+
return err
162297
}
163-
t.Vprint(t.Green(" Brev tunnel ready."))
164298

165-
t.Vprint("")
166-
t.Vprint(t.Yellow("[Step 2/3] Collecting hardware profile..."))
167-
t.Vprint("")
168-
169-
hwProfile, err := deps.hardwareProfiler.Profile()
170-
if err != nil {
171-
return fmt.Errorf("failed to collect hardware profile: %w", err)
299+
if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") {
300+
if err := grantSSHAccessWithPort(ctx, t, deps, s, reg, brevUser, osUser, 0); err != nil {
301+
t.Vprintf(" Warning: SSH access not granted: %v\n", err)
302+
}
172303
}
173304

174-
t.Vprint(" Hardware profile:")
175-
t.Vprint(FormatHardwareProfile(hwProfile))
305+
return nil
306+
}
176307

177-
t.Vprint("")
178-
t.Vprint(t.Yellow("[Step 3/3] Registering with Brev..."))
308+
func runRegisterFlagDriven(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, orgName string, enableSSH bool, sshPort int32, deps registerDeps) error {
309+
if !deps.platform.IsCompatible() {
310+
return breverrors.New("brev register is only supported on Linux")
311+
}
179312

180-
deviceID := uuid.New().String()
181-
client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL())
182-
addResp, err := client.AddNode(ctx, connect.NewRequest(&nodev1.AddNodeRequest{
183-
OrganizationId: org.ID,
184-
Name: name,
185-
DeviceId: deviceID,
186-
NodeSpec: toProtoNodeSpec(hwProfile),
187-
}))
188-
if err != nil {
189-
return fmt.Errorf("failed to register node: %w", err)
313+
if name == "" || orgName == "" {
314+
return fmt.Errorf("in non-interactive mode --name and --org are required")
190315
}
191316

192-
node := addResp.Msg.GetExternalNode()
193-
reg := &DeviceRegistration{
194-
ExternalNodeID: node.GetExternalNodeId(),
195-
DisplayName: name,
196-
OrgID: org.ID,
197-
DeviceID: deviceID,
198-
RegisteredAt: time.Now().UTC().Format(time.RFC3339),
199-
HardwareProfile: *hwProfile,
317+
alreadyRegistered, err := deps.registrationStore.Exists()
318+
if err != nil {
319+
return breverrors.WrapAndTrace(err)
200320
}
201-
if err := deps.registrationStore.Save(reg); err != nil {
202-
return fmt.Errorf("node registered but failed to save locally: %w", err)
321+
if alreadyRegistered {
322+
return checkExistingRegistration(ctx, t, s, name, deps)
203323
}
204324

205-
t.Vprint(t.Green(" Registration complete."))
325+
if err := names.ValidateNodeName(name); err != nil {
326+
return breverrors.WrapAndTrace(err)
327+
}
206328

207-
runSetup(node, t, deps)
329+
org, err := getOrgToRegisterFor(s, orgName)
330+
if err != nil {
331+
return err
332+
}
208333

209-
if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") {
210-
if err := grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser); err != nil {
211-
t.Vprintf(" Warning: SSH access not granted: %v\n", err)
334+
reg, err := runRegisterSteps(ctx, t, s, name, org, deps)
335+
if err != nil {
336+
return err
337+
}
338+
339+
if enableSSH {
340+
brevUser, err := s.GetCurrentUser()
341+
if err != nil {
342+
return breverrors.WrapAndTrace(err)
343+
}
344+
osUser, err := user.Current()
345+
if err != nil {
346+
return fmt.Errorf("failed to determine current Linux user: %w", err)
347+
}
348+
if err := grantSSHAccessWithPort(ctx, t, deps, s, reg, brevUser, osUser, sshPort); err != nil {
349+
return err
212350
}
213351
}
214352

@@ -324,7 +462,8 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps
324462
}
325463
}
326464

327-
func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) error {
465+
// grantSSHAccessWithPort enables SSH; if port is 0, prompts for port (prompt-driven). Otherwise uses the given port (flag-driven).
466+
func grantSSHAccessWithPort(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User, port int32) error {
328467
t.Vprint("")
329468
t.Vprint(t.Green("Enabling SSH access on this device"))
330469
t.Vprint("")
@@ -333,9 +472,14 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps
333472
t.Vprintf(" Linux user: %s\n", osUser.Username)
334473
t.Vprint("")
335474

336-
port, err := PromptSSHPort(t)
337-
if err != nil {
338-
return fmt.Errorf("SSH port: %w", err)
475+
var err error
476+
if port == 0 {
477+
port, err = PromptSSHPort(t)
478+
if err != nil {
479+
return fmt.Errorf("SSH port: %w", err)
480+
}
481+
} else {
482+
t.Vprintf(" SSH port: %d\n", port)
339483
}
340484

341485
if err := OpenSSHPort(ctx, t, deps.nodeClients, tokenProvider, reg, port); err != nil {

0 commit comments

Comments
 (0)