Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions providers/openidConnect/openidConnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ type Provider struct {
LocationClaims []string

SkipUserInfoRequest bool

// PKCEMethod is the code challenge method to use for PKCE. It is automatically
// selected from the discovery document's code_challenge_methods_supported field.
// Supported values are "S256" and "plain". Empty string means PKCE is disabled.
PKCEMethod string
}

type OpenIDConfig struct {
Expand All @@ -82,6 +87,10 @@ type OpenIDConfig struct {
// https://openid.net/specs/openid-connect-session-1_0-17.html#OPMetadata
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
Issuer string `json:"issuer"`

// CodeChallengeMethodsSupported lists PKCE code challenge methods supported by the provider.
// See https://www.rfc-editor.org/rfc/rfc7636
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"`
}

type RefreshTokenResponse struct {
Expand Down Expand Up @@ -142,6 +151,7 @@ func NewNamed(name, clientKey, secret, callbackURL, openIDAutoDiscoveryURL strin
return nil, err
}
p.OpenIDConfig = openIDConfig
p.PKCEMethod = selectPKCEMethod(openIDConfig.CodeChallengeMethodsSupported)

p.config = newConfig(p, scopes, openIDConfig)
return p, nil
Expand Down Expand Up @@ -204,10 +214,24 @@ func (p *Provider) Debug(debug bool) {}

// BeginAuth asks the OpenID Connect provider for an authentication end-point.
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
url := p.config.AuthCodeURL(state, p.authCodeOptions...)
session := &Session{
AuthURL: url,
authCodeOptions := p.authCodeOptions
session := &Session{}

if p.PKCEMethod != "" {
verifier := oauth2.GenerateVerifier()
switch p.PKCEMethod {
case "S256":
authCodeOptions = append(authCodeOptions, oauth2.S256ChallengeOption(verifier))
case "plain":
authCodeOptions = append(authCodeOptions,
oauth2.SetAuthURLParam("code_challenge", verifier),
oauth2.SetAuthURLParam("code_challenge_method", "plain"),
)
}
session.CodeVerifier = verifier
}

session.AuthURL = p.config.AuthCodeURL(state, authCodeOptions...)
return session, nil
}

Expand Down Expand Up @@ -527,3 +551,28 @@ func unMarshal(payload []byte) (map[string]interface{}, error) {

return data, json.NewDecoder(bytes.NewBuffer(payload)).Decode(&data)
}

// selectPKCEMethod selects the best PKCE code challenge method from the list
// advertised by the provider. S256 is preferred over plain per RFC 7636 §4.2.
// Returns an empty string if neither method is supported.
func selectPKCEMethod(methods []string) string {
hasS256 := false
hasPlain := false
for _, m := range methods {
switch m {
case "S256":
hasS256 = true
case "plain":
hasPlain = true
}
}
if hasS256 {
return "S256"
}
if hasPlain {
return "plain"
}
return ""
}


120 changes: 120 additions & 0 deletions providers/openidConnect/openidConnect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/markbates/goth"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
)

var (
Expand Down Expand Up @@ -76,6 +77,125 @@ func Test_BeginAuth(t *testing.T) {
a.Contains(s.AuthURL, "state=test_state")
a.Contains(s.AuthURL, "redirect_uri=http%3A%2F%2Flocalhost%2Ffoo")
a.Contains(s.AuthURL, "scope=openid")

// The mock server advertises ["plain","S256"] so PKCE must be used with S256.
a.Equal("S256", provider.PKCEMethod)
a.NotEmpty(s.CodeVerifier)
a.Contains(s.AuthURL, "code_challenge=")
a.Contains(s.AuthURL, "code_challenge_method=S256")
}

func Test_BeginAuth_PKCE_S256_Challenge(t *testing.T) {
t.Parallel()
a := assert.New(t)

provider := openidConnectProvider()
session, err := provider.BeginAuth("test_state")
a.NoError(err)
s := session.(*Session)

// Verify that the code_challenge in the URL matches the S256 of the stored verifier.
expected := oauth2.S256ChallengeFromVerifier(s.CodeVerifier)
a.Contains(s.AuthURL, "code_challenge="+expected)
}

func Test_BeginAuth_NoPKCE_WhenNotAdvertised(t *testing.T) {
t.Parallel()
a := assert.New(t)

// Spin up a server that does NOT advertise code_challenge_methods_supported.
noPKCEServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo"}`)
}))
defer noPKCEServer.Close()

provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", noPKCEServer.URL)
a.NoError(err)
a.Equal("", provider.PKCEMethod)

session, err := provider.BeginAuth("test_state")
a.NoError(err)
s := session.(*Session)
a.Empty(s.CodeVerifier)
a.NotContains(s.AuthURL, "code_challenge")
a.NotContains(s.AuthURL, "code_challenge_method")
}

func Test_SelectPKCEMethod(t *testing.T) {
t.Parallel()
a := assert.New(t)

a.Equal("S256", selectPKCEMethod([]string{"plain", "S256"}))
a.Equal("S256", selectPKCEMethod([]string{"S256"}))
a.Equal("plain", selectPKCEMethod([]string{"plain"}))
a.Equal("", selectPKCEMethod([]string{}))
a.Equal("", selectPKCEMethod(nil))
a.Equal("", selectPKCEMethod([]string{"other"}))
}

func Test_GenerateCodeVerifier(t *testing.T) {
t.Parallel()
a := assert.New(t)

v := oauth2.GenerateVerifier()
// RFC 7636 §4.1: verifier must be between 43 and 128 chars
a.GreaterOrEqual(len(v), 43)
a.LessOrEqual(len(v), 128)
// All chars must be URL-safe base64 alphabet
for _, c := range v {
a.Contains("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", string(c))
}

// Two verifiers must differ (with overwhelming probability)
v2 := oauth2.GenerateVerifier()
a.NotEqual(v, v2)
}

func Test_GenerateS256Challenge(t *testing.T) {
t.Parallel()
a := assert.New(t)

// Known test vector from RFC 7636 Appendix B:
// code_verifier = dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk
// code_challenge = E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM
a.Equal("E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", oauth2.S256ChallengeFromVerifier("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"))
}

func Test_New_PKCE_MethodSelectedFromDiscovery(t *testing.T) {
t.Parallel()
a := assert.New(t)

// Mock server advertises only "plain"
plainOnlyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo","code_challenge_methods_supported":["plain"]}`)
}))
defer plainOnlyServer.Close()

provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", plainOnlyServer.URL)
a.NoError(err)
a.Equal("plain", provider.PKCEMethod)
}

func Test_BeginAuth_PKCE_PlainMethod(t *testing.T) {
t.Parallel()
a := assert.New(t)

// Mock server advertises only "plain"
plainOnlyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `{"issuer":"https://accounts.google.com","authorization_endpoint":"https://accounts.google.com/o/oauth2/v2/auth","token_endpoint":"https://www.googleapis.com/oauth2/v4/token","userinfo_endpoint":"https://www.googleapis.com/oauth2/v3/userinfo","code_challenge_methods_supported":["plain"]}`)
}))
defer plainOnlyServer.Close()

provider, err := New(os.Getenv("OPENID_CONNECT_KEY"), os.Getenv("OPENID_CONNECT_SECRET"), "http://localhost/foo", plainOnlyServer.URL)
a.NoError(err)

session, err := provider.BeginAuth("test_state")
a.NoError(err)
s := session.(*Session)
a.NotEmpty(s.CodeVerifier)
// For "plain", challenge == verifier
a.Contains(s.AuthURL, "code_challenge="+s.CodeVerifier)
a.Contains(s.AuthURL, "code_challenge_method=plain")
}

func Test_BeginAuth_AuthCodeOptions(t *testing.T) {
Expand Down
14 changes: 11 additions & 3 deletions providers/openidConnect/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type Session struct {
RefreshToken string
ExpiresAt time.Time
IDToken string
// CodeVerifier holds the PKCE code verifier generated during BeginAuth.
// It is used at token exchange time to prove possession of the original verifier.
CodeVerifier string `json:",omitempty"`
}

// GetAuthURL will return the URL set by calling the `BeginAuth` function on the OpenID Connect provider.
Expand All @@ -39,10 +42,15 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string,
authParams = append(authParams, oauth2.SetAuthURLParam("redirect_uri", redirectURL))
}

// set code_verifier if passed as param
codeVerifier := params.Get("code_verifier")
// set code_verifier for PKCE: prefer the verifier stored in the session
// (generated automatically during BeginAuth), fall back to one passed as
// a callback parameter for backward compatibility.
codeVerifier := s.CodeVerifier
if codeVerifier == "" {
codeVerifier = params.Get("code_verifier")
}
if codeVerifier != "" {
authParams = append(authParams, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
authParams = append(authParams, oauth2.VerifierOption(codeVerifier))
}

token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code"), authParams...)
Expand Down