diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 247a88f94..4ec35b34b 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,8 @@ ### New Features and Improvements +* Use a free port in `u2m` authentication flows rather than 8020. + ### Bug Fixes ### Documentation diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index 7e28a25f5..ce2a904d6 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -26,7 +26,8 @@ const ( appClientID = "databricks-cli" // appRedirectAddr is the default address for the OAuth2 callback server. - appRedirectAddr = "localhost:8020" + // Using ":0" tells the system to pick a random available port. + appRedirectAddr = "localhost:0" // listenerTimeout is the maximum amount of time to acquire listener on // appRedirectAddr. @@ -57,6 +58,11 @@ type PersistentAuth struct { // ctx is the context to use for underlying operations. This is needed in // order to implement the oauth2.TokenSource interface. ctx context.Context + // redirectAddr is the redirect address for OAuth2 callbacks. The value is + // set to localhost:PORT by startListener which will dynamically assign a + // random port. If a value is already provided, it will be used instead + // (e.g. for testing). + redirectAddr string } type PersistentAuthOption func(*PersistentAuth) @@ -128,9 +134,6 @@ func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*Pers return nil, fmt.Errorf("cache: %w", err) } } - if p.oAuthArgument == nil { - return nil, errors.New("missing OAuthArgument") - } if err := p.validateArg(); err != nil { return nil, err } @@ -162,7 +165,9 @@ func (a *PersistentAuth) Token() (t *oauth2.Token, err error) { return nil, fmt.Errorf("token refresh: %w", err) } } - // do not print refresh token to end-user + + // Do not include the refresh token for security reasons. Refresh tokens are + // long-lived credentials that we do not want to expose unnecessarily. t.RefreshToken = "" return t, nil } @@ -269,12 +274,19 @@ func (a *PersistentAuth) Challenge() error { // startListener starts a listener on appRedirectAddr, retrying if the address // is already in use. func (a *PersistentAuth) startListener(ctx context.Context) error { + // Use the value of redirectURL if it is already set. This is only expected + // in tests to set a fixed redirect URL. + addr := a.redirectAddr + if addr == "" { + addr = appRedirectAddr + } + listener, err := retries.Poll(ctx, listenerTimeout, func() (*net.Listener, *retries.Err) { var lc net.ListenConfig - l, err := lc.Listen(ctx, "tcp", appRedirectAddr) + l, err := lc.Listen(ctx, "tcp", addr) if err != nil { - logger.Debugf(ctx, "failed to listen on %s: %v, retrying", appRedirectAddr, err) + logger.Debugf(ctx, "failed to listen on %s: %v, retrying", addr, err) return nil, retries.Continue(err) } return &l, nil @@ -283,6 +295,10 @@ func (a *PersistentAuth) startListener(ctx context.Context) error { return fmt.Errorf("listener: %w", err) } a.ln = *listener + + // Get the actual address that was assigned (including the port). + a.redirectAddr = a.ln.Addr().String() + logger.Debugf(ctx, "OAuth callback server listening on %s", a.redirectAddr) return nil } @@ -296,6 +312,9 @@ func (a *PersistentAuth) Close() error { // validateArg ensures that the OAuthArgument is either a WorkspaceOAuthArgument // or an AccountOAuthArgument. func (a *PersistentAuth) validateArg() error { + if a.oAuthArgument == nil { + return errors.New("missing OAuthArgument") + } _, isWorkspaceArg := a.oAuthArgument.(WorkspaceOAuthArgument) _, isAccountArg := a.oAuthArgument.(AccountOAuthArgument) if !isWorkspaceArg && !isAccountArg { @@ -331,7 +350,7 @@ func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { TokenURL: endpoints.TokenEndpoint, AuthStyle: oauth2.AuthStyleInParams, }, - RedirectURL: fmt.Sprintf("http://%s", appRedirectAddr), + RedirectURL: fmt.Sprintf("http://%s", a.redirectAddr), Scopes: scopes, }, nil } diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index 402ecdecc..27048ead7 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/url" + "regexp" "strings" "testing" "time" @@ -297,6 +298,9 @@ func TestChallenge(t *testing.T) { } defer p.Close() + // Set a fixed redirect URL for the test. + p.redirectAddr = "localhost:1337" + errc := make(chan error) go func() { err := p.Challenge() @@ -305,7 +309,7 @@ func TestChallenge(t *testing.T) { }() state := <-browserOpened - resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__THIS__&state=%s", state)) + resp, err := http.Get(fmt.Sprintf("http://localhost:1337?code=__THIS__&state=%s", state)) if err != nil { t.Fatalf("http.Get(): want no error, got %v", err) } @@ -347,6 +351,8 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) { } defer p.Close() + p.redirectAddr = "localhost:1337" // set a fixed redirect URL for the test + errc := make(chan error) go func() { err := p.Challenge() @@ -355,7 +361,7 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) { }() <-browserOpened - resp, err := http.Get("http://localhost:8020?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request") + resp, err := http.Get("http://localhost:1337?error=access_denied&error_description=Policy%20evaluation%20failed%20for%20this%20request") if err != nil { t.Fatalf("http.Get(): want no error, got %v", err) } @@ -373,3 +379,37 @@ func TestChallenge_ReturnsErrorOnFailure(t *testing.T) { t.Fatalf("p.Challenge(): want error containing 'authorize: access_denied: Policy evaluation failed for this request', got %v", err) } } + +// Verifies that startListener assigns a random port to the redirectAddr. +func TestPersistentAuth_startListener_useDifferentPorts(t *testing.T) { + ctx := context.Background() + arg, err := NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + if err != nil { + t.Fatalf("NewBasicAccountOAuthArgument(): want no error, got %v", err) + } + + p1, err := NewPersistentAuth(ctx, WithOAuthArgument(arg)) + if err != nil { + t.Fatalf("NewPersistentAuth(): want no error, got %v", err) + } + defer p1.Close() + + p2, err := NewPersistentAuth(ctx, WithOAuthArgument(arg)) + if err != nil { + t.Fatalf("NewPersistentAuth(): want no error, got %v", err) + } + defer p2.Close() + + p1.startListener(ctx) + p2.startListener(ctx) + + if !regexp.MustCompile(`^127\.0\.0\.1:\d+$`).MatchString(p1.redirectAddr) { + t.Errorf("p1.redirectAddr should be random localhost port, got %s", p1.redirectAddr) + } + if !regexp.MustCompile(`^127\.0\.0\.1:\d+$`).MatchString(p2.redirectAddr) { + t.Errorf("p2.redirectAddr should be random localhost port, got %s", p2.redirectAddr) + } + if p1.redirectAddr == p2.redirectAddr { + t.Errorf("p1.redirectURL and p2.redirectURL should be different, got %s", p1.redirectAddr) + } +}