@@ -53,10 +53,12 @@ func (cfg Config) AuthServiceConfigType() string {
5353
5454// Initialize a generic auth service
5555func (cfg Config ) Initialize () (auth.AuthService , error ) {
56- // Discover the JWKS URL from the OIDC configuration endpoint
57- jwksURL , err := discoverJWKSURL (cfg .AuthorizationServer )
56+ httpClient := newSecureHTTPClient ()
57+
58+ // Discover OIDC endpoints
59+ jwksURL , introspectionURL , err := discoverOIDCConfig (httpClient , cfg .AuthorizationServer )
5860 if err != nil {
59- return nil , fmt .Errorf ("failed to discover JWKS URL : %w" , err )
61+ return nil , fmt .Errorf ("failed to discover OIDC config : %w" , err )
6062 }
6163
6264 // Create the keyfunc to fetch and cache the JWKS in the background
@@ -66,8 +68,10 @@ func (cfg Config) Initialize() (auth.AuthService, error) {
6668 }
6769
6870 a := & AuthService {
69- Config : cfg ,
70- kf : kf ,
71+ Config : cfg ,
72+ kf : kf ,
73+ client : httpClient ,
74+ introspectionURL : introspectionURL ,
7175 }
7276 return a , nil
7377}
@@ -88,68 +92,68 @@ func newSecureHTTPClient() *http.Client {
8892 }
8993}
9094
91- func discoverJWKSURL ( AuthorizationServer string ) (string , error ) {
95+ func discoverOIDCConfig ( client * http. Client , AuthorizationServer string ) (jwksURI string , introspectionEndpoint string , err error ) {
9296 u , err := url .Parse (AuthorizationServer )
9397 if err != nil {
94- return "" , fmt .Errorf ("invalid auth URL" )
98+ return "" , "" , fmt .Errorf ("invalid auth URL" )
9599 }
96100 if u .Scheme != "https" {
97101 log .Printf ("WARNING: HTTP instead of HTTPS is being used for AuthorizationServer: %s" , AuthorizationServer )
98102 }
99103
100104 oidcConfigURL , err := url .JoinPath (AuthorizationServer , ".well-known/openid-configuration" )
101105 if err != nil {
102- return "" , err
106+ return "" , "" , err
103107 }
104108
105- // HTTP Client
106- client := newSecureHTTPClient ()
107-
108109 resp , err := client .Get (oidcConfigURL )
109110 if err != nil {
110- return "" , fmt .Errorf ("failed to fetch OIDC config: %w" , err )
111+ return "" , "" , fmt .Errorf ("failed to fetch OIDC config: %w" , err )
111112 }
112113 defer resp .Body .Close ()
113114
114115 if resp .StatusCode != http .StatusOK {
115- return "" , fmt .Errorf ("unexpected status: %d" , resp .StatusCode )
116+ return "" , "" , fmt .Errorf ("unexpected status: %d" , resp .StatusCode )
116117 }
117118
118119 // Limit read size to 1MB to prevent memory exhaustion
119120 body , err := io .ReadAll (io .LimitReader (resp .Body , 1 << 20 ))
120121 if err != nil {
121- return "" , err
122+ return "" , "" , err
122123 }
123124
124125 var config struct {
125- JWKSURI string `json:"jwks_uri"`
126+ JwksUri string `json:"jwks_uri"`
127+ IntrospectionEndpoint string `json:"introspection_endpoint"`
126128 }
127129 if err := json .Unmarshal (body , & config ); err != nil {
128- return "" , err
130+ return "" , "" , err
129131 }
130132
131- if config .JWKSURI == "" {
132- return "" , fmt .Errorf ("jwks_uri not found in config" )
133+ if config .JwksUri == "" {
134+ return "" , "" , fmt .Errorf ("jwks_uri not found in config" )
133135 }
134136
135137 // Sanitize the resulting JWKS URI before returning it
136- parsedJWKS , err := url .Parse (config .JWKSURI )
138+ parsedJWKS , err := url .Parse (config .JwksUri )
137139 if err != nil {
138- return "" , fmt .Errorf ("invalid jwks_uri detected" )
140+ return "" , "" , fmt .Errorf ("invalid jwks_uri detected" )
139141 }
140142 if parsedJWKS .Scheme != "https" {
141- log .Printf ("WARNING: HTTP instead of HTTPS is being used for JWKS URI: %s" , config .JWKSURI )
143+ log .Printf ("WARNING: HTTP instead of HTTPS is being used for JWKS URI: %s" , config .JwksUri )
142144 }
143145
144- return config .JWKSURI , nil
146+ return config .JwksUri , config . IntrospectionEndpoint , nil
145147}
146148
147149var _ auth.AuthService = AuthService {}
148150
149151// struct used to store auth service info
150152type AuthService struct {
151153 Config
152- kf keyfunc.Keyfunc
154+ kf keyfunc.Keyfunc
155+ client * http.Client
156+ introspectionURL string
153157}
154158
155159// Returns the auth service type
@@ -246,6 +250,7 @@ func isJWTFormat(token string) bool {
246250 return strings .Count (token , "." ) == 2
247251}
248252
253+ // validateJwtToken validates a JWT token locally
249254func (a AuthService ) validateJwtToken (ctx context.Context , tokenStr string ) error {
250255 token , err := jwt .Parse (tokenStr , a .kf .Keyfunc )
251256 if err != nil || ! token .Valid {
@@ -263,50 +268,24 @@ func (a AuthService) validateJwtToken(ctx context.Context, tokenStr string) erro
263268 return & MCPAuthError {Code : http .StatusUnauthorized , Message : "could not parse audience from token" , ScopesRequired : a .ScopesRequired }
264269 }
265270
266- isAudValid := false
267- for _ , audItem := range aud {
268- if audItem == a .Audience {
269- isAudValid = true
270- break
271- }
272- }
273-
274- if ! isAudValid {
275- return & MCPAuthError {Code : http .StatusUnauthorized , Message : "audience validation failed" , ScopesRequired : a .ScopesRequired }
276- }
277-
278- // Check scopes
279- if len (a .ScopesRequired ) > 0 {
280- scopeClaim , ok := claims ["scope" ].(string )
281- if ! ok {
282- return & MCPAuthError {Code : http .StatusForbidden , Message : "insufficient scopes" , ScopesRequired : a .ScopesRequired }
283- }
284-
285- tokenScopes := strings .Split (scopeClaim , " " )
286- scopeMap := make (map [string ]bool )
287- for _ , s := range tokenScopes {
288- scopeMap [s ] = true
289- }
290-
291- for _ , requiredScope := range a .ScopesRequired {
292- if ! scopeMap [requiredScope ] {
293- return & MCPAuthError {Code : http .StatusForbidden , Message : "insufficient scopes" , ScopesRequired : a .ScopesRequired }
294- }
295- }
296- }
271+ scopeClaim , _ := claims ["scope" ].(string )
297272
298- return nil
273+ return a . validateClaims ( ctx , aud , scopeClaim )
299274}
300275
276+ // validateOpaqueToken validates an opaque token by calling the introspection endpoint
301277func (a AuthService ) validateOpaqueToken (ctx context.Context , tokenStr string ) error {
302278 logger , err := util .LoggerFromContext (ctx )
303279 if err != nil {
304280 return fmt .Errorf ("failed to get logger from context: %w" , err )
305281 }
306282
307- introspectionURL , err := url .JoinPath (a .AuthorizationServer , "introspect" )
308- if err != nil {
309- return fmt .Errorf ("failed to construct introspection URL: %w" , err )
283+ introspectionURL := a .introspectionURL
284+ if introspectionURL == "" {
285+ introspectionURL , err = url .JoinPath (a .AuthorizationServer , "introspect" )
286+ if err != nil {
287+ return fmt .Errorf ("failed to construct introspection URL: %w" , err )
288+ }
310289 }
311290
312291 data := url.Values {}
@@ -320,9 +299,7 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e
320299 req .Header .Set ("Accept" , "application/json" )
321300
322301 // Send request to auth server's introspection endpoint
323- client := newSecureHTTPClient ()
324-
325- resp , err := client .Do (req )
302+ resp , err := a .client .Do (req )
326303 if err != nil {
327304 logger .ErrorContext (ctx , "failed to call introspection endpoint: %v" , err )
328305 return & MCPAuthError {Code : http .StatusInternalServerError , Message : fmt .Sprintf ("failed to call introspection endpoint: %v" , err ), ScopesRequired : a .ScopesRequired }
@@ -340,10 +317,10 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e
340317 }
341318
342319 var introspectResp struct {
343- Active bool `json:"active"`
344- Scope string `json:"scope"`
345- ClientId string `json:"client_id "`
346- Exp int64 `json:"exp"`
320+ Active bool `json:"active"`
321+ Scope string `json:"scope"`
322+ Aud json. RawMessage `json:"aud "`
323+ Exp int64 `json:"exp"`
347324 }
348325
349326 if err := json .Unmarshal (body , & introspectResp ); err != nil {
@@ -355,29 +332,66 @@ func (a AuthService) validateOpaqueToken(ctx context.Context, tokenStr string) e
355332 return & MCPAuthError {Code : http .StatusUnauthorized , Message : "token is not active" , ScopesRequired : a .ScopesRequired }
356333 }
357334
358- // Verify audience (client_id)
359- if a .Audience != "" && introspectResp .ClientId != a .Audience {
360- logger .WarnContext (ctx , "audience validation failed: expected %s, got %s" , a .Audience , introspectResp .ClientId )
361- return & MCPAuthError {Code : http .StatusUnauthorized , Message : "audience validation failed" , ScopesRequired : a .ScopesRequired }
362- }
363-
364- // Verify expiration (with 1 minute leeway) to account for potential time difference between Toolbox and the auth server
335+ // Verify expiration (with 1 minute leeway)
365336 const leeway = 60
366337 if introspectResp .Exp > 0 && time .Now ().Unix () > (introspectResp .Exp + leeway ) {
367338 logger .WarnContext (ctx , "token has expired: exp=%d, now=%d" , introspectResp .Exp , time .Now ().Unix ())
368339 return & MCPAuthError {Code : http .StatusUnauthorized , Message : "token has expired" , ScopesRequired : a .ScopesRequired }
369340 }
370341
371- // Verify scopes
342+ // Extract audience
343+ // According to RFC 7662, the aud claim can be a string or an array of strings
344+ var aud []string
345+ if len (introspectResp .Aud ) > 0 {
346+ var audStr string
347+ var audArr []string
348+ if err := json .Unmarshal (introspectResp .Aud , & audStr ); err == nil {
349+ aud = []string {audStr }
350+ } else if err := json .Unmarshal (introspectResp .Aud , & audArr ); err == nil {
351+ aud = audArr
352+ } else {
353+ logger .WarnContext (ctx , "failed to parse aud claim in introspection response" )
354+ return & MCPAuthError {Code : http .StatusUnauthorized , Message : "invalid aud claim" , ScopesRequired : a .ScopesRequired }
355+ }
356+ }
357+
358+ return a .validateClaims (ctx , aud , introspectResp .Scope )
359+ }
360+
361+ // validateClaims validates the audience and scopes of a token
362+ func (a AuthService ) validateClaims (ctx context.Context , aud []string , scopeStr string ) error {
363+ logger , err := util .LoggerFromContext (ctx )
364+ if err != nil {
365+ return fmt .Errorf ("failed to get logger from context: %w" , err )
366+ }
367+
368+ // Validate audience
369+ if a .Audience != "" {
370+ isAudValid := false
371+ for _ , audItem := range aud {
372+ if audItem == a .Audience {
373+ isAudValid = true
374+ break
375+ }
376+ }
377+
378+ if ! isAudValid {
379+ logger .WarnContext (ctx , "audience validation failed: expected %s" , a .Audience )
380+ return & MCPAuthError {Code : http .StatusUnauthorized , Message : "audience validation failed" , ScopesRequired : a .ScopesRequired }
381+ }
382+ }
383+
384+ // Check scopes
372385 if len (a .ScopesRequired ) > 0 {
373- tokenScopes := strings .Split (introspectResp . Scope , " " )
386+ tokenScopes := strings .Split (scopeStr , " " )
374387 scopeMap := make (map [string ]bool )
375388 for _ , s := range tokenScopes {
376389 scopeMap [s ] = true
377390 }
378391
379392 for _ , requiredScope := range a .ScopesRequired {
380393 if ! scopeMap [requiredScope ] {
394+ logger .WarnContext (ctx , "insufficient scopes: missing %s" , requiredScope )
381395 return & MCPAuthError {Code : http .StatusForbidden , Message : "insufficient scopes" , ScopesRequired : a .ScopesRequired }
382396 }
383397 }
0 commit comments