diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 7b7b0d33a..fdb312e75 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,8 @@ ### Bug Fixes +- Properly cache tokens when using the `metadata-service` auth type ([#1225](https://github.com/databricks/databricks-sdk-go/pull/1225)). + ### Documentation ### Internal Changes diff --git a/config/api_client.go b/config/api_client.go index 0101dc2ac..04587e354 100644 --- a/config/api_client.go +++ b/config/api_client.go @@ -106,5 +106,5 @@ type noopAuth struct{} func (noopAuth) Name() string { return "noop" } func (noopAuth) Configure(context.Context, *Config) (credentials.CredentialsProvider, error) { visitor := func(r *http.Request) error { return nil } - return credentials.NewCredentialsProvider(visitor), nil + return credentials.CredentialsProviderFn(visitor), nil } diff --git a/config/auth_azure_github_oidc.go b/config/auth_azure_github_oidc.go index 8120f3b7b..eaa9628af 100644 --- a/config/auth_azure_github_oidc.go +++ b/config/auth_azure_github_oidc.go @@ -7,6 +7,7 @@ import ( "time" "github.com/databricks/databricks-sdk-go/config/credentials" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" "github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc" "github.com/databricks/databricks-sdk-go/httpclient" "golang.org/x/oauth2" @@ -49,7 +50,8 @@ func (c AzureGithubOIDCCredentials) Configure(ctx context.Context, cfg *Config) httpClient: cfg.refreshClient, } - return credentials.NewOAuthCredentialsProvider(refreshableVisitor(ts), ts.Token), nil + cts := auth.NewCachedTokenSource(ts) + return credentials.NewOAuthCredentialsProviderFromTokenSource(cts), nil } // azureOIDCTokenSource implements [oauth2.TokenSource] to obtain Azure auth @@ -65,8 +67,8 @@ type azureOIDCTokenSource struct { const azureOICDTimeout = 10 * time.Second -func (ts *azureOIDCTokenSource) Token() (*oauth2.Token, error) { - ctx, cancel := context.WithTimeout(context.Background(), azureOICDTimeout) +func (ts *azureOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, error) { + ctx, cancel := context.WithTimeout(ctx, azureOICDTimeout) defer cancel() resp := struct { // anonymous struct to parse the response diff --git a/config/auth_azure_github_oidc_test.go b/config/auth_azure_github_oidc_test.go index 7f53e401d..d3e62dd51 100644 --- a/config/auth_azure_github_oidc_test.go +++ b/config/auth_azure_github_oidc_test.go @@ -168,7 +168,7 @@ func TestAzureGithubOIDCCredentials(t *testing.T) { }, }, }, - wantErrPrefix: errPrefix("inner token: http 500"), + wantErrPrefix: errPrefix("error getting token: http 500"), }, { desc: "invalid auth token", @@ -204,7 +204,7 @@ func TestAzureGithubOIDCCredentials(t *testing.T) { }, }, }, - wantErrPrefix: errPrefix("inner token: invalid token"), + wantErrPrefix: errPrefix("error getting token: invalid token"), }, { desc: "success", diff --git a/config/auth_azure_msi.go b/config/auth_azure_msi.go index 7a947d7d0..aaf0c53bd 100644 --- a/config/auth_azure_msi.go +++ b/config/auth_azure_msi.go @@ -105,7 +105,6 @@ func (token msiToken) Token() (*oauth2.Token, error) { } epoch, err := token.ExpiresOn.Int64() if err != nil { - // go 1.19 doesn't support multiple error unwraps return nil, fmt.Errorf("%w: %s", errInvalidTokenExpiry, err) } return &oauth2.Token{ diff --git a/config/auth_basic.go b/config/auth_basic.go index 8b25fa456..63e2e31b3 100644 --- a/config/auth_basic.go +++ b/config/auth_basic.go @@ -26,5 +26,5 @@ func (c BasicCredentials) Configure(ctx context.Context, cfg *Config) (credentia r.Header.Set("Authorization", fmt.Sprintf("Basic %s", b64)) return nil } - return credentials.NewCredentialsProvider(visitor), nil + return credentials.CredentialsProviderFn(visitor), nil } diff --git a/config/auth_gcp_google_id.go b/config/auth_gcp_google_id.go index 3d3a600b1..1c81b2451 100644 --- a/config/auth_gcp_google_id.go +++ b/config/auth_gcp_google_id.go @@ -31,7 +31,7 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (c if !cfg.IsAccountClient() { logger.Infof(ctx, "Using Google Default Application Credentials for Workspace") visitor := refreshableVisitor(inner) - return credentials.NewCredentialsProvider(visitor), nil + return credentials.CredentialsProviderFn(visitor), nil } // source for generateAccessToken platform, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{ diff --git a/config/auth_m2m.go b/config/auth_m2m.go index 24d129450..76958b2ff 100644 --- a/config/auth_m2m.go +++ b/config/auth_m2m.go @@ -11,8 +11,7 @@ import ( "github.com/databricks/databricks-sdk-go/logger" ) -type M2mCredentials struct { -} +type M2mCredentials struct{} func (c M2mCredentials) Name() string { return "oauth-m2m" @@ -34,6 +33,7 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials TokenURL: endpoints.TokenEndpoint, Scopes: []string{"all-apis"}, }).TokenSource(ctx) + visitor := refreshableVisitor(ts) return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil } diff --git a/config/auth_m2m_test.go b/config/auth_m2m_test.go index 4284cbe62..e7b107e8d 100644 --- a/config/auth_m2m_test.go +++ b/config/auth_m2m_test.go @@ -11,62 +11,70 @@ import ( ) func TestM2mHappyFlow(t *testing.T) { - assertHeaders(t, &Config{ - Host: "a", - ClientID: "b", - ClientSecret: "c", - HTTPTransport: fixtures.MappingTransport{ - "GET /oidc/.well-known/oauth-authorization-server": { - Response: u2m.OAuthAuthorizationServer{ - AuthorizationEndpoint: "https://localhost:1234/dummy/auth", - TokenEndpoint: "https://localhost:1234/dummy/token", + assertHeaders( + t, + &Config{ + Host: "a", + ClientID: "b", + ClientSecret: "c", + AuthType: "oauth-m2m", + HTTPTransport: fixtures.MappingTransport{ + "GET /oidc/.well-known/oauth-authorization-server": { + Response: u2m.OAuthAuthorizationServer{ + AuthorizationEndpoint: "https://localhost:1234/dummy/auth", + TokenEndpoint: "https://localhost:1234/dummy/token", + }, }, - }, - "POST /dummy/token": { - ExpectedHeaders: map[string]string{ - "Authorization": "Basic Yjpj", - "Content-Type": "application/x-www-form-urlencoded", - }, - ExpectedRequest: url.Values{ - "grant_type": {"client_credentials"}, - "scope": {"all-apis"}, - }, - Response: oauth2.Token{ - TokenType: "Some", - AccessToken: "cde", + "POST /dummy/token": { + ExpectedHeaders: map[string]string{ + "Authorization": "Basic Yjpj", + "Content-Type": "application/x-www-form-urlencoded", + }, + ExpectedRequest: url.Values{ + "grant_type": {"client_credentials"}, + "scope": {"all-apis"}, + }, + Response: oauth2.Token{ + TokenType: "Some", + AccessToken: "cde", + }, }, }, }, - }, map[string]string{ - "Authorization": "Some cde", - }) + map[string]string{ + "Authorization": "Some cde", + }, + ) } func TestM2mHappyFlowForAccount(t *testing.T) { - assertHeaders(t, &Config{ - Host: "accounts.cloud.databricks.com", - AccountID: "a", - ClientID: "b", - ClientSecret: "c", - HTTPTransport: fixtures.MappingTransport{ - "POST /oidc/accounts/a/v1/token": { - ExpectedHeaders: map[string]string{ - "Authorization": "Basic Yjpj", - "Content-Type": "application/x-www-form-urlencoded", - }, - ExpectedRequest: url.Values{ - "grant_type": {"client_credentials"}, - "scope": {"all-apis"}, - }, - Response: oauth2.Token{ - TokenType: "Some", - AccessToken: "cde", + assertHeaders(t, + &Config{ + Host: "accounts.cloud.databricks.com", + AccountID: "a", + ClientID: "b", + ClientSecret: "c", + HTTPTransport: fixtures.MappingTransport{ + "POST /oidc/accounts/a/v1/token": { + ExpectedHeaders: map[string]string{ + "Authorization": "Basic Yjpj", + "Content-Type": "application/x-www-form-urlencoded", + }, + ExpectedRequest: url.Values{ + "grant_type": {"client_credentials"}, + "scope": {"all-apis"}, + }, + Response: oauth2.Token{ + TokenType: "Some", + AccessToken: "cde", + }, }, }, }, - }, map[string]string{ - "Authorization": "Some cde", - }) + map[string]string{ + "Authorization": "Some cde", + }, + ) } func TestM2mNotSupported(t *testing.T) { diff --git a/config/auth_metadata_service.go b/config/auth_metadata_service.go index 3718b4a93..cab658386 100644 --- a/config/auth_metadata_service.go +++ b/config/auth_metadata_service.go @@ -9,6 +9,7 @@ import ( "time" "github.com/databricks/databricks-sdk-go/config/credentials" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" @@ -55,26 +56,29 @@ func (c MetadataServiceCredentials) Configure(ctx context.Context, cfg *Config) } parsedMetadataServiceURL, err := url.Parse(cfg.MetadataServiceURL) if err != nil { - // go 1.19 doesn't allow multiple error unwraping return nil, fmt.Errorf("%w: %s", errMetadataServiceMalformed, err) } - // only allow localhost URLs + + // Only allow localhost URLs. if parsedMetadataServiceURL.Hostname() != "localhost" && parsedMetadataServiceURL.Hostname() != "127.0.0.1" { return nil, fmt.Errorf("%w: %s", errMetadataServiceNotLocalhost, cfg.MetadataServiceURL) } + ms := metadataService{ metadataServiceURL: parsedMetadataServiceURL, config: cfg, } - response, err := ms.Get(ctx) - if err != nil { + + // Sanity check that a token can be obtained. + // + // TODO: Move this outside of this function. If credentials providers have + // to be tested, this should be done in the main default loop, not here. + if _, err := ms.Token(ctx); err != nil { return nil, err } - if response == nil { - return nil, nil - } - return credentials.NewOAuthCredentialsProviderFromTokenSource(ms), nil + cts := auth.NewCachedTokenSource(ms) + return credentials.NewOAuthCredentialsProviderFromTokenSource(cts), nil } type metadataService struct { @@ -82,33 +86,26 @@ type metadataService struct { config *Config } -// performs a request to the metadata service and returns the token -func (s metadataService) Get(ctx context.Context) (*oauth2.Token, error) { +func (ms metadataService) Token(ctx context.Context) (*oauth2.Token, error) { ctx, cancel := context.WithTimeout(ctx, metadataServiceTimeout) defer cancel() - var inner msiToken - err := s.config.refreshClient.Do(ctx, http.MethodGet, - s.metadataServiceURL.String(), + + var mt msiToken + err := ms.config.refreshClient.Do(ctx, http.MethodGet, + ms.metadataServiceURL.String(), httpclient.WithRequestHeader(MetadataServiceVersionHeader, MetadataServiceVersion), - httpclient.WithRequestHeader(MetadataServiceHostHeader, s.config.Host), - httpclient.WithResponseUnmarshal(&inner), + httpclient.WithRequestHeader(MetadataServiceHostHeader, ms.config.Host), + httpclient.WithResponseUnmarshal(&mt), ) if err != nil { return nil, fmt.Errorf("token request: %w", err) } - return inner.Token() -} -func (t metadataService) Token(ctx context.Context) (*oauth2.Token, error) { - token, err := t.Get(ctx) + token, err := mt.Token() if err != nil { - return nil, err + return nil, fmt.Errorf("token parse: %w", err) } - if token == nil { - return nil, fmt.Errorf("no token returned from metadata service") - } - logger.Debugf(ctx, - "Refreshed access token from local metadata service, which expires on %s", - token.Expiry.Format(time.RFC3339)) + + logger.Debugf(ctx, "Refreshed access token from local metadata service, which expires on %s", token.Expiry.Format(time.RFC3339)) return token, nil } diff --git a/config/auth_pat.go b/config/auth_pat.go index dc51256c8..b2356a58d 100644 --- a/config/auth_pat.go +++ b/config/auth_pat.go @@ -23,5 +23,5 @@ func (c PatCredentials) Configure(ctx context.Context, cfg *Config) (credentials r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cfg.Token)) return nil } - return credentials.NewCredentialsProvider(visitor), nil + return credentials.CredentialsProviderFn(visitor), nil } diff --git a/config/config.go b/config/config.go index a31508517..9bfef3354 100644 --- a/config/config.go +++ b/config/config.go @@ -15,7 +15,6 @@ import ( "github.com/databricks/databricks-sdk-go/common/environment" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/config/experimental/auth" - "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" @@ -252,7 +251,7 @@ func (c *Config) GetTokenSource() auth.TokenSource { return errorTokenSource(err) } if h, ok := c.credentialsProvider.(credentials.OAuthCredentialsProvider); ok { - return authconv.AuthTokenSource(h) + return h } else { return errorTokenSource(fmt.Errorf("OAuth Token not supported for current auth type %s", c.AuthType)) } diff --git a/config/credentials/credentials.go b/config/credentials/credentials.go index 580040fe2..997ea9b53 100644 --- a/config/credentials/credentials.go +++ b/config/credentials/credentials.go @@ -16,23 +16,33 @@ type CredentialsProvider interface { SetHeaders(r *http.Request) error } -type credentialsProvider func(r *http.Request) error +// CredentialsProviderFn is an adapter to allow the use of an ordinary function +// as a CredentialsProvider. +// +// Example: +// +// cp := CredentialsProviderFn(func(r *http.Request) error { +// return nil +// }) +type CredentialsProviderFn func(r *http.Request) error -func (c credentialsProvider) SetHeaders(r *http.Request) error { +func (c CredentialsProviderFn) SetHeaders(r *http.Request) error { return c(r) } // NewCredentialsProvider returns a new CredentialsProvider that uses the // provided function to set headers on the request. +// +// DEPRECATED: Use CredentialsProviderFn instead. func NewCredentialsProvider(f func(r *http.Request) error) CredentialsProvider { - return credentialsProvider(f) + return CredentialsProviderFn(f) } // OAuthCredentialsProvider is a specialized CredentialsProvider uses and provides an OAuth token. type OAuthCredentialsProvider interface { CredentialsProvider // Token returns the OAuth token generated by the provider. - Token() (*oauth2.Token, error) + Token(ctx context.Context) (*oauth2.Token, error) } // NewOAuthCredentialsProviderFromTokenSource returns a new OAuthCredentialsProvider @@ -58,8 +68,8 @@ func (cp tsOAuthCredentialsProvider) SetHeaders(r *http.Request) error { return nil } -func (cp tsOAuthCredentialsProvider) Token() (*oauth2.Token, error) { - return cp.ts.Token(context.Background()) +func (cp tsOAuthCredentialsProvider) Token(ctx context.Context) (*oauth2.Token, error) { + return cp.ts.Token(ctx) } // DEPRECATED: Use NewOAuthCredentialsProviderFromTokenSource instead. @@ -79,6 +89,6 @@ func (c *oauthCredentialsProvider) SetHeaders(r *http.Request) error { return c.setHeaders(r) } -func (c *oauthCredentialsProvider) Token() (*oauth2.Token, error) { +func (c *oauthCredentialsProvider) Token(_ context.Context) (*oauth2.Token, error) { return c.token() } diff --git a/config/token_source_strategy.go b/config/token_source_strategy.go index 5181a3260..ac5fce1b2 100644 --- a/config/token_source_strategy.go +++ b/config/token_source_strategy.go @@ -27,7 +27,7 @@ func (tss *tokenSourceStrategy) Configure(ctx context.Context, cfg *Config) (cre // // TODO: Move this outside of this function. If credentials providers have // to be tested, this should be done in the main default loop, not here. - if _, err := cp.Token(); err != nil { + if _, err := cp.Token(ctx); err != nil { return nil, err } diff --git a/examples/custom-auth/main.go b/examples/custom-auth/main.go index 76b48cab9..f45a0536a 100644 --- a/examples/custom-auth/main.go +++ b/examples/custom-auth/main.go @@ -28,7 +28,7 @@ func (c *CustomCredentials) Configure( r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) return nil } - return credentials.NewCredentialsProvider(visitor), nil + return credentials.CredentialsProviderFn(visitor), nil } func main() {