Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### New Features and Improvements

* Add `u2m.WithDiscoveryAccountTarget()` option that sets `target=ACCOUNT` on the login.databricks.com authorize URL, so the discovery flow lands the user on the account selector instead of the workspace selector.

### Bug Fixes

### Documentation
Expand Down
24 changes: 20 additions & 4 deletions credentials/u2m/discovery_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,26 @@ func DeriveTokenEndpoint(issuer string) string {
return strings.TrimRight(issuer, "/") + "/v1/token"
}

// discoveryTargetAccount is the value of the `target` query parameter that
// tells login.databricks.com to land the user on the account selector instead
// of the workspace selector. Used when the caller has signalled (e.g. via
// WithDiscoveryAccountTarget) that they only want account-level access.
const discoveryTargetAccount = "ACCOUNT"

// BuildDiscoveryAuthorizeURL builds the login.databricks.com URL that initiates
// the discovery OAuth flow. The OIDC authorize path with all OAuth query params
// is URL-encoded as the destination_url parameter.
func BuildDiscoveryAuthorizeURL(redirectAddr, state string, pkce PKCEParams, scopes []string) string {
return buildDiscoveryAuthorizeURL(defaultLoginDatabricksHost, redirectAddr, state, pkce, scopes)
return buildDiscoveryAuthorizeURL(defaultLoginDatabricksHost, redirectAddr, state, pkce, scopes, "")
}

// buildDiscoveryAuthorizeURL builds the discovery authorize URL against the
// given host. Trailing slashes on host are trimmed so the result is
// well-formed regardless of how an override is written.
func buildDiscoveryAuthorizeURL(host, redirectAddr, state string, pkce PKCEParams, scopes []string) string {
// well-formed regardless of how an override is written. When target is
// non-empty it is set as the top-level `target` query parameter, which
// login.databricks.com uses to route the user to a specific selector page
// (e.g. "ACCOUNT" for the account selector).
func buildDiscoveryAuthorizeURL(host, redirectAddr, state string, pkce PKCEParams, scopes []string, target string) string {
// Build the nested OIDC authorize path with query parameters.
authParams := url.Values{}
authParams.Set("client_id", appClientID)
Expand All @@ -73,6 +82,9 @@ func buildDiscoveryAuthorizeURL(host, redirectAddr, state string, pkce PKCEParam
// Wrap the authorize path as the destination_url query parameter on the
// discovery host.
topParams := url.Values{}
if target != "" {
topParams.Set("target", target)
}
topParams.Set("destination_url", destinationURL)
return strings.TrimRight(host, "/") + "/?" + topParams.Encode()
}
Expand All @@ -93,6 +105,10 @@ type discoveryTokenSource struct {
pa *PersistentAuth
// host overrides defaultLoginDatabricksHost when non-empty.
host string
// target is the value of the top-level `target` query parameter on the
// authorize URL. When non-empty (e.g. "ACCOUNT"), login.databricks.com
// routes the user directly to the corresponding selector.
target string
}

// challenge initiates the discovery OAuth flow through login.databricks.com.
Expand Down Expand Up @@ -122,7 +138,7 @@ func (d *discoveryTokenSource) challenge() error {
if host == "" {
host = defaultLoginDatabricksHost
}
authorizeURL := buildDiscoveryAuthorizeURL(host, d.pa.redirectAddr, state, pkce, scopes)
authorizeURL := buildDiscoveryAuthorizeURL(host, d.pa.redirectAddr, state, pkce, scopes, d.target)

// Use cb.Handler to open the browser and wait for the callback.
code, returnedState, err := cb.Handler(authorizeURL)
Expand Down
46 changes: 45 additions & 1 deletion credentials/u2m/discovery_token_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func TestBuildDiscoveryAuthorizeURL_HostOverride(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildDiscoveryAuthorizeURL(tc.host, "localhost:8020", "s", pkce, scopes)
got := buildDiscoveryAuthorizeURL(tc.host, "localhost:8020", "s", pkce, scopes, "")
u, err := url.Parse(got)
if err != nil {
t.Fatalf("parsing URL: %v", err)
Expand All @@ -199,6 +199,50 @@ func TestBuildDiscoveryAuthorizeURL_HostOverride(t *testing.T) {
}
}

func TestBuildDiscoveryAuthorizeURL_Target(t *testing.T) {
pkce := PKCEParams{
Challenge: "c",
ChallengeMethod: "S256",
Verifier: "v",
}
scopes := []string{"offline_access", "all-apis"}
tests := []struct {
name string
target string
wantTarget string
}{
{name: "no target", target: "", wantTarget: ""},
{name: "account target", target: "ACCOUNT", wantTarget: "ACCOUNT"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildDiscoveryAuthorizeURL(defaultLoginDatabricksHost, "localhost:8020", "s", pkce, scopes, tc.target)
u, err := url.Parse(got)
if err != nil {
t.Fatalf("parsing URL: %v", err)
}
if g := u.Query().Get("target"); g != tc.wantTarget {
t.Errorf("target = %q, want %q", g, tc.wantTarget)
}
// destination_url must still be present in every variant.
if u.Query().Get("destination_url") == "" {
t.Error("destination_url should be set regardless of target")
}
})
}
}

func TestWithDiscoveryAccountTarget(t *testing.T) {
var a PersistentAuth
if a.discoveryAccountTarget {
t.Fatal("discoveryAccountTarget should default to false")
}
WithDiscoveryAccountTarget()(&a)
if !a.discoveryAccountTarget {
t.Error("WithDiscoveryAccountTarget did not set discoveryAccountTarget")
}
}

func TestWithDiscoveryHost_NormalizesScheme(t *testing.T) {
tests := []struct {
name string
Expand Down
21 changes: 21 additions & 0 deletions credentials/u2m/persistent_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ type PersistentAuth struct {
// discoveryHost overrides the default login.databricks.com host used by
// the discovery flow. Empty means the production host.
discoveryHost string

// discoveryAccountTarget, when true, instructs the discovery flow to set
// the top-level `target=ACCOUNT` query parameter on the authorize URL so
// login.databricks.com lands the user on the account selector instead of
// the workspace selector. Use for account-only logins.
discoveryAccountTarget bool
}

type PersistentAuthOption func(*PersistentAuth)
Expand Down Expand Up @@ -200,6 +206,18 @@ func WithDiscoveryHost(host string) PersistentAuthOption {
}
}

// WithDiscoveryAccountTarget sets the top-level `target=ACCOUNT` query
// parameter on the discovery authorize URL so login.databricks.com lands the
// user on the account selector instead of the workspace selector. Use for
// account-only logins where workspace selection would be a wasted step.
//
// Has no effect unless WithDiscoveryLogin is also set.
func WithDiscoveryAccountTarget() PersistentAuthOption {
return func(a *PersistentAuth) {
a.discoveryAccountTarget = true
}
}

// NewPersistentAuth creates a new PersistentAuth with the provided options.
func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) {
p := &PersistentAuth{}
Expand Down Expand Up @@ -424,6 +442,9 @@ func (a *PersistentAuth) discoveryChallenge() error {
}
defer a.Close()
ds := &discoveryTokenSource{pa: a, host: a.discoveryHost}
if a.discoveryAccountTarget {
ds.target = discoveryTargetAccount
}
return ds.challenge()
}

Expand Down
Loading