From 61ba8ed44a996d4928690f19dc8c1e9547bea309 Mon Sep 17 00:00:00 2001 From: Morgan PEYRE Date: Thu, 23 Apr 2026 12:08:48 +0000 Subject: [PATCH 1/2] feat(openidConnect): auto-detect and apply PKCE from OIDC discovery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the OIDC discovery document advertises code_challenge_methods_supported, the provider now automatically: - Selects the best method (S256 preferred over plain per RFC 7636 §4.2) - Generates a cryptographically random code_verifier on each BeginAuth call - Computes the code_challenge and injects it into the authorization URL - Stores the code_verifier in the Session for use at token exchange - Uses the stored verifier in Authorize(), falling back to the legacy code_verifier query param for backward compatibility If code_challenge_methods_supported is absent or empty, PKCE is not applied. --- providers/openidConnect/openidConnect.go | 78 +++++++++++- providers/openidConnect/openidConnect_test.go | 119 ++++++++++++++++++ providers/openidConnect/session.go | 12 +- 3 files changed, 204 insertions(+), 5 deletions(-) diff --git a/providers/openidConnect/openidConnect.go b/providers/openidConnect/openidConnect.go index dea40d2d..dc81a6aa 100644 --- a/providers/openidConnect/openidConnect.go +++ b/providers/openidConnect/openidConnect.go @@ -2,6 +2,8 @@ package openidConnect import ( "bytes" + "crypto/rand" + "crypto/sha256" "encoding/base64" "encoding/json" "errors" @@ -70,6 +72,11 @@ type Provider struct { LocationClaims []string SkipUserInfoRequest bool + + // PKCEMethod is the code challenge method to use for PKCE. It is automatically + // selected from the discovery document's code_challenge_methods_supported field. + // Supported values are "S256" and "plain". Empty string means PKCE is disabled. + PKCEMethod string } type OpenIDConfig struct { @@ -82,6 +89,10 @@ type OpenIDConfig struct { // https://openid.net/specs/openid-connect-session-1_0-17.html#OPMetadata EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` Issuer string `json:"issuer"` + + // CodeChallengeMethodsSupported lists PKCE code challenge methods supported by the provider. + // See https://www.rfc-editor.org/rfc/rfc7636 + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` } type RefreshTokenResponse struct { @@ -142,6 +153,7 @@ func NewNamed(name, clientKey, secret, callbackURL, openIDAutoDiscoveryURL strin return nil, err } p.OpenIDConfig = openIDConfig + p.PKCEMethod = selectPKCEMethod(openIDConfig.CodeChallengeMethodsSupported) p.config = newConfig(p, scopes, openIDConfig) return p, nil @@ -204,10 +216,29 @@ func (p *Provider) Debug(debug bool) {} // BeginAuth asks the OpenID Connect provider for an authentication end-point. func (p *Provider) BeginAuth(state string) (goth.Session, error) { - url := p.config.AuthCodeURL(state, p.authCodeOptions...) - session := &Session{ - AuthURL: url, + authCodeOptions := p.authCodeOptions + session := &Session{} + + if p.PKCEMethod != "" { + verifier, err := generateCodeVerifier() + if err != nil { + return nil, fmt.Errorf("openidConnect: failed to generate PKCE code verifier: %w", err) + } + var challenge string + switch p.PKCEMethod { + case "S256": + challenge = generateS256Challenge(verifier) + case "plain": + challenge = verifier + } + authCodeOptions = append(authCodeOptions, + oauth2.SetAuthURLParam("code_challenge", challenge), + oauth2.SetAuthURLParam("code_challenge_method", p.PKCEMethod), + ) + session.CodeVerifier = verifier } + + session.AuthURL = p.config.AuthCodeURL(state, authCodeOptions...) return session, nil } @@ -527,3 +558,44 @@ func unMarshal(payload []byte) (map[string]interface{}, error) { return data, json.NewDecoder(bytes.NewBuffer(payload)).Decode(&data) } + +// selectPKCEMethod selects the best PKCE code challenge method from the list +// advertised by the provider. S256 is preferred over plain per RFC 7636 §4.2. +// Returns an empty string if neither method is supported. +func selectPKCEMethod(methods []string) string { + hasS256 := false + hasPlain := false + for _, m := range methods { + switch m { + case "S256": + hasS256 = true + case "plain": + hasPlain = true + } + } + if hasS256 { + return "S256" + } + if hasPlain { + return "plain" + } + return "" +} + +// generateCodeVerifier creates a cryptographically random PKCE code verifier +// of 43 URL-safe characters (32 random bytes, base64url-encoded without padding) +// as specified in RFC 7636 §4.1. +func generateCodeVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b), nil +} + +// generateS256Challenge computes the S256 PKCE code challenge from a verifier: +// BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) per RFC 7636 §4.2. +func generateS256Challenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(h[:]) +} diff --git a/providers/openidConnect/openidConnect_test.go b/providers/openidConnect/openidConnect_test.go index 7dd76e04..47c5ec43 100644 --- a/providers/openidConnect/openidConnect_test.go +++ b/providers/openidConnect/openidConnect_test.go @@ -76,6 +76,125 @@ func Test_BeginAuth(t *testing.T) { a.Contains(s.AuthURL, "state=test_state") a.Contains(s.AuthURL, "redirect_uri=http%3A%2F%2Flocalhost%2Ffoo") a.Contains(s.AuthURL, "scope=openid") + + // The mock server advertises ["plain","S256"] so PKCE must be used with S256. + a.Equal("S256", provider.PKCEMethod) + a.NotEmpty(s.CodeVerifier) + a.Contains(s.AuthURL, "code_challenge=") + a.Contains(s.AuthURL, "code_challenge_method=S256") +} + +func Test_BeginAuth_PKCE_S256_Challenge(t *testing.T) { + t.Parallel() + a := assert.New(t) + + provider := openidConnectProvider() + session, err := provider.BeginAuth("test_state") + a.NoError(err) + s := session.(*Session) + + // Verify that the code_challenge in the URL matches the S256 of the stored verifier. + expected := generateS256Challenge(s.CodeVerifier) + a.Contains(s.AuthURL, "code_challenge="+expected) +} + +func Test_BeginAuth_NoPKCE_WhenNotAdvertised(t *testing.T) { + t.Parallel() + a := assert.New(t) + + // Spin up a server that does NOT advertise code_challenge_methods_supported. + noPKCEServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo"}`) + })) + defer noPKCEServer.Close() + + provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", noPKCEServer.URL) + a.NoError(err) + a.Equal("", provider.PKCEMethod) + + session, err := provider.BeginAuth("test_state") + a.NoError(err) + s := session.(*Session) + a.Empty(s.CodeVerifier) + a.NotContains(s.AuthURL, "code_challenge") + a.NotContains(s.AuthURL, "code_challenge_method") +} + +func Test_SelectPKCEMethod(t *testing.T) { + t.Parallel() + a := assert.New(t) + + a.Equal("S256", selectPKCEMethod([]string{"plain", "S256"})) + a.Equal("S256", selectPKCEMethod([]string{"S256"})) + a.Equal("plain", selectPKCEMethod([]string{"plain"})) + a.Equal("", selectPKCEMethod([]string{})) + a.Equal("", selectPKCEMethod(nil)) + a.Equal("", selectPKCEMethod([]string{"other"})) +} + +func Test_GenerateCodeVerifier(t *testing.T) { + t.Parallel() + a := assert.New(t) + + v, err := generateCodeVerifier() + a.NoError(err) + // 32 bytes base64url-encoded without padding = 43 chars + a.Equal(43, len(v)) + // All chars must be URL-safe base64 alphabet + for _, c := range v { + a.Contains("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", string(c)) + } + + // Two verifiers must differ (with overwhelming probability) + v2, _ := generateCodeVerifier() + a.NotEqual(v, v2) +} + +func Test_GenerateS256Challenge(t *testing.T) { + t.Parallel() + a := assert.New(t) + + // Known test vector from RFC 7636 Appendix B: + // code_verifier = dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk + // code_challenge = E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM + a.Equal("E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", generateS256Challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk")) +} + +func Test_New_PKCE_MethodSelectedFromDiscovery(t *testing.T) { + t.Parallel() + a := assert.New(t) + + // Mock server advertises only "plain" + plainOnlyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo","code_challenge_methods_supported":["plain"]}`) + })) + defer plainOnlyServer.Close() + + provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", plainOnlyServer.URL) + a.NoError(err) + a.Equal("plain", provider.PKCEMethod) +} + +func Test_BeginAuth_PKCE_PlainMethod(t *testing.T) { + t.Parallel() + a := assert.New(t) + + // Mock server advertises only "plain" + plainOnlyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo","code_challenge_methods_supported":["plain"]}`) + })) + defer plainOnlyServer.Close() + + provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", plainOnlyServer.URL) + a.NoError(err) + + session, err := provider.BeginAuth("test_state") + a.NoError(err) + s := session.(*Session) + a.NotEmpty(s.CodeVerifier) + // For "plain", challenge == verifier + a.Contains(s.AuthURL, "code_challenge="+s.CodeVerifier) + a.Contains(s.AuthURL, "code_challenge_method=plain") } func Test_BeginAuth_AuthCodeOptions(t *testing.T) { diff --git a/providers/openidConnect/session.go b/providers/openidConnect/session.go index 84b577c3..d3ee6645 100644 --- a/providers/openidConnect/session.go +++ b/providers/openidConnect/session.go @@ -17,6 +17,9 @@ type Session struct { RefreshToken string ExpiresAt time.Time IDToken string + // CodeVerifier holds the PKCE code verifier generated during BeginAuth. + // It is used at token exchange time to prove possession of the original verifier. + CodeVerifier string `json:",omitempty"` } // GetAuthURL will return the URL set by calling the `BeginAuth` function on the OpenID Connect provider. @@ -39,8 +42,13 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, authParams = append(authParams, oauth2.SetAuthURLParam("redirect_uri", redirectURL)) } - // set code_verifier if passed as param - codeVerifier := params.Get("code_verifier") + // set code_verifier for PKCE: prefer the verifier stored in the session + // (generated automatically during BeginAuth), fall back to one passed as + // a callback parameter for backward compatibility. + codeVerifier := s.CodeVerifier + if codeVerifier == "" { + codeVerifier = params.Get("code_verifier") + } if codeVerifier != "" { authParams = append(authParams, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) } From efb752a117518537292a22f35ef99ed9785d3a68 Mon Sep 17 00:00:00 2001 From: Morgan PEYRE Date: Thu, 23 Apr 2026 13:09:24 +0000 Subject: [PATCH 2/2] refactor(openidConnect): use golang.org/x/oauth2 built-in PKCE functions Replace custom generateCodeVerifier and generateS256Challenge helpers with oauth2.GenerateVerifier(), oauth2.S256ChallengeOption(), oauth2.VerifierOption(), and oauth2.S256ChallengeFromVerifier() from golang.org/x/oauth2 v0.30.0. --- providers/openidConnect/openidConnect.go | 35 ++++--------------- providers/openidConnect/openidConnect_test.go | 15 ++++---- providers/openidConnect/session.go | 2 +- 3 files changed, 15 insertions(+), 37 deletions(-) diff --git a/providers/openidConnect/openidConnect.go b/providers/openidConnect/openidConnect.go index dc81a6aa..c22017fd 100644 --- a/providers/openidConnect/openidConnect.go +++ b/providers/openidConnect/openidConnect.go @@ -2,8 +2,6 @@ package openidConnect import ( "bytes" - "crypto/rand" - "crypto/sha256" "encoding/base64" "encoding/json" "errors" @@ -220,21 +218,16 @@ func (p *Provider) BeginAuth(state string) (goth.Session, error) { session := &Session{} if p.PKCEMethod != "" { - verifier, err := generateCodeVerifier() - if err != nil { - return nil, fmt.Errorf("openidConnect: failed to generate PKCE code verifier: %w", err) - } - var challenge string + verifier := oauth2.GenerateVerifier() switch p.PKCEMethod { case "S256": - challenge = generateS256Challenge(verifier) + authCodeOptions = append(authCodeOptions, oauth2.S256ChallengeOption(verifier)) case "plain": - challenge = verifier + authCodeOptions = append(authCodeOptions, + oauth2.SetAuthURLParam("code_challenge", verifier), + oauth2.SetAuthURLParam("code_challenge_method", "plain"), + ) } - authCodeOptions = append(authCodeOptions, - oauth2.SetAuthURLParam("code_challenge", challenge), - oauth2.SetAuthURLParam("code_challenge_method", p.PKCEMethod), - ) session.CodeVerifier = verifier } @@ -582,20 +575,4 @@ func selectPKCEMethod(methods []string) string { return "" } -// generateCodeVerifier creates a cryptographically random PKCE code verifier -// of 43 URL-safe characters (32 random bytes, base64url-encoded without padding) -// as specified in RFC 7636 §4.1. -func generateCodeVerifier() (string, error) { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", err - } - return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b), nil -} -// generateS256Challenge computes the S256 PKCE code challenge from a verifier: -// BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) per RFC 7636 §4.2. -func generateS256Challenge(verifier string) string { - h := sha256.Sum256([]byte(verifier)) - return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(h[:]) -} diff --git a/providers/openidConnect/openidConnect_test.go b/providers/openidConnect/openidConnect_test.go index 47c5ec43..613f42ce 100644 --- a/providers/openidConnect/openidConnect_test.go +++ b/providers/openidConnect/openidConnect_test.go @@ -9,6 +9,7 @@ import ( "github.com/markbates/goth" "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" ) var ( @@ -94,7 +95,7 @@ func Test_BeginAuth_PKCE_S256_Challenge(t *testing.T) { s := session.(*Session) // Verify that the code_challenge in the URL matches the S256 of the stored verifier. - expected := generateS256Challenge(s.CodeVerifier) + expected := oauth2.S256ChallengeFromVerifier(s.CodeVerifier) a.Contains(s.AuthURL, "code_challenge="+expected) } @@ -136,17 +137,17 @@ func Test_GenerateCodeVerifier(t *testing.T) { t.Parallel() a := assert.New(t) - v, err := generateCodeVerifier() - a.NoError(err) - // 32 bytes base64url-encoded without padding = 43 chars - a.Equal(43, len(v)) + v := oauth2.GenerateVerifier() + // RFC 7636 §4.1: verifier must be between 43 and 128 chars + a.GreaterOrEqual(len(v), 43) + a.LessOrEqual(len(v), 128) // All chars must be URL-safe base64 alphabet for _, c := range v { a.Contains("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", string(c)) } // Two verifiers must differ (with overwhelming probability) - v2, _ := generateCodeVerifier() + v2 := oauth2.GenerateVerifier() a.NotEqual(v, v2) } @@ -157,7 +158,7 @@ func Test_GenerateS256Challenge(t *testing.T) { // Known test vector from RFC 7636 Appendix B: // code_verifier = dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk // code_challenge = E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM - a.Equal("E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", generateS256Challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk")) + a.Equal("E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", oauth2.S256ChallengeFromVerifier("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk")) } func Test_New_PKCE_MethodSelectedFromDiscovery(t *testing.T) { diff --git a/providers/openidConnect/session.go b/providers/openidConnect/session.go index d3ee6645..523b8348 100644 --- a/providers/openidConnect/session.go +++ b/providers/openidConnect/session.go @@ -50,7 +50,7 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, codeVerifier = params.Get("code_verifier") } if codeVerifier != "" { - authParams = append(authParams, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + authParams = append(authParams, oauth2.VerifierOption(codeVerifier)) } token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code"), authParams...)