From 485d260b12d053ce9c0d28df00b3ddb541d6b1ac Mon Sep 17 00:00:00 2001 From: Simon Faltum Date: Wed, 27 May 2026 13:11:51 +0200 Subject: [PATCH 1/2] Consolidate issuer handling in callback result Carry the OAuth callback's iss parameter on the same oauthResult that carries the code and state, so all three originate from the same callback request. Removes the parallel lastIssuer field and its mutex. Co-authored-by: Isaac Signed-off-by: Simon Faltum --- credentials/u2m/callback.go | 49 +++++++++++------------ credentials/u2m/callback_test.go | 29 +++++--------- credentials/u2m/discovery_token_source.go | 4 +- 3 files changed, 34 insertions(+), 48 deletions(-) diff --git a/credentials/u2m/callback.go b/credentials/u2m/callback.go index 1d634a128..5db0eeeac 100644 --- a/credentials/u2m/callback.go +++ b/credentials/u2m/callback.go @@ -7,7 +7,6 @@ import ( "html/template" "net/http" "strings" - "sync" "golang.org/x/text/cases" "golang.org/x/text/language" @@ -21,6 +20,7 @@ type oauthResult struct { ErrorDescription string State string Code string + Issuer string Host string } @@ -46,12 +46,6 @@ type callbackServer struct { // rendering the page.html template. renderErrCh chan error - // lastIssuer stores the iss (issuer) query parameter from the OAuth - // callback, per RFC 9207. Used by the discovery login flow to identify - // which workspace the user selected. Protected by issuerMu. - issuerMu sync.Mutex - lastIssuer string - // feedbackCh is a channel that receives the result of the authentication // attempt. feedbackCh chan oauthResult @@ -95,14 +89,12 @@ func (cb *callbackServer) Close() error { // ServeHTTP renders the page.html template. func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - cb.issuerMu.Lock() - cb.lastIssuer = r.FormValue("iss") - cb.issuerMu.Unlock() res := oauthResult{ Error: r.FormValue("error"), ErrorDescription: r.FormValue("error_description"), Code: r.FormValue("code"), State: r.FormValue("state"), + Issuer: r.FormValue("iss"), Host: cb.getHost(), } if res.Error != "" { @@ -128,30 +120,37 @@ func (cb *callbackServer) getHost() string { } } -// Issuer returns the iss parameter from the last OAuth callback received. -// This is populated during the discovery login flow when login.databricks.com -// redirects back with the workspace issuer. -func (cb *callbackServer) Issuer() string { - cb.issuerMu.Lock() - defer cb.issuerMu.Unlock() - return cb.lastIssuer -} - -// Handler opens up a browser waits for redirect to come back from the identity provider -func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { +func (cb *callbackServer) awaitResult(authCodeURL string) (oauthResult, error) { err := cb.browser(authCodeURL) if err != nil { fmt.Printf("Please continue the authentication process in your browser:\n%s\n", authCodeURL) } select { case <-cb.ctx.Done(): - return "", "", cb.ctx.Err() + return oauthResult{}, cb.ctx.Err() case renderErr := <-cb.renderErrCh: - return "", "", renderErr + return oauthResult{}, renderErr case res := <-cb.feedbackCh: if res.Error != "" { - return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription) + return oauthResult{}, fmt.Errorf("%s: %s", res.Error, res.ErrorDescription) } - return res.Code, res.State, nil + return res, nil + } +} + +// Handler opens up a browser waits for redirect to come back from the identity provider +func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { + res, err := cb.awaitResult(authCodeURL) + if err != nil { + return "", "", err + } + return res.Code, res.State, nil +} + +func (cb *callbackServer) handlerWithIssuer(authCodeURL string) (code, state, issuer string, err error) { + res, err := cb.awaitResult(authCodeURL) + if err != nil { + return "", "", "", err } + return res.Code, res.State, res.Issuer, nil } diff --git a/credentials/u2m/callback_test.go b/credentials/u2m/callback_test.go index 3fac8c89e..2000365bb 100644 --- a/credentials/u2m/callback_test.go +++ b/credentials/u2m/callback_test.go @@ -64,11 +64,6 @@ func TestCallbackServer_ExtractsIssuer(t *testing.T) { } defer cb.Close() - // Verify issuer is empty before any callback. - if got := cb.Issuer(); got != "" { - t.Fatalf("Issuer() before callback: want %q, got %q", "", got) - } - // Fire a callback with iss parameter. issuerURL := "https://adb-123.azuredatabricks.net/oidc" resp, err := http.Get(fmt.Sprintf("http://%s?code=xxx&state=yyy&iss=%s", p.redirectAddr, issuerURL)) @@ -77,11 +72,9 @@ func TestCallbackServer_ExtractsIssuer(t *testing.T) { } defer resp.Body.Close() - // Drain the feedback channel so ServeHTTP completes. - <-cb.feedbackCh - - if got := cb.Issuer(); got != issuerURL { - t.Fatalf("Issuer(): want %q, got %q", issuerURL, got) + res := <-cb.feedbackCh + if got := res.Issuer; got != issuerURL { + t.Fatalf("result Issuer: want %q, got %q", issuerURL, got) } } @@ -128,11 +121,9 @@ func TestCallbackServer_NoIssuer(t *testing.T) { } defer resp.Body.Close() - // Drain the feedback channel so ServeHTTP completes. - <-cb.feedbackCh - - if got := cb.Issuer(); got != "" { - t.Fatalf("Issuer(): want %q, got %q", "", got) + res := <-cb.feedbackCh + if got := res.Issuer; got != "" { + t.Fatalf("result Issuer: want %q, got %q", "", got) } } @@ -180,10 +171,8 @@ func TestCallbackServer_IssuerWithAccountPath(t *testing.T) { } defer resp.Body.Close() - // Drain the feedback channel so ServeHTTP completes. - <-cb.feedbackCh - - if got := cb.Issuer(); got != issuerURL { - t.Fatalf("Issuer(): want %q, got %q", issuerURL, got) + res := <-cb.feedbackCh + if got := res.Issuer; got != issuerURL { + t.Fatalf("result Issuer: want %q, got %q", issuerURL, got) } } diff --git a/credentials/u2m/discovery_token_source.go b/credentials/u2m/discovery_token_source.go index 8106cee32..e96891dec 100644 --- a/credentials/u2m/discovery_token_source.go +++ b/credentials/u2m/discovery_token_source.go @@ -124,8 +124,7 @@ func (d *discoveryTokenSource) challenge() error { } authorizeURL := buildDiscoveryAuthorizeURL(host, d.pa.redirectAddr, state, pkce, scopes) - // Use cb.Handler to open the browser and wait for the callback. - code, returnedState, err := cb.Handler(authorizeURL) + code, returnedState, issuer, err := cb.handlerWithIssuer(authorizeURL) if err != nil { return fmt.Errorf("authorize: %w", err) } @@ -135,7 +134,6 @@ func (d *discoveryTokenSource) challenge() error { return fmt.Errorf("state mismatch: expected %q, got %q", state, returnedState) } - issuer := cb.Issuer() if issuer == "" { return fmt.Errorf("discovery login failed: callback did not include an issuer (iss) parameter") } From 0580c278e9bd7fc8474696df6db5379cbabb713b Mon Sep 17 00:00:00 2001 From: simon Date: Wed, 27 May 2026 14:04:21 +0200 Subject: [PATCH 2/2] update --- credentials/u2m/callback_test.go | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/credentials/u2m/callback_test.go b/credentials/u2m/callback_test.go index 2000365bb..8d37f5cb7 100644 --- a/credentials/u2m/callback_test.go +++ b/credentials/u2m/callback_test.go @@ -3,13 +3,80 @@ package u2m import ( "context" "fmt" + "html/template" "net/http" + "net/http/httptest" + "net/url" "testing" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "golang.org/x/oauth2" ) +func TestCallbackServer_HandlerWithIssuerBindsIssuerToResult(t *testing.T) { + browserOpened := make(chan string, 1) + cb := &callbackServer{ + ctx: context.Background(), + browser: func(redirect string) error { + browserOpened <- redirect + return nil + }, + renderErrCh: make(chan error), + feedbackCh: make(chan oauthResult, 2), + tmpl: template.Must(template.New("page").Parse("")), + } + + const authCodeURL = "https://login.databricks.com/?destination_url=%2Foidc%2Fv1%2Fauthorize" + legitimate := struct { + code string + state string + issuer string + }{ + code: "legit-code", + state: "legit-state", + issuer: "https://adb-123.azuredatabricks.net/oidc", + } + + serveCallback := func(code, state, issuer string) { + callbackURL := fmt.Sprintf("/?code=%s&state=%s&iss=%s", + url.QueryEscape(code), url.QueryEscape(state), url.QueryEscape(issuer)) + req := httptest.NewRequest(http.MethodGet, callbackURL, nil) + rec := httptest.NewRecorder() + cb.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("ServeHTTP status = %d, want %d", rec.Code, http.StatusOK) + } + } + + // Queue two callbacks with different code/state/iss; the handler must + // return the values from the first result and not interleave. + serveCallback(legitimate.code, legitimate.state, legitimate.issuer) + serveCallback("second-code", "second-state", "https://other.example/oidc") + + code, state, issuer, err := cb.handlerWithIssuer(authCodeURL) + if err != nil { + t.Fatalf("handlerWithIssuer(): %v", err) + } + if code != legitimate.code { + t.Errorf("code = %q, want %q", code, legitimate.code) + } + if state != legitimate.state { + t.Errorf("state = %q, want %q", state, legitimate.state) + } + if issuer != legitimate.issuer { + t.Errorf("issuer = %q, want %q", issuer, legitimate.issuer) + } + + select { + case got := <-browserOpened: + if got != authCodeURL { + t.Errorf("browser opened %q, want %q", got, authCodeURL) + } + default: + t.Fatal("browser was not opened") + } +} + func TestCallbackServer_ExtractsIssuer(t *testing.T) { ctx := context.Background()