Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### Bug Fixes

- Properly cache tokens when using the `metadata-service` auth type ([#1225](https://github.com/databricks/databricks-sdk-go/pull/1225)).
Comment thread
parthban-db marked this conversation as resolved.

### Documentation

### Internal Changes
Expand Down
2 changes: 1 addition & 1 deletion config/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ type noopAuth struct{}
func (noopAuth) Name() string { return "noop" }
func (noopAuth) Configure(context.Context, *Config) (credentials.CredentialsProvider, error) {
visitor := func(r *http.Request) error { return nil }
return credentials.NewCredentialsProvider(visitor), nil
return credentials.CredentialsProviderFn(visitor), nil
}
8 changes: 5 additions & 3 deletions config/auth_azure_github_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
"github.com/databricks/databricks-sdk-go/httpclient"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -49,7 +50,8 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config)
httpClient: cfg.refreshClient,
}

return credentials.NewOAuthCredentialsProvider(refreshableVisitor(ts), ts.Token), nil
cts := auth.NewCachedTokenSource(ts)
return credentials.NewOAuthCredentialsProviderFromTokenSource(cts), nil
}

// azureOIDCTokenSource implements [oauth2.TokenSource] to obtain Azure auth
Expand All @@ -65,8 +67,8 @@ type azureOIDCTokenSource struct {

const azureOICDTimeout = 10 * time.Second

func (ts *azureOIDCTokenSource) Token() (*oauth2.Token, error) {
ctx, cancel := context.WithTimeout(context.Background(), azureOICDTimeout)
func (ts *azureOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, error) {
ctx, cancel := context.WithTimeout(ctx, azureOICDTimeout)
defer cancel()

resp := struct { // anonymous struct to parse the response
Expand Down
4 changes: 2 additions & 2 deletions config/auth_azure_github_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func TestAzureGithubOIDCCredentials(t *testing.T) {
},
},
},
wantErrPrefix: errPrefix("inner token: http 500"),
wantErrPrefix: errPrefix("error getting token: http 500"),
},
{
desc: "invalid auth token",
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestAzureGithubOIDCCredentials(t *testing.T) {
},
},
},
wantErrPrefix: errPrefix("inner token: invalid token"),
wantErrPrefix: errPrefix("error getting token: invalid token"),
},
{
desc: "success",
Expand Down
1 change: 0 additions & 1 deletion config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func (token msiToken) Token() (*oauth2.Token, error) {
}
epoch, err := token.ExpiresOn.Int64()
if err != nil {
// go 1.19 doesn't support multiple error unwraps
return nil, fmt.Errorf("%w: %s", errInvalidTokenExpiry, err)
}
return &oauth2.Token{
Expand Down
2 changes: 1 addition & 1 deletion config/auth_basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ func (c BasicCredentials) Configure(ctx context.Context, cfg *Config) (credentia
r.Header.Set("Authorization", fmt.Sprintf("Basic %s", b64))
return nil
}
return credentials.NewCredentialsProvider(visitor), nil
return credentials.CredentialsProviderFn(visitor), nil
}
2 changes: 1 addition & 1 deletion config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (c
if !cfg.IsAccountClient() {
logger.Infof(ctx, "Using Google Default Application Credentials for Workspace")
visitor := refreshableVisitor(inner)
return credentials.NewCredentialsProvider(visitor), nil
return credentials.CredentialsProviderFn(visitor), nil
}
// source for generateAccessToken
platform, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
Expand Down
4 changes: 2 additions & 2 deletions config/auth_m2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import (
"github.com/databricks/databricks-sdk-go/logger"
)

type M2mCredentials struct {
}
type M2mCredentials struct{}

func (c M2mCredentials) Name() string {
return "oauth-m2m"
Expand All @@ -34,6 +33,7 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials
TokenURL: endpoints.TokenEndpoint,
Scopes: []string{"all-apis"},
}).TokenSource(ctx)

visitor := refreshableVisitor(ts)
return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil
}
100 changes: 54 additions & 46 deletions config/auth_m2m_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,62 +11,70 @@ import (
)

func TestM2mHappyFlow(t *testing.T) {
assertHeaders(t, &Config{
Host: "a",
ClientID: "b",
ClientSecret: "c",
HTTPTransport: fixtures.MappingTransport{
"GET /oidc/.well-known/oauth-authorization-server": {
Response: u2m.OAuthAuthorizationServer{
AuthorizationEndpoint: "https://localhost:1234/dummy/auth",
TokenEndpoint: "https://localhost:1234/dummy/token",
assertHeaders(
t,
&Config{
Host: "a",
ClientID: "b",
ClientSecret: "c",
AuthType: "oauth-m2m",
HTTPTransport: fixtures.MappingTransport{
"GET /oidc/.well-known/oauth-authorization-server": {
Response: u2m.OAuthAuthorizationServer{
AuthorizationEndpoint: "https://localhost:1234/dummy/auth",
TokenEndpoint: "https://localhost:1234/dummy/token",
},
},
},
"POST /dummy/token": {
ExpectedHeaders: map[string]string{
"Authorization": "Basic Yjpj",
"Content-Type": "application/x-www-form-urlencoded",
},
ExpectedRequest: url.Values{
"grant_type": {"client_credentials"},
"scope": {"all-apis"},
},
Response: oauth2.Token{
TokenType: "Some",
AccessToken: "cde",
"POST /dummy/token": {
ExpectedHeaders: map[string]string{
"Authorization": "Basic Yjpj",
"Content-Type": "application/x-www-form-urlencoded",
},
ExpectedRequest: url.Values{
"grant_type": {"client_credentials"},
"scope": {"all-apis"},
},
Response: oauth2.Token{
TokenType: "Some",
AccessToken: "cde",
},
},
},
},
}, map[string]string{
"Authorization": "Some cde",
})
map[string]string{
"Authorization": "Some cde",
},
)
}

func TestM2mHappyFlowForAccount(t *testing.T) {
assertHeaders(t, &Config{
Host: "accounts.cloud.databricks.com",
AccountID: "a",
ClientID: "b",
ClientSecret: "c",
HTTPTransport: fixtures.MappingTransport{
"POST /oidc/accounts/a/v1/token": {
ExpectedHeaders: map[string]string{
"Authorization": "Basic Yjpj",
"Content-Type": "application/x-www-form-urlencoded",
},
ExpectedRequest: url.Values{
"grant_type": {"client_credentials"},
"scope": {"all-apis"},
},
Response: oauth2.Token{
TokenType: "Some",
AccessToken: "cde",
assertHeaders(t,
&Config{
Host: "accounts.cloud.databricks.com",
AccountID: "a",
ClientID: "b",
ClientSecret: "c",
HTTPTransport: fixtures.MappingTransport{
"POST /oidc/accounts/a/v1/token": {
ExpectedHeaders: map[string]string{
"Authorization": "Basic Yjpj",
"Content-Type": "application/x-www-form-urlencoded",
},
ExpectedRequest: url.Values{
"grant_type": {"client_credentials"},
"scope": {"all-apis"},
},
Response: oauth2.Token{
TokenType: "Some",
AccessToken: "cde",
},
},
},
},
}, map[string]string{
"Authorization": "Some cde",
})
map[string]string{
"Authorization": "Some cde",
},
)
}

func TestM2mNotSupported(t *testing.T) {
Expand Down
49 changes: 23 additions & 26 deletions config/auth_metadata_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -55,60 +56,56 @@ func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config)
}
parsedMetadataServiceURL, err := url.Parse(cfg.MetadataServiceURL)
if err != nil {
// go 1.19 doesn't allow multiple error unwraping
return nil, fmt.Errorf("%w: %s", errMetadataServiceMalformed, err)
}
// only allow localhost URLs

// Only allow localhost URLs.
if parsedMetadataServiceURL.Hostname() != "localhost" && parsedMetadataServiceURL.Hostname() != "127.0.0.1" {
return nil, fmt.Errorf("%w: %s", errMetadataServiceNotLocalhost, cfg.MetadataServiceURL)
}

ms := metadataService{
metadataServiceURL: parsedMetadataServiceURL,
config: cfg,
}
response, err := ms.Get(ctx)
if err != nil {

// Sanity check that a token can be obtained.
//
// TODO: Move this outside of this function. If credentials providers have
// to be tested, this should be done in the main default loop, not here.
if _, err := ms.Token(ctx); err != nil {
return nil, err
}
if response == nil {
return nil, nil
}

return credentials.NewOAuthCredentialsProviderFromTokenSource(ms), nil
cts := auth.NewCachedTokenSource(ms)
return credentials.NewOAuthCredentialsProviderFromTokenSource(cts), nil
}

type metadataService struct {
metadataServiceURL *url.URL
config *Config
}

// performs a request to the metadata service and returns the token
func (s metadataService) Get(ctx context.Context) (*oauth2.Token, error) {
func (ms metadataService) Token(ctx context.Context) (*oauth2.Token, error) {
ctx, cancel := context.WithTimeout(ctx, metadataServiceTimeout)
defer cancel()
var inner msiToken
err := s.config.refreshClient.Do(ctx, http.MethodGet,
s.metadataServiceURL.String(),

var mt msiToken
err := ms.config.refreshClient.Do(ctx, http.MethodGet,
ms.metadataServiceURL.String(),
httpclient.WithRequestHeader(MetadataServiceVersionHeader, MetadataServiceVersion),
httpclient.WithRequestHeader(MetadataServiceHostHeader, s.config.Host),
httpclient.WithResponseUnmarshal(&inner),
httpclient.WithRequestHeader(MetadataServiceHostHeader, ms.config.Host),
httpclient.WithResponseUnmarshal(&mt),
)
if err != nil {
return nil, fmt.Errorf("token request: %w", err)
}
return inner.Token()
}

func (t metadataService) Token(ctx context.Context) (*oauth2.Token, error) {
token, err := t.Get(ctx)
token, err := mt.Token()
if err != nil {
return nil, err
return nil, fmt.Errorf("token parse: %w", err)
}
if token == nil {
return nil, fmt.Errorf("no token returned from metadata service")
}
logger.Debugf(ctx,
"Refreshed access token from local metadata service, which expires on %s",
token.Expiry.Format(time.RFC3339))

logger.Debugf(ctx, "Refreshed access token from local metadata service, which expires on %s", token.Expiry.Format(time.RFC3339))
return token, nil
}
2 changes: 1 addition & 1 deletion config/auth_pat.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ func (c PatCredentials) Configure(ctx context.Context, cfg *Config) (credentials
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.Token))
return nil
}
return credentials.NewCredentialsProvider(visitor), nil
return credentials.CredentialsProviderFn(visitor), nil
}
3 changes: 1 addition & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/databricks/databricks-sdk-go/common/environment"
"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
"github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
Expand Down Expand Up @@ -252,7 +251,7 @@ func (c *Config) GetTokenSource() auth.TokenSource {
return errorTokenSource(err)
}
if h, ok := c.credentialsProvider.(credentials.OAuthCredentialsProvider); ok {
return authconv.AuthTokenSource(h)
return h
} else {
return errorTokenSource(fmt.Errorf("OAuth Token not supported for current auth type %s", c.AuthType))
}
Expand Down
24 changes: 17 additions & 7 deletions config/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,33 @@ type CredentialsProvider interface {
SetHeaders(r *http.Request) error
}

type credentialsProvider func(r *http.Request) error
// CredentialsProviderFn is an adapter to allow the use of an ordinary function
// as a CredentialsProvider.
//
// Example:
//
// cp := CredentialsProviderFn(func(r *http.Request) error {
// return nil
// })
type CredentialsProviderFn func(r *http.Request) error

func (c credentialsProvider) SetHeaders(r *http.Request) error {
func (c CredentialsProviderFn) SetHeaders(r *http.Request) error {
return c(r)
}

// NewCredentialsProvider returns a new CredentialsProvider that uses the
// provided function to set headers on the request.
//
// DEPRECATED: Use CredentialsProviderFn instead.
func NewCredentialsProvider(f func(r *http.Request) error) CredentialsProvider {
return credentialsProvider(f)
return CredentialsProviderFn(f)
}

// OAuthCredentialsProvider is a specialized CredentialsProvider uses and provides an OAuth token.
type OAuthCredentialsProvider interface {
CredentialsProvider
// Token returns the OAuth token generated by the provider.
Token() (*oauth2.Token, error)
Token(ctx context.Context) (*oauth2.Token, error)
}

// NewOAuthCredentialsProviderFromTokenSource returns a new OAuthCredentialsProvider
Expand All @@ -58,8 +68,8 @@ func (cp tsOAuthCredentialsProvider) SetHeaders(r *http.Request) error {
return nil
}

func (cp tsOAuthCredentialsProvider) Token() (*oauth2.Token, error) {
return cp.ts.Token(context.Background())
func (cp tsOAuthCredentialsProvider) Token(ctx context.Context) (*oauth2.Token, error) {
return cp.ts.Token(ctx)
}

// DEPRECATED: Use NewOAuthCredentialsProviderFromTokenSource instead.
Expand All @@ -79,6 +89,6 @@ func (c *oauthCredentialsProvider) SetHeaders(r *http.Request) error {
return c.setHeaders(r)
}

func (c *oauthCredentialsProvider) Token() (*oauth2.Token, error) {
func (c *oauthCredentialsProvider) Token(_ context.Context) (*oauth2.Token, error) {
return c.token()
}
Loading
Loading