Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### CLI

* Wait for the SSH server health check to pass for the full startup timeout (10 minutes, or 45 minutes with `--accelerator`) instead of a fixed 60 seconds, so `databricks ssh connect` no longer fails with a driver-proxy 503 while a custom `--base-environment` finishes installing ([#5807](https://github.com/databricks/cli/pull/5807)).

### Bundles

### Dependency updates
Expand Down
38 changes: 25 additions & 13 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1063,35 +1063,47 @@
sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
defer sp.Close()
sp.Update("Waiting for the SSH server to start...")
maxRetries := 30
for retries := range maxRetries {
if ctx.Err() != nil {
return "", 0, "", ctx.Err()
}
// After the task reaches RUNNING the driver proxy still answers /metadata with 503
// until the container's HTTP endpoint is reachable; with a custom base environment
// that install happens post-RUNNING, so warmup can outlast a short fixed window. Poll
// for the same budget we allowed for the task to start (accelerator-aware), backing
// off between attempts rather than hammering the proxy every 2s during a long wait.
_, pollErr := retries.Poll(ctx, opts.TaskStartupTimeout, func() (*struct{}, *retries.Err) {
serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
if err == nil {
cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...")
break
return &struct{}{}, nil
}
// The metadata never appears if the bootstrap job dies after reaching RUNNING.
// Surface the job's actual error instead of waiting out the full timeout with a
// generic "metadata.json doesn't exist" message.
if failure, terminated := runFailureIfTerminated(ctx, client, runID); terminated {
return "", 0, "", fmt.Errorf("ssh server bootstrap job failed:\n%s", failure)
}
if retries < maxRetries-1 {
time.Sleep(2 * time.Second)
} else {
return "", 0, "", fmt.Errorf("failed to start the ssh server: %w\n%s", err, describeRunFailure(ctx, client, runID))
return nil, retries.Halt(fmt.Errorf("ssh server bootstrap job failed:\n%s", failure))
}
return nil, retries.Continue(fmt.Errorf("waiting for ssh server health check: %w", err))
})
if pollErr != nil {
return "", 0, "", formatServerStartError(ctx, client, runID, pollErr)
}
cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...")
} else if err != nil {
return "", 0, "", err
}

return userName, serverPort, effectiveClusterID, nil
}

// formatServerStartError turns a failed health-check poll into the error returned to the user.
// A halted poll (the bootstrap job terminated) already carries the job's failure details, so it
// is returned as-is. A timeout carries only the last generic health-check error, so the run's
// diagnostics are appended. Splitting on the timeout case avoids printing the failure trace twice.
func formatServerStartError(ctx context.Context, client *databricks.WorkspaceClient, runID int64, pollErr error) error {
var timedOut *retries.ErrTimedOut
if errors.As(pollErr, &timedOut) {

Check failure on line 1101 in experimental/ssh/internal/client/client.go

View workflow job for this annotation

GitHub Actions / lint

use of `errors.As` forbidden because "Use errors.AsType[T](err) for type-safe error unwrapping (Go 1.26+)." (forbidigo)
return fmt.Errorf("failed to start the ssh server: %w\n%s", pollErr, describeRunFailure(ctx, client, runID))
}
return pollErr
}

func logSshTunnelEvent(ctx context.Context, opts ClientOptions, isSuccess, isReconnect bool, serverStartTimeMs int64) {
computeType := protos.SshTunnelComputeTypeDedicated
if opts.IsServerlessMode() {
Expand Down
32 changes: 32 additions & 0 deletions experimental/ssh/internal/client/client_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package client

import (
"fmt"
"strings"
"testing"
"time"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/retries"
"github.com/databricks/databricks-sdk-go/service/environments"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -193,6 +195,36 @@ func TestRunFailureIfTerminated(t *testing.T) {
})
}

// TestFormatServerStartErrorTimeoutAppendsRunFailure checks that a timed-out poll is wrapped with
// the "failed to start the ssh server" prefix and that the run's failure message, served by the
// mocked Jobs API, ends up in the resulting error text.
func TestFormatServerStartErrorTimeoutAppendsRunFailure(t *testing.T) {
	const stateMessage = "Could not reach driver of cluster 0605-x."

	ctx := cmdio.MockDiscard(t.Context())
	ws := mocks.NewMockWorkspaceClient(t)
	jobsAPI := ws.GetMockJobsAPI()

	// Run 1 is reported as terminated with stateMessage; run 99's output is empty.
	failedRun := terminatedRun(1, 99, stateMessage, "https://example.test/run/1")
	jobsAPI.EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 1}).Return(failedRun, nil)
	jobsAPI.EXPECT().GetRunOutput(mock.Anything, jobs.GetRunOutputRequest{RunId: 99}).Return(&jobs.RunOutput{}, nil)

	timeout := &retries.ErrTimedOut{}
	got := formatServerStartError(ctx, ws.WorkspaceClient, 1, timeout)
	require.Error(t, got)

	text := got.Error()
	assert.Contains(t, text, "failed to start the ssh server")
	// On timeout the run diagnostics must have been fetched and appended.
	assert.Contains(t, text, stateMessage)
}

// TestFormatServerStartErrorHaltReturnedAsIs checks that a non-timeout poll error, which already
// carries the job failure details, passes through formatServerStartError untouched: same text,
// the failure trace present exactly once, and no "failed to start the ssh server" prefix.
func TestFormatServerStartErrorHaltReturnedAsIs(t *testing.T) {
	const trace = "Could not reach driver of cluster 0605-x."

	ctx := cmdio.MockDiscard(t.Context())
	// Deliberately no GetRun expectation is registered on the mock: the run must not be
	// fetched again for a halted poll.
	ws := mocks.NewMockWorkspaceClient(t)
	halted := fmt.Errorf("ssh server bootstrap job failed:\n%s", trace)

	got := formatServerStartError(ctx, ws.WorkspaceClient, 1, halted)
	require.Error(t, got)

	text := got.Error()
	assert.Equal(t, halted.Error(), text)
	assert.Equal(t, 1, strings.Count(text, trace))
	assert.NotContains(t, text, "failed to start the ssh server")
}

func TestWaitForJobToStartSurfacesFailure(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
m := mocks.NewMockWorkspaceClient(t)
Expand Down
Loading