Skip to content
Merged
3 changes: 3 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## Release v0.90.0

### New Features and Improvements
* Add support for unified hosts, i.e. hosts that support both workspace-level and account-level operations
* Deprecate Config.IsAccountClient, which will not work for unified hosts, and replace it with Config.HostType and Config.ConfigType methods.

### Bug Fixes

Expand All @@ -11,3 +13,4 @@
### Internal Changes

### API Changes

7 changes: 6 additions & 1 deletion account_functions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package databricks

import "github.com/databricks/databricks-sdk-go/service/provisioning"
import (
"fmt"

"github.com/databricks/databricks-sdk-go/service/provisioning"
)

// GetWorkspaceClient returns a WorkspaceClient for the given workspace. The
// workspace can be fetched by calling w.Workspaces.Get() or w.Workspaces.List().
Expand Down Expand Up @@ -32,6 +36,7 @@ func (c *AccountClient) GetWorkspaceClient(ws provisioning.Workspace) (*Workspac
return nil, err
}
cfg.AzureResourceID = ws.AzureResourceId()
cfg.WorkspaceId = fmt.Sprintf("%d", ws.WorkspaceId)
w, err := NewWorkspaceClient((*Config)(cfg))
if err != nil {
return nil, err
Expand Down
86 changes: 49 additions & 37 deletions config/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,54 @@ func HTTPClientConfigFromConfig(cfg *Config) (httpclient.ClientConfig, error) {
return httpclient.ClientConfig{}, err
}

visitors := []httpclient.RequestVisitor{
func(r *http.Request) error {
if r.URL == nil {
return fmt.Errorf("no URL found in request")
}
url, err := url.Parse(cfg.Host)
if err != nil {
return err
}
r.URL.Host = url.Host
r.URL.Scheme = url.Scheme
return nil
},
authInUserAgentVisitor(cfg),
func(r *http.Request) error {
// Detect if we are running in a CI/CD environment
provider := useragent.CiCdProvider()
if provider == "" {
return nil
}
// Add the detected CI/CD provider to the user agent
ctx := useragent.InContext(r.Context(), useragent.CicdKey, provider)
*r = *r.WithContext(ctx) // replace request
return nil
},
func(r *http.Request) error {
// Detect if the SDK is being run in a Databricks Runtime.
v := useragent.Runtime()
if v == "" {
return nil
}
// Add the detected Databricks Runtime version to the user agent
ctx := useragent.InContext(r.Context(), useragent.RuntimeKey, v)
*r = *r.WithContext(ctx) // replace request
return nil
},
}

// Unified hosts use X-Databricks-Org-Id header to determine which workspace to route the request to.
// The header must not be set for account-level API requests, otherwise the request will fail.
// This visitor relies on the assumption that cfg.WorkspaceId is only set for workspace client configs.
if cfg.HostType() == UnifiedHost && cfg.WorkspaceId != "" {
visitors = append(visitors, func(r *http.Request) error {
r.Header.Set("X-Databricks-Org-Id", cfg.WorkspaceId)
return nil
})
}

return httpclient.ClientConfig{
AccountID: cfg.AccountID,
Host: cfg.Host,
Expand All @@ -38,43 +86,7 @@ func HTTPClientConfigFromConfig(cfg *Config) (httpclient.ClientConfig, error) {
InsecureSkipVerify: cfg.InsecureSkipVerify,
Transport: cfg.HTTPTransport,
AuthVisitor: cfg.Authenticate,
Visitors: []httpclient.RequestVisitor{
func(r *http.Request) error {
if r.URL == nil {
return fmt.Errorf("no URL found in request")
}
url, err := url.Parse(cfg.Host)
if err != nil {
return err
}
r.URL.Host = url.Host
r.URL.Scheme = url.Scheme
return nil
},
authInUserAgentVisitor(cfg),
func(r *http.Request) error {
// Detect if we are running in a CI/CD environment
provider := useragent.CiCdProvider()
if provider == "" {
return nil
}
// Add the detected CI/CD provider to the user agent
ctx := useragent.InContext(r.Context(), useragent.CicdKey, provider)
*r = *r.WithContext(ctx) // replace request
return nil
},
func(r *http.Request) error {
// Detect if the SDK is being run in a Databricks Runtime.
v := useragent.Runtime()
if v == "" {
return nil
}
// Add the detected Databricks Runtime version to the user agent
ctx := useragent.InContext(r.Context(), useragent.RuntimeKey, v)
*r = *r.WithContext(ctx) // replace request
return nil
},
},
Visitors: visitors,
TransientErrors: []string{
// This is temporary workaround for SCIM API returning 500.
// TODO: Remove when it's fixed.
Expand Down
2 changes: 1 addition & 1 deletion config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (c AzureMsiCredentials) Name() string {
}

func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && !cfg.IsAccountClient()) {
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && cfg.ConfigType() == WorkspaceConfig) {
return nil, nil
}
env := cfg.Environment()
Expand Down
2 changes: 1 addition & 1 deletion config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt
Audience: cfg.TokenAudience,
IDTokenSource: ts,
}
if cfg.IsAccountClient() {
if cfg.HostType() != WorkspaceHost {
oidcConfig.AccountID = cfg.AccountID
}
tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig)
Expand Down
2 changes: 1 addition & 1 deletion config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (c
if err != nil {
return nil, err
}
if !cfg.IsAccountClient() {
if cfg.ConfigType() == WorkspaceConfig {
logger.Infof(ctx, "Using Google Default Application Credentials for Workspace")
visitor := refreshableVisitor(inner)
return credentials.CredentialsProviderFn(visitor), nil
Expand Down
2 changes: 2 additions & 0 deletions config/auth_u2m_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/databricks/databricks-sdk-go/credentials/u2m/cache"
"github.com/databricks/databricks-sdk-go/internal/env"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
Expand All @@ -24,6 +25,7 @@ func (m mockU2mTokenSource) Token() (*oauth2.Token, error) {
}

func TestU2MCredentials(t *testing.T) {
env.CleanupEnvironment(t)
tests := []struct {
name string
cfg *Config
Expand Down
115 changes: 108 additions & 7 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ type Loader interface {
Configure(*Config) error
}

// HostType represents the type of API the configured host supports.
type HostType string

const (
// WorkspaceHost supports only workspace-level APIs.
WorkspaceHost HostType = "WORKSPACE_HOST"
// AccountHost supports only account-level APIs.
AccountHost HostType = "ACCOUNT_HOST"
// UnifiedHost supports both workspace-level and account-level APIs.
UnifiedHost HostType = "UNIFIED_HOST"
)

// ConfigType represents the type of API this config is valid for.
type ConfigType string

const (
// WorkspaceConfig is valid for workspace-level API requests.
WorkspaceConfig ConfigType = "WORKSPACE_CONFIG"
// AccountConfig is valid for account-level API requests.
AccountConfig ConfigType = "ACCOUNT_CONFIG"
// InvalidConfig is returned when the config is not valid for either workspace-level or account-level APIs.
InvalidConfig ConfigType = "INVALID_CONFIG"
)
Comment thread
mgyucht marked this conversation as resolved.

// Config represents configuration for Databricks Connectivity
type Config struct {
// Credentials holds an instance of Credentials Strategy to authenticate with Databricks REST APIs.
Expand All @@ -58,6 +82,9 @@ type Config struct {
// Databricks Account ID for Accounts API. This field is used in dependencies.
AccountID string `name:"account_id" env:"DATABRICKS_ACCOUNT_ID"`

// Databricks Workspace ID for Workspace clients when working with unified hosts
WorkspaceId string `name:"workspace_id" env:"DATABRICKS_WORKSPACE_ID"`

Token string `name:"token" env:"DATABRICKS_TOKEN" auth:"pat,sensitive"`
Username string `name:"username" env:"DATABRICKS_USERNAME" auth:"basic"`
Password string `name:"password" env:"DATABRICKS_PASSWORD" auth:"basic,sensitive"`
Expand Down Expand Up @@ -183,6 +210,9 @@ type Config struct {

// Keep track of the source of each attribute
attrSource map[string]Source

// Marker for unified hosts. Will be redundant once we can recognize unified hosts by their hostname.
Experimental_IsUnifiedHost bool `name:"experimental_is_unified_host" env:"DATABRICKS_EXPERIMENTAL_IS_UNIFIED_HOST" auth:"-"`
}

// NewWithWorkspaceHost returns a new instance of the Config with the host set to
Expand All @@ -195,7 +225,7 @@ func (c *Config) NewWithWorkspaceHost(host string) (*Config, error) {
return nil, err
}

var fieldsToSkip = map[string]struct{}{
fieldsToSkip := map[string]struct{}{
"Host": {},
"AzureResourceID": {},
"AccountID": {},
Expand Down Expand Up @@ -289,8 +319,15 @@ func (c *Config) IsAws() bool {
return c.Host != "" && !c.IsAzure() && !c.IsGcp()
}

// IsAccountClient returns true if client is configured for Accounts API
// IsAccountClient returns true if client is configured for Accounts API.
// Panics if the config has the unified host flag set.
Comment thread
jh-db marked this conversation as resolved.
//
// Deprecated: Use HostType() if possible, or ConfigType() if necessary.
func (c *Config) IsAccountClient() bool {
if c.Experimental_IsUnifiedHost {
panic("IsAccountClient cannot be used with unified hosts; use HostType() instead")
}

if c.AccountID != "" && c.isTesting {
return true
}
Expand All @@ -307,6 +344,55 @@ func (c *Config) IsAccountClient() bool {
return false
}

// HostType returns the type of host that the client is configured for.
func (c *Config) HostType() HostType {
if c.Experimental_IsUnifiedHost {
return UnifiedHost
}

// TODO: Refactor tests so that this is not needed.
if c.AccountID != "" && c.isTesting {
return AccountHost
}
Comment on lines +354 to +356

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if c.AccountID != "" && c.isTesting {
return AccountHost
}
// TODO: Refactor tests so that this is not needed.
if c.AccountID != "" && c.isTesting {
return AccountHost
}


accountsPrefixes := []string{
"https://accounts.",
"https://accounts-dod.",
}
for _, prefix := range accountsPrefixes {
if strings.HasPrefix(c.Host, prefix) {
return AccountHost
}
}

return WorkspaceHost
}

// ConfigType returns the type of config that the client is configured for.
// Returns InvalidConfig if the config is invalid.
// Use of this function should be avoided where possible, because we plan
// to remove WorkspaceClient and AccountClient in favor of a single unified
// client in the future.
func (c *Config) ConfigType() ConfigType {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this used?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's needed by Terraform. We could use it in azure_msi and gcp_id auth.

switch c.HostType() {
case AccountHost:
return AccountConfig
case WorkspaceHost:
return WorkspaceConfig
case UnifiedHost:
if c.AccountID == "" {
// All unified host configs must have an account ID
return InvalidConfig
}
if c.WorkspaceId != "" {
return WorkspaceConfig
}
return AccountConfig
default:
return InvalidConfig
}
}

func (c *Config) EnsureResolved() error {
if c.resolved {
return nil
Expand All @@ -327,7 +413,6 @@ func (c *Config) EnsureResolved() error {
logger.Tracef(ctx, "Loading config via %s", loader.Name())
err := loader.Configure(c)
if err != nil {

return c.wrapDebug(fmt.Errorf("resolve: %w", err))
}
}
Expand Down Expand Up @@ -475,16 +560,32 @@ func (c *Config) getOidcEndpoints(ctx context.Context) (*u2m.OAuthAuthorizationS
Client: c.refreshClient,
}
host := c.CanonicalHostName()
if c.IsAccountClient() {
switch c.HostType() {
case AccountHost:
return oauthClient.GetAccountOAuthEndpoints(ctx, host, c.AccountID)
case UnifiedHost:
return oauthClient.GetUnifiedOAuthEndpoints(ctx, host, c.AccountID)
case WorkspaceHost:
return oauthClient.GetWorkspaceOAuthEndpoints(ctx, host)
default:
return nil, fmt.Errorf("unknown host type: %v", c.HostType())
}
return oauthClient.GetWorkspaceOAuthEndpoints(ctx, host)
}

func (c *Config) getOAuthArgument() (u2m.OAuthArgument, error) {
err := c.EnsureResolved()
if err != nil {
return nil, err
}
host := c.CanonicalHostName()
if c.IsAccountClient() {
switch c.HostType() {
case AccountHost:
return u2m.NewBasicAccountOAuthArgument(host, c.AccountID)
case UnifiedHost:
return u2m.NewBasicUnifiedOAuthArgument(host, c.AccountID)
case WorkspaceHost:
return u2m.NewBasicWorkspaceOAuthArgument(host)
default:
return nil, fmt.Errorf("unknown host type: %v", c.HostType())
}
return u2m.NewBasicWorkspaceOAuthArgument(host)
}
Loading
Loading