Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### New Features and Improvements

* Use a free port in `u2m` authentication flows rather than 8020.

### Bug Fixes

### Documentation
Expand Down
35 changes: 27 additions & 8 deletions credentials/u2m/persistent_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ const (
appClientID = "databricks-cli"

// appRedirectAddr is the default address for the OAuth2 callback server.
appRedirectAddr = "localhost:8020"
// Using ":0" tells the system to pick a random available port.
appRedirectAddr = "localhost:0"

// listenerTimeout is the maximum amount of time to acquire listener on
// appRedirectAddr.
Expand Down Expand Up @@ -57,6 +58,11 @@ type PersistentAuth struct {
// ctx is the context to use for underlying operations. This is needed in
// order to implement the oauth2.TokenSource interface.
ctx context.Context
// redirectAddr is the redirect address for OAuth2 callbacks. The value is
// set to localhost:PORT by startListener which will dynamically assign a
// random port. If a value is already provided, it will be used instead
// (e.g. for testing).
redirectAddr string
}

type PersistentAuthOption func(*PersistentAuth)
Expand Down Expand Up @@ -128,9 +134,6 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers
return nil, fmt.Errorf("cache: %w", err)
}
}
if p.oAuthArgument == nil {
return nil, errors.New("missing OAuthArgument")
}
if err := p.validateArg(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -162,7 +165,9 @@ func (a *PersistentAuth) Token() (t *oauth2.Token, err error) {
return nil, fmt.Errorf("token refresh: %w", err)
}
}
// do not print refresh token to end-user

// Do not include the refresh token for security reasons. Refresh tokens are
// long-lived credentials that we do not want to expose unnecessarily.
t.RefreshToken = ""
return t, nil
}
Expand Down Expand Up @@ -269,12 +274,19 @@ func (a *PersistentAuth) Challenge() error {
// startListener starts a listener on appRedirectAddr, retrying if the address
// is already in use.
func (a *PersistentAuth) startListener(ctx context.Context) error {
// Use the value of redirectURL if it is already set. This is only expected
// in tests to set a fixed redirect URL.
addr := a.redirectAddr
if addr == "" {
addr = appRedirectAddr
}

listener, err := retries.Poll(ctx, listenerTimeout,
func() (*net.Listener, *retries.Err) {
var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", appRedirectAddr)
l, err := lc.Listen(ctx, "tcp", addr)
if err != nil {
logger.Debugf(ctx, "failed to listen on %s: %v, retrying", appRedirectAddr, err)
logger.Debugf(ctx, "failed to listen on %s: %v, retrying", addr, err)
return nil, retries.Continue(err)
}
return &l, nil
Expand All @@ -283,6 +295,10 @@ func (a *PersistentAuth) startListener(ctx context.Context) error {
return fmt.Errorf("listener: %w", err)
}
a.ln = *listener

// Get the actual address that was assigned (including the port).
a.redirectAddr = a.ln.Addr().String()
logger.Debugf(ctx, "OAuth callback server listening on %s", a.redirectAddr)
return nil
}

Expand All @@ -296,6 +312,9 @@ func (a *PersistentAuth) Close() error {
// validateArg ensures that the OAuthArgument is either a WorkspaceOAuthArgument
// or an AccountOAuthArgument.
func (a *PersistentAuth) validateArg() error {
if a.oAuthArgument == nil {
return errors.New("missing OAuthArgument")
}
_, isWorkspaceArg := a.oAuthArgument.(WorkspaceOAuthArgument)
_, isAccountArg := a.oAuthArgument.(AccountOAuthArgument)
if !isWorkspaceArg && !isAccountArg {
Expand Down Expand Up @@ -331,7 +350,7 @@ func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
TokenURL: endpoints.TokenEndpoint,
AuthStyle: oauth2.AuthStyleInParams,
},
RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr),
RedirectURL: fmt.Sprintf("http://%s", a.redirectAddr),
Scopes: scopes,
}, nil
}
Expand Down
44 changes: 42 additions & 2 deletions credentials/u2m/persistent_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -297,6 +298,9 @@ func TestChallenge(t *testing.T) {
}
defer p.Close()

// Set a fixed redirect URL for the test.
p.redirectAddr = "localhost:1337"

errc := make(chan error)
go func() {
err := p.Challenge()
Expand All @@ -305,7 +309,7 @@ func TestChallenge(t *testing.T) {
}()

state := <-browserOpened
resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__THIS__&state=%s", state))
resp, err := http.Get(fmt.Sprintf("http://localhost:1337?code=__THIS__&state=%s", state))
if err != nil {
t.Fatalf("http.Get(): want no error, got %v", err)
}
Expand Down Expand Up @@ -347,6 +351,8 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) {
}
defer p.Close()

p.redirectAddr = "localhost:1337" // set a fixed redirect URL for the test

errc := make(chan error)
go func() {
err := p.Challenge()
Expand All @@ -355,7 +361,7 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) {
}()

<-browserOpened
resp, err := http.Get("http://localhost:8020?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request")
resp, err := http.Get("http://localhost:1337?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request")
if err != nil {
t.Fatalf("http.Get(): want no error, got %v", err)
}
Expand All @@ -373,3 +379,37 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) {
t.Fatalf("p.Challenge(): want error containing 'authorize: access_denied: Policy evaluation failed for this request', got %v", err)
}
}

// Verifies that startListener assigns a random port to the redirectAddr.
func TestPersistentAuth_startListener_useDifferentPorts(t *testing.T) {
ctx := context.Background()
arg, err := NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz")
if err != nil {
t.Fatalf("NewBasicAccountOAuthArgument(): want no error, got %v", err)
}

p1, err := NewPersistentAuth(ctx, WithOAuthArgument(arg))
if err != nil {
t.Fatalf("NewPersistentAuth(): want no error, got %v", err)
}
defer p1.Close()

p2, err := NewPersistentAuth(ctx, WithOAuthArgument(arg))
if err != nil {
t.Fatalf("NewPersistentAuth(): want no error, got %v", err)
}
defer p2.Close()

p1.startListener(ctx)
p2.startListener(ctx)

if !regexp.MustCompile(`^127\.0\.0\.1:\d+$`).MatchString(p1.redirectAddr) {
t.Errorf("p1.redirectAddr should be random localhost port, got %s", p1.redirectAddr)
}
if !regexp.MustCompile(`^127\.0\.0\.1:\d+$`).MatchString(p2.redirectAddr) {
t.Errorf("p2.redirectAddr should be random localhost port, got %s", p2.redirectAddr)
}
if p1.redirectAddr == p2.redirectAddr {
t.Errorf("p1.redirectURL and p2.redirectURL should be different, got %s", p1.redirectAddr)
}
}
Loading