Skip to content

Commit b084eb4

Browse files
authored
Auth validation update (#1700)
Update to auth validation logic NO_CHANGELOG=true --------- Signed-off-by: Simon Faltum <simon.faltum@databricks.com>
1 parent 6da9c9a commit b084eb4

3 files changed

Lines changed: 101 additions & 48 deletions

File tree

credentials/u2m/callback.go

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"html/template"
88
"net/http"
99
"strings"
10-
"sync"
1110

1211
"golang.org/x/text/cases"
1312
"golang.org/x/text/language"
@@ -21,6 +20,7 @@ type oauthResult struct {
2120
ErrorDescription string
2221
State string
2322
Code string
23+
Issuer string
2424
Host string
2525
}
2626

@@ -46,12 +46,6 @@ type callbackServer struct {
4646
// rendering the page.html template.
4747
renderErrCh chan error
4848

49-
// lastIssuer stores the iss (issuer) query parameter from the OAuth
50-
// callback, per RFC 9207. Used by the discovery login flow to identify
51-
// which workspace the user selected. Protected by issuerMu.
52-
issuerMu sync.Mutex
53-
lastIssuer string
54-
5549
// feedbackCh is a channel that receives the result of the authentication
5650
// attempt.
5751
feedbackCh chan oauthResult
@@ -95,14 +89,12 @@ func (cb *callbackServer) Close() error {
9589

9690
// ServeHTTP renders the page.html template.
9791
func (cb *callbackServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
98-
cb.issuerMu.Lock()
99-
cb.lastIssuer = r.FormValue("iss")
100-
cb.issuerMu.Unlock()
10192
res := oauthResult{
10293
Error: r.FormValue("error"),
10394
ErrorDescription: r.FormValue("error_description"),
10495
Code: r.FormValue("code"),
10596
State: r.FormValue("state"),
97+
Issuer: r.FormValue("iss"),
10698
Host: cb.getHost(),
10799
}
108100
if res.Error != "" {
@@ -128,30 +120,37 @@ func (cb *callbackServer) getHost() string {
128120
}
129121
}
130122

131-
// Issuer returns the iss parameter from the last OAuth callback received.
132-
// This is populated during the discovery login flow when login.databricks.com
133-
// redirects back with the workspace issuer.
134-
func (cb *callbackServer) Issuer() string {
135-
cb.issuerMu.Lock()
136-
defer cb.issuerMu.Unlock()
137-
return cb.lastIssuer
138-
}
139-
140-
// Handler opens up a browser waits for redirect to come back from the identity provider
141-
func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) {
123+
func (cb *callbackServer) awaitResult(authCodeURL string) (oauthResult, error) {
142124
err := cb.browser(authCodeURL)
143125
if err != nil {
144126
fmt.Printf("Please continue the authentication process in your browser:\n%s\n", authCodeURL)
145127
}
146128
select {
147129
case <-cb.ctx.Done():
148-
return "", "", cb.ctx.Err()
130+
return oauthResult{}, cb.ctx.Err()
149131
case renderErr := <-cb.renderErrCh:
150-
return "", "", renderErr
132+
return oauthResult{}, renderErr
151133
case res := <-cb.feedbackCh:
152134
if res.Error != "" {
153-
return "", "", fmt.Errorf("%s: %s", res.Error, res.ErrorDescription)
135+
return oauthResult{}, fmt.Errorf("%s: %s", res.Error, res.ErrorDescription)
154136
}
155-
return res.Code, res.State, nil
137+
return res, nil
138+
}
139+
}
140+
141+
// Handler opens up a browser waits for redirect to come back from the identity provider
142+
func (cb *callbackServer) Handler(authCodeURL string) (string, string, error) {
143+
res, err := cb.awaitResult(authCodeURL)
144+
if err != nil {
145+
return "", "", err
146+
}
147+
return res.Code, res.State, nil
148+
}
149+
150+
func (cb *callbackServer) handlerWithIssuer(authCodeURL string) (code, state, issuer string, err error) {
151+
res, err := cb.awaitResult(authCodeURL)
152+
if err != nil {
153+
return "", "", "", err
156154
}
155+
return res.Code, res.State, res.Issuer, nil
157156
}

credentials/u2m/callback_test.go

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,80 @@ package u2m
33
import (
44
"context"
55
"fmt"
6+
"html/template"
67
"net/http"
8+
"net/http/httptest"
9+
"net/url"
710
"testing"
811

912
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
1013
"golang.org/x/oauth2"
1114
)
1215

16+
func TestCallbackServer_HandlerWithIssuerBindsIssuerToResult(t *testing.T) {
17+
browserOpened := make(chan string, 1)
18+
cb := &callbackServer{
19+
ctx: context.Background(),
20+
browser: func(redirect string) error {
21+
browserOpened <- redirect
22+
return nil
23+
},
24+
renderErrCh: make(chan error),
25+
feedbackCh: make(chan oauthResult, 2),
26+
tmpl: template.Must(template.New("page").Parse("")),
27+
}
28+
29+
const authCodeURL = "https://login.databricks.com/?destination_url=%2Foidc%2Fv1%2Fauthorize"
30+
legitimate := struct {
31+
code string
32+
state string
33+
issuer string
34+
}{
35+
code: "legit-code",
36+
state: "legit-state",
37+
issuer: "https://adb-123.azuredatabricks.net/oidc",
38+
}
39+
40+
serveCallback := func(code, state, issuer string) {
41+
callbackURL := fmt.Sprintf("/?code=%s&state=%s&iss=%s",
42+
url.QueryEscape(code), url.QueryEscape(state), url.QueryEscape(issuer))
43+
req := httptest.NewRequest(http.MethodGet, callbackURL, nil)
44+
rec := httptest.NewRecorder()
45+
cb.ServeHTTP(rec, req)
46+
if rec.Code != http.StatusOK {
47+
t.Fatalf("ServeHTTP status = %d, want %d", rec.Code, http.StatusOK)
48+
}
49+
}
50+
51+
// Queue two callbacks with different code/state/iss; the handler must
52+
// return the values from the first result and not interleave.
53+
serveCallback(legitimate.code, legitimate.state, legitimate.issuer)
54+
serveCallback("second-code", "second-state", "https://other.example/oidc")
55+
56+
code, state, issuer, err := cb.handlerWithIssuer(authCodeURL)
57+
if err != nil {
58+
t.Fatalf("handlerWithIssuer(): %v", err)
59+
}
60+
if code != legitimate.code {
61+
t.Errorf("code = %q, want %q", code, legitimate.code)
62+
}
63+
if state != legitimate.state {
64+
t.Errorf("state = %q, want %q", state, legitimate.state)
65+
}
66+
if issuer != legitimate.issuer {
67+
t.Errorf("issuer = %q, want %q", issuer, legitimate.issuer)
68+
}
69+
70+
select {
71+
case got := <-browserOpened:
72+
if got != authCodeURL {
73+
t.Errorf("browser opened %q, want %q", got, authCodeURL)
74+
}
75+
default:
76+
t.Fatal("browser was not opened")
77+
}
78+
}
79+
1380
func TestCallbackServer_ExtractsIssuer(t *testing.T) {
1481
ctx := context.Background()
1582

@@ -64,11 +131,6 @@ func TestCallbackServer_ExtractsIssuer(t *testing.T) {
64131
}
65132
defer cb.Close()
66133

67-
// Verify issuer is empty before any callback.
68-
if got := cb.Issuer(); got != "" {
69-
t.Fatalf("Issuer() before callback: want %q, got %q", "", got)
70-
}
71-
72134
// Fire a callback with iss parameter.
73135
issuerURL := "https://adb-123.azuredatabricks.net/oidc"
74136
resp, err := http.Get(fmt.Sprintf("http://%s?code=xxx&state=yyy&iss=%s", p.redirectAddr, issuerURL))
@@ -77,11 +139,9 @@ func TestCallbackServer_ExtractsIssuer(t *testing.T) {
77139
}
78140
defer resp.Body.Close()
79141

80-
// Drain the feedback channel so ServeHTTP completes.
81-
<-cb.feedbackCh
82-
83-
if got := cb.Issuer(); got != issuerURL {
84-
t.Fatalf("Issuer(): want %q, got %q", issuerURL, got)
142+
res := <-cb.feedbackCh
143+
if got := res.Issuer; got != issuerURL {
144+
t.Fatalf("result Issuer: want %q, got %q", issuerURL, got)
85145
}
86146
}
87147

@@ -128,11 +188,9 @@ func TestCallbackServer_NoIssuer(t *testing.T) {
128188
}
129189
defer resp.Body.Close()
130190

131-
// Drain the feedback channel so ServeHTTP completes.
132-
<-cb.feedbackCh
133-
134-
if got := cb.Issuer(); got != "" {
135-
t.Fatalf("Issuer(): want %q, got %q", "", got)
191+
res := <-cb.feedbackCh
192+
if got := res.Issuer; got != "" {
193+
t.Fatalf("result Issuer: want %q, got %q", "", got)
136194
}
137195
}
138196

@@ -180,10 +238,8 @@ func TestCallbackServer_IssuerWithAccountPath(t *testing.T) {
180238
}
181239
defer resp.Body.Close()
182240

183-
// Drain the feedback channel so ServeHTTP completes.
184-
<-cb.feedbackCh
185-
186-
if got := cb.Issuer(); got != issuerURL {
187-
t.Fatalf("Issuer(): want %q, got %q", issuerURL, got)
241+
res := <-cb.feedbackCh
242+
if got := res.Issuer; got != issuerURL {
243+
t.Fatalf("result Issuer: want %q, got %q", issuerURL, got)
188244
}
189245
}

credentials/u2m/discovery_token_source.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ func (d *discoveryTokenSource) challenge() error {
124124
}
125125
authorizeURL := buildDiscoveryAuthorizeURL(host, d.pa.redirectAddr, state, pkce, scopes)
126126

127-
// Use cb.Handler to open the browser and wait for the callback.
128-
code, returnedState, err := cb.Handler(authorizeURL)
127+
code, returnedState, issuer, err := cb.handlerWithIssuer(authorizeURL)
129128
if err != nil {
130129
return fmt.Errorf("authorize: %w", err)
131130
}
@@ -135,7 +134,6 @@ func (d *discoveryTokenSource) challenge() error {
135134
return fmt.Errorf("state mismatch: expected %q, got %q", state, returnedState)
136135
}
137136

138-
issuer := cb.Issuer()
139137
if issuer == "" {
140138
return fmt.Errorf("discovery login failed: callback did not include an issuer (iss) parameter")
141139
}

0 commit comments

Comments
 (0)