diff --git a/providers/openidConnect/openidConnect.go b/providers/openidConnect/openidConnect.go index dea40d2d..c22017fd 100644 --- a/providers/openidConnect/openidConnect.go +++ b/providers/openidConnect/openidConnect.go @@ -70,6 +70,11 @@ type Provider struct { LocationClaims []string SkipUserInfoRequest bool + + // PKCEMethod is the code challenge method to use for PKCE. It is automatically + // selected from the discovery document's code_challenge_methods_supported field. + // Supported values are "S256" and "plain". Empty string means PKCE is disabled. + PKCEMethod string } type OpenIDConfig struct { @@ -82,6 +87,10 @@ type OpenIDConfig struct { // https://openid.net/specs/openid-connect-session-1_0-17.html#OPMetadata EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` Issuer string `json:"issuer"` + + // CodeChallengeMethodsSupported lists PKCE code challenge methods supported by the provider. + // See https://www.rfc-editor.org/rfc/rfc7636 + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` } type RefreshTokenResponse struct { @@ -142,6 +151,7 @@ func NewNamed(name, clientKey, secret, callbackURL, openIDAutoDiscoveryURL strin return nil, err } p.OpenIDConfig = openIDConfig + p.PKCEMethod = selectPKCEMethod(openIDConfig.CodeChallengeMethodsSupported) p.config = newConfig(p, scopes, openIDConfig) return p, nil @@ -204,10 +214,24 @@ func (p *Provider) Debug(debug bool) {} // BeginAuth asks the OpenID Connect provider for an authentication end-point. func (p *Provider) BeginAuth(state string) (goth.Session, error) { - url := p.config.AuthCodeURL(state, p.authCodeOptions...) - session := &Session{ - AuthURL: url, + authCodeOptions := p.authCodeOptions + session := &Session{} + + if p.PKCEMethod != "" { + verifier := oauth2.GenerateVerifier() + switch p.PKCEMethod { + case "S256": + authCodeOptions = append(authCodeOptions, oauth2.S256ChallengeOption(verifier)) + case "plain": + authCodeOptions = append(authCodeOptions, + oauth2.SetAuthURLParam("code_challenge", verifier), + oauth2.SetAuthURLParam("code_challenge_method", "plain"), + ) + } + session.CodeVerifier = verifier } + + session.AuthURL = p.config.AuthCodeURL(state, authCodeOptions...) return session, nil } @@ -527,3 +551,28 @@ func unMarshal(payload []byte) (map[string]interface{}, error) { return data, json.NewDecoder(bytes.NewBuffer(payload)).Decode(&data) } + +// selectPKCEMethod selects the best PKCE code challenge method from the list +// advertised by the provider. S256 is preferred over plain per RFC 7636 §4.2. +// Returns an empty string if neither method is supported. +func selectPKCEMethod(methods []string) string { + hasS256 := false + hasPlain := false + for _, m := range methods { + switch m { + case "S256": + hasS256 = true + case "plain": + hasPlain = true + } + } + if hasS256 { + return "S256" + } + if hasPlain { + return "plain" + } + return "" +} + + diff --git a/providers/openidConnect/openidConnect_test.go b/providers/openidConnect/openidConnect_test.go index 7dd76e04..613f42ce 100644 --- a/providers/openidConnect/openidConnect_test.go +++ b/providers/openidConnect/openidConnect_test.go @@ -9,6 +9,7 @@ import ( "github.com/markbates/goth" "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" ) var ( @@ -76,6 +77,125 @@ func Test_BeginAuth(t *testing.T) { a.Contains(s.AuthURL, "state=test_state") a.Contains(s.AuthURL, "redirect_uri=http%3A%2F%2Flocalhost%2Ffoo") a.Contains(s.AuthURL, "scope=openid") + + // The mock server advertises ["plain","S256"] so PKCE must be used with S256. + a.Equal("S256", provider.PKCEMethod) + a.NotEmpty(s.CodeVerifier) + a.Contains(s.AuthURL, "code_challenge=") + a.Contains(s.AuthURL, "code_challenge_method=S256") +} + +func Test_BeginAuth_PKCE_S256_Challenge(t *testing.T) { + t.Parallel() + a := assert.New(t) + + provider := openidConnectProvider() + session, err := provider.BeginAuth("test_state") + a.NoError(err) + s := session.(*Session) + + // Verify that the code_challenge in the URL matches the S256 of the stored verifier. + expected := oauth2.S256ChallengeFromVerifier(s.CodeVerifier) + a.Contains(s.AuthURL, "code_challenge="+expected) +} + +func Test_BeginAuth_NoPKCE_WhenNotAdvertised(t *testing.T) { + t.Parallel() + a := assert.New(t) + + // Spin up a server that does NOT advertise code_challenge_methods_supported. + noPKCEServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo"}`) + })) + defer noPKCEServer.Close() + + provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", noPKCEServer.URL) + a.NoError(err) + a.Equal("", provider.PKCEMethod) + + session, err := provider.BeginAuth("test_state") + a.NoError(err) + s := session.(*Session) + a.Empty(s.CodeVerifier) + a.NotContains(s.AuthURL, "code_challenge") + a.NotContains(s.AuthURL, "code_challenge_method") +} + +func Test_SelectPKCEMethod(t *testing.T) { + t.Parallel() + a := assert.New(t) + + a.Equal("S256", selectPKCEMethod([]string{"plain", "S256"})) + a.Equal("S256", selectPKCEMethod([]string{"S256"})) + a.Equal("plain", selectPKCEMethod([]string{"plain"})) + a.Equal("", selectPKCEMethod([]string{})) + a.Equal("", selectPKCEMethod(nil)) + a.Equal("", selectPKCEMethod([]string{"other"})) +} + +func Test_GenerateCodeVerifier(t *testing.T) { + t.Parallel() + a := assert.New(t) + + v := oauth2.GenerateVerifier() + // RFC 7636 §4.1: verifier must be between 43 and 128 chars + a.GreaterOrEqual(len(v), 43) + a.LessOrEqual(len(v), 128) + // All chars must be URL-safe base64 alphabet + for _, c := range v { + a.Contains("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", string(c)) + } + + // Two verifiers must differ (with overwhelming probability) + v2 := oauth2.GenerateVerifier() + a.NotEqual(v, v2) +} + +func Test_GenerateS256Challenge(t *testing.T) { + t.Parallel() + a := assert.New(t) + + // Known test vector from RFC 7636 Appendix B: + // code_verifier = dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk + // code_challenge = E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM + a.Equal("E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", oauth2.S256ChallengeFromVerifier("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk")) +} + +func Test_New_PKCE_MethodSelectedFromDiscovery(t *testing.T) { + t.Parallel() + a := assert.New(t) + + // Mock server advertises only "plain" + plainOnlyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo","code_challenge_methods_supported":["plain"]}`) + })) + defer plainOnlyServer.Close() + + provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", plainOnlyServer.URL) + a.NoError(err) + a.Equal("plain", provider.PKCEMethod) +} + +func Test_BeginAuth_PKCE_PlainMethod(t *testing.T) { + t.Parallel() + a := assert.New(t) + + // Mock server advertises only "plain" + plainOnlyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo","code_challenge_methods_supported":["plain"]}`) + })) + defer plainOnlyServer.Close() + + provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", plainOnlyServer.URL) + a.NoError(err) + + session, err := provider.BeginAuth("test_state") + a.NoError(err) + s := session.(*Session) + a.NotEmpty(s.CodeVerifier) + // For "plain", challenge == verifier + a.Contains(s.AuthURL, "code_challenge="+s.CodeVerifier) + a.Contains(s.AuthURL, "code_challenge_method=plain") } func Test_BeginAuth_AuthCodeOptions(t *testing.T) { diff --git a/providers/openidConnect/session.go b/providers/openidConnect/session.go index 84b577c3..523b8348 100644 --- a/providers/openidConnect/session.go +++ b/providers/openidConnect/session.go @@ -17,6 +17,9 @@ type Session struct { RefreshToken string ExpiresAt time.Time IDToken string + // CodeVerifier holds the PKCE code verifier generated during BeginAuth. + // It is used at token exchange time to prove possession of the original verifier. + CodeVerifier string `json:",omitempty"` } // GetAuthURL will return the URL set by calling the `BeginAuth` function on the OpenID Connect provider. @@ -39,10 +42,15 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, authParams = append(authParams, oauth2.SetAuthURLParam("redirect_uri", redirectURL)) } - // set code_verifier if passed as param - codeVerifier := params.Get("code_verifier") + // set code_verifier for PKCE: prefer the verifier stored in the session + // (generated automatically during BeginAuth), fall back to one passed as + // a callback parameter for backward compatibility. + codeVerifier := s.CodeVerifier + if codeVerifier == "" { + codeVerifier = params.Get("code_verifier") + } if codeVerifier != "" { - authParams = append(authParams, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + authParams = append(authParams, oauth2.VerifierOption(codeVerifier)) } token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code"), authParams...)