Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ type Config struct {
// When using Workload Identity Federation, the audience to specify when fetching an ID token from the ID token supplier.
TokenAudience string `name:"audience" env:"DATABRICKS_TOKEN_AUDIENCE" auth:"-"`

// When using Databricks OAuth, the scopes to request for the token.
// If not set, the default scopes will be used ("offline_access" + "all-apis").
OAuthScopes []string `name:"scopes" env:"DATABRICKS_OAUTH_SCOPES" auth:"-"`

Loaders []Loader

// marker for configuration resolving
Expand Down Expand Up @@ -481,7 +485,7 @@ func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationS
func (c *Config) getOAuthArgument() (u2m.OAuthArgument, error) {
host := c.CanonicalHostName()
if c.IsAccountClient() {
return u2m.NewBasicAccountOAuthArgument(host, c.AccountID)
return u2m.NewBasicAccountOAuthArgument(host, c.AccountID, c.OAuthScopes...)
}
return u2m.NewBasicWorkspaceOAuthArgument(host)
return u2m.NewBasicWorkspaceOAuthArgument(host, c.OAuthScopes...)
}
9 changes: 6 additions & 3 deletions credentials/u2m/account_oauth_argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@ type AccountOAuthArgument interface {
// BasicAccountOAuthArgument is a basic implementation of the AccountOAuthArgument
// interface that links each account with exactly one OAuth token.
type BasicAccountOAuthArgument struct {
oauthScopes

accountHost string
accountID string
}

var _ AccountOAuthArgument = BasicAccountOAuthArgument{}

// NewBasicAccountOAuthArgument creates a new BasicAccountOAuthArgument.
func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountOAuthArgument, error) {
func NewBasicAccountOAuthArgument(accountsHost, accountID string, scopes ...string) (BasicAccountOAuthArgument, error) {
if err := validateHost(accountsHost); err != nil {
return BasicAccountOAuthArgument{}, err
}
return BasicAccountOAuthArgument{accountHost: accountsHost, accountID: accountID}, nil
return BasicAccountOAuthArgument{oauthScopes: newOAuthScopes(scopes...), accountHost: accountsHost, accountID: accountID}, nil
}

// GetAccountHost returns the host of the account to authenticate to.
Expand All @@ -46,5 +48,6 @@ func (a BasicAccountOAuthArgument) GetAccountId() string {
// GetCacheKey returns a unique key for caching the OAuth token for the account.
// The key is in the format "<accountHost>/oidc/accounts/<accountID>".
func (a BasicAccountOAuthArgument) GetCacheKey() string {
return fmt.Sprintf("%s/oidc/accounts/%s", a.accountHost, a.accountID)
base := fmt.Sprintf("%s/oidc/accounts/%s", a.accountHost, a.accountID)
return computeScopedCacheKey(base, a.oauthScopes)
}
3 changes: 3 additions & 0 deletions credentials/u2m/oauth_argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ type OAuthArgument interface {
// GetCacheKey returns a unique key for the OAuthArgument. This key is used
// to store and retrieve the token from the token cache.
GetCacheKey() string

// GetScopes returns the OAuth scopes to request for this argument.
GetScopes() []string
}
5 changes: 1 addition & 4 deletions credentials/u2m/persistent_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,7 @@ func (a *PersistentAuth) validateArg() error {

// oauth2Config returns the OAuth2 configuration for the given OAuthArgument.
func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
scopes := []string{
"offline_access", // ensures OAuth token includes refresh token
"all-apis", // ensures OAuth token has access to all control-plane APIs
}
scopes := a.oAuthArgument.GetScopes()
var endpoints *OAuthAuthorizationServer
var err error
switch argg := a.oAuthArgument.(type) {
Expand Down
86 changes: 86 additions & 0 deletions credentials/u2m/scopes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package u2m

import (
"slices"

Check failure on line 4 in credentials/u2m/scopes.go

View workflow job for this annotation

GitHub Actions / tests (1.20)

package slices is not in GOROOT (/opt/hostedtoolcache/go/1.20.14/x64/src/slices)

Check failure on line 4 in credentials/u2m/scopes.go

View workflow job for this annotation

GitHub Actions / tests (1.20)

package slices is not in GOROOT (/opt/hostedtoolcache/go/1.20.14/x64/src/slices)

Check failure on line 4 in credentials/u2m/scopes.go

View workflow job for this annotation

GitHub Actions / tests (1.19)

cannot find package "." in:

Check failure on line 4 in credentials/u2m/scopes.go

View workflow job for this annotation

GitHub Actions / tests (1.19)

cannot find package "." in:
"sort"
"strings"
)

// OAuthScopes encapsulates OAuth scopes configuration and normalization logic.
//
// It guarantees that the resulting scope list:
// - always includes "offline_access" (to receive refresh tokens)
// - is lower-cased, de-duplicated, and stable-sorted
//
// The default value equals the SDK's historical behavior: offline_access + all-apis.
type oauthScopes struct {
values []string
}

var defaultScopeValues = []string{
"offline_access", // Ensures OAuth token includes refresh token.
"all-apis", // Ensures OAuth token has access to all control-plane APIs.
}

// newOAuthScopes constructs oauthScopes from the provided scope values.
// If no scopes are provided, it returns the default scopes.
func newOAuthScopes(scopes ...string) oauthScopes {
if len(scopes) == 0 {
scopes = defaultScopeValues
}
vals := make([]string, len(scopes))
copy(vals, scopes)
return oauthScopes{values: vals}
}

// GetScopes returns the normalized, de-duplicated, stable-sorted list of scopes,
// guaranteeing that "offline_access" is included.
// Exported because it implements the OAuthArgument interface.
func (s oauthScopes) GetScopes() []string {
return normalizeScopes(s.values)
}

// isDefault reports whether the scopes are equivalent to the default scopes
// (order-insensitive, case-insensitive, duplicates ignored).
func (s oauthScopes) isDefault() bool {
left := normalizeScopes(s.values)
right := normalizeScopes(defaultScopeValues)
return slices.Equal(left, right)
}

// ComputeScopedCacheKey produces a backward-compatible cache key:
// - For default scopes, it returns the baseKey unchanged
// - For non-default scopes, it appends a stable hash suffix of the scopes
// to avoid very long keys while ensuring uniqueness
//
// Result format for non-default scopes: "<baseKey>#scopes=<scope1,scope2,...>"
// where scopes are normalized and comma-separated for readability.
func computeScopedCacheKey(baseKey string, scopes oauthScopes) string {
if scopes.isDefault() {
return baseKey
}
joined := strings.Join(normalizeScopes(scopes.values), ",")
return baseKey + "#scopes=" + joined
}

// normalizeScopes lowercases, ensures offline_access presence, removes duplicates,
// and returns a stable-sorted slice.
func normalizeScopes(in []string) []string {
set := map[string]struct{}{}
for _, s := range in {
v := strings.TrimSpace(strings.ToLower(s))
if v == "" {
continue
}
set[v] = struct{}{}
}
// Always include offline_access to guarantee refresh tokens
set["offline_access"] = struct{}{}

out := make([]string, 0, len(set))
for v := range set {
out = append(out, v)
}
sort.Strings(out)
return out
}
44 changes: 44 additions & 0 deletions credentials/u2m/scopes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package u2m

import (
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestDefaultOAuthScopes_NormalizationAndDefault(t *testing.T) {
s := newOAuthScopes()
got := s.GetScopes()
want := []string{"all-apis", "offline_access"}
require.Equal(t, want, got, "GetScopes() should return normalized default scopes")
require.True(t, s.isDefault(), "IsDefault() should be true for default scopes")
}

func TestNewOAuthScopes_Normalization(t *testing.T) {
// duplicates, mixed case, empty, and whitespace
s := newOAuthScopes("All-APIs", "offline_access", "", " x:y ", "x:y")
got := s.GetScopes()
// normalized, sorted
want := []string{"all-apis", "offline_access", "x:y"}
require.Equal(t, want, got, "GetScopes() should return normalized, sorted scopes")
require.False(t, s.isDefault(), "IsDefault() should be false for non-default scopes")
}

func TestComputeScopedCacheKey_DefaultKeepsBase(t *testing.T) {
base := "https://abc"
key := computeScopedCacheKey(base, newOAuthScopes())
require.Equal(t, base, key)
}

func TestComputeScopedCacheKey_NonDefaultIncludesReadableScopes(t *testing.T) {
base := "https://abc"
key := computeScopedCacheKey(base, newOAuthScopes("foo", "offline_access"))
require.True(t, strings.HasPrefix(key, base+"#scopes="))
require.Contains(t, key, "foo")
// ensure normalization kept offline_access
require.Contains(t, key, "offline_access")
// stability check: same scopes order-insensitive yields same key
key2 := computeScopedCacheKey(base, newOAuthScopes("OFFLINE_ACCESS", "foo"))
require.Equal(t, key, key2)
}
9 changes: 6 additions & 3 deletions credentials/u2m/workspace_oauth_argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type WorkspaceOAuthArgument interface {
// BasicWorkspaceOAuthArgument is a basic implementation of the WorkspaceOAuthArgument
// interface that links each host with exactly one OAuth token.
type BasicWorkspaceOAuthArgument struct {
oauthScopes

// host is the host of the workspace to authenticate to. This must start
// with "https://" and must not have a trailing slash.
host string
Expand All @@ -38,11 +40,11 @@ func validateHost(host string) error {
}

// NewBasicWorkspaceOAuthArgument creates a new BasicWorkspaceOAuthArgument.
func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, error) {
func NewBasicWorkspaceOAuthArgument(host string, scopes ...string) (BasicWorkspaceOAuthArgument, error) {
if err := validateHost(host); err != nil {
return BasicWorkspaceOAuthArgument{}, err
}
return BasicWorkspaceOAuthArgument{host: host}, nil
return BasicWorkspaceOAuthArgument{oauthScopes: newOAuthScopes(scopes...), host: host}, nil
}

// GetWorkspaceHost returns the host of the workspace to authenticate to.
Expand All @@ -57,7 +59,8 @@ func (a BasicWorkspaceOAuthArgument) GetCacheKey() string {
if !strings.HasPrefix(a.host, "http") {
a.host = fmt.Sprintf("https://%s", a.host)
}
return a.host
base := a.host
return computeScopedCacheKey(base, a.oauthScopes)
}

var _ WorkspaceOAuthArgument = BasicWorkspaceOAuthArgument{}
Loading