Skip to content

Commit 0ec7e06

Browse files
committed
Fix feature flags and U2M OAuth for SPOG support
Feature flags: - Fix endpoint path: /api/2.0/feature-flags -> /api/2.0/connector-service/feature-flags/GOLANG/{version} - Fix response parsing: map format -> array of {name, value} entries - Add extraHeaders for SPOG routing (x-databricks-org-id) - Extract ?o=<workspaceId> from httpPath in connector U2M OAuth: - Don't set ClientSecret for public apps (PKCE) - Force AuthStyleInParams to prevent Basic auth with empty password - Server rejects "Public app should not use a client secret" otherwise Signed-off-by: Madhavendra Rathore <madhavendra.rathore@databricks.com> Co-authored-by: Isaac Signed-off-by: Madhavendra Rathore <madhavendra.rathore@databricks.com>
1 parent 23697e5 commit 0ec7e06

7 files changed

Lines changed: 113 additions & 87 deletions

File tree

auth/oauth/u2m/u2m.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@ func GetConfig(ctx context.Context, hostName, clientID, clientSecret, callbackUR
3030
RedirectURL: callbackURL,
3131
Scopes: scopes,
3232
}
33-
// Only set ClientSecret if non-empty. For U2M (public apps using PKCE),
34-
// sending an empty client_secret causes the server to reject with
35-
// "Public app should not use a client secret".
3633
if clientSecret != "" {
3734
config.ClientSecret = clientSecret
35+
} else {
36+
// For U2M (public apps using PKCE), force AuthStyleInParams to avoid
37+
// sending Basic auth with empty password. AuthStyleInHeader sends
38+
// "Authorization: Basic base64(clientID:)" which the server rejects
39+
// with "Public app should not use a client secret".
40+
config.Endpoint.AuthStyle = oauth2.AuthStyleInParams
3841
}
3942

4043
return config, nil

connector.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"database/sql/driver"
77
"fmt"
88
"net/http"
9+
"net/url"
910
"strings"
1011
"time"
1112

@@ -76,12 +77,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
7677
}
7778
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")
7879

80+
// Extract SPOG routing headers from ?o= in HTTPPath
81+
spogHeaders := extractSpogHeaders(c.cfg.HTTPPath)
82+
7983
// Initialize telemetry: client config overlay decides; if unset, feature flags decide
8084
conn.telemetry = telemetry.InitializeForConnection(
8185
ctx,
8286
c.cfg.Host,
8387
c.client,
8488
c.cfg.EnableTelemetry,
89+
spogHeaders,
8590
)
8691
if conn.telemetry != nil {
8792
log.Debug().Msg("telemetry initialized for connection")
@@ -107,6 +112,7 @@ func NewConnector(options ...ConnOption) (driver.Connector, error) {
107112
// config with default options
108113
cfg := config.WithDefaults()
109114
cfg.DriverVersion = DriverVersion
115+
telemetry.SetDriverVersion(DriverVersion)
110116

111117
for _, opt := range options {
112118
opt(cfg)
@@ -117,6 +123,25 @@ func NewConnector(options ...ConnOption) (driver.Connector, error) {
117123
return &connector{cfg: cfg, client: client}, nil
118124
}
119125

126+
// extractSpogHeaders extracts ?o=<workspaceId> from httpPath and returns
127+
// an x-databricks-org-id header for SPOG routing.
128+
func extractSpogHeaders(httpPath string) map[string]string {
129+
if !strings.Contains(httpPath, "?") {
130+
return nil
131+
}
132+
// Parse query string from httpPath
133+
parts := strings.SplitN(httpPath, "?", 2)
134+
params, err := url.ParseQuery(parts[1])
135+
if err != nil {
136+
return nil
137+
}
138+
orgID := params.Get("o")
139+
if orgID == "" {
140+
return nil
141+
}
142+
return map[string]string{"x-databricks-org-id": orgID}
143+
}
144+
120145
func withUserConfig(ucfg config.UserConfig) ConnOption {
121146
return func(c *config.Config) {
122147
c.UserConfig = ucfg

telemetry/config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func ParseTelemetryConfig(params map[string]string) *Config {
102102
//
103103
// Returns:
104104
// - bool: true if telemetry should be enabled, false otherwise
105-
func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClient *http.Client) bool {
105+
func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClient *http.Client, extraHeaders map[string]string) bool {
106106
// Priority 1: Client explicitly set (overrides server)
107107
if cfg.EnableTelemetry.IsSet() {
108108
val, _ := cfg.EnableTelemetry.Get()
@@ -111,7 +111,7 @@ func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClien
111111

112112
// Priority 2: Check server-side feature flag
113113
flagCache := getFeatureFlagCache()
114-
serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, httpClient)
114+
serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, httpClient, extraHeaders)
115115
if err != nil {
116116
// Priority 3: Fail-safe default (disabled)
117117
return false

telemetry/config_test.go

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package telemetry
22

33
import (
44
"context"
5-
"encoding/json"
5+
66
"net/http"
77
"net/http/httptest"
88
"testing"
@@ -206,12 +206,8 @@ func TestIsTelemetryEnabled_ClientOverrideEnabled(t *testing.T) {
206206
// Setup: Create a server that returns disabled
207207
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
208208
// Server says disabled, but client override should win
209-
resp := map[string]interface{}{
210-
"flags": map[string]bool{
211-
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false,
212-
},
213-
}
214-
_ = json.NewEncoder(w).Encode(resp)
209+
w.Header().Set("Content-Type", "application/json")
210+
_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}]}`))
215211
}))
216212
defer server.Close()
217213

@@ -228,7 +224,7 @@ func TestIsTelemetryEnabled_ClientOverrideEnabled(t *testing.T) {
228224
defer flagCache.releaseContext(server.URL)
229225

230226
// Client override should bypass server check
231-
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
227+
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)
232228

233229
if !result {
234230
t.Error("Expected telemetry to be enabled when client explicitly sets enableTelemetry=true, got disabled")
@@ -240,12 +236,8 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) {
240236
// Setup: Create a server that returns enabled
241237
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
242238
// Server says enabled, but client override should win
243-
resp := map[string]interface{}{
244-
"flags": map[string]bool{
245-
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true,
246-
},
247-
}
248-
_ = json.NewEncoder(w).Encode(resp)
239+
w.Header().Set("Content-Type", "application/json")
240+
_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}]}`))
249241
}))
250242
defer server.Close()
251243

@@ -261,7 +253,7 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) {
261253
flagCache.getOrCreateContext(server.URL)
262254
defer flagCache.releaseContext(server.URL)
263255

264-
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
256+
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)
265257

266258
if result {
267259
t.Error("Expected telemetry to be disabled when client explicitly sets enableTelemetry=false, got enabled")
@@ -272,12 +264,8 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) {
272264
func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) {
273265
// Setup: Create a server that returns enabled
274266
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
275-
resp := map[string]interface{}{
276-
"flags": map[string]bool{
277-
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true,
278-
},
279-
}
280-
_ = json.NewEncoder(w).Encode(resp)
267+
w.Header().Set("Content-Type", "application/json")
268+
_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}]}`))
281269
}))
282270
defer server.Close()
283271

@@ -293,7 +281,7 @@ func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) {
293281
flagCache.getOrCreateContext(server.URL)
294282
defer flagCache.releaseContext(server.URL)
295283

296-
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
284+
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)
297285

298286
if !result {
299287
t.Error("Expected telemetry to be enabled when server flag is true, got disabled")
@@ -304,12 +292,8 @@ func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) {
304292
func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) {
305293
// Setup: Create a server that returns disabled
306294
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
307-
resp := map[string]interface{}{
308-
"flags": map[string]bool{
309-
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false,
310-
},
311-
}
312-
_ = json.NewEncoder(w).Encode(resp)
295+
w.Header().Set("Content-Type", "application/json")
296+
_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}]}`))
313297
}))
314298
defer server.Close()
315299

@@ -325,7 +309,7 @@ func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) {
325309
flagCache.getOrCreateContext(server.URL)
326310
defer flagCache.releaseContext(server.URL)
327311

328-
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
312+
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)
329313

330314
if result {
331315
t.Error("Expected telemetry to be disabled when server flag is false, got enabled")
@@ -340,7 +324,7 @@ func TestIsTelemetryEnabled_FailSafeDefault(t *testing.T) {
340324
httpClient := &http.Client{Timeout: 5 * time.Second}
341325

342326
// No server available, should default to disabled (fail-safe)
343-
result := isTelemetryEnabled(ctx, cfg, "nonexistent-host", httpClient)
327+
result := isTelemetryEnabled(ctx, cfg, "nonexistent-host", httpClient, nil)
344328

345329
if result {
346330
t.Error("Expected telemetry to be disabled when server unavailable (fail-safe), got enabled")
@@ -367,7 +351,7 @@ func TestIsTelemetryEnabled_ServerError(t *testing.T) {
367351
flagCache.getOrCreateContext(server.URL)
368352
defer flagCache.releaseContext(server.URL)
369353

370-
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
354+
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)
371355

372356
// On error, should default to disabled (fail-safe)
373357
if result {
@@ -390,7 +374,7 @@ func TestIsTelemetryEnabled_ServerUnreachable(t *testing.T) {
390374
flagCache.getOrCreateContext(unreachableHost)
391375
defer flagCache.releaseContext(unreachableHost)
392376

393-
result := isTelemetryEnabled(ctx, cfg, unreachableHost, httpClient)
377+
result := isTelemetryEnabled(ctx, cfg, unreachableHost, httpClient, nil)
394378

395379
// On error, should default to disabled (fail-safe)
396380
if result {
@@ -418,7 +402,7 @@ func TestIsTelemetryEnabled_ClientOverridesServerError(t *testing.T) {
418402
flagCache.getOrCreateContext(server.URL)
419403
defer flagCache.releaseContext(server.URL)
420404

421-
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
405+
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)
422406

423407
// Client override should work even when server errors
424408
if !result {

telemetry/driver_integration.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
// - host: Databricks host
1717
// - httpClient: HTTP client for making requests
1818
// - enableTelemetry: Client config overlay (unset = check server flag, true/false = override server)
19+
// - extraHeaders: Additional HTTP headers for SPOG routing (e.g. x-databricks-org-id)
1920
//
2021
// Returns:
2122
// - *Interceptor: Telemetry interceptor if enabled, nil otherwise
@@ -24,13 +25,14 @@ func InitializeForConnection(
2425
host string,
2526
httpClient *http.Client,
2627
enableTelemetry config.ConfigValue[bool],
28+
extraHeaders map[string]string,
2729
) *Interceptor {
2830
// Create telemetry config and apply client overlay
2931
cfg := DefaultConfig()
3032
cfg.EnableTelemetry = enableTelemetry
3133

3234
// Check if telemetry should be enabled
33-
if !isTelemetryEnabled(ctx, cfg, host, httpClient) {
35+
if !isTelemetryEnabled(ctx, cfg, host, httpClient, extraHeaders) {
3436
return nil
3537
}
3638

telemetry/featureflag.go

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ const (
2424
// flagEnableNewFeature = "databricks.partnerplatform.clientConfigsFeatureFlags.enableNewFeatureForGoDriver"
2525
)
2626

27+
// driverVersion is set during initialization from the connector package.
28+
var driverVersion = "unknown"
29+
30+
// SetDriverVersion sets the driver version used in feature flag endpoint paths.
31+
func SetDriverVersion(version string) {
32+
driverVersion = version
33+
}
34+
2735
// featureFlagCache manages feature flag state per host with reference counting.
2836
// This prevents rate limiting by caching feature flag responses.
2937
type featureFlagCache struct {
@@ -90,7 +98,7 @@ func (c *featureFlagCache) releaseContext(host string) {
9098
// getFeatureFlag retrieves a specific feature flag value for the host.
9199
// This is the generic method that handles caching and fetching for any flag.
92100
// Uses cached value if available and not expired.
93-
func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, httpClient *http.Client, flagName string) (bool, error) {
101+
func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, httpClient *http.Client, flagName string, extraHeaders map[string]string) (bool, error) {
94102
c.mu.RLock()
95103
flagCtx, exists := c.contexts[host]
96104
c.mu.RUnlock()
@@ -111,7 +119,7 @@ func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, http
111119

112120
// If we just created the context, make the initial blocking fetch
113121
if !exists {
114-
flags, err := fetchFeatureFlags(ctx, host, httpClient)
122+
flags, err := fetchFeatureFlags(ctx, host, httpClient, extraHeaders)
115123

116124
flagCtx.mu.Lock()
117125
flagCtx.fetching = false
@@ -155,7 +163,7 @@ func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, http
155163
flagCtx.mu.RUnlock()
156164

157165
// Fetch fresh values for all flags
158-
flags, err := fetchFeatureFlags(ctx, host, httpClient)
166+
flags, err := fetchFeatureFlags(ctx, host, httpClient, extraHeaders)
159167

160168
// Update cache (with proper locking)
161169
flagCtx.mu.Lock()
@@ -184,8 +192,8 @@ func (c *featureFlagCache) getFeatureFlag(ctx context.Context, host string, http
184192

185193
// isTelemetryEnabled checks if telemetry is enabled for the host.
186194
// Uses cached value if available and not expired.
187-
func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, httpClient *http.Client) (bool, error) {
188-
return c.getFeatureFlag(ctx, host, httpClient, flagEnableTelemetry)
195+
func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, httpClient *http.Client, extraHeaders map[string]string) (bool, error) {
196+
return c.getFeatureFlag(ctx, host, httpClient, flagEnableTelemetry, extraHeaders)
189197
}
190198

191199
// isExpired returns true if the cache has expired.
@@ -203,35 +211,46 @@ func getAllFeatureFlags() []string {
203211
}
204212
}
205213

206-
// fetchFeatureFlags fetches multiple feature flag values from Databricks in a single request.
214+
// featureFlagEntry represents a single flag from the connector-service response.
215+
type featureFlagEntry struct {
216+
Name string `json:"name"`
217+
Value string `json:"value"`
218+
}
219+
220+
// featureFlagResponse represents the response from the connector-service endpoint.
221+
type featureFlagResponse struct {
222+
Flags []featureFlagEntry `json:"flags"`
223+
TTLSeconds int `json:"ttl_seconds,omitempty"`
224+
}
225+
226+
// fetchFeatureFlags fetches feature flag values from the connector-service endpoint.
227+
// Endpoint: GET /api/2.0/connector-service/feature-flags/{CLIENT_TYPE}/{VERSION}
207228
// Returns a map of flag names to their boolean values.
208-
func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client) (map[string]bool, error) {
229+
func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client, extraHeaders map[string]string) (map[string]bool, error) {
209230
// Add timeout to context if it doesn't have a deadline
210231
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
211232
var cancel context.CancelFunc
212233
ctx, cancel = context.WithTimeout(ctx, featureFlagHTTPTimeout)
213234
defer cancel()
214235
}
215236

216-
// Construct endpoint URL, adding https:// if not already present
237+
// Construct endpoint URL using the connector-service path
217238
var endpoint string
218239
if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") {
219-
endpoint = fmt.Sprintf("%s/api/2.0/feature-flags", host)
240+
endpoint = fmt.Sprintf("%s/api/2.0/connector-service/feature-flags/GOLANG/%s", host, driverVersion)
220241
} else {
221-
endpoint = fmt.Sprintf("https://%s/api/2.0/feature-flags", host)
242+
endpoint = fmt.Sprintf("https://%s/api/2.0/connector-service/feature-flags/GOLANG/%s", host, driverVersion)
222243
}
223244

224245
req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
225246
if err != nil {
226247
return nil, fmt.Errorf("failed to create feature flag request: %w", err)
227248
}
228249

229-
// Add query parameter with comma-separated list of feature flags
230-
// This fetches all flags in a single request for efficiency
231-
allFlags := getAllFeatureFlags()
232-
q := req.URL.Query()
233-
q.Add("flags", strings.Join(allFlags, ","))
234-
req.URL.RawQuery = q.Encode()
250+
// Add extra headers (e.g. x-databricks-org-id for SPOG routing)
251+
for k, v := range extraHeaders {
252+
req.Header.Set(k, v)
253+
}
235254

236255
resp, err := httpClient.Do(req)
237256
if err != nil {
@@ -245,18 +264,16 @@ func fetchFeatureFlags(ctx context.Context, host string, httpClient *http.Client
245264
return nil, fmt.Errorf("feature flag check failed: %d", resp.StatusCode)
246265
}
247266

248-
var result struct {
249-
Flags map[string]bool `json:"flags"`
250-
}
267+
var result featureFlagResponse
251268
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
252269
return nil, fmt.Errorf("failed to decode feature flag response: %w", err)
253270
}
254271

255-
// Return the full map of flags
256-
// Flags not present in the response will have false value when accessed
257-
if result.Flags == nil {
258-
return make(map[string]bool), nil
272+
// Convert array of {name, value} entries to a map of name -> bool
273+
flags := make(map[string]bool, len(result.Flags))
274+
for _, f := range result.Flags {
275+
flags[f.Name] = strings.EqualFold(f.Value, "true")
259276
}
260277

261-
return result.Flags, nil
278+
return flags, nil
262279
}

0 commit comments

Comments
 (0)