Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions connector/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,17 @@ func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Ser
})
})

mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) {
url := fmt.Sprintf("http://%s", r.Host)

json.NewEncoder(w).Encode(&map[string]string{
"issuer": url,
"token_endpoint": fmt.Sprintf("%s/token", url),
"authorization_endpoint": fmt.Sprintf("%s/authorize", url),
"jwks_uri": fmt.Sprintf("%s/keys", url),
})
})

return httptest.NewServer(mux), nil
}

Expand Down
2 changes: 1 addition & 1 deletion server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func (d dexAPI) GetVersion(ctx context.Context, req *api.VersionReq) (*api.Versi
}

func (d dexAPI) GetDiscovery(ctx context.Context, req *api.DiscoveryReq) (*api.DiscoveryResp, error) {
discoveryDoc := d.server.constructDiscovery()
discoveryDoc := d.server.constructDiscoveryOIDC()
data, err := json.Marshal(discoveryDoc)
if err != nil {
return nil, fmt.Errorf("failed to marshal discovery data: %v", err)
Expand Down
63 changes: 58 additions & 5 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
w.Write(data)
}

type discovery struct {
type discoveryOIDC struct {
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Expand All @@ -90,8 +90,36 @@ type discovery struct {
Claims []string `json:"claims_supported"`
}

func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
d := s.constructDiscovery()
type discoveryOAuth2 struct {
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"`
DeviceEndpoint string `json:"device_authorization_endpoint,omitempty"`
Introspect string `json:"introspection_endpoint,omitempty"`
GrantTypes []string `json:"grant_types_supported"`
ResponseTypes []string `json:"response_types_supported"`
CodeChallengeAlgs []string `json:"code_challenge_methods_supported,omitempty"`
Scopes []string `json:"scopes_supported,omitempty"`
AuthMethods []string `json:"token_endpoint_auth_methods_supported,omitempty"`
}

type DiscoveryType int

const (
DiscoveryOIDC DiscoveryType = iota
DiscoveryOAuth2
)

func (s *Server) discoveryHandler(t DiscoveryType) (http.HandlerFunc, error) {
var d interface{}

switch t {
case DiscoveryOAuth2:
d = s.constructDiscoveryOAuth2()
default:
d = s.constructDiscoveryOIDC()
}

data, err := json.MarshalIndent(d, "", " ")
if err != nil {
Expand All @@ -105,8 +133,8 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
}), nil
}

func (s *Server) constructDiscovery() discovery {
d := discovery{
func (s *Server) constructDiscoveryOIDC() discoveryOIDC {
d := discoveryOIDC{
Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"),
Token: s.absURL("/token"),
Expand Down Expand Up @@ -134,6 +162,31 @@ func (s *Server) constructDiscovery() discovery {
return d
}

func (s *Server) constructDiscoveryOAuth2() discoveryOAuth2 {
d := discoveryOAuth2{
Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"),
Token: s.absURL("/token"),
Keys: s.absURL("/keys"),
DeviceEndpoint: s.absURL("/device/code"),
Introspect: s.absURL("/token/introspect"),
CodeChallengeAlgs: []string{codeChallengeMethodS256, codeChallengeMethodPlain},
Scopes: []string{"offline_access"},
AuthMethods: []string{"client_secret_basic", "client_secret_post"},
}

// response_types_supported
for responseType := range s.supportedResponseTypes {
d.ResponseTypes = append(d.ResponseTypes, responseType)
}
sort.Strings(d.ResponseTypes)

// grant_types_supported
d.GrantTypes = s.supportedGrantTypes

return d
}

// handleAuthorization handles the OAuth2 auth endpoint.
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
Expand Down
51 changes: 48 additions & 3 deletions server/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestHandleHealth(t *testing.T) {
}
}

func TestHandleDiscovery(t *testing.T) {
func TestHandleDiscoveryOIDC(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()

Expand All @@ -44,10 +44,10 @@ func TestHandleDiscovery(t *testing.T) {
t.Errorf("expected 200 got %d", rr.Code)
}

var res discovery
var res discoveryOIDC
err := json.NewDecoder(rr.Result().Body).Decode(&res)
require.NoError(t, err)
require.Equal(t, discovery{
require.Equal(t, discoveryOIDC{
Issuer: httpServer.URL,
Auth: fmt.Sprintf("%s/auth", httpServer.URL),
Token: fmt.Sprintf("%s/token", httpServer.URL),
Expand Down Expand Up @@ -101,6 +101,51 @@ func TestHandleDiscovery(t *testing.T) {
}, res)
}

func TestHandleDiscoveryOAuth2(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()

rr := httptest.NewRecorder()
server.ServeHTTP(rr, httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil))

if rr.Code != http.StatusOK {
t.Errorf("expected 200 got %d", rr.Code)
}

var res discoveryOAuth2
err := json.NewDecoder(rr.Result().Body).Decode(&res)
require.NoError(t, err)

require.Equal(t, discoveryOAuth2{
Issuer: httpServer.URL,
Auth: fmt.Sprintf("%s/auth", httpServer.URL),
Token: fmt.Sprintf("%s/token", httpServer.URL),
Keys: fmt.Sprintf("%s/keys", httpServer.URL),
DeviceEndpoint: fmt.Sprintf("%s/device/code", httpServer.URL),
Introspect: fmt.Sprintf("%s/token/introspect", httpServer.URL),
GrantTypes: []string{
"authorization_code",
"refresh_token",
"urn:ietf:params:oauth:grant-type:device_code",
"urn:ietf:params:oauth:grant-type:token-exchange",
},
ResponseTypes: []string{
"code",
},
CodeChallengeAlgs: []string{
"S256",
"plain",
},
Scopes: []string{
"offline_access",
},
AuthMethods: []string{
"client_secret_basic",
"client_secret_post",
},
}, res)
}

func TestHandleHealthFailure(t *testing.T) {
httpServer, server := newTestServer(t, func(c *Config) {
c.HealthChecker = gosundheit.New()
Expand Down
10 changes: 8 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,17 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
}
r.NotFoundHandler = http.NotFoundHandler()

discoveryHandler, err := s.discoveryHandler()
oidcHandler, err := s.discoveryHandler(DiscoveryOIDC)
if err != nil {
return nil, err
}
handleWithCORS("/.well-known/openid-configuration", discoveryHandler)
handleWithCORS("/.well-known/openid-configuration", oidcHandler)

oauthHandler, err := s.discoveryHandler(DiscoveryOAuth2)
if err != nil {
return nil, err
}
handleWithCORS("/.well-known/oauth-authorization-server", oauthHandler)
// Handle the root path for the better user experience.
handleWithCORS("/", func(w http.ResponseWriter, r *http.Request) {
_, err := fmt.Fprintf(w, `<!DOCTYPE html>
Expand Down