Skip to content

Commit 60a1691

Browse files
committed
resolve comments
1 parent a8e001c commit 60a1691

2 files changed

Lines changed: 204 additions & 95 deletions

File tree

internal/auth/generic/generic.go

Lines changed: 88 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,12 @@ func (cfg Config) AuthServiceConfigType() string {
5353

5454
// Initialize a generic auth service
5555
func (cfg Config) Initialize() (auth.AuthService, error) {
56-
// Discover the JWKS URL from the OIDC configuration endpoint
57-
jwksURL, err := discoverJWKSURL(cfg.AuthorizationServer)
56+
httpClient := newSecureHTTPClient()
57+
58+
// Discover OIDC endpoints
59+
jwksURL, introspectionURL, err := discoverOIDCConfig(httpClient, cfg.AuthorizationServer)
5860
if err != nil {
59-
return nil, fmt.Errorf("failed to discover JWKS URL: %w", err)
61+
return nil, fmt.Errorf("failed to discover OIDC config: %w", err)
6062
}
6163

6264
// Create the keyfunc to fetch and cache the JWKS in the background
@@ -66,8 +68,10 @@ func (cfg Config) Initialize() (auth.AuthService, error) {
6668
}
6769

6870
a := &AuthService{
69-
Config: cfg,
70-
kf: kf,
71+
Config: cfg,
72+
kf: kf,
73+
client: httpClient,
74+
introspectionURL: introspectionURL,
7175
}
7276
return a, nil
7377
}
@@ -88,68 +92,68 @@ func newSecureHTTPClient() *http.Client {
8892
}
8993
}
9094

91-
func discoverJWKSURL(AuthorizationServer string) (string, error) {
95+
func discoverOIDCConfig(client *http.Client, AuthorizationServer string) (jwksURI string, introspectionEndpoint string, err error) {
9296
u, err := url.Parse(AuthorizationServer)
9397
if err != nil {
94-
return "", fmt.Errorf("invalid auth URL")
98+
return "", "", fmt.Errorf("invalid auth URL")
9599
}
96100
if u.Scheme != "https" {
97101
log.Printf("WARNING: HTTP instead of HTTPS is being used for AuthorizationServer: %s", AuthorizationServer)
98102
}
99103

100104
oidcConfigURL, err := url.JoinPath(AuthorizationServer, ".well-known/openid-configuration")
101105
if err != nil {
102-
return "", err
106+
return "", "", err
103107
}
104108

105-
// HTTP Client
106-
client := newSecureHTTPClient()
107-
108109
resp, err := client.Get(oidcConfigURL)
109110
if err != nil {
110-
return "", fmt.Errorf("failed to fetch OIDC config: %w", err)
111+
return "", "", fmt.Errorf("failed to fetch OIDC config: %w", err)
111112
}
112113
defer resp.Body.Close()
113114

114115
if resp.StatusCode != http.StatusOK {
115-
return "", fmt.Errorf("unexpected status: %d", resp.StatusCode)
116+
return "", "", fmt.Errorf("unexpected status: %d", resp.StatusCode)
116117
}
117118

118119
// Limit read size to 1MB to prevent memory exhaustion
119120
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
120121
if err != nil {
121-
return "", err
122+
return "", "", err
122123
}
123124

124125
var config struct {
125-
JWKSURI string `json:"jwks_uri"`
126+
JwksUri string `json:"jwks_uri"`
127+
IntrospectionEndpoint string `json:"introspection_endpoint"`
126128
}
127129
if err := json.Unmarshal(body, &config); err != nil {
128-
return "", err
130+
return "", "", err
129131
}
130132

131-
if config.JWKSURI == "" {
132-
return "", fmt.Errorf("jwks_uri not found in config")
133+
if config.JwksUri == "" {
134+
return "", "", fmt.Errorf("jwks_uri not found in config")
133135
}
134136

135137
// Sanitize the resulting JWKS URI before returning it
136-
parsedJWKS, err := url.Parse(config.JWKSURI)
138+
parsedJWKS, err := url.Parse(config.JwksUri)
137139
if err != nil {
138-
return "", fmt.Errorf("invalid jwks_uri detected")
140+
return "", "", fmt.Errorf("invalid jwks_uri detected")
139141
}
140142
if parsedJWKS.Scheme != "https" {
141-
log.Printf("WARNING: HTTP instead of HTTPS is being used for JWKS URI: %s", config.JWKSURI)
143+
log.Printf("WARNING: HTTP instead of HTTPS is being used for JWKS URI: %s", config.JwksUri)
142144
}
143145

144-
return config.JWKSURI, nil
146+
return config.JwksUri, config.IntrospectionEndpoint, nil
145147
}
146148

147149
var _ auth.AuthService = AuthService{}
148150

149151
// struct used to store auth service info
150152
type AuthService struct {
151153
Config
152-
kf keyfunc.Keyfunc
154+
kf keyfunc.Keyfunc
155+
client *http.Client
156+
introspectionURL string
153157
}
154158

155159
// Returns the auth service type
@@ -246,6 +250,7 @@ func isJWTFormat(token string) bool {
246250
return strings.Count(token, ".") == 2
247251
}
248252

253+
// validateJwtToken validates a JWT token locally
249254
func (a AuthService) validateJwtToken(ctx context.Context, tokenStr string) error {
250255
token, err := jwt.Parse(tokenStr, a.kf.Keyfunc)
251256
if err != nil || !token.Valid {
@@ -263,50 +268,24 @@ func (a AuthService) validateJwtToken(ctx context.Context, tokenStr string) erro
263268
return &MCPAuthError{Code: http.StatusUnauthorized, Message: "could not parse audience from token", ScopesRequired: a.ScopesRequired}
264269
}
265270

266-
isAudValid := false
267-
for _, audItem := range aud {
268-
if audItem == a.Audience {
269-
isAudValid = true
270-
break
271-
}
272-
}
273-
274-
if !isAudValid {
275-
return &MCPAuthError{Code: http.StatusUnauthorized, Message: "audience validation failed", ScopesRequired: a.ScopesRequired}
276-
}
277-
278-
// Check scopes
279-
if len(a.ScopesRequired) > 0 {
280-
scopeClaim, ok := claims["scope"].(string)
281-
if !ok {
282-
return &MCPAuthError{Code: http.StatusForbidden, Message: "insufficient scopes", ScopesRequired: a.ScopesRequired}
283-
}
284-
285-
tokenScopes := strings.Split(scopeClaim, " ")
286-
scopeMap := make(map[string]bool)
287-
for _, s := range tokenScopes {
288-
scopeMap[s] = true
289-
}
290-
291-
for _, requiredScope := range a.ScopesRequired {
292-
if !scopeMap[requiredScope] {
293-
return &MCPAuthError{Code: http.StatusForbidden, Message: "insufficient scopes", ScopesRequired: a.ScopesRequired}
294-
}
295-
}
296-
}
271+
scopeClaim, _ := claims["scope"].(string)
297272

298-
return nil
273+
return a.validateClaims(ctx, aud, scopeClaim)
299274
}
300275

276+
// validateOpaqueToken validates an opaque token by calling the introspection endpoint
301277
func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) error {
302278
logger, err := util.LoggerFromContext(ctx)
303279
if err != nil {
304280
return fmt.Errorf("failed to get logger from context: %w", err)
305281
}
306282

307-
introspectionURL, err := url.JoinPath(a.AuthorizationServer, "introspect")
308-
if err != nil {
309-
return fmt.Errorf("failed to construct introspection URL: %w", err)
283+
introspectionURL := a.introspectionURL
284+
if introspectionURL == "" {
285+
introspectionURL, err = url.JoinPath(a.AuthorizationServer, "introspect")
286+
if err != nil {
287+
return fmt.Errorf("failed to construct introspection URL: %w", err)
288+
}
310289
}
311290

312291
data := url.Values{}
@@ -320,9 +299,7 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e
320299
req.Header.Set("Accept", "application/json")
321300

322301
// Send request to auth server's introspection endpoint
323-
client := newSecureHTTPClient()
324-
325-
resp, err := client.Do(req)
302+
resp, err := a.client.Do(req)
326303
if err != nil {
327304
logger.ErrorContext(ctx, "failed to call introspection endpoint: %v", err)
328305
return &MCPAuthError{Code: http.StatusInternalServerError, Message: fmt.Sprintf("failed to call introspection endpoint: %v", err), ScopesRequired: a.ScopesRequired}
@@ -340,10 +317,10 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e
340317
}
341318

342319
var introspectResp struct {
343-
Active bool `json:"active"`
344-
Scope string `json:"scope"`
345-
ClientId string `json:"client_id"`
346-
Exp int64 `json:"exp"`
320+
Active bool `json:"active"`
321+
Scope string `json:"scope"`
322+
Aud json.RawMessage `json:"aud"`
323+
Exp int64 `json:"exp"`
347324
}
348325

349326
if err := json.Unmarshal(body, &introspectResp); err != nil {
@@ -355,29 +332,66 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e
355332
return &MCPAuthError{Code: http.StatusUnauthorized, Message: "token is not active", ScopesRequired: a.ScopesRequired}
356333
}
357334

358-
// Verify audience (client_id)
359-
if a.Audience != "" && introspectResp.ClientId != a.Audience {
360-
logger.WarnContext(ctx, "audience validation failed: expected %s, got %s", a.Audience, introspectResp.ClientId)
361-
return &MCPAuthError{Code: http.StatusUnauthorized, Message: "audience validation failed", ScopesRequired: a.ScopesRequired}
362-
}
363-
364-
// Verify expiration (with 1 minute leeway) to account for potential time difference between Toolbox and the auth server
335+
// Verify expiration (with 1 minute leeway)
365336
const leeway = 60
366337
if introspectResp.Exp > 0 && time.Now().Unix() > (introspectResp.Exp+leeway) {
367338
logger.WarnContext(ctx, "token has expired: exp=%d, now=%d", introspectResp.Exp, time.Now().Unix())
368339
return &MCPAuthError{Code: http.StatusUnauthorized, Message: "token has expired", ScopesRequired: a.ScopesRequired}
369340
}
370341

371-
// Verify scopes
342+
// Extract audience
343+
// According to RFC 7662, the aud claim can be a string or an array of strings
344+
var aud []string
345+
if len(introspectResp.Aud) > 0 {
346+
var audStr string
347+
var audArr []string
348+
if err := json.Unmarshal(introspectResp.Aud, &audStr); err == nil {
349+
aud = []string{audStr}
350+
} else if err := json.Unmarshal(introspectResp.Aud, &audArr); err == nil {
351+
aud = audArr
352+
} else {
353+
logger.WarnContext(ctx, "failed to parse aud claim in introspection response")
354+
return &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid aud claim", ScopesRequired: a.ScopesRequired}
355+
}
356+
}
357+
358+
return a.validateClaims(ctx, aud, introspectResp.Scope)
359+
}
360+
361+
// validateClaims validates the audience and scopes of a token
362+
func (a AuthService) validateClaims(ctx context.Context, aud []string, scopeStr string) error {
363+
logger, err := util.LoggerFromContext(ctx)
364+
if err != nil {
365+
return fmt.Errorf("failed to get logger from context: %w", err)
366+
}
367+
368+
// Validate audience
369+
if a.Audience != "" {
370+
isAudValid := false
371+
for _, audItem := range aud {
372+
if audItem == a.Audience {
373+
isAudValid = true
374+
break
375+
}
376+
}
377+
378+
if !isAudValid {
379+
logger.WarnContext(ctx, "audience validation failed: expected %s", a.Audience)
380+
return &MCPAuthError{Code: http.StatusUnauthorized, Message: "audience validation failed", ScopesRequired: a.ScopesRequired}
381+
}
382+
}
383+
384+
// Check scopes
372385
if len(a.ScopesRequired) > 0 {
373-
tokenScopes := strings.Split(introspectResp.Scope, " ")
386+
tokenScopes := strings.Split(scopeStr, " ")
374387
scopeMap := make(map[string]bool)
375388
for _, s := range tokenScopes {
376389
scopeMap[s] = true
377390
}
378391

379392
for _, requiredScope := range a.ScopesRequired {
380393
if !scopeMap[requiredScope] {
394+
logger.WarnContext(ctx, "insufficient scopes: missing %s", requiredScope)
381395
return &MCPAuthError{Code: http.StatusForbidden, Message: "insufficient scopes", ScopesRequired: a.ScopesRequired}
382396
}
383397
}

0 commit comments

Comments
 (0)