diff --git a/credentials/u2m/callback.go b/credentials/u2m/callback.go index 1d634a128..5db0eeeac 100644 --- a/credentials/u2m/callback.go +++ b/credentials/u2m/callback.go @@ -7,7 +7,6 @@ import ( "html/template" "net/http" "strings" - "sync" "golang.org/x/text/cases" "golang.org/x/text/language" @@ -21,6 +20,7 @@ type oauthResult struct { ErrorDescription string State string Code string + Issuer string Host string } @@ -46,12 +46,6 @@ type callbackServer struct { // rendering the page.html template. renderErrCh chan error - // lastIssuer stores the iss (issuer) query parameter from the OAuth - // callback, per RFC 9207. Used by the discovery login flow to identify - // which workspace the user selected. Protected by issuerMu. - issuerMu sync.Mutex - lastIssuer string - // feedbackCh is a channel that receives the result of the authentication // attempt. feedbackCh chan oauthResult @@ -95,14 +89,12 @@ func (cb *callbackServer) Close() error { // ServeHTTP renders the page.html template. func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - cb.issuerMu.Lock() - cb.lastIssuer = r.FormValue("iss") - cb.issuerMu.Unlock() res := oauthResult{ Error: r.FormValue("error"), ErrorDescription: r.FormValue("error_description"), Code: r.FormValue("code"), State: r.FormValue("state"), + Issuer: r.FormValue("iss"), Host: cb.getHost(), } if res.Error != "" { @@ -128,30 +120,37 @@ func (cb *callbackServer) getHost() string { } } -// Issuer returns the iss parameter from the last OAuth callback received. -// This is populated during the discovery login flow when login.databricks.com -// redirects back with the workspace issuer. -func (cb *callbackServer) Issuer() string { - cb.issuerMu.Lock() - defer cb.issuerMu.Unlock() - return cb.lastIssuer -} - -// Handler opens up a browser waits for redirect to come back from the identity provider -func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { +func (cb *callbackServer) awaitResult(authCodeURL string) (oauthResult, error) { err := cb.browser(authCodeURL) if err != nil { fmt.Printf("Please continue the authentication process in your browser:\n%s\n", authCodeURL) } select { case <-cb.ctx.Done(): - return "", "", cb.ctx.Err() + return oauthResult{}, cb.ctx.Err() case renderErr := <-cb.renderErrCh: - return "", "", renderErr + return oauthResult{}, renderErr case res := <-cb.feedbackCh: if res.Error != "" { - return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription) + return oauthResult{}, fmt.Errorf("%s: %s", res.Error, res.ErrorDescription) } - return res.Code, res.State, nil + return res, nil + } +} + +// Handler opens up a browser waits for redirect to come back from the identity provider +func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) { + res, err := cb.awaitResult(authCodeURL) + if err != nil { + return "", "", err + } + return res.Code, res.State, nil +} + +func (cb *callbackServer) handlerWithIssuer(authCodeURL string) (code, state, issuer string, err error) { + res, err := cb.awaitResult(authCodeURL) + if err != nil { + return "", "", "", err } + return res.Code, res.State, res.Issuer, nil } diff --git a/credentials/u2m/callback_test.go b/credentials/u2m/callback_test.go index 3fac8c89e..8d37f5cb7 100644 --- a/credentials/u2m/callback_test.go +++ b/credentials/u2m/callback_test.go @@ -3,13 +3,80 @@ package u2m import ( "context" "fmt" + "html/template" "net/http" + "net/http/httptest" + "net/url" "testing" "github.com/databricks/databricks-sdk-go/httpclient/fixtures" "golang.org/x/oauth2" ) +func TestCallbackServer_HandlerWithIssuerBindsIssuerToResult(t *testing.T) { + browserOpened := make(chan string, 1) + cb := &callbackServer{ + ctx: context.Background(), + browser: func(redirect string) error { + browserOpened <- redirect + return nil + }, + renderErrCh: make(chan error), + feedbackCh: make(chan oauthResult, 2), + tmpl: template.Must(template.New("page").Parse("")), + } + + const authCodeURL = "https://login.databricks.com/?destination_url=%2Foidc%2Fv1%2Fauthorize" + legitimate := struct { + code string + state string + issuer string + }{ + code: "legit-code", + state: "legit-state", + issuer: "https://adb-123.azuredatabricks.net/oidc", + } + + serveCallback := func(code, state, issuer string) { + callbackURL := fmt.Sprintf("/?code=%s&state=%s&iss=%s", + url.QueryEscape(code), url.QueryEscape(state), url.QueryEscape(issuer)) + req := httptest.NewRequest(http.MethodGet, callbackURL, nil) + rec := httptest.NewRecorder() + cb.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("ServeHTTP status = %d, want %d", rec.Code, http.StatusOK) + } + } + + // Queue two callbacks with different code/state/iss; the handler must + // return the values from the first result and not interleave. + serveCallback(legitimate.code, legitimate.state, legitimate.issuer) + serveCallback("second-code", "second-state", "https://other.example/oidc") + + code, state, issuer, err := cb.handlerWithIssuer(authCodeURL) + if err != nil { + t.Fatalf("handlerWithIssuer(): %v", err) + } + if code != legitimate.code { + t.Errorf("code = %q, want %q", code, legitimate.code) + } + if state != legitimate.state { + t.Errorf("state = %q, want %q", state, legitimate.state) + } + if issuer != legitimate.issuer { + t.Errorf("issuer = %q, want %q", issuer, legitimate.issuer) + } + + select { + case got := <-browserOpened: + if got != authCodeURL { + t.Errorf("browser opened %q, want %q", got, authCodeURL) + } + default: + t.Fatal("browser was not opened") + } +} + func TestCallbackServer_ExtractsIssuer(t *testing.T) { ctx := context.Background() @@ -64,11 +131,6 @@ func TestCallbackServer_ExtractsIssuer(t *testing.T) { } defer cb.Close() - // Verify issuer is empty before any callback. - if got := cb.Issuer(); got != "" { - t.Fatalf("Issuer() before callback: want %q, got %q", "", got) - } - // Fire a callback with iss parameter. issuerURL := "https://adb-123.azuredatabricks.net/oidc" resp, err := http.Get(fmt.Sprintf("http://%s?code=xxx&state=yyy&iss=%s", p.redirectAddr, issuerURL)) @@ -77,11 +139,9 @@ func TestCallbackServer_ExtractsIssuer(t *testing.T) { } defer resp.Body.Close() - // Drain the feedback channel so ServeHTTP completes. - <-cb.feedbackCh - - if got := cb.Issuer(); got != issuerURL { - t.Fatalf("Issuer(): want %q, got %q", issuerURL, got) + res := <-cb.feedbackCh + if got := res.Issuer; got != issuerURL { + t.Fatalf("result Issuer: want %q, got %q", issuerURL, got) } } @@ -128,11 +188,9 @@ func TestCallbackServer_NoIssuer(t *testing.T) { } defer resp.Body.Close() - // Drain the feedback channel so ServeHTTP completes. - <-cb.feedbackCh - - if got := cb.Issuer(); got != "" { - t.Fatalf("Issuer(): want %q, got %q", "", got) + res := <-cb.feedbackCh + if got := res.Issuer; got != "" { + t.Fatalf("result Issuer: want %q, got %q", "", got) } } @@ -180,10 +238,8 @@ func TestCallbackServer_IssuerWithAccountPath(t *testing.T) { } defer resp.Body.Close() - // Drain the feedback channel so ServeHTTP completes. - <-cb.feedbackCh - - if got := cb.Issuer(); got != issuerURL { - t.Fatalf("Issuer(): want %q, got %q", issuerURL, got) + res := <-cb.feedbackCh + if got := res.Issuer; got != issuerURL { + t.Fatalf("result Issuer: want %q, got %q", issuerURL, got) } } diff --git a/credentials/u2m/discovery_token_source.go b/credentials/u2m/discovery_token_source.go index 8106cee32..e96891dec 100644 --- a/credentials/u2m/discovery_token_source.go +++ b/credentials/u2m/discovery_token_source.go @@ -124,8 +124,7 @@ func (d *discoveryTokenSource) challenge() error { } authorizeURL := buildDiscoveryAuthorizeURL(host, d.pa.redirectAddr, state, pkce, scopes) - // Use cb.Handler to open the browser and wait for the callback. - code, returnedState, err := cb.Handler(authorizeURL) + code, returnedState, issuer, err := cb.handlerWithIssuer(authorizeURL) if err != nil { return fmt.Errorf("authorize: %w", err) } @@ -135,7 +134,6 @@ func (d *discoveryTokenSource) challenge() error { return fmt.Errorf("state mismatch: expected %q, got %q", state, returnedState) } - issuer := cb.Issuer() if issuer == "" { return fmt.Errorf("discovery login failed: callback did not include an issuer (iss) parameter") }