diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 9772892a0..7adac500b 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,8 @@ ## Release v0.90.0 ### New Features and Improvements +* Add support for unified hosts, i.e. hosts that support both workspace-level and account-level operations +* Deprecate Config.IsAccountClient, which will not work for unified hosts, and replace it with Config.HostType and Config.ConfigType methods. ### Bug Fixes @@ -11,3 +13,4 @@ ### Internal Changes ### API Changes + diff --git a/account_functions.go b/account_functions.go index 50488cd63..603ce7e65 100644 --- a/account_functions.go +++ b/account_functions.go @@ -1,6 +1,10 @@ package databricks -import "github.com/databricks/databricks-sdk-go/service/provisioning" +import ( + "fmt" + + "github.com/databricks/databricks-sdk-go/service/provisioning" +) // GetWorkspaceClient returns a WorkspaceClient for the given workspace. The // workspace can be fetched by calling w.Workspaces.Get() or w.Workspaces.List(). @@ -32,6 +36,7 @@ func (c *AccountClient) GetWorkspaceClient(ws provisioning.Workspace) (*Workspac return nil, err } cfg.AzureResourceID = ws.AzureResourceId() + cfg.WorkspaceId = fmt.Sprintf("%d", ws.WorkspaceId) w, err := NewWorkspaceClient((*Config)(cfg)) if err != nil { return nil, err diff --git a/config/api_client.go b/config/api_client.go index 2a4163161..9b558bf34 100644 --- a/config/api_client.go +++ b/config/api_client.go @@ -27,6 +27,54 @@ func HTTPClientConfigFromConfig(cfg *Config) (httpclient.ClientConfig, error) { return httpclient.ClientConfig{}, err } + visitors := []httpclient.RequestVisitor{ + func(r *http.Request) error { + if r.URL == nil { + return fmt.Errorf("no URL found in request") + } + url, err := url.Parse(cfg.Host) + if err != nil { + return err + } + r.URL.Host = url.Host + r.URL.Scheme = url.Scheme + return nil + }, + authInUserAgentVisitor(cfg), + func(r *http.Request) error { + // Detect if we are running in a CI/CD environment + provider := useragent.CiCdProvider() + if provider == "" { + return nil + } + // Add the detected CI/CD provider to the user agent + ctx := useragent.InContext(r.Context(), useragent.CicdKey, provider) + *r = *r.WithContext(ctx) // replace request + return nil + }, + func(r *http.Request) error { + // Detect if the SDK is being run in a Databricks Runtime. + v := useragent.Runtime() + if v == "" { + return nil + } + // Add the detected Databricks Runtime version to the user agent + ctx := useragent.InContext(r.Context(), useragent.RuntimeKey, v) + *r = *r.WithContext(ctx) // replace request + return nil + }, + } + + // Unified hosts use X-Databricks-Org-Id header to determine which workspace to route the request to. + // The header must not be set for account-level API requests, otherwise the request will fail. + // This visitor relies on the assumption that cfg.WorkspaceId is only set for workspace client configs. + if cfg.HostType() == UnifiedHost && cfg.WorkspaceId != "" { + visitors = append(visitors, func(r *http.Request) error { + r.Header.Set("X-Databricks-Org-Id", cfg.WorkspaceId) + return nil + }) + } + return httpclient.ClientConfig{ AccountID: cfg.AccountID, Host: cfg.Host, @@ -38,43 +86,7 @@ func HTTPClientConfigFromConfig(cfg *Config) (httpclient.ClientConfig, error) { InsecureSkipVerify: cfg.InsecureSkipVerify, Transport: cfg.HTTPTransport, AuthVisitor: cfg.Authenticate, - Visitors: []httpclient.RequestVisitor{ - func(r *http.Request) error { - if r.URL == nil { - return fmt.Errorf("no URL found in request") - } - url, err := url.Parse(cfg.Host) - if err != nil { - return err - } - r.URL.Host = url.Host - r.URL.Scheme = url.Scheme - return nil - }, - authInUserAgentVisitor(cfg), - func(r *http.Request) error { - // Detect if we are running in a CI/CD environment - provider := useragent.CiCdProvider() - if provider == "" { - return nil - } - // Add the detected CI/CD provider to the user agent - ctx := useragent.InContext(r.Context(), useragent.CicdKey, provider) - *r = *r.WithContext(ctx) // replace request - return nil - }, - func(r *http.Request) error { - // Detect if the SDK is being run in a Databricks Runtime. - v := useragent.Runtime() - if v == "" { - return nil - } - // Add the detected Databricks Runtime version to the user agent - ctx := useragent.InContext(r.Context(), useragent.RuntimeKey, v) - *r = *r.WithContext(ctx) // replace request - return nil - }, - }, + Visitors: visitors, TransientErrors: []string{ // This is temporary workaround for SCIM API returning 500. // TODO: Remove when it's fixed. diff --git a/config/auth_azure_msi.go b/config/auth_azure_msi.go index aaf0c53bd..b4e615dfc 100644 --- a/config/auth_azure_msi.go +++ b/config/auth_azure_msi.go @@ -32,7 +32,7 @@ func (c AzureMsiCredentials) Name() string { } func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) { - if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && !cfg.IsAccountClient()) { + if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) { return nil, nil } env := cfg.Environment() diff --git a/config/auth_default.go b/config/auth_default.go index 2983ba6b5..1baa8ba6c 100644 --- a/config/auth_default.go +++ b/config/auth_default.go @@ -131,7 +131,7 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt Audience: cfg.TokenAudience, IDTokenSource: ts, } - if cfg.IsAccountClient() { + if cfg.HostType() != WorkspaceHost { oidcConfig.AccountID = cfg.AccountID } tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig) diff --git a/config/auth_gcp_google_id.go b/config/auth_gcp_google_id.go index 1c81b2451..f05413d33 100644 --- a/config/auth_gcp_google_id.go +++ b/config/auth_gcp_google_id.go @@ -28,7 +28,7 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (c if err != nil { return nil, err } - if !cfg.IsAccountClient() { + if cfg.ConfigType() == WorkspaceConfig { logger.Infof(ctx, "Using Google Default Application Credentials for Workspace") visitor := refreshableVisitor(inner) return credentials.CredentialsProviderFn(visitor), nil diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index e9e998ade..2f634f770 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -10,6 +10,7 @@ import ( "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "github.com/databricks/databricks-sdk-go/internal/env" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) @@ -24,6 +25,7 @@ func (m mockU2mTokenSource) Token() (*oauth2.Token, error) { } func TestU2MCredentials(t *testing.T) { + env.CleanupEnvironment(t) tests := []struct { name string cfg *Config diff --git a/config/config.go b/config/config.go index 591de1f4e..6ecc90ca9 100644 --- a/config/config.go +++ b/config/config.go @@ -39,6 +39,30 @@ type Loader interface { Configure(*Config) error } +// HostType represents the type of API the configured host supports. +type HostType string + +const ( + // WorkspaceHost supports only workspace-level APIs. + WorkspaceHost HostType = "WORKSPACE_HOST" + // AccountHost supports only account-level APIs. + AccountHost HostType = "ACCOUNT_HOST" + // UnifiedHost supports both workspace-level and account-level APIs. + UnifiedHost HostType = "UNIFIED_HOST" +) + +// ConfigType represents the type of API this config is valid for. +type ConfigType string + +const ( + // WorkspaceConfig is valid for workspace-level API requests. + WorkspaceConfig ConfigType = "WORKSPACE_CONFIG" + // AccountConfig is valid for account-level API requests. + AccountConfig ConfigType = "ACCOUNT_CONFIG" + // InvalidConfig is returned when the config is not valid for either workspace-level or account-level APIs. + InvalidConfig ConfigType = "INVALID_CONFIG" +) + // Config represents configuration for Databricks Connectivity type Config struct { // Credentials holds an instance of Credentials Strategy to authenticate with Databricks REST APIs. @@ -58,6 +82,9 @@ type Config struct { // Databricks Account ID for Accounts API. This field is used in dependencies. AccountID string `name:"account_id" env:"DATABRICKS_ACCOUNT_ID"` + // Databricks Workspace ID for Workspace clients when working with unified hosts + WorkspaceId string `name:"workspace_id" env:"DATABRICKS_WORKSPACE_ID"` + Token string `name:"token" env:"DATABRICKS_TOKEN" auth:"pat,sensitive"` Username string `name:"username" env:"DATABRICKS_USERNAME" auth:"basic"` Password string `name:"password" env:"DATABRICKS_PASSWORD" auth:"basic,sensitive"` @@ -183,6 +210,9 @@ type Config struct { // Keep track of the source of each attribute attrSource map[string]Source + + // Marker for unified hosts. Will be redundant once we can recognize unified hosts by their hostname. + Experimental_IsUnifiedHost bool `name:"experimental_is_unified_host" env:"DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST" auth:"-"` } // NewWithWorkspaceHost returns a new instance of the Config with the host set to @@ -195,7 +225,7 @@ func (c *Config) NewWithWorkspaceHost(host string) (*Config, error) { return nil, err } - var fieldsToSkip = map[string]struct{}{ + fieldsToSkip := map[string]struct{}{ "Host": {}, "AzureResourceID": {}, "AccountID": {}, @@ -289,8 +319,15 @@ func (c *Config) IsAws() bool { return c.Host != "" && !c.IsAzure() && !c.IsGcp() } -// IsAccountClient returns true if client is configured for Accounts API +// IsAccountClient returns true if client is configured for Accounts API. +// Panics if the config has the unified host flag set. +// +// Deprecated: Use HostType() if possible, or ConfigType() if necessary. func (c *Config) IsAccountClient() bool { + if c.Experimental_IsUnifiedHost { + panic("IsAccountClient cannot be used with unified hosts; use HostType() instead") + } + if c.AccountID != "" && c.isTesting { return true } @@ -307,6 +344,55 @@ func (c *Config) IsAccountClient() bool { return false } +// HostType returns the type of host that the client is configured for. +func (c *Config) HostType() HostType { + if c.Experimental_IsUnifiedHost { + return UnifiedHost + } + + // TODO: Refactor tests so that this is not needed. + if c.AccountID != "" && c.isTesting { + return AccountHost + } + + accountsPrefixes := []string{ + "https://accounts.", + "https://accounts-dod.", + } + for _, prefix := range accountsPrefixes { + if strings.HasPrefix(c.Host, prefix) { + return AccountHost + } + } + + return WorkspaceHost +} + +// ConfigType returns the type of config that the client is configured for. +// Returns InvalidConfig if the config is invalid. +// Use of this function should be avoided where possible, because we plan +// to remove WorkspaceClient and AccountClient in favor of a single unified +// client in the future. +func (c *Config) ConfigType() ConfigType { + switch c.HostType() { + case AccountHost: + return AccountConfig + case WorkspaceHost: + return WorkspaceConfig + case UnifiedHost: + if c.AccountID == "" { + // All unified host configs must have an account ID + return InvalidConfig + } + if c.WorkspaceId != "" { + return WorkspaceConfig + } + return AccountConfig + default: + return InvalidConfig + } +} + func (c *Config) EnsureResolved() error { if c.resolved { return nil @@ -327,7 +413,6 @@ func (c *Config) EnsureResolved() error { logger.Tracef(ctx, "Loading config via %s", loader.Name()) err := loader.Configure(c) if err != nil { - return c.wrapDebug(fmt.Errorf("resolve: %w", err)) } } @@ -475,16 +560,32 @@ func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationS Client: c.refreshClient, } host := c.CanonicalHostName() - if c.IsAccountClient() { + switch c.HostType() { + case AccountHost: return oauthClient.GetAccountOAuthEndpoints(ctx, host, c.AccountID) + case UnifiedHost: + return oauthClient.GetUnifiedOAuthEndpoints(ctx, host, c.AccountID) + case WorkspaceHost: + return oauthClient.GetWorkspaceOAuthEndpoints(ctx, host) + default: + return nil, fmt.Errorf("unknown host type: %v", c.HostType()) } - return oauthClient.GetWorkspaceOAuthEndpoints(ctx, host) } func (c *Config) getOAuthArgument() (u2m.OAuthArgument, error) { + err := c.EnsureResolved() + if err != nil { + return nil, err + } host := c.CanonicalHostName() - if c.IsAccountClient() { + switch c.HostType() { + case AccountHost: return u2m.NewBasicAccountOAuthArgument(host, c.AccountID) + case UnifiedHost: + return u2m.NewBasicUnifiedOAuthArgument(host, c.AccountID) + case WorkspaceHost: + return u2m.NewBasicWorkspaceOAuthArgument(host) + default: + return nil, fmt.Errorf("unknown host type: %v", c.HostType()) } - return u2m.NewBasicWorkspaceOAuthArgument(host) } diff --git a/config/config_test.go b/config/config_test.go index 7b117bdbd..3f6b93392 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -11,28 +11,46 @@ import ( "github.com/stretchr/testify/require" ) -func TestIsAccountClient_AwsAccount(t *testing.T) { +func TestHostType_AwsAccount(t *testing.T) { c := &Config{ Host: "https://accounts.cloud.databricks.com", AccountID: "123e4567-e89b-12d3-a456-426614174000", } - assert.True(t, c.IsAccountClient()) + assert.Equal(t, AccountHost, c.HostType()) } -func TestIsAccountClient_AwsDodAccount(t *testing.T) { +func TestHostType_AwsDodAccount(t *testing.T) { c := &Config{ Host: "https://accounts-dod.cloud.databricks.us", AccountID: "123e4567-e89b-12d3-a456-426614174000", } - assert.True(t, c.IsAccountClient()) + assert.Equal(t, AccountHost, c.HostType()) } -func TestIsAccountClient_AwsWorkspace(t *testing.T) { +func TestHostType_AwsWorkspace(t *testing.T) { c := &Config{ Host: "https://my-workspace.cloud.databricks.us", AccountID: "123e4567-e89b-12d3-a456-426614174000", } - assert.False(t, c.IsAccountClient()) + assert.Equal(t, WorkspaceHost, c.HostType()) +} + +func TestHostType_Unified(t *testing.T) { + c := &Config{ + Host: "https://unified.cloud.databricks.com", + AccountID: "123e4567-e89b-12d3-a456-426614174000", + Experimental_IsUnifiedHost: true, + } + assert.Equal(t, UnifiedHost, c.HostType()) +} + +func TestIsAccountClient_PanicsOnUnifiedHost(t *testing.T) { + c := &Config{ + Host: "https://unified.cloud.databricks.com", + AccountID: "test-account", + Experimental_IsUnifiedHost: true, + } + assert.Panics(t, func() { c.IsAccountClient() }) } func TestNewWithWorkspaceHost(t *testing.T) { @@ -141,6 +159,49 @@ func TestConfig_getOidcEndpoints_workspace(t *testing.T) { } } +func TestConfig_getOidcEndpoints_unified(t *testing.T) { + tests := []struct { + name string + host string + accountID string + }{ + { + name: "without trailing slash", + host: "https://unified.cloud.databricks.com", + accountID: "abc", + }, + { + name: "with trailing slash", + host: "https://unified.cloud.databricks.com/", + accountID: "abc", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Config{ + Host: tt.host, + AccountID: tt.accountID, + Experimental_IsUnifiedHost: true, + HTTPTransport: fixtures.SliceTransport{ + { + Method: "GET", + Resource: "/oidc/accounts/abc/.well-known/oauth-authorization-server", + Status: 200, + Response: `{"authorization_endpoint": "https://unified.cloud.databricks.com/oidc/accounts/abc/v1/authorize", "token_endpoint": "https://unified.cloud.databricks.com/oidc/accounts/abc/v1/token"}`, + }, + }, + } + got, err := c.getOidcEndpoints(context.Background()) + assert.NoError(t, err) + assert.Equal(t, &u2m.OAuthAuthorizationServer{ + AuthorizationEndpoint: "https://unified.cloud.databricks.com/oidc/accounts/abc/v1/authorize", + TokenEndpoint: "https://unified.cloud.databricks.com/oidc/accounts/abc/v1/token", + }, got) + }) + } +} + func TestConfig_getOAuthArgument_account(t *testing.T) { tests := []struct { name string @@ -203,3 +264,38 @@ func TestConfig_getOAuthArgument_workspace(t *testing.T) { }) } } + +func TestConfig_getOAuthArgument_Unified(t *testing.T) { + tests := []struct { + name string + host string + accountID string + }{ + { + name: "without trailing slash", + host: "https://unified.cloud.databricks.com", + accountID: "account-123", + }, + { + name: "with trailing slash", + host: "https://unified.cloud.databricks.com/", + accountID: "account-123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Config{ + Host: tt.host, + AccountID: tt.accountID, + Experimental_IsUnifiedHost: true, + } + rawGot, err := c.getOAuthArgument() + assert.NoError(t, err) + got, ok := rawGot.(u2m.UnifiedOAuthArgument) + assert.True(t, ok, "Expected UnifiedOAuthArgument") + assert.Equal(t, "https://unified.cloud.databricks.com", got.GetHost()) + assert.Equal(t, "account-123", got.GetAccountId()) + }) + } +} diff --git a/config/oauth_visitors.go b/config/oauth_visitors.go index 69fadc03f..c7e2b9562 100644 --- a/config/oauth_visitors.go +++ b/config/oauth_visitors.go @@ -12,7 +12,7 @@ import ( ) // serviceToServiceVisitor returns a visitor that sets the Authorization header -// to the token from the auth token sourcevand the provided secondary header to +// to the token from the auth token source and the provided secondary header to // the token from the secondary token source. func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHeader string) func(r *http.Request) error { refreshableAuth := auth.NewCachedTokenSource(authconv.AuthTokenSource(primary)) diff --git a/credentials/u2m/endpoint_supplier.go b/credentials/u2m/endpoint_supplier.go index 362d22129..c0fa976ea 100644 --- a/credentials/u2m/endpoint_supplier.go +++ b/credentials/u2m/endpoint_supplier.go @@ -19,6 +19,9 @@ type OAuthEndpointSupplier interface { // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) + + // GetUnifiedOAuthEndpoints returns the OAuth2 endpoints for the unified host. + GetUnifiedOAuthEndpoints(ctx context.Context, host string, accountId string) (*OAuthAuthorizationServer, error) } // BasicOAuthEndpointSupplier is an implementation of the OAuthEndpointSupplier interface. @@ -27,13 +30,10 @@ type BasicOAuthEndpointSupplier struct { Client *httpclient.ApiClient } -// GetWorkspaceOAuthEndpoints returns the OAuth endpoints for the given workspace. -// It queries the OIDC discovery endpoint to get the OAuth endpoints using the -// provided ApiClient. -func (c *BasicOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { - oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", workspaceHost) +// getOAuthEndpointsByDiscoveryUrl queries the OIDC discovery endpoint to get the OAuth endpoints. +func (c *BasicOAuthEndpointSupplier) getOAuthEndpointsByDiscoveryUrl(ctx context.Context, discoveryUrl string) (*OAuthAuthorizationServer, error) { var oauthEndpoints OAuthAuthorizationServer - if err := c.Client.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { + if err := c.Client.Do(ctx, "GET", discoveryUrl, httpclient.WithResponseUnmarshal(&oauthEndpoints)); err != nil { if errors.Is(err, apierr.ErrNotFound) { return nil, ErrOAuthNotSupported } @@ -42,6 +42,12 @@ func (c *BasicOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Cont return &oauthEndpoints, nil } +// GetWorkspaceOAuthEndpoints returns the OAuth endpoints for the given workspace. +func (c *BasicOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Context, workspaceHost string) (*OAuthAuthorizationServer, error) { + oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", workspaceHost) + return c.getOAuthEndpointsByDiscoveryUrl(ctx, oidc) +} + // GetAccountOAuthEndpoints returns the OAuth2 endpoints for the account. The // account-level OAuth endpoints are fixed based on the account ID and host. func (c *BasicOAuthEndpointSupplier) GetAccountOAuthEndpoints(ctx context.Context, accountHost string, accountId string) (*OAuthAuthorizationServer, error) { @@ -51,6 +57,12 @@ func (c *BasicOAuthEndpointSupplier) GetAccountOAuthEndpoints(ctx context.Contex }, nil } +// GetUnifiedOAuthEndpoints returns the OAuth2 endpoints for the unified host +func (c *BasicOAuthEndpointSupplier) GetUnifiedOAuthEndpoints(ctx context.Context, host string, accountId string) (*OAuthAuthorizationServer, error) { + oidc := fmt.Sprintf("%s/oidc/accounts/%s/.well-known/oauth-authorization-server", host, accountId) + return c.getOAuthEndpointsByDiscoveryUrl(ctx, oidc) +} + // OAuthAuthorizationServer contains the OAuth endpoints for a Databricks account // or workspace. type OAuthAuthorizationServer struct { diff --git a/credentials/u2m/endpoint_supplier_test.go b/credentials/u2m/endpoint_supplier_test.go index 72106e91b..dfac1de08 100644 --- a/credentials/u2m/endpoint_supplier_test.go +++ b/credentials/u2m/endpoint_supplier_test.go @@ -35,3 +35,23 @@ func TestGetWorkspaceOAuthEndpoints(t *testing.T) { assert.Equal(t, "a", endpoints.AuthorizationEndpoint) assert.Equal(t, "b", endpoints.TokenEndpoint) } + +func TestGetUnifiedOAuthEndpoints(t *testing.T) { + p := httpclient.NewApiClient(httpclient.ClientConfig{ + Transport: fixtures.MappingTransport{ + "GET /oidc/accounts/xyz/.well-known/oauth-authorization-server": { + Status: 200, + Response: map[string]string{ + "authorization_endpoint": "https://abc/oidc/accounts/xyz/v1/authorize", + "token_endpoint": "https://abc/oidc/accounts/xyz/v1/token", + }, + }, + }, + }) + c := &BasicOAuthEndpointSupplier{Client: p} + endpoints, err := c.GetUnifiedOAuthEndpoints(context.Background(), "https://abc", "xyz") + + assert.NoError(t, err) + assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", endpoints.AuthorizationEndpoint) + assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", endpoints.TokenEndpoint) +} diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index 768b2cad5..d1d553b14 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -358,15 +358,16 @@ func (a *PersistentAuth) Close() error { } // validateArg ensures that the OAuthArgument is either a WorkspaceOAuthArgument -// or an AccountOAuthArgument. +// or an AccountOAuthArgument or a UnifiedOAuthArgument. func (a *PersistentAuth) validateArg() error { if a.oAuthArgument == nil { return errors.New("missing OAuthArgument") } _, isWorkspaceArg := a.oAuthArgument.(WorkspaceOAuthArgument) _, isAccountArg := a.oAuthArgument.(AccountOAuthArgument) - if !isWorkspaceArg && !isAccountArg { - return fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", a.oAuthArgument) + _, isUnifiedArg := a.oAuthArgument.(UnifiedOAuthArgument) + if !isWorkspaceArg && !isAccountArg && !isUnifiedArg { + return fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument, AccountOAuthArgument or UnifiedOAuthArgument interface", a.oAuthArgument) } return nil } @@ -385,8 +386,10 @@ func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { case AccountOAuthArgument: endpoints, err = a.endpointSupplier.GetAccountOAuthEndpoints( a.ctx, argg.GetAccountHost(), argg.GetAccountId()) + case UnifiedOAuthArgument: + endpoints, err = a.endpointSupplier.GetUnifiedOAuthEndpoints(a.ctx, argg.GetHost(), argg.GetAccountId()) default: - return nil, fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument or AccountOAuthArgument interface", a.oAuthArgument) + return nil, fmt.Errorf("unsupported OAuthArgument type: %T, must implement either WorkspaceOAuthArgument, AccountOAuthArgument or UnifiedOAuthArgument interface", a.oAuthArgument) } if err != nil { return nil, fmt.Errorf("fetching OAuth endpoints: %w", err) diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index 25f2ed597..a09cf8d03 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -79,6 +79,13 @@ func (m MockOAuthEndpointSupplier) GetWorkspaceOAuthEndpoints(ctx context.Contex }, nil } +func (m MockOAuthEndpointSupplier) GetUnifiedOAuthEndpoints(ctx context.Context, host string, accountId string) (*OAuthAuthorizationServer, error) { + return &OAuthAuthorizationServer{ + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/authorize", host, accountId), + TokenEndpoint: fmt.Sprintf("%s/oidc/accounts/%s/v1/token", host, accountId), + }, nil +} + func TestToken_RefreshesExpiredAccessToken(t *testing.T) { ctx := context.Background() expectedKey := "https://accounts.cloud.databricks.com/oidc/accounts/xyz" diff --git a/credentials/u2m/unified_oauth_argument.go b/credentials/u2m/unified_oauth_argument.go new file mode 100644 index 000000000..7c9c032cd --- /dev/null +++ b/credentials/u2m/unified_oauth_argument.go @@ -0,0 +1,50 @@ +package u2m + +import ( + "fmt" +) + +// UnifiedOAuthArgument is an interface that provides the necessary information +// to authenticate using OAuth to a host that supports both account and workspace APIs. +type UnifiedOAuthArgument interface { + OAuthArgument + + // GetHost returns the host to authenticate to. + GetHost() string + + // GetAccountId returns the account ID of the account to authenticate to. + GetAccountId() string +} + +// BasicUnifiedOAuthArgument is a basic implementation of the UnifiedOAuthArgument +// interface that links each account with exactly one OAuth token. +type BasicUnifiedOAuthArgument struct { + host string + accountID string +} + +var _ UnifiedOAuthArgument = BasicUnifiedOAuthArgument{} + +// NewBasicUnifiedOAuthArgument creates a new BasicUnifiedOAuthArgument. +func NewBasicUnifiedOAuthArgument(accountsHost, accountID string) (BasicUnifiedOAuthArgument, error) { + if err := validateHost(accountsHost); err != nil { + return BasicUnifiedOAuthArgument{}, err + } + return BasicUnifiedOAuthArgument{host: accountsHost, accountID: accountID}, nil +} + +// GetAccountHost returns the host of the account to authenticate to. +func (a BasicUnifiedOAuthArgument) GetHost() string { + return a.host +} + +// GetAccountId returns the account ID of the account to authenticate to. +func (a BasicUnifiedOAuthArgument) GetAccountId() string { + return a.accountID +} + +// GetCacheKey returns a unique key for caching the OAuth token for the account. +// The key is in the format "/oidc/accounts/". +func (a BasicUnifiedOAuthArgument) GetCacheKey() string { + return fmt.Sprintf("%s/oidc/accounts/%s", a.host, a.accountID) +} diff --git a/credentials/u2m/unified_oauth_argument_test.go b/credentials/u2m/unified_oauth_argument_test.go new file mode 100644 index 000000000..b95ba0c09 --- /dev/null +++ b/credentials/u2m/unified_oauth_argument_test.go @@ -0,0 +1,91 @@ +package u2m + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewBasicUnifiedOAuthArgument(t *testing.T) { + arg, err := NewBasicUnifiedOAuthArgument("https://unified.databricks.com", "account-123") + assert.NoError(t, err) + assert.Equal(t, "https://unified.databricks.com", arg.GetHost()) + assert.Equal(t, "account-123", arg.GetAccountId()) + assert.Equal(t, "https://unified.databricks.com/oidc/accounts/account-123", arg.GetCacheKey()) +} + +func TestNewBasicUnifiedOAuthArgument_ValidatesHost(t *testing.T) { + tests := []struct { + name string + host string + accountID string + wantErr string + }{ + { + name: "invalid http protocol", + host: "http://insecure.com", + accountID: "account-123", + wantErr: "host must start with 'https://': http://insecure.com", + }, + { + name: "trailing slash", + host: "https://unified.databricks.com/", + accountID: "account-123", + wantErr: "host must not have a trailing slash: https://unified.databricks.com/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewBasicUnifiedOAuthArgument(tt.host, tt.accountID) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestBasicUnifiedOAuthArgument_GetHost(t *testing.T) { + arg, _ := NewBasicUnifiedOAuthArgument("https://unified.databricks.com", "account-123") + assert.Equal(t, "https://unified.databricks.com", arg.GetHost()) +} + +func TestBasicUnifiedOAuthArgument_GetAccountId(t *testing.T) { + arg, _ := NewBasicUnifiedOAuthArgument("https://unified.databricks.com", "account-123") + assert.Equal(t, "account-123", arg.GetAccountId()) +} + +func TestBasicUnifiedOAuthArgument_GetCacheKey(t *testing.T) { + tests := []struct { + name string + host string + accountID string + wantKey string + }{ + { + name: "standard case", + host: "https://unified.databricks.com", + accountID: "account-123", + wantKey: "https://unified.databricks.com/oidc/accounts/account-123", + }, + { + name: "different account", + host: "https://unified.databricks.com", + accountID: "account-456", + wantKey: "https://unified.databricks.com/oidc/accounts/account-456", + }, + { + name: "different host", + host: "https://other-unified.databricks.com", + accountID: "account-123", + wantKey: "https://other-unified.databricks.com/oidc/accounts/account-123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + arg, err := NewBasicUnifiedOAuthArgument(tt.host, tt.accountID) + assert.NoError(t, err) + assert.Equal(t, tt.wantKey, arg.GetCacheKey()) + }) + } +} diff --git a/internal/init_test.go b/internal/init_test.go index 7a427b588..c4fd6a29b 100644 --- a/internal/init_test.go +++ b/internal/init_test.go @@ -66,7 +66,7 @@ func accountTest(t *testing.T) (context.Context, *databricks.AccountClient) { if err != nil { skipf(t)("error: %s", err) } - if !cfg.IsAccountClient() { + if cfg.HostType() == config.WorkspaceHost { skipf(t)("Not in account env: %s/%s", cfg.AccountID, cfg.Host) } t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV")) @@ -86,7 +86,7 @@ func ucacctTest(t *testing.T) (context.Context, *databricks.AccountClient) { if err != nil { skipf(t)("error: %s", err) } - if !cfg.IsAccountClient() { + if cfg.HostType() == config.WorkspaceHost { skipf(t)("Not in account env: %s/%s", cfg.AccountID, cfg.Host) } t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))