Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 24 additions & 25 deletions credentials/u2m/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"html/template"
"net/http"
"strings"
"sync"

"golang.org/x/text/cases"
"golang.org/x/text/language"
Expand All @@ -21,6 +20,7 @@ type oauthResult struct {
ErrorDescription string
State string
Code string
Issuer string
Host string
}

Expand All @@ -46,12 +46,6 @@ type callbackServer struct {
// rendering the page.html template.
renderErrCh chan error

// lastIssuer stores the iss (issuer) query parameter from the OAuth
// callback, per RFC 9207. Used by the discovery login flow to identify
// which workspace the user selected. Protected by issuerMu.
issuerMu sync.Mutex
lastIssuer string

// feedbackCh is a channel that receives the result of the authentication
// attempt.
feedbackCh chan oauthResult
Expand Down Expand Up @@ -95,14 +89,12 @@ func (cb *callbackServer) Close() error {

// ServeHTTP renders the page.html template.
func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cb.issuerMu.Lock()
cb.lastIssuer = r.FormValue("iss")
cb.issuerMu.Unlock()
res := oauthResult{
Error: r.FormValue("error"),
ErrorDescription: r.FormValue("error_description"),
Code: r.FormValue("code"),
State: r.FormValue("state"),
Issuer: r.FormValue("iss"),
Host: cb.getHost(),
}
if res.Error != "" {
Expand All @@ -128,30 +120,37 @@ func (cb *callbackServer) getHost() string {
}
}

// Issuer returns the iss parameter from the last OAuth callback received.
// This is populated during the discovery login flow when login.databricks.com
// redirects back with the workspace issuer.
func (cb *callbackServer) Issuer() string {
cb.issuerMu.Lock()
defer cb.issuerMu.Unlock()
return cb.lastIssuer
}

// Handler opens up a browser waits for redirect to come back from the identity provider
func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) {
func (cb *callbackServer) awaitResult(authCodeURL string) (oauthResult, error) {
err := cb.browser(authCodeURL)
if err != nil {
fmt.Printf("Please continue the authentication process in your browser:\n%s\n", authCodeURL)
}
select {
case <-cb.ctx.Done():
return "", "", cb.ctx.Err()
return oauthResult{}, cb.ctx.Err()
case renderErr := <-cb.renderErrCh:
return "", "", renderErr
return oauthResult{}, renderErr
case res := <-cb.feedbackCh:
if res.Error != "" {
return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription)
return oauthResult{}, fmt.Errorf("%s: %s", res.Error, res.ErrorDescription)
}
return res.Code, res.State, nil
return res, nil
}
}

// Handler opens up a browser waits for redirect to come back from the identity provider
func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) {
res, err := cb.awaitResult(authCodeURL)
if err != nil {
return "", "", err
}
return res.Code, res.State, nil
}

func (cb *callbackServer) handlerWithIssuer(authCodeURL string) (code, state, issuer string, err error) {
res, err := cb.awaitResult(authCodeURL)
if err != nil {
return "", "", "", err
}
return res.Code, res.State, res.Issuer, nil
}
96 changes: 76 additions & 20 deletions credentials/u2m/callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,80 @@ package u2m
import (
"context"
"fmt"
"html/template"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"golang.org/x/oauth2"
)

func TestCallbackServer_HandlerWithIssuerBindsIssuerToResult(t *testing.T) {
browserOpened := make(chan string, 1)
cb := &callbackServer{
ctx: context.Background(),
browser: func(redirect string) error {
browserOpened <- redirect
return nil
},
renderErrCh: make(chan error),
feedbackCh: make(chan oauthResult, 2),
tmpl: template.Must(template.New("page").Parse("")),
}

const authCodeURL = "https://login.databricks.com/?destination_url=%2Foidc%2Fv1%2Fauthorize"
legitimate := struct {
code string
state string
issuer string
}{
code: "legit-code",
state: "legit-state",
issuer: "https://adb-123.azuredatabricks.net/oidc",
}

serveCallback := func(code, state, issuer string) {
callbackURL := fmt.Sprintf("/?code=%s&state=%s&iss=%s",
url.QueryEscape(code), url.QueryEscape(state), url.QueryEscape(issuer))
req := httptest.NewRequest(http.MethodGet, callbackURL, nil)
rec := httptest.NewRecorder()
cb.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("ServeHTTP status = %d, want %d", rec.Code, http.StatusOK)
}
}

// Queue two callbacks with different code/state/iss; the handler must
// return the values from the first result and not interleave.
serveCallback(legitimate.code, legitimate.state, legitimate.issuer)
serveCallback("second-code", "second-state", "https://other.example/oidc")

code, state, issuer, err := cb.handlerWithIssuer(authCodeURL)
if err != nil {
t.Fatalf("handlerWithIssuer(): %v", err)
}
if code != legitimate.code {
t.Errorf("code = %q, want %q", code, legitimate.code)
}
if state != legitimate.state {
t.Errorf("state = %q, want %q", state, legitimate.state)
}
if issuer != legitimate.issuer {
t.Errorf("issuer = %q, want %q", issuer, legitimate.issuer)
}

select {
case got := <-browserOpened:
if got != authCodeURL {
t.Errorf("browser opened %q, want %q", got, authCodeURL)
}
default:
t.Fatal("browser was not opened")
}
}

func TestCallbackServer_ExtractsIssuer(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -64,11 +131,6 @@ func TestCallbackServer_ExtractsIssuer(t *testing.T) {
}
defer cb.Close()

// Verify issuer is empty before any callback.
if got := cb.Issuer(); got != "" {
t.Fatalf("Issuer() before callback: want %q, got %q", "", got)
}

// Fire a callback with iss parameter.
issuerURL := "https://adb-123.azuredatabricks.net/oidc"
resp, err := http.Get(fmt.Sprintf("http://%s?code=xxx&state=yyy&iss=%s", p.redirectAddr, issuerURL))
Expand All @@ -77,11 +139,9 @@ func TestCallbackServer_ExtractsIssuer(t *testing.T) {
}
defer resp.Body.Close()

// Drain the feedback channel so ServeHTTP completes.
<-cb.feedbackCh

if got := cb.Issuer(); got != issuerURL {
t.Fatalf("Issuer(): want %q, got %q", issuerURL, got)
res := <-cb.feedbackCh
if got := res.Issuer; got != issuerURL {
t.Fatalf("result Issuer: want %q, got %q", issuerURL, got)
}
}

Expand Down Expand Up @@ -128,11 +188,9 @@ func TestCallbackServer_NoIssuer(t *testing.T) {
}
defer resp.Body.Close()

// Drain the feedback channel so ServeHTTP completes.
<-cb.feedbackCh

if got := cb.Issuer(); got != "" {
t.Fatalf("Issuer(): want %q, got %q", "", got)
res := <-cb.feedbackCh
if got := res.Issuer; got != "" {
t.Fatalf("result Issuer: want %q, got %q", "", got)
}
}

Expand Down Expand Up @@ -180,10 +238,8 @@ func TestCallbackServer_IssuerWithAccountPath(t *testing.T) {
}
defer resp.Body.Close()

// Drain the feedback channel so ServeHTTP completes.
<-cb.feedbackCh

if got := cb.Issuer(); got != issuerURL {
t.Fatalf("Issuer(): want %q, got %q", issuerURL, got)
res := <-cb.feedbackCh
if got := res.Issuer; got != issuerURL {
t.Fatalf("result Issuer: want %q, got %q", issuerURL, got)
}
}
4 changes: 1 addition & 3 deletions credentials/u2m/discovery_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ func (d *discoveryTokenSource) challenge() error {
}
authorizeURL := buildDiscoveryAuthorizeURL(host, d.pa.redirectAddr, state, pkce, scopes)

// Use cb.Handler to open the browser and wait for the callback.
code, returnedState, err := cb.Handler(authorizeURL)
code, returnedState, issuer, err := cb.handlerWithIssuer(authorizeURL)
if err != nil {
return fmt.Errorf("authorize: %w", err)
}
Expand All @@ -135,7 +134,6 @@ func (d *discoveryTokenSource) challenge() error {
return fmt.Errorf("state mismatch: expected %q, got %q", state, returnedState)
}

issuer := cb.Issuer()
if issuer == "" {
return fmt.Errorf("discovery login failed: callback did not include an issuer (iss) parameter")
}
Expand Down
Loading