Skip to content

Commit dc81fe9

Browse files
authored
Make ssh setup work with interactive cluster selection (#5207)
## Changes Move the creation of `proxyCommand` to _after_ interactive cluster selection ## Why Before `proxyCommand` was created before interactive cluster selection, meaning we would output a broken proxy command in the generate SSH config. Moving creation of `proxyCommand` to after interactive cluster selection means the selected cluster is properly populated in the generated SSH config ## Tests Added test
1 parent 596207d commit dc81fe9

3 files changed

Lines changed: 73 additions & 54 deletions

File tree

experimental/ssh/cmd/setup.go

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
package ssh
22

33
import (
4-
"fmt"
54
"time"
65

76
"github.com/databricks/cli/cmd/root"
8-
"github.com/databricks/cli/experimental/ssh/internal/client"
97
"github.com/databricks/cli/experimental/ssh/internal/setup"
108
"github.com/databricks/cli/libs/cmdctx"
119
"github.com/spf13/cobra"
@@ -57,17 +55,6 @@ an SSH host configuration to your SSH config file.
5755
Profile: wsClient.Config.Profile,
5856
AutoApprove: autoApprove,
5957
}
60-
clientOpts := client.ClientOptions{
61-
ClusterID: setupOpts.ClusterID,
62-
AutoStartCluster: setupOpts.AutoStartCluster,
63-
ShutdownDelay: setupOpts.ShutdownDelay,
64-
Profile: setupOpts.Profile,
65-
}
66-
proxyCommand, err := clientOpts.ToProxyCommand()
67-
if err != nil {
68-
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
69-
}
70-
setupOpts.ProxyCommand = proxyCommand
7158
return setup.Setup(ctx, wsClient, setupOpts)
7259
}
7360

experimental/ssh/internal/setup/setup.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"time"
88

9+
sshclient "github.com/databricks/cli/experimental/ssh/internal/client"
910
"github.com/databricks/cli/experimental/ssh/internal/keys"
1011
"github.com/databricks/cli/experimental/ssh/internal/sshconfig"
1112
"github.com/databricks/cli/libs/cmdio"
@@ -28,8 +29,6 @@ type SetupOptions struct {
2829
SSHKeysDir string
2930
// Optional auth profile name. If present, will be added as --profile flag to the ProxyCommand
3031
Profile string
31-
// Proxy command to use for the SSH connection
32-
ProxyCommand string
3332
// Skip confirmation prompts (e.g. recreate existing host config without asking)
3433
AutoApprove bool
3534
}
@@ -45,17 +44,20 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie
4544
return nil
4645
}
4746

48-
func generateHostConfig(ctx context.Context, opts SetupOptions) (string, error) {
47+
func generateHostConfig(ctx context.Context, opts SetupOptions, proxyCommand string) (string, error) {
4948
identityFilePath, err := keys.GetLocalSSHKeyPath(ctx, opts.ClusterID, opts.SSHKeysDir)
5049
if err != nil {
5150
return "", fmt.Errorf("failed to get local keys folder: %w", err)
5251
}
5352

54-
hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand)
53+
hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, proxyCommand)
5554
return hostConfig, nil
5655
}
5756

58-
func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {
57+
// clusterSelectionPrompt is a package-level var so tests can replace it with a mock.
58+
var clusterSelectionPrompt = defaultClusterSelectionPrompt
59+
60+
func defaultClusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {
5961
sp := cmdio.NewSpinner(ctx)
6062
sp.Update("Loading clusters.")
6163
clusters, err := client.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{
@@ -92,6 +94,20 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp
9294
return err
9395
}
9496

97+
// Build the ProxyCommand after the cluster ID is resolved. When the user
98+
// omits --cluster, the ID is only known after the interactive picker above,
99+
// so building it earlier would serialize an empty --cluster= flag.
100+
clientOpts := sshclient.ClientOptions{
101+
ClusterID: opts.ClusterID,
102+
AutoStartCluster: opts.AutoStartCluster,
103+
ShutdownDelay: opts.ShutdownDelay,
104+
Profile: opts.Profile,
105+
}
106+
proxyCommand, err := clientOpts.ToProxyCommand()
107+
if err != nil {
108+
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
109+
}
110+
95111
configPath, err := sshconfig.GetMainConfigPathOrDefault(ctx, opts.SSHConfigPath)
96112
if err != nil {
97113
return err
@@ -102,7 +118,7 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp
102118
return err
103119
}
104120

105-
hostConfig, err := generateHostConfig(ctx, opts)
121+
hostConfig, err := generateHostConfig(ctx, opts, proxyCommand)
106122
if err != nil {
107123
return err
108124
}

experimental/ssh/internal/setup/setup_test.go

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package setup
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"os"
@@ -10,6 +11,7 @@ import (
1011

1112
"github.com/databricks/cli/experimental/ssh/internal/client"
1213
"github.com/databricks/cli/libs/cmdio"
14+
"github.com/databricks/databricks-sdk-go"
1315
"github.com/databricks/databricks-sdk-go/experimental/mocks"
1416
"github.com/databricks/databricks-sdk-go/service/compute"
1517
"github.com/stretchr/testify/assert"
@@ -134,10 +136,9 @@ func TestGenerateHostConfig_Valid(t *testing.T) {
134136
SSHKeysDir: tmpDir,
135137
ShutdownDelay: 30 * time.Second,
136138
Profile: "test-profile",
137-
ProxyCommand: proxyCommand,
138139
}
139140

140-
result, err := generateHostConfig(t.Context(), opts)
141+
result, err := generateHostConfig(t.Context(), opts, proxyCommand)
141142
assert.NoError(t, err)
142143

143144
assert.Contains(t, result, "Host test-host")
@@ -169,10 +170,9 @@ func TestGenerateHostConfig_WithoutProfile(t *testing.T) {
169170
SSHKeysDir: tmpDir,
170171
ShutdownDelay: 30 * time.Second,
171172
Profile: "",
172-
ProxyCommand: proxyCommand,
173173
}
174174

175-
result, err := generateHostConfig(t.Context(), opts)
175+
result, err := generateHostConfig(t.Context(), opts, proxyCommand)
176176
assert.NoError(t, err)
177177

178178
assert.NotContains(t, result, "--profile=")
@@ -193,7 +193,7 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) {
193193
ShutdownDelay: 30 * time.Second,
194194
}
195195

196-
result, err := generateHostConfig(t.Context(), opts)
196+
result, err := generateHostConfig(t.Context(), opts, "")
197197
assert.NoError(t, err)
198198

199199
// Check that quotes are properly escaped
@@ -225,17 +225,7 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) {
225225
Profile: "test-profile",
226226
}
227227

228-
clientOpts := client.ClientOptions{
229-
ClusterID: opts.ClusterID,
230-
AutoStartCluster: opts.AutoStartCluster,
231-
ShutdownDelay: opts.ShutdownDelay,
232-
Profile: opts.Profile,
233-
}
234-
proxyCommand, err := clientOpts.ToProxyCommand()
235-
require.NoError(t, err)
236-
opts.ProxyCommand = proxyCommand
237-
238-
err = Setup(ctx, m.WorkspaceClient, opts)
228+
err := Setup(ctx, m.WorkspaceClient, opts)
239229
assert.NoError(t, err)
240230

241231
// Check that main config has Include directive
@@ -285,15 +275,7 @@ func TestSetup_AutoApproveRecreatesExistingHost(t *testing.T) {
285275
AutoApprove: true,
286276
}
287277

288-
clientOpts := client.ClientOptions{
289-
ClusterID: opts.ClusterID,
290-
ShutdownDelay: opts.ShutdownDelay,
291-
}
292-
proxyCommand, err := clientOpts.ToProxyCommand()
293-
require.NoError(t, err)
294-
opts.ProxyCommand = proxyCommand
295-
296-
err = Setup(ctx, m.WorkspaceClient, opts)
278+
err := Setup(ctx, m.WorkspaceClient, opts)
297279
assert.NoError(t, err)
298280

299281
// Host config should be recreated (no longer contains the stale User).
@@ -304,6 +286,50 @@ func TestSetup_AutoApproveRecreatesExistingHost(t *testing.T) {
304286
assert.Contains(t, s, "--cluster=cluster-123")
305287
}
306288

289+
func TestSetup_PromptsForClusterWhenNotProvided(t *testing.T) {
290+
ctx := cmdio.MockDiscard(t.Context())
291+
tmpDir := t.TempDir()
292+
t.Setenv("HOME", tmpDir)
293+
t.Setenv("USERPROFILE", tmpDir)
294+
295+
configPath := filepath.Join(tmpDir, "ssh_config")
296+
297+
// Replace the cluster picker with a stub returning a fixed ID. This lets the
298+
// test exercise the empty-ClusterID path of Setup without driving promptui.
299+
origPrompt := clusterSelectionPrompt
300+
t.Cleanup(func() { clusterSelectionPrompt = origPrompt })
301+
promptCalled := false
302+
clusterSelectionPrompt = func(_ context.Context, _ *databricks.WorkspaceClient) (string, error) {
303+
promptCalled = true
304+
return "picked-cluster", nil
305+
}
306+
307+
m := mocks.NewMockWorkspaceClient(t)
308+
clustersAPI := m.GetMockClustersAPI()
309+
clustersAPI.EXPECT().Get(ctx, compute.GetClusterRequest{ClusterId: "picked-cluster"}).Return(&compute.ClusterDetails{
310+
DataSecurityMode: compute.DataSecurityModeSingleUser,
311+
}, nil)
312+
313+
opts := SetupOptions{
314+
HostName: "test-host",
315+
SSHConfigPath: configPath,
316+
SSHKeysDir: tmpDir,
317+
ShutdownDelay: 30 * time.Second,
318+
}
319+
320+
err := Setup(ctx, m.WorkspaceClient, opts)
321+
require.NoError(t, err)
322+
assert.True(t, promptCalled, "cluster picker should run when ClusterID is empty")
323+
324+
// The picked ID must be serialized into the ProxyCommand's --cluster= flag.
325+
hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "test-host")
326+
hostContent, err := os.ReadFile(hostConfigPath)
327+
require.NoError(t, err)
328+
hostConfigStr := string(hostContent)
329+
assert.Contains(t, hostConfigStr, "--cluster=picked-cluster")
330+
assert.NotContains(t, hostConfigStr, "--cluster= ")
331+
}
332+
307333
func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) {
308334
ctx := cmdio.MockDiscard(t.Context())
309335
tmpDir := t.TempDir()
@@ -332,16 +358,6 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) {
332358
ShutdownDelay: 60 * time.Second,
333359
}
334360

335-
clientOpts := client.ClientOptions{
336-
ClusterID: opts.ClusterID,
337-
AutoStartCluster: opts.AutoStartCluster,
338-
ShutdownDelay: opts.ShutdownDelay,
339-
Profile: opts.Profile,
340-
}
341-
proxyCommand, err := clientOpts.ToProxyCommand()
342-
require.NoError(t, err)
343-
opts.ProxyCommand = proxyCommand
344-
345361
err = Setup(ctx, m.WorkspaceClient, opts)
346362
assert.NoError(t, err)
347363

0 commit comments

Comments
 (0)