diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 53f2f500f..cc9e99159 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,8 @@ ### New Features and Improvements +* Add `u2m.WithDiscoveryAccountTarget()` option that sets `target=ACCOUNT` on the login.databricks.com authorize URL, so the discovery flow lands the user on the account selector instead of the workspace selector. + ### Bug Fixes ### Documentation diff --git a/credentials/u2m/discovery_token_source.go b/credentials/u2m/discovery_token_source.go index 8106cee32..578374f3d 100644 --- a/credentials/u2m/discovery_token_source.go +++ b/credentials/u2m/discovery_token_source.go @@ -48,17 +48,26 @@ func DeriveTokenEndpoint(issuer string) string { return strings.TrimRight(issuer, "/") + "/v1/token" } +// discoveryTargetAccount is the value of the `target` query parameter that +// tells login.databricks.com to land the user on the account selector instead +// of the workspace selector. Used when the caller has signalled (e.g. via +// WithDiscoveryAccountTarget) that they only want account-level access. +const discoveryTargetAccount = "ACCOUNT" + // BuildDiscoveryAuthorizeURL builds the login.databricks.com URL that initiates // the discovery OAuth flow. The OIDC authorize path with all OAuth query params // is URL-encoded as the destination_url parameter. func BuildDiscoveryAuthorizeURL(redirectAddr, state string, pkce PKCEParams, scopes []string) string { - return buildDiscoveryAuthorizeURL(defaultLoginDatabricksHost, redirectAddr, state, pkce, scopes) + return buildDiscoveryAuthorizeURL(defaultLoginDatabricksHost, redirectAddr, state, pkce, scopes, "") } // buildDiscoveryAuthorizeURL builds the discovery authorize URL against the // given host. Trailing slashes on host are trimmed so the result is -// well-formed regardless of how an override is written. -func buildDiscoveryAuthorizeURL(host, redirectAddr, state string, pkce PKCEParams, scopes []string) string { +// well-formed regardless of how an override is written. When target is +// non-empty it is set as the top-level `target` query parameter, which +// login.databricks.com uses to route the user to a specific selector page +// (e.g. "ACCOUNT" for the account selector). +func buildDiscoveryAuthorizeURL(host, redirectAddr, state string, pkce PKCEParams, scopes []string, target string) string { // Build the nested OIDC authorize path with query parameters. authParams := url.Values{} authParams.Set("client_id", appClientID) @@ -73,6 +82,9 @@ func buildDiscoveryAuthorizeURL(host, redirectAddr, state string, pkce PKCEParam // Wrap the authorize path as the destination_url query parameter on the // discovery host. topParams := url.Values{} + if target != "" { + topParams.Set("target", target) + } topParams.Set("destination_url", destinationURL) return strings.TrimRight(host, "/") + "/?" + topParams.Encode() } @@ -93,6 +105,10 @@ type discoveryTokenSource struct { pa *PersistentAuth // host overrides defaultLoginDatabricksHost when non-empty. host string + // target is the value of the top-level `target` query parameter on the + // authorize URL. When non-empty (e.g. "ACCOUNT"), login.databricks.com + // routes the user directly to the corresponding selector. + target string } // challenge initiates the discovery OAuth flow through login.databricks.com. @@ -122,7 +138,7 @@ func (d *discoveryTokenSource) challenge() error { if host == "" { host = defaultLoginDatabricksHost } - authorizeURL := buildDiscoveryAuthorizeURL(host, d.pa.redirectAddr, state, pkce, scopes) + authorizeURL := buildDiscoveryAuthorizeURL(host, d.pa.redirectAddr, state, pkce, scopes, d.target) // Use cb.Handler to open the browser and wait for the callback. code, returnedState, err := cb.Handler(authorizeURL) diff --git a/credentials/u2m/discovery_token_source_test.go b/credentials/u2m/discovery_token_source_test.go index f081dd299..d3eb81d47 100644 --- a/credentials/u2m/discovery_token_source_test.go +++ b/credentials/u2m/discovery_token_source_test.go @@ -187,7 +187,7 @@ func TestBuildDiscoveryAuthorizeURL_HostOverride(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got := buildDiscoveryAuthorizeURL(tc.host, "localhost:8020", "s", pkce, scopes) + got := buildDiscoveryAuthorizeURL(tc.host, "localhost:8020", "s", pkce, scopes, "") u, err := url.Parse(got) if err != nil { t.Fatalf("parsing URL: %v", err) @@ -199,6 +199,50 @@ func TestBuildDiscoveryAuthorizeURL_HostOverride(t *testing.T) { } } +func TestBuildDiscoveryAuthorizeURL_Target(t *testing.T) { + pkce := PKCEParams{ + Challenge: "c", + ChallengeMethod: "S256", + Verifier: "v", + } + scopes := []string{"offline_access", "all-apis"} + tests := []struct { + name string + target string + wantTarget string + }{ + {name: "no target", target: "", wantTarget: ""}, + {name: "account target", target: "ACCOUNT", wantTarget: "ACCOUNT"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := buildDiscoveryAuthorizeURL(defaultLoginDatabricksHost, "localhost:8020", "s", pkce, scopes, tc.target) + u, err := url.Parse(got) + if err != nil { + t.Fatalf("parsing URL: %v", err) + } + if g := u.Query().Get("target"); g != tc.wantTarget { + t.Errorf("target = %q, want %q", g, tc.wantTarget) + } + // destination_url must still be present in every variant. + if u.Query().Get("destination_url") == "" { + t.Error("destination_url should be set regardless of target") + } + }) + } +} + +func TestWithDiscoveryAccountTarget(t *testing.T) { + var a PersistentAuth + if a.discoveryAccountTarget { + t.Fatal("discoveryAccountTarget should default to false") + } + WithDiscoveryAccountTarget()(&a) + if !a.discoveryAccountTarget { + t.Error("WithDiscoveryAccountTarget did not set discoveryAccountTarget") + } +} + func TestWithDiscoveryHost_NormalizesScheme(t *testing.T) { tests := []struct { name string diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index fb8c81d0b..ba5d96d1b 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -111,6 +111,12 @@ type PersistentAuth struct { // discoveryHost overrides the default login.databricks.com host used by // the discovery flow. Empty means the production host. discoveryHost string + + // discoveryAccountTarget, when true, instructs the discovery flow to set + // the top-level `target=ACCOUNT` query parameter on the authorize URL so + // login.databricks.com lands the user on the account selector instead of + // the workspace selector. Use for account-only logins. + discoveryAccountTarget bool } type PersistentAuthOption func(*PersistentAuth) @@ -200,6 +206,18 @@ func WithDiscoveryHost(host string) PersistentAuthOption { } } +// WithDiscoveryAccountTarget sets the top-level `target=ACCOUNT` query +// parameter on the discovery authorize URL so login.databricks.com lands the +// user on the account selector instead of the workspace selector. Use for +// account-only logins where workspace selection would be a wasted step. +// +// Has no effect unless WithDiscoveryLogin is also set. +func WithDiscoveryAccountTarget() PersistentAuthOption { + return func(a *PersistentAuth) { + a.discoveryAccountTarget = true + } +} + // NewPersistentAuth creates a new PersistentAuth with the provided options. func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) { p := &PersistentAuth{} @@ -424,6 +442,9 @@ func (a *PersistentAuth) discoveryChallenge() error { } defer a.Close() ds := &discoveryTokenSource{pa: a, host: a.discoveryHost} + if a.discoveryAccountTarget { + ds.target = discoveryTargetAccount + } return ds.challenge() }