@@ -15,7 +15,6 @@ import (
1515 cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache"
1616 "github.com/databricks/databricks-sdk-go/httpclient"
1717 "github.com/databricks/databricks-sdk-go/logger"
18- "github.com/databricks/databricks-sdk-go/retries"
1918 "github.com/pkg/browser"
2019 "golang.org/x/oauth2"
2120 "golang.org/x/oauth2/authhandler"
@@ -25,15 +24,25 @@ const (
2524 // appClientId is the default client ID used by the SDK for U2M OAuth.
2625 appClientID = "databricks-cli"
2726
28- // appRedirectAddr is the default address for the OAuth2 callback server.
29- // Using ":0" tells the system to pick a random available port .
30- appRedirectAddr = "localhost:0"
27+ // defaultPort is the default port for the OAuth2 callback server. If the
28+ // port is already in use, the next port is tried (8021, 8022, etc.) .
29+ defaultPort = 8020
3130
32- // listenerTimeout is the maximum amount of time to acquire listener on
33- // appRedirectAddr.
31+ // maxPortFallback is the maximum port to try when using the fallback
32+ // mechanism.
33+ maxPortFallback = 8040
34+
35+ // listenerTimeout is the maximum duration spent trying to acquire a
36+ // listener (including port selection).
3437 listenerTimeout = 45 * time .Second
3538)
3639
40+ var (
41+ // Internal errors used for testing.
42+ errListenerTimeout = errors .New ("failed to listen on any port: timeout" )
43+ errNoPortAvailable = errors .New ("no port available to listen on" )
44+ )
45+
3746// PersistentAuth is an OAuth manager that handles the U2M OAuth flow. Tokens
3847// are stored in and looked up from the provided cache. Tokens include the
3948// refresh token. On load, if the access token is expired, it is refreshed
@@ -44,25 +53,41 @@ const (
4453type PersistentAuth struct {
4554 // cache is the token cache to store and lookup tokens.
4655 cache cache.TokenCache
56+
4757 // client is the HTTP client to use for OAuth2 requests.
4858 client * http.Client
59+
4960 // endpointSupplier is the HTTP endpointSupplier to use for OAuth2 requests.
5061 endpointSupplier OAuthEndpointSupplier
62+
5163 // oAuthArgument defines the workspace or account to authenticate to and the
5264 // cache key for the token.
5365 oAuthArgument OAuthArgument
66+
5467 // browser is the function to open a URL in the default browser.
5568 browser func (url string ) error
69+
5670 // ln is the listener for the OAuth2 callback server.
5771 ln net.Listener
72+
5873 // ctx is the context to use for underlying operations. This is needed in
5974 // order to implement the oauth2.TokenSource interface.
6075 ctx context.Context
76+
6177 // redirectAddr is the redirect address for OAuth2 callbacks. The value is
6278 // set to localhost:PORT by startListener which will dynamically assign a
6379 // random port. If a value is already provided, it will be used instead
6480 // (e.g. for testing).
6581 redirectAddr string
82+
83+ // Optional port to use for the OAuth2 callback server. If set to 0, the
84+ // default port with fallback is used. This means that setting a port will
85+ // disable the fallback mechanism.
86+ port int
87+
88+ // netListen is an optional function to listen on a TCP address. If not set,
89+ // it will use net.Listen by default. This is useful for testing.
90+ netListen func (network , address string ) (net.Listener , error )
6691}
6792
6893type PersistentAuthOption func (* PersistentAuth )
@@ -103,6 +128,13 @@ func WithBrowser(b func(url string) error) PersistentAuthOption {
103128 }
104129}
105130
131+ // WithPort sets the port for the PersistentAuth.
132+ func WithPort (port int ) PersistentAuthOption {
133+ return func (a * PersistentAuth ) {
134+ a .port = port
135+ }
136+ }
137+
106138// NewPersistentAuth creates a new PersistentAuth with the provided options.
107139func NewPersistentAuth (ctx context.Context , opts ... PersistentAuthOption ) (* PersistentAuth , error ) {
108140 p := & PersistentAuth {}
@@ -273,35 +305,51 @@ func (a *PersistentAuth) Challenge() error {
273305
274306// startListener starts a listener on appRedirectAddr, retrying if the address
275307// is already in use.
276- func (a * PersistentAuth ) startListener (ctx context.Context ) error {
277- // Use the value of redirectURL if it is already set. This is only expected
278- // in tests to set a fixed redirect URL.
279- addr := a .redirectAddr
280- if addr == "" {
281- addr = appRedirectAddr
282- }
283-
284- listener , err := retries .Poll (ctx , listenerTimeout ,
285- func () (* net.Listener , * retries.Err ) {
286- var lc net.ListenConfig
287- l , err := lc .Listen (ctx , "tcp" , addr )
288- if err != nil {
289- logger .Debugf (ctx , "failed to listen on %s: %v, retrying" , addr , err )
290- return nil , retries .Continue (err )
291- }
292- return & l , nil
293- })
294- if err != nil {
295- return fmt .Errorf ("listener: %w" , err )
308+ func (pa * PersistentAuth ) startListener (ctx context.Context ) error {
309+ if pa .port != 0 { // if port is set, use it
310+ return pa .startListenerWithPort (pa .port )
296311 }
297- a .ln = * listener
312+ return pa .startListenerWithFallback (ctx )
313+ }
298314
299- // Get the actual address that was assigned (including the port).
300- a .redirectAddr = a .ln .Addr ().String ()
301- logger .Debugf (ctx , "OAuth callback server listening on %s" , a .redirectAddr )
315+ // startListenerWithFallback starts a listener that will try to find a free
316+ // port to listen on starting from the default port and incrementing by 1 until
317+ // a free port is found.
318+ func (pa * PersistentAuth ) startListenerWithFallback (ctx context.Context ) error {
319+ startTime := time .Now ()
320+ for port := defaultPort ; port <= maxPortFallback ; port ++ {
321+ if time .Since (startTime ) > listenerTimeout {
322+ return errListenerTimeout
323+ }
324+ if err := pa .startListenerWithPort (port ); err != nil {
325+ logger .Debugf (ctx , "failed to listen on %d: %v, retrying" , port , err )
326+ continue
327+ }
328+ logger .Debugf (ctx , "OAuth callback server listening on %s" , pa .redirectAddr )
329+ return nil
330+ }
331+ return errNoPortAvailable
332+ }
333+
334+ func (pa * PersistentAuth ) startListenerWithPort (port int ) error {
335+ addr := fmt .Sprintf ("localhost:%d" , port )
336+ listener , err := pa .listen ("tcp" , addr )
337+ if err != nil {
338+ return fmt .Errorf ("failed to listen on %s: %w" , addr , err )
339+ }
340+ pa .ln = listener
341+ pa .redirectAddr = addr
302342 return nil
303343}
304344
345+ func (pa * PersistentAuth ) listen (network , addr string ) (net.Listener , error ) {
346+ if pa .netListen != nil {
347+ return pa .netListen (network , addr )
348+ } else {
349+ return net .Listen (network , addr )
350+ }
351+ }
352+
305353func (a * PersistentAuth ) Close () error {
306354 if a .ln == nil {
307355 return nil
0 commit comments