Skip to content

Commit e558398

Browse files
committed
fix: address blind spots — remove unused import, add discovery default, enforce PKCE at /auth, add tests
1 parent c46073a commit e558398

4 files changed

Lines changed: 172 additions & 1 deletion

File tree

connector/oauth/oauth.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9-
"io"
109
"log/slog"
1110
"net/http"
1211
"strings"

server/handlers.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ func (s *Server) discoveryHandler(ctx context.Context, t DiscoveryType) (http.Ha
120120
d = s.constructDiscoveryOAuth2()
121121
case DiscoveryOIDC:
122122
d = s.constructDiscoveryOIDC(ctx)
123+
default:
124+
d = s.constructDiscoveryOIDC(ctx)
123125
}
124126

125127
data, err := json.MarshalIndent(d, "", " ")

server/oauth2.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,11 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
477477
return nil, newRedirectedErr(errInvalidRequest, description)
478478
}
479479

480+
// Public clients MUST use PKCE — reject early at /auth if code_challenge is missing.
481+
if client.Public && codeChallenge == "" {
482+
return nil, newRedirectedErr(errInvalidRequest, "Public clients must use PKCE (code_challenge required).")
483+
}
484+
480485
var (
481486
unrecognized []string
482487
invalidScopes []string

server/server_test.go

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,171 @@ func TestOAuth2CodeFlow(t *testing.T) {
897897
}
898898
}
899899

900+
// TestPublicClientPKCE tests that public clients:
901+
// 1. Can exchange auth codes WITHOUT a client_secret when using PKCE
902+
// 2. Are REJECTED when they don't use PKCE at all
903+
func TestPublicClientPKCE(t *testing.T) {
904+
publicClientID := "public-test-client"
905+
906+
t0 := time.Now()
907+
now := func() time.Time { return t0 }
908+
idTokensValidFor := time.Second * 30
909+
910+
oidcConfig := &oidc.Config{SkipClientIDCheck: true}
911+
basicIDTokenVerify := func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
912+
idToken, ok := token.Extra("id_token").(string)
913+
if !ok {
914+
return fmt.Errorf("no id token found")
915+
}
916+
if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil {
917+
return fmt.Errorf("failed to verify id token: %v", err)
918+
}
919+
return nil
920+
}
921+
922+
tests := []test{
923+
{
924+
// Public client with plain PKCE should succeed without client_secret
925+
name: "public client with PKCE succeeds",
926+
authCodeOptions: []oauth2.AuthCodeOption{
927+
oauth2.SetAuthURLParam("code_challenge", "challenge123"),
928+
},
929+
retrieveTokenOptions: []oauth2.AuthCodeOption{
930+
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
931+
},
932+
handleToken: basicIDTokenVerify,
933+
},
934+
{
935+
// Public client with S256 PKCE should succeed
936+
name: "public client with S256 PKCE succeeds",
937+
authCodeOptions: []oauth2.AuthCodeOption{
938+
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
939+
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
940+
},
941+
retrieveTokenOptions: []oauth2.AuthCodeOption{
942+
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
943+
},
944+
handleToken: basicIDTokenVerify,
945+
},
946+
{
947+
// Public client WITHOUT PKCE should be rejected at /auth (redirect error)
948+
name: "public client without PKCE is rejected at auth",
949+
handleToken: basicIDTokenVerify,
950+
authError: &OAuth2ErrorResponse{
951+
Error: errInvalidRequest,
952+
ErrorDescription: "Public clients must use PKCE (code_challenge required).",
953+
},
954+
},
955+
}
956+
957+
for _, tc := range tests {
958+
t.Run(tc.name, func(t *testing.T) {
959+
ctx := t.Context()
960+
961+
httpServer, s := newTestServer(t, func(c *Config) {
962+
c.Issuer += "/non-root-path"
963+
c.Now = now
964+
c.IDTokensValidFor = idTokensValidFor
965+
})
966+
defer httpServer.Close()
967+
968+
p, err := oidc.NewProvider(ctx, httpServer.URL)
969+
if err != nil {
970+
t.Fatalf("failed to get provider: %v", err)
971+
}
972+
973+
var (
974+
gotCode bool
975+
reqDump, respDump []byte
976+
state = "a_state"
977+
)
978+
defer func() {
979+
if !gotCode && tc.authError == nil {
980+
t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump)
981+
}
982+
}()
983+
984+
var oauth2Config *oauth2.Config
985+
oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
986+
if r.URL.Path != "/callback" {
987+
http.Redirect(w, r, oauth2Config.AuthCodeURL(state, tc.authCodeOptions...), http.StatusSeeOther)
988+
return
989+
}
990+
991+
q := r.URL.Query()
992+
if errType := q.Get("error"); errType != "" {
993+
if tc.authError != nil {
994+
if errType != tc.authError.Error {
995+
t.Errorf("expected auth error %q, got %q", tc.authError.Error, errType)
996+
}
997+
gotCode = true // prevent the deferred "never got a code" error
998+
return
999+
}
1000+
t.Errorf("got error from server %s: %s", errType, q.Get("error_description"))
1001+
w.WriteHeader(http.StatusInternalServerError)
1002+
return
1003+
}
1004+
1005+
if code := q.Get("code"); code != "" {
1006+
gotCode = true
1007+
token, err := oauth2Config.Exchange(ctx, code, tc.retrieveTokenOptions...)
1008+
if tc.tokenError.StatusCode != 0 {
1009+
checkErrorResponse(err, t, tc)
1010+
return
1011+
}
1012+
if err != nil {
1013+
t.Errorf("failed to exchange code for token: %v", err)
1014+
return
1015+
}
1016+
err = tc.handleToken(ctx, p, oauth2Config, token, nil)
1017+
if err != nil {
1018+
t.Errorf("%s: %v", tc.name, err)
1019+
}
1020+
return
1021+
}
1022+
1023+
if gotState := q.Get("state"); gotState != state {
1024+
t.Errorf("state did not match, want=%q got=%q", state, gotState)
1025+
}
1026+
w.WriteHeader(http.StatusOK)
1027+
}))
1028+
defer oauth2Client.Close()
1029+
1030+
// Register a PUBLIC client (no secret, public=true)
1031+
redirectURL := oauth2Client.URL + "/callback"
1032+
client := storage.Client{
1033+
ID: publicClientID,
1034+
Public: true,
1035+
RedirectURIs: []string{redirectURL},
1036+
}
1037+
if err := s.storage.CreateClient(ctx, client); err != nil {
1038+
t.Fatalf("failed to create client: %v", err)
1039+
}
1040+
1041+
// OAuth2 config with NO client secret (public client)
1042+
oauth2Config = &oauth2.Config{
1043+
ClientID: client.ID,
1044+
Endpoint: p.Endpoint(),
1045+
Scopes: []string{oidc.ScopeOpenID, "email", "profile"},
1046+
RedirectURL: redirectURL,
1047+
}
1048+
1049+
resp, err := http.Get(oauth2Client.URL + "/login")
1050+
if err != nil {
1051+
t.Fatalf("get failed: %v", err)
1052+
}
1053+
defer resp.Body.Close()
1054+
1055+
if reqDump, err = httputil.DumpRequest(resp.Request, false); err != nil {
1056+
t.Fatal(err)
1057+
}
1058+
if respDump, err = httputil.DumpResponse(resp, true); err != nil {
1059+
t.Fatal(err)
1060+
}
1061+
})
1062+
}
1063+
}
1064+
9001065
func TestOAuth2ImplicitFlow(t *testing.T) {
9011066
ctx := t.Context()
9021067

0 commit comments

Comments
 (0)