@@ -28,6 +28,7 @@ type RegisterStore interface {
2828 GetCurrentUser () (* entity.User , error )
2929 GetActiveOrganizationOrDefault () (* entity.Organization , error )
3030 GetOrganizationsByName (name string ) ([]entity.Organization , error )
31+ ListOrganizations () ([]entity.Organization , error )
3132 GetAccessToken () (string , error )
3233}
3334
@@ -51,6 +52,7 @@ type SetupRunner interface {
5152type registerDeps struct {
5253 platform externalnode.PlatformChecker
5354 prompter terminal.Confirmer
55+ selector terminal.Selector
5456 netbird NetBirdManager
5557 setupRunner SetupRunner
5658 nodeClients externalnode.NodeClientFactory
@@ -59,9 +61,11 @@ type registerDeps struct {
5961}
6062
6163func defaultRegisterDeps () registerDeps {
64+ p := TerminalPrompter {}
6265 return registerDeps {
6366 platform : LinuxPlatform {},
64- prompter : TerminalPrompter {},
67+ prompter : p ,
68+ selector : p ,
6569 netbird : Netbird {},
6670 setupRunner : ShellSetupRunner {},
6771 nodeClients : DefaultNodeClientFactory {},
@@ -73,13 +77,27 @@ func defaultRegisterDeps() registerDeps {
7377var (
7478 registerLong = `Register your device with NVIDIA Brev
7579
76- This command sets up network connectivity and registers this machine with Brev.`
80+ This command sets up network connectivity and registers this machine with Brev.
7781
78- registerExample = ` brev register "My DGX Spark"`
82+ Two modes are supported:
83+ • Interactive (default): run 'brev register' (or 'brev register <name>') and follow prompts for org and options.
84+ • Non-interactive: use any of --name, --org, --enable-ssh, or --ssh-port (or --non-interactive). No prompts; --name and --org are required. Use for scripts/CI.`
85+
86+ registerExample = ` # Interactive (prompts for org, confirmations)
87+ brev register
88+ brev register "My DGX Spark"
89+
90+ # Non-interactive (any flag implies no prompts; --name and --org required)
91+ brev register --name my-node --org my-org
92+ brev register --name my-node --org my-org --enable-ssh --ssh-port 22`
7993)
8094
8195func NewCmdRegister (t * terminal.Terminal , store RegisterStore ) * cobra.Command {
8296 var orgFlag string
97+ var nonInteractive bool
98+ var nameFlag string
99+ var enableSSH bool
100+ var sshPort int
83101
84102 cmd := & cobra.Command {
85103 Annotations : map [string ]string {"configuration" : "" },
@@ -90,20 +108,129 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command {
90108 Example : registerExample ,
91109 Args : cobra .MaximumNArgs (1 ),
92110 RunE : func (cmd * cobra.Command , args []string ) error {
93- var name string
94- if len (args ) > 0 {
111+ name := nameFlag
112+ if name == "" && len (args ) > 0 {
95113 name = args [0 ]
96114 }
97- return runRegister (cmd .Context (), t , store , name , orgFlag , defaultRegisterDeps ())
115+ // Non-interactive if explicit flag or any register-specific flag is set (implies script/CI).
116+ flagDriven := nonInteractive ||
117+ nameFlag != "" ||
118+ orgFlag != "" ||
119+ enableSSH ||
120+ cmd .Flags ().Changed ("ssh-port" )
121+ if flagDriven {
122+ return runRegisterFlagDriven (cmd .Context (), t , store , name , orgFlag , enableSSH , int32 (sshPort ), defaultRegisterDeps ())
123+ }
124+ return runRegisterPromptDriven (cmd .Context (), t , store , name , orgFlag , defaultRegisterDeps ())
98125 },
99126 }
100127
101- cmd .Flags ().StringVarP (& orgFlag , "org" , "o" , "" , "organization name (overrides active org)" )
128+ cmd .Flags ().StringVarP (& orgFlag , "org" , "o" , "" , "organization name (required when using non-interactive mode)" )
129+ cmd .Flags ().BoolVar (& nonInteractive , "non-interactive" , false , "non-interactive mode (also implied by --name, --org, --enable-ssh, or --ssh-port)" )
130+ cmd .Flags ().StringVar (& nameFlag , "name" , "" , "device name (required when using non-interactive mode)" )
131+ cmd .Flags ().BoolVar (& enableSSH , "enable-ssh" , false , "enable SSH access after registration (non-interactive mode)" )
132+ cmd .Flags ().IntVar (& sshPort , "ssh-port" , 22 , "SSH port when using --enable-ssh" )
102133
103134 return cmd
104135}
105136
106- func runRegister (ctx context.Context , t * terminal.Terminal , s RegisterStore , name string , orgName string , deps registerDeps ) error { //nolint:funlen // registration flow
137+ func runRegister (ctx context.Context , t * terminal.Terminal , s RegisterStore , name string , orgName string , deps registerDeps ) error {
138+ return runRegisterPromptDriven (ctx , t , s , name , orgName , deps )
139+ }
140+
141+ // runRegisterSteps performs netbird install, hardware profile, AddNode, save registration, and runSetup.
142+ // It does not prompt or enable SSH. Used by both flag-driven and prompt-driven flows.
143+ func runRegisterSteps (ctx context.Context , t * terminal.Terminal , s RegisterStore , name string , org * entity.Organization , deps registerDeps ) (* DeviceRegistration , error ) {
144+ t .Vprint ("" )
145+ t .Vprint (t .Yellow ("[Step 1/3] Setting up Brev tunnel..." ))
146+ if err := deps .netbird .Install (); err != nil {
147+ return nil , fmt .Errorf ("brev tunnel setup failed: %w" , err )
148+ }
149+ t .Vprint (t .Green (" Brev tunnel ready." ))
150+
151+ t .Vprint ("" )
152+ t .Vprint (t .Yellow ("[Step 2/3] Collecting hardware profile..." ))
153+ t .Vprint ("" )
154+
155+ hwProfile , err := deps .hardwareProfiler .Profile ()
156+ if err != nil {
157+ return nil , fmt .Errorf ("failed to collect hardware profile: %w" , err )
158+ }
159+
160+ t .Vprint (" Hardware profile:" )
161+ t .Vprint (FormatHardwareProfile (hwProfile ))
162+
163+ t .Vprint ("" )
164+ t .Vprint (t .Yellow ("[Step 3/3] Registering with Brev..." ))
165+
166+ deviceID := uuid .New ().String ()
167+ client := deps .nodeClients .NewNodeClient (s , config .GlobalConfig .GetBrevPublicAPIURL ())
168+ addResp , err := client .AddNode (ctx , connect .NewRequest (& nodev1.AddNodeRequest {
169+ OrganizationId : org .ID ,
170+ Name : name ,
171+ DeviceId : deviceID ,
172+ NodeSpec : toProtoNodeSpec (hwProfile ),
173+ }))
174+ if err != nil {
175+ return nil , fmt .Errorf ("failed to register node: %w" , err )
176+ }
177+
178+ node := addResp .Msg .GetExternalNode ()
179+ reg := & DeviceRegistration {
180+ ExternalNodeID : node .GetExternalNodeId (),
181+ DisplayName : name ,
182+ OrgID : org .ID ,
183+ DeviceID : deviceID ,
184+ RegisteredAt : time .Now ().UTC ().Format (time .RFC3339 ),
185+ HardwareProfile : * hwProfile ,
186+ }
187+ if err := deps .registrationStore .Save (reg ); err != nil {
188+ return nil , fmt .Errorf ("node registered but failed to save locally: %w" , err )
189+ }
190+
191+ t .Vprint (t .Green (" Registration complete." ))
192+ runSetup (node , t , deps )
193+ return reg , nil
194+ }
195+
196+ // resolveOrgPromptDriven resolves organization for prompt-driven flow: by name if --org given, else always list and select with arrow keys.
197+ func resolveOrgPromptDriven (s RegisterStore , orgName string , deps registerDeps ) (* entity.Organization , error ) {
198+ if orgName != "" {
199+ orgs , err := s .GetOrganizationsByName (orgName )
200+ if err != nil {
201+ return nil , breverrors .WrapAndTrace (err )
202+ }
203+ if len (orgs ) == 0 {
204+ return nil , fmt .Errorf ("no organization found with name %q" , orgName )
205+ }
206+ if len (orgs ) > 1 {
207+ return nil , fmt .Errorf ("multiple organizations found with name %q" , orgName )
208+ }
209+ return & orgs [0 ], nil
210+ }
211+
212+ list , err := s .ListOrganizations ()
213+ if err != nil {
214+ return nil , breverrors .WrapAndTrace (err )
215+ }
216+ if len (list ) == 0 {
217+ return nil , fmt .Errorf ("no organization found; please create or join an organization first" )
218+ }
219+
220+ names := make ([]string , len (list ))
221+ for i := range list {
222+ names [i ] = list [i ].Name
223+ }
224+ chosen := deps .selector .Select ("Select organization" , names )
225+ for i := range list {
226+ if list [i ].Name == chosen {
227+ return & list [i ], nil
228+ }
229+ }
230+ return nil , fmt .Errorf ("selected organization not found" )
231+ }
232+
233+ func runRegisterPromptDriven (ctx context.Context , t * terminal.Terminal , s RegisterStore , name string , orgName string , deps registerDeps ) error {
107234 if ! deps .platform .IsCompatible () {
108235 return breverrors .New ("brev register is only supported on Linux" )
109236 }
@@ -120,15 +247,24 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
120247 return checkExistingRegistration (ctx , t , s , name , deps )
121248 }
122249
250+ if name == "" {
251+ name = terminal .PromptGetInput (terminal.PromptContent {
252+ Label : "Device name" ,
253+ ErrorMsg : "name is required" ,
254+ AllowEmpty : false ,
255+ })
256+ name = strings .TrimSpace (name )
257+ }
123258 if err := names .ValidateNodeName (name ); err != nil {
124259 return breverrors .WrapAndTrace (err )
125260 }
126261
127- brevUser , err := s . GetCurrentUser ( )
262+ org , err := resolveOrgPromptDriven ( s , orgName , deps )
128263 if err != nil {
129- return breverrors . WrapAndTrace ( err )
264+ return err
130265 }
131- org , err := getOrgToRegisterFor (s , orgName )
266+
267+ brevUser , err := s .GetCurrentUser ()
132268 if err != nil {
133269 return breverrors .WrapAndTrace (err )
134270 }
@@ -155,60 +291,62 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
155291 return nil
156292 }
157293
158- t .Vprint ("" )
159- t .Vprint (t .Yellow ("[Step 1/3] Setting up Brev tunnel..." ))
160- if err := deps .netbird .Install (); err != nil {
161- return fmt .Errorf ("brev tunnel setup failed: %w" , err )
294+ reg , err := runRegisterSteps (ctx , t , s , name , org , deps )
295+ if err != nil {
296+ return err
162297 }
163- t .Vprint (t .Green (" Brev tunnel ready." ))
164298
165- t .Vprint ("" )
166- t .Vprint (t .Yellow ("[Step 2/3] Collecting hardware profile..." ))
167- t .Vprint ("" )
168-
169- hwProfile , err := deps .hardwareProfiler .Profile ()
170- if err != nil {
171- return fmt .Errorf ("failed to collect hardware profile: %w" , err )
299+ if deps .prompter .ConfirmYesNo ("Would you like to enable SSH access to this device?" ) {
300+ if err := grantSSHAccessWithPort (ctx , t , deps , s , reg , brevUser , osUser , 0 ); err != nil {
301+ t .Vprintf (" Warning: SSH access not granted: %v\n " , err )
302+ }
172303 }
173304
174- t . Vprint ( " Hardware profile:" )
175- t . Vprint ( FormatHardwareProfile ( hwProfile ))
305+ return nil
306+ }
176307
177- t .Vprint ("" )
178- t .Vprint (t .Yellow ("[Step 3/3] Registering with Brev..." ))
308+ func runRegisterFlagDriven (ctx context.Context , t * terminal.Terminal , s RegisterStore , name string , orgName string , enableSSH bool , sshPort int32 , deps registerDeps ) error {
309+ if ! deps .platform .IsCompatible () {
310+ return breverrors .New ("brev register is only supported on Linux" )
311+ }
179312
180- deviceID := uuid .New ().String ()
181- client := deps .nodeClients .NewNodeClient (s , config .GlobalConfig .GetBrevPublicAPIURL ())
182- addResp , err := client .AddNode (ctx , connect .NewRequest (& nodev1.AddNodeRequest {
183- OrganizationId : org .ID ,
184- Name : name ,
185- DeviceId : deviceID ,
186- NodeSpec : toProtoNodeSpec (hwProfile ),
187- }))
188- if err != nil {
189- return fmt .Errorf ("failed to register node: %w" , err )
313+ if name == "" || orgName == "" {
314+ return fmt .Errorf ("in non-interactive mode --name and --org are required" )
190315 }
191316
192- node := addResp .Msg .GetExternalNode ()
193- reg := & DeviceRegistration {
194- ExternalNodeID : node .GetExternalNodeId (),
195- DisplayName : name ,
196- OrgID : org .ID ,
197- DeviceID : deviceID ,
198- RegisteredAt : time .Now ().UTC ().Format (time .RFC3339 ),
199- HardwareProfile : * hwProfile ,
317+ alreadyRegistered , err := deps .registrationStore .Exists ()
318+ if err != nil {
319+ return breverrors .WrapAndTrace (err )
200320 }
201- if err := deps . registrationStore . Save ( reg ); err != nil {
202- return fmt . Errorf ( "node registered but failed to save locally: %w" , err )
321+ if alreadyRegistered {
322+ return checkExistingRegistration ( ctx , t , s , name , deps )
203323 }
204324
205- t .Vprint (t .Green (" Registration complete." ))
325+ if err := names .ValidateNodeName (name ); err != nil {
326+ return breverrors .WrapAndTrace (err )
327+ }
206328
207- runSetup (node , t , deps )
329+ org , err := getOrgToRegisterFor (s , orgName )
330+ if err != nil {
331+ return err
332+ }
208333
209- if deps .prompter .ConfirmYesNo ("Would you like to enable SSH access to this device?" ) {
210- if err := grantSSHAccess (ctx , t , deps , s , reg , brevUser , osUser ); err != nil {
211- t .Vprintf (" Warning: SSH access not granted: %v\n " , err )
334+ reg , err := runRegisterSteps (ctx , t , s , name , org , deps )
335+ if err != nil {
336+ return err
337+ }
338+
339+ if enableSSH {
340+ brevUser , err := s .GetCurrentUser ()
341+ if err != nil {
342+ return breverrors .WrapAndTrace (err )
343+ }
344+ osUser , err := user .Current ()
345+ if err != nil {
346+ return fmt .Errorf ("failed to determine current Linux user: %w" , err )
347+ }
348+ if err := grantSSHAccessWithPort (ctx , t , deps , s , reg , brevUser , osUser , sshPort ); err != nil {
349+ return err
212350 }
213351 }
214352
@@ -324,7 +462,8 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps
324462 }
325463}
326464
327- func grantSSHAccess (ctx context.Context , t * terminal.Terminal , deps registerDeps , tokenProvider externalnode.TokenProvider , reg * DeviceRegistration , brevUser * entity.User , osUser * user.User ) error {
465+ // grantSSHAccessWithPort enables SSH; if port is 0, prompts for port (prompt-driven). Otherwise uses the given port (flag-driven).
466+ func grantSSHAccessWithPort (ctx context.Context , t * terminal.Terminal , deps registerDeps , tokenProvider externalnode.TokenProvider , reg * DeviceRegistration , brevUser * entity.User , osUser * user.User , port int32 ) error {
328467 t .Vprint ("" )
329468 t .Vprint (t .Green ("Enabling SSH access on this device" ))
330469 t .Vprint ("" )
@@ -333,9 +472,14 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps
333472 t .Vprintf (" Linux user: %s\n " , osUser .Username )
334473 t .Vprint ("" )
335474
336- port , err := PromptSSHPort (t )
337- if err != nil {
338- return fmt .Errorf ("SSH port: %w" , err )
475+ var err error
476+ if port == 0 {
477+ port , err = PromptSSHPort (t )
478+ if err != nil {
479+ return fmt .Errorf ("SSH port: %w" , err )
480+ }
481+ } else {
482+ t .Vprintf (" SSH port: %d\n " , port )
339483 }
340484
341485 if err := OpenSSHPort (ctx , t , deps .nodeClients , tokenProvider , reg , port ); err != nil {
0 commit comments