Skip to content

Commit 4431777

Browse files
Update u2m port selection mechanism to start from 8020. (#1269)
## What changes are proposed in this pull request? This PR replaces the port selection mechanism in `U2M` to start from `8020`, and fallback incrementally (`8021`, `8022`, and so on... until 8040) if a port is not free. The PR also provides an environment/config variable to specify a specific port for the OAuth callback. If set, that port disables the fallback mechanism. This change fixes a bug in the Databricks CLI which is only allowlisted for port `8020` at the moment. ## How is this tested? Unit tests
1 parent b6d5a08 commit 4431777

5 files changed

Lines changed: 186 additions & 61 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66

77
### Bug Fixes
88

9+
- Update the `U2M` port selection mechanism to try port `8020` first and fall
10+
back incrementally (to port `8021`, `8022`, and so on...) if a port is not
11+
available. This fixes an issue with the Databricks CLI which is only
12+
allowlisted on port `8020`.
13+
914
### Documentation
1015

1116
### Internal Changes

config/auth_u2m.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials
4444
}
4545

4646
if u.ts == nil {
47-
auth, err := u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg))
47+
auth, err := u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg), u2m.WithPort(cfg.OAuthCallbackPort))
4848
if err != nil {
4949
logger.Debugf(ctx, "failed to create persistent auth: %v, continuing", err)
5050
return nil, nil

config/config.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ type Config struct {
7070
// in ~/.databrickscfg.
7171
ConfigFile string `name:"config_file" env:"DATABRICKS_CONFIG_FILE"`
7272

73+
// OAuthPort is the port to use for the OAuth2 callback server. If not set,
74+
// the default port with fallback is used. This means that setting a port
75+
// will disable the fallback mechanism.
76+
OAuthCallbackPort int `name:"oauth_callback_port" env:"DATABRICKS_OAUTH_CALLBACK_PORT" auth:"-"`
77+
7378
GoogleServiceAccount string `name:"google_service_account" env:"DATABRICKS_GOOGLE_SERVICE_ACCOUNT" auth:"google" auth_types:"google-id"`
7479
GoogleCredentials string `name:"google_credentials" env:"GOOGLE_CREDENTIALS" auth:"google,sensitive" auth_types:"google-credentials"`
7580

credentials/u2m/persistent_auth.go

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache"
1616
"github.com/databricks/databricks-sdk-go/httpclient"
1717
"github.com/databricks/databricks-sdk-go/logger"
18-
"github.com/databricks/databricks-sdk-go/retries"
1918
"github.com/pkg/browser"
2019
"golang.org/x/oauth2"
2120
"golang.org/x/oauth2/authhandler"
@@ -25,15 +24,25 @@ const (
2524
// appClientId is the default client ID used by the SDK for U2M OAuth.
2625
appClientID = "databricks-cli"
2726

28-
// appRedirectAddr is the default address for the OAuth2 callback server.
29-
// Using ":0" tells the system to pick a random available port.
30-
appRedirectAddr = "localhost:0"
27+
// defaultPort is the default port for the OAuth2 callback server. If the
28+
// port is already in use, the next port is tried (8021, 8022, etc.).
29+
defaultPort = 8020
3130

32-
// listenerTimeout is the maximum amount of time to acquire listener on
33-
// appRedirectAddr.
31+
// maxPortFallback is the maximum port to try when using the fallback
32+
// mechanism.
33+
maxPortFallback = 8040
34+
35+
// listenerTimeout is the maximum duration spent trying to acquire a
36+
// listener (including port selection).
3437
listenerTimeout = 45 * time.Second
3538
)
3639

40+
var (
41+
// Internal errors used for testing.
42+
errListenerTimeout = errors.New("failed to listen on any port: timeout")
43+
errNoPortAvailable = errors.New("no port available to listen on")
44+
)
45+
3746
// PersistentAuth is an OAuth manager that handles the U2M OAuth flow. Tokens
3847
// are stored in and looked up from the provided cache. Tokens include the
3948
// refresh token. On load, if the access token is expired, it is refreshed
@@ -44,25 +53,41 @@ const (
4453
type PersistentAuth struct {
4554
// cache is the token cache to store and lookup tokens.
4655
cache cache.TokenCache
56+
4757
// client is the HTTP client to use for OAuth2 requests.
4858
client *http.Client
59+
4960
// endpointSupplier is the HTTP endpointSupplier to use for OAuth2 requests.
5061
endpointSupplier OAuthEndpointSupplier
62+
5163
// oAuthArgument defines the workspace or account to authenticate to and the
5264
// cache key for the token.
5365
oAuthArgument OAuthArgument
66+
5467
// browser is the function to open a URL in the default browser.
5568
browser func(url string) error
69+
5670
// ln is the listener for the OAuth2 callback server.
5771
ln net.Listener
72+
5873
// ctx is the context to use for underlying operations. This is needed in
5974
// order to implement the oauth2.TokenSource interface.
6075
ctx context.Context
76+
6177
// redirectAddr is the redirect address for OAuth2 callbacks. The value is
6278
// set to localhost:PORT by startListener which will dynamically assign a
6379
// random port. If a value is already provided, it will be used instead
6480
// (e.g. for testing).
6581
redirectAddr string
82+
83+
// Optional port to use for the OAuth2 callback server. If set to 0, the
84+
// default port with fallback is used. This means that setting a port will
85+
// disable the fallback mechanism.
86+
port int
87+
88+
// netListen is an optional function to listen on a TCP address. If not set,
89+
// it will use net.Listen by default. This is useful for testing.
90+
netListen func(network, address string) (net.Listener, error)
6691
}
6792

6893
type PersistentAuthOption func(*PersistentAuth)
@@ -103,6 +128,13 @@ func WithBrowser(b func(url string) error) PersistentAuthOption {
103128
}
104129
}
105130

131+
// WithPort sets the port for the PersistentAuth.
132+
func WithPort(port int) PersistentAuthOption {
133+
return func(a *PersistentAuth) {
134+
a.port = port
135+
}
136+
}
137+
106138
// NewPersistentAuth creates a new PersistentAuth with the provided options.
107139
func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) {
108140
p := &PersistentAuth{}
@@ -273,35 +305,51 @@ func (a *PersistentAuth) Challenge() error {
273305

274306
// startListener starts a listener on appRedirectAddr, retrying if the address
275307
// is already in use.
276-
func (a *PersistentAuth) startListener(ctx context.Context) error {
277-
// Use the value of redirectURL if it is already set. This is only expected
278-
// in tests to set a fixed redirect URL.
279-
addr := a.redirectAddr
280-
if addr == "" {
281-
addr = appRedirectAddr
282-
}
283-
284-
listener, err := retries.Poll(ctx, listenerTimeout,
285-
func() (*net.Listener, *retries.Err) {
286-
var lc net.ListenConfig
287-
l, err := lc.Listen(ctx, "tcp", addr)
288-
if err != nil {
289-
logger.Debugf(ctx, "failed to listen on %s: %v, retrying", addr, err)
290-
return nil, retries.Continue(err)
291-
}
292-
return &l, nil
293-
})
294-
if err != nil {
295-
return fmt.Errorf("listener: %w", err)
308+
func (pa *PersistentAuth) startListener(ctx context.Context) error {
309+
if pa.port != 0 { // if port is set, use it
310+
return pa.startListenerWithPort(pa.port)
296311
}
297-
a.ln = *listener
312+
return pa.startListenerWithFallback(ctx)
313+
}
298314

299-
// Get the actual address that was assigned (including the port).
300-
a.redirectAddr = a.ln.Addr().String()
301-
logger.Debugf(ctx, "OAuth callback server listening on %s", a.redirectAddr)
315+
// startListenerWithFallback starts a listener that will try to find a free
316+
// port to listen on starting from the default port and incrementing by 1 until
317+
// a free port is found.
318+
func (pa *PersistentAuth) startListenerWithFallback(ctx context.Context) error {
319+
startTime := time.Now()
320+
for port := defaultPort; port <= maxPortFallback; port++ {
321+
if time.Since(startTime) > listenerTimeout {
322+
return errListenerTimeout
323+
}
324+
if err := pa.startListenerWithPort(port); err != nil {
325+
logger.Debugf(ctx, "failed to listen on %d: %v, retrying", port, err)
326+
continue
327+
}
328+
logger.Debugf(ctx, "OAuth callback server listening on %s", pa.redirectAddr)
329+
return nil
330+
}
331+
return errNoPortAvailable
332+
}
333+
334+
func (pa *PersistentAuth) startListenerWithPort(port int) error {
335+
addr := fmt.Sprintf("localhost:%d", port)
336+
listener, err := pa.listen("tcp", addr)
337+
if err != nil {
338+
return fmt.Errorf("failed to listen on %s: %w", addr, err)
339+
}
340+
pa.ln = listener
341+
pa.redirectAddr = addr
302342
return nil
303343
}
304344

345+
func (pa *PersistentAuth) listen(network, addr string) (net.Listener, error) {
346+
if pa.netListen != nil {
347+
return pa.netListen(network, addr)
348+
} else {
349+
return net.Listen(network, addr)
350+
}
351+
}
352+
305353
func (a *PersistentAuth) Close() error {
306354
if a.ln == nil {
307355
return nil

credentials/u2m/persistent_auth_test.go

Lines changed: 97 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"net"
78
"net/http"
89
"net/url"
9-
"regexp"
1010
"strings"
1111
"testing"
1212
"time"
@@ -298,9 +298,6 @@ func TestChallenge(t *testing.T) {
298298
}
299299
defer p.Close()
300300

301-
// Set a fixed redirect URL for the test.
302-
p.redirectAddr = "localhost:1337"
303-
304301
errc := make(chan error)
305302
go func() {
306303
err := p.Challenge()
@@ -309,7 +306,7 @@ func TestChallenge(t *testing.T) {
309306
}()
310307

311308
state := <-browserOpened
312-
resp, err := http.Get(fmt.Sprintf("http://localhost:1337?code=__THIS__&state=%s", state))
309+
resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__THIS__&state=%s", state))
313310
if err != nil {
314311
t.Fatalf("http.Get(): want no error, got %v", err)
315312
}
@@ -351,8 +348,6 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) {
351348
}
352349
defer p.Close()
353350

354-
p.redirectAddr = "localhost:1337" // set a fixed redirect URL for the test
355-
356351
errc := make(chan error)
357352
go func() {
358353
err := p.Challenge()
@@ -361,7 +356,7 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) {
361356
}()
362357

363358
<-browserOpened
364-
resp, err := http.Get("http://localhost:1337?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request")
359+
resp, err := http.Get("http://localhost:8020?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request")
365360
if err != nil {
366361
t.Fatalf("http.Get(): want no error, got %v", err)
367362
}
@@ -380,36 +375,108 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) {
380375
}
381376
}
382377

383-
// Verifies that startListener assigns a random port to the redirectAddr.
384-
func TestPersistentAuth_startListener_useDifferentPorts(t *testing.T) {
385-
ctx := context.Background()
386-
arg, err := NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz")
387-
if err != nil {
388-
t.Fatalf("NewBasicAccountOAuthArgument(): want no error, got %v", err)
378+
func TestPersistentAuth_startListener_startFrom8020(t *testing.T) {
379+
pa := &PersistentAuth{}
380+
pa.netListen = func(_, address string) (net.Listener, error) {
381+
return nil, nil
389382
}
390383

391-
p1, err := NewPersistentAuth(ctx, WithOAuthArgument(arg))
392-
if err != nil {
393-
t.Fatalf("NewPersistentAuth(): want no error, got %v", err)
384+
gotErr := pa.startListener(context.Background())
385+
386+
if gotErr != nil {
387+
t.Fatalf("pa.startListener(): want no error, got %v", gotErr)
394388
}
395-
defer p1.Close()
389+
if pa.redirectAddr != "localhost:8020" {
390+
t.Errorf("pa.redirectAddr should be localhost:8020, got %s", pa.redirectAddr)
391+
}
392+
}
396393

397-
p2, err := NewPersistentAuth(ctx, WithOAuthArgument(arg))
398-
if err != nil {
399-
t.Fatalf("NewPersistentAuth(): want no error, got %v", err)
394+
func TestPersistentAuth_startListener_incrementalFallBack(t *testing.T) {
395+
pa := &PersistentAuth{}
396+
pa.netListen = func(_, address string) (net.Listener, error) {
397+
if address == "localhost:8020" {
398+
return nil, fmt.Errorf("address already in use")
399+
}
400+
if address == "localhost:8021" {
401+
return nil, fmt.Errorf("address already in use")
402+
}
403+
return nil, nil
404+
}
405+
406+
gotErr := pa.startListener(context.Background())
407+
408+
if gotErr != nil {
409+
t.Fatalf("pa.startListener(): want no error, got %v", gotErr)
410+
}
411+
if pa.redirectAddr != "localhost:8022" {
412+
t.Errorf("pa.redirectAddr should be localhost:8022, got %s", pa.redirectAddr)
413+
}
414+
}
415+
416+
func TestPersistentAuth_startListener_noAvailablePort(t *testing.T) {
417+
pa := &PersistentAuth{}
418+
pa.netListen = func(_, address string) (net.Listener, error) {
419+
return nil, fmt.Errorf("address already in use")
400420
}
401-
defer p2.Close()
402421

403-
p1.startListener(ctx)
404-
p2.startListener(ctx)
422+
gotErr := pa.startListener(context.Background())
423+
424+
if !errors.Is(gotErr, errNoPortAvailable) {
425+
t.Fatalf("pa.startListener(): want error %v, got %v", errNoPortAvailable, gotErr)
426+
}
427+
}
405428

406-
if !regexp.MustCompile(`^127\.0\.0\.1:\d+$`).MatchString(p1.redirectAddr) {
407-
t.Errorf("p1.redirectAddr should be random localhost port, got %s", p1.redirectAddr)
429+
func TestPersistentAuth_startListener_maxPortFallbackIncluded(t *testing.T) {
430+
maxAddress := fmt.Sprintf("localhost:%d", maxPortFallback)
431+
pa := &PersistentAuth{}
432+
pa.netListen = func(_, address string) (net.Listener, error) {
433+
if address == maxAddress {
434+
return nil, nil
435+
}
436+
return nil, fmt.Errorf("address already in use")
408437
}
409-
if !regexp.MustCompile(`^127\.0\.0\.1:\d+$`).MatchString(p2.redirectAddr) {
410-
t.Errorf("p2.redirectAddr should be random localhost port, got %s", p2.redirectAddr)
438+
439+
gotErr := pa.startListener(context.Background())
440+
441+
if gotErr != nil {
442+
t.Fatalf("pa.startListener(): want no error, got %v", gotErr)
443+
}
444+
if pa.redirectAddr != maxAddress {
445+
t.Errorf("pa.redirectAddr should be %s, got %s", maxAddress, pa.redirectAddr)
446+
}
447+
}
448+
449+
func TestPersistentAuth_startListener_explicitPort(t *testing.T) {
450+
explicitPort := 1337
451+
pa := &PersistentAuth{port: explicitPort}
452+
pa.netListen = func(_, address string) (net.Listener, error) {
453+
return nil, nil
411454
}
412-
if p1.redirectAddr == p2.redirectAddr {
413-
t.Errorf("p1.redirectURL and p2.redirectURL should be different, got %s", p1.redirectAddr)
455+
456+
gotErr := pa.startListener(context.Background())
457+
458+
if gotErr != nil {
459+
t.Fatalf("pa.startListener(): want no error, got %v", gotErr)
460+
}
461+
if pa.redirectAddr != "localhost:1337" {
462+
t.Errorf("pa.redirectAddr should be localhost:1337, got %s", pa.redirectAddr)
463+
}
464+
}
465+
466+
func TestPersistentAuth_startListener_explicitPortNoFallBack(t *testing.T) {
467+
testError := errors.New("test error")
468+
explicitPort := 1337
469+
pa := &PersistentAuth{port: explicitPort}
470+
pa.netListen = func(_, address string) (net.Listener, error) {
471+
if address == "localhost:1337" {
472+
return nil, testError
473+
}
474+
return nil, nil
475+
}
476+
477+
gotErr := pa.startListener(context.Background())
478+
479+
if !errors.Is(gotErr, testError) {
480+
t.Fatalf("pa.startListener(): want error %v, got %v", testError, gotErr)
414481
}
415482
}

0 commit comments

Comments
 (0)