From f2e2ddb524e5835c35d663cc26b313af696c61e2 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Tue, 3 Feb 2026 16:09:31 +0000 Subject: [PATCH 01/26] all: WIP client side OAuth support --- .github/workflows/conformance.yml | 5 +- auth/authorization_code.go | 336 ++++++++++++++++++++++++++ auth/client.go | 71 +++++- conformance/baseline.yml | 20 +- conformance/everything-client/main.go | 207 +++++++++++++++- mcp/streamable.go | 99 ++++++-- mcp/streamable_test.go | 27 ++- oauthex/auth_meta.go | 89 +++++-- oauthex/oauth2.go | 12 +- oauthex/resource_meta.go | 98 +++++++- scripts/client-conformance.sh | 13 +- scripts/server-conformance.sh | 2 +- 12 files changed, 885 insertions(+), 94 deletions(-) create mode 100644 auth/authorization_code.go diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 332975cf..8d9dc46b 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -25,7 +25,8 @@ jobs: go-version: "^1.25" - name: Start everything-server run: | - go run ./conformance/everything-server/main.go -http=":3001" & + # TODO: mcp_go_client_oauth should not be needed here, but the client OAuth abstractions have spilled into the streamable transport. + go run -tags mcp_go_client_oauth ./conformance/everything-server/main.go -http=":3001" & # Wait for the server to be ready. timeout 15 bash -c 'until curl -s http://localhost:3001/mcp; do sleep 0.5; done' - name: "Run conformance tests" @@ -49,6 +50,6 @@ jobs: uses: modelcontextprotocol/conformance@c2f3fdaf781dcd5a862cb0d2f6454c1c210bf0f0 # v0.1.11 with: mode: client - command: go run ./conformance/everything-client/main.go + command: go run -tags mcp_go_client_oauth ./conformance/everything-client/main.go suite: core expected-failures: ./conformance/baseline.yml diff --git a/auth/authorization_code.go b/auth/authorization_code.go new file mode 100644 index 00000000..8ea20772 --- /dev/null +++ b/auth/authorization_code.go @@ -0,0 +1,336 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "slices" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// ErrRedirected is returned when the user was redirected to the authorization server. +var ErrRedirected = errors.New("redirected") + +// ClientIDMetadataDocumentConfig is used to configure the Client ID Metadata Document +// based client registration per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents. +// See https://client.dev/ for more information. +type ClientIDMetadataDocumentConfig struct { + // URL is the client identifier URL as per + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-client-id-metadata-document-00#section-3. + URL string +} + +// PreregisteredClientConfig is used to configure a pre-registered client per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration. +type PreregisteredClientConfig struct { + ClientID string + ClientSecret string + AuthStyle oauth2.AuthStyle +} + +// DynamicClientRegistrationConfig is used to configure dynamic client registration per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration. +type DynamicClientRegistrationConfig struct { + // Metadata to be used in dynamic client registration request as per + // https://datatracker.ietf.org/doc/html/rfc7591#section-2. + Metadata *oauthex.ClientRegistrationMetadata +} + +type registrationType int + +const ( + registrationTypeClientIDMetadataDocument registrationType = iota + registrationTypePreregistered + registrationTypeDynamic +) + +type resolvedClientConfig struct { + registrationType registrationType + clientID string + clientSecret string + authStyle oauth2.AuthStyle +} + +type AuthorizationCodeOAuthHandler struct { + // Client registration configuration. + // It is attempted in the following order: + // 1. Client ID Metadata Document + // 2. Preregistration + // 3. Dynamic Client Registration + ClientIDMetadataDocumentConfig *ClientIDMetadataDocumentConfig + PreregisteredClientConfig *PreregisteredClientConfig + DynamicClientRegistrationConfig *DynamicClientRegistrationConfig + + // RedirectURL is the URL to redirect to after authorization. + // If Dynamic Client Registration is used, the RedirectURL must be consistent + // with [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. + RedirectURL string + + // AuthorizationURLHandler is called to handle the authorization URL. + // It is responsible for opening the URL in a browser. + // It should return once the redirect has been issued. + // The redirect callback should be handled by the caller and the authorization code + // should be set by calling [SetAuthorizationCode] before retrying the request. + AuthorizationURLHandler func(ctx context.Context, authorizationURL string) error + + // StateProvider is an optional function to generate a state string for authorization + // requests. If not provided, a random string will be generated. + // The state should be validated on the redirect callback. + StateProvider func() string + + // TokenStore is an optional object that allows persistent storage of tokens. + TokenStore TokenStore + + // resolvedClientConfig used during the authorization flow. + resolvedClientConfig *resolvedClientConfig + // tokenSource is the token source to use for authorization. + // It can be prepopulated by calling [SetTokenSource]. + tokenSource oauth2.TokenSource + // codeVerifier is the PKCE code verifier. + codeVerifier string + // authorizationCode is the authorization code obtained from the authorization server. + authorizationCode string + // state is the state string used in the authorization request. + state string +} + +func (h *AuthorizationCodeOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h.tokenSource, nil +} + +// TODO: extract some logic into helper functions. +// TODO: validate required args +func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + defer resp.Body.Close() + log.Printf("Authorize: %s %s", req.Method, req.URL) + + if h.resolvedClientConfig == nil && h.authorizationCode != "" { + return fmt.Errorf("exchanging authorization code with unregistered client is not allowed") + } + + resourceURL := req.URL.String() + challenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) + if err != nil { + return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) + } + log.Printf("WWW-Authenticate header: %v", challenges) + var prm *oauthex.ProtectedResourceMetadata + for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(challenges), resourceURL) { + var err error + log.Printf("Getting protected resource metadata from %q", url) + prm, err = oauthex.GetProtectedResourceMetadata(ctx, url, http.DefaultClient) + if err == nil { + break + } + log.Printf("Failed to get protected resource metadata from %q: %v", url, err) + } + var authServerURL string + if prm != nil && len(prm.AuthorizationServers) > 0 { + // Use the first authorization server, similarly to other SDKs. + authServerURL = prm.AuthorizationServers[0] + } else { + // Fallback to 2025-03-26 spec: MCP server base URL acts as Authorization Server. + authURL, err := url.Parse(resourceURL) + if err != nil { + return fmt.Errorf("failed to parse resource URL: %v", err) + } + authURL.Path = "" + authServerURL = authURL.String() + } + log.Printf("Authorization server URL: %s", authServerURL) + + asm, err := oauthex.GetAuthServerMeta(ctx, authServerURL, http.DefaultClient) + if err != nil { + return fmt.Errorf("failed to get authorization server metadata: %w", err) + } + log.Print("Authorization server medatada fetched") + + if err := h.handleRegistration(ctx, authServerURL, asm); err != nil { + return err + } + + scopes := oauthex.Scopes(challenges) + if len(scopes) == 0 && prm != nil && len(prm.ScopesSupported) > 0 { + scopes = prm.ScopesSupported + } + + var authorizationEndpoint, tokenEndpoint string + if asm != nil { + authorizationEndpoint = asm.AuthorizationEndpoint + tokenEndpoint = asm.TokenEndpoint + } else { + // Fallback to 2025-03-26 spec: predefined endpoints if not provided by AS. + authorizationEndpoint = authServerURL + "/authorize" + tokenEndpoint = authServerURL + "/token" + } + + cfg := &oauth2.Config{ + ClientID: h.resolvedClientConfig.clientID, + ClientSecret: h.resolvedClientConfig.clientSecret, + + Endpoint: oauth2.Endpoint{ + AuthURL: authorizationEndpoint, + TokenURL: tokenEndpoint, + // TODO: validate if the auth style is supported by the AS. + AuthStyle: h.resolvedClientConfig.authStyle, + }, + RedirectURL: h.RedirectURL, + Scopes: scopes, + } + + if h.authorizationCode != "" { + log.Print("Authorization code is available, exchanging for token") + opts := []oauth2.AuthCodeOption{ + oauth2.VerifierOption(h.codeVerifier), + oauth2.SetAuthURLParam("resource", req.URL.String()), + } + token, err := cfg.Exchange(ctx, h.authorizationCode, opts...) + defer func() { + // Authorization code has been consumed, clear it. + h.authorizationCode = "" + }() + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + ts := cfg.TokenSource(ctx, token) + if h.TokenStore != nil { + // Persist the returned tokens to the store if requested. + ts = NewPersistentTokenSource(ctx, ts, h.TokenStore) + } + h.tokenSource = ts + return nil + } + + h.codeVerifier = oauth2.GenerateVerifier() + h.state = rand.Text() + if h.StateProvider != nil { + h.state = h.StateProvider() + } + + authURL := cfg.AuthCodeURL(h.state, + oauth2.S256ChallengeOption(h.codeVerifier), + oauth2.SetAuthURLParam("resource", req.URL.String()), + ) + + log.Print("No authorization code available, opening authorization URL") + if h.AuthorizationURLHandler != nil { + if err := h.AuthorizationURLHandler(ctx, authURL); err != nil { + return fmt.Errorf("authorization URL handler failed: %w", err) + } + } + + return ErrRedirected +} + +func (h *AuthorizationCodeOAuthHandler) SetTokenSource(ts oauth2.TokenSource) { + h.tokenSource = ts +} + +func (h *AuthorizationCodeOAuthHandler) FinalizeAuthorization(code, state string) error { + defer func() { + // State has been used for validation, clear it. + h.state = "" + }() + if state != h.state { + return fmt.Errorf("state mismatch: expected %q, got %q", h.state, state) + } + h.authorizationCode = code + return nil +} + +func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, authServerURL string, asm *oauthex.AuthServerMeta) error { + // 1. Attempt to use Client ID Metadata Document (SEP-991). + cimdCfg := h.ClientIDMetadataDocumentConfig + if cimdCfg != nil { + supportsCIMD := asm != nil && asm.ClientIDMetadataDocumentSupported + if supportsCIMD { + if !isNonRootHTTPSURL(cimdCfg.URL) { + return fmt.Errorf("client ID metadata document URL is not a non-root HTTPS URL") + } + h.resolvedClientConfig = &resolvedClientConfig{ + registrationType: registrationTypeClientIDMetadataDocument, + clientID: cimdCfg.URL, + } + return nil + } + } + // 2. Attempt to use pre-registered client ID. + pCfg := h.PreregisteredClientConfig + if pCfg != nil { + if pCfg.ClientID == "" || pCfg.ClientSecret == "" { + return fmt.Errorf("pre-registered client ID or secret is empty") + } + h.resolvedClientConfig = &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: pCfg.ClientID, + clientSecret: pCfg.ClientSecret, + authStyle: pCfg.AuthStyle, + } + return nil + } + // 3. Attempt to use dynamic client registration. + dcrCfg := h.DynamicClientRegistrationConfig + if dcrCfg != nil { + if !slices.Contains(dcrCfg.Metadata.RedirectURIs, h.RedirectURL) { + return fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", h.RedirectURL) + } + var registrationEndpoint string + if asm != nil { + if asm.RegistrationEndpoint == "" { + return fmt.Errorf("authorization server does not support dynamic client registration") + } + registrationEndpoint = asm.RegistrationEndpoint + } else { + // Fallback to 2025-03-26 spec: predefined endpoints if not provided by AS. + registrationEndpoint = authServerURL + "/register" + } + log.Printf("Attempting dynamic client registration at %v", registrationEndpoint) + regResp, err := oauthex.RegisterClient(ctx, registrationEndpoint, dcrCfg.Metadata, http.DefaultClient) + if err != nil { + return fmt.Errorf("failed to register client: %w", err) + } + h.resolvedClientConfig = &resolvedClientConfig{ + registrationType: registrationTypeDynamic, + clientID: regResp.ClientID, + clientSecret: regResp.ClientSecret, + } + switch regResp.TokenEndpointAuthMethod { + case "client_secret_post": + h.resolvedClientConfig.authStyle = oauth2.AuthStyleInParams + case "client_secret_basic": + h.resolvedClientConfig.authStyle = oauth2.AuthStyleInHeader + case "none": + // "none" is equivalent to "client_secret_post" but without sending client secret. + h.resolvedClientConfig.authStyle = oauth2.AuthStyleInParams + h.resolvedClientConfig.clientSecret = "" + default: + // We leave the AuthStyle set to zero value, which is auto-detection. + } + log.Printf("Client registered with client ID: %s", regResp.ClientID) + return nil + } + return fmt.Errorf("no client registration method configured") +} + +func isNonRootHTTPSURL(u string) bool { + pu, err := url.Parse(u) + if err != nil { + return false + } + return pu.Scheme == "https" && pu.Path != "" +} + +var _ OAuthHandler = (*AuthorizationCodeOAuthHandler)(nil) diff --git a/auth/client.go b/auth/client.go index acadc51b..73c5969a 100644 --- a/auth/client.go +++ b/auth/client.go @@ -8,6 +8,7 @@ package auth import ( "bytes" + "context" "errors" "io" "net/http" @@ -16,16 +17,74 @@ import ( "golang.org/x/oauth2" ) -// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// An OAuthHandlerLegacy conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization // is approved, or an error if not. // The handler receives the HTTP request and response that triggered the authentication flow. // To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. -type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +type OAuthHandlerLegacy func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) + +// Error that will be thrown if the call failed due to authorization. +var ErrUnauthorized = errors.New("unauthorized") + +type OAuthHandler interface { + // TokenSource returns a token source to be used for outgoing requests. + TokenSource(context.Context) (oauth2.TokenSource, error) + + // Authorize is called when an HTTP request results in an error that may + // be addressed by the authorization flow (currently 401 Unauthorized and 403 Forbidden). + // It is responsible for initiating the OAuth flow to obtain a token source. + // The arguments are the request that failed and the response that was received for it. + // If the returned error is nil, [TokenSource] is expected to return a non-nil token source. + // After a successful call to [Authorize], the HTTP request should be retried by the transport. + // The function is responsible for closing the response body. + Authorize(context.Context, *http.Request, *http.Response) error +} + +// TokenStore is an interface than can be used by OAuthHandler implementations +// to save tokens to a persistent storage. +type TokenStore interface { + Save(context.Context, *oauth2.Token) error +} + +type persistentTokenSource struct { + wrapped oauth2.TokenSource + store TokenStore + ctx context.Context +} + +// NewPersistentTokenSource returns a [oauth2.TokenSource] that +// persists the token to a given [TokenStore] after every successful +// [oauth2.TokenSource.Token] call. +// It is especially useful when wrapping a [oauth2.TokenSource] +// that automatically refreshes the token when it expires. +// The passed context is used for [TokenStore.Save] calls. +func NewPersistentTokenSource(ctx context.Context, wrapped oauth2.TokenSource, store TokenStore) oauth2.TokenSource { + return &persistentTokenSource{ + wrapped: wrapped, + store: store, + ctx: ctx, + } +} + +func (t *persistentTokenSource) Token() (*oauth2.Token, error) { + token, err := t.wrapped.Token() + if err != nil { + return nil, err + } + if err := t.store.Save(t.ctx, token); err != nil { + return nil, err + } + return token, nil +} // HTTPTransport is an [http.RoundTripper] that follows the MCP // OAuth protocol when it encounters a 401 Unauthorized response. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. type HTTPTransport struct { - handler OAuthHandler + handler OAuthHandlerLegacy mu sync.Mutex // protects opts.Base opts HTTPTransportOptions } @@ -34,7 +93,9 @@ type HTTPTransport struct { // The handler is invoked when an HTTP request results in a 401 Unauthorized status. // It is called only once per transport. Once a TokenSource is obtained, it is used // for the lifetime of the transport; subsequent 401s are not processed. -func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +func NewHTTPTransport(handler OAuthHandlerLegacy, opts *HTTPTransportOptions) (*HTTPTransport, error) { if handler == nil { return nil, errors.New("handler cannot be nil") } @@ -51,6 +112,8 @@ func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTr } // HTTPTransportOptions are options to [NewHTTPTransport]. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. type HTTPTransportOptions struct { // Base is the [http.RoundTripper] to use. // If nil, [http.DefaultTransport] is used. diff --git a/conformance/baseline.yml b/conformance/baseline.yml index f0e15132..7765063e 100644 --- a/conformance/baseline.yml +++ b/conformance/baseline.yml @@ -1,21 +1,3 @@ server: - dns-rebinding-protection -client: -- auth/basic-cimd -- auth/metadata-default -- auth/metadata-var1 -- auth/metadata-var2 -- auth/metadata-var3 -- auth/2025-03-26-oauth-metadata-backcompat -- auth/2025-03-26-oauth-endpoint-fallback -- auth/scope-from-www-authenticate -- auth/scope-from-scopes-supported -- auth/scope-omitted-when-undefined -- auth/scope-step-up -- auth/scope-retry-limit -- auth/token-endpoint-auth-basic -- auth/token-endpoint-auth-post -- auth/token-endpoint-auth-none -- auth/client-credentials-jwt -- auth/client-credentials-basic -- auth/pre-registration +client: [] # All pass! diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index 9674dbbc..f87e153c 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -9,14 +9,19 @@ package main import ( "context" + "errors" "fmt" "log" + "net/http" + "net/url" "os" "slices" "sort" "strings" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" ) // scenarioHandler is the function signature for all conformance test scenarios. @@ -42,6 +47,30 @@ func init() { registerScenario("tools_call", runToolsCallClient) registerScenario("elicitation-sep1034-client-defaults", runElicitationDefaultsClient) registerScenario("sse-retry", runSSERetryClient) + + authScenarios := []string{ + "auth/2025-03-26-oauth-metadata-backcompat", + "auth/2025-03-26-oauth-endpoint-fallback", + "auth/basic-cimd", + "auth/metadata-default", + "auth/metadata-var1", + "auth/metadata-var2", + "auth/metadata-var3", + "auth/resource-mismatch", + "auth/scope-from-www-authenticate", + "auth/scope-from-scopes-supported", + "auth/scope-omitted-when-undefined", + "auth/scope-step-up", + "auth/scope-retry-limit", + "auth/token-endpoint-auth-basic", + "auth/token-endpoint-auth-post", + "auth/token-endpoint-auth-none", + } + for _, scenario := range authScenarios { + registerScenario(scenario, runAuthClient) + } + + registerScenario("auth/pre-registration", runPreregisteredClient) } // ============================================================================ @@ -174,6 +203,170 @@ func runSSERetryClient(ctx context.Context, serverURL string) error { return nil } +// ============================================================================ +// Auth scenarios +// ============================================================================ + +func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (code, state string, err error) { + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil) + if err != nil { + return "", "", err + } + + resp, err := client.Do(req) + if err != nil { + return "", "", err + } + defer resp.Body.Close() + + location := resp.Header.Get("Location") + if location == "" { + return "", "", fmt.Errorf("no Location header in redirect") + } + + locURL, err := url.Parse(location) + if err != nil { + return "", "", fmt.Errorf("parse location: %v", err) + } + + code = locURL.Query().Get("code") + if code == "" { + return "", "", fmt.Errorf("no code parameter in redirect URL") + } + state = locURL.Query().Get("state") + if state == "" { + return "", "", fmt.Errorf("no state parameter in redirect URL") + } + + return code, state, nil +} + +func runAuthClient(ctx context.Context, serverURL string) error { + authHandler := &auth.AuthorizationCodeOAuthHandler{ + RedirectURL: "http://localhost:3000/callback", + // Try client ID metadata document based registration. + ClientIDMetadataDocumentConfig: &auth.ClientIDMetadataDocumentConfig{ + URL: "https://conformance-test.local/client-metadata.json", + }, + // Try dynamic client registration. + DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{"http://localhost:3000/callback"}, + }, + }, + } + + authHandler.AuthorizationURLHandler = func(ctx context.Context, authURL string) error { + // Normally this handler would trigger user browser to be opened. + // Here we query the authorization URL automatically and the AS is configured + // to authorize and redirect immediately. We save the resulting code. + code, state, err := fetchAuthorizationCodeAndState(ctx, authURL) + if err != nil { + return err + } + if err := authHandler.FinalizeAuthorization(code, state); err != nil { + return err + } + return nil + } + + session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) + if err != nil { + if !errors.Is(err, auth.ErrRedirected) { + return err + } + // Received auth.ErrRedirected. Normally we would wait for the callback triggered + // by the AS redirect to RedirectURL, but here we already have the authorization code + // so we can immediately retry. + session, err = connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) + if err != nil { + return nil + } + } + defer session.Close() + + if _, err := session.ListTools(ctx, nil); err != nil { + // Retry for the scope step-up scenario. + if !errors.Is(err, auth.ErrRedirected) { + return fmt.Errorf("session.ListTools(): %v", err) + } + // Received auth.ErrRedirected. Normally we would wait for the callback triggered + // by the AS redirect to RedirectURL, but here we already have the authorization code + // so we can immediately retry. + _, err = session.ListTools(ctx, nil) + if err != nil { + return fmt.Errorf("session.ListTools(): %v", err) + } + } + + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "test-tool", + Arguments: map[string]any{}, + }); err != nil { + return fmt.Errorf("session.CallTool('test-tool'): %v", err) + } + + return nil +} + +func runPreregisteredClient(ctx context.Context, serverURL string) error { + authHandler := &auth.AuthorizationCodeOAuthHandler{ + RedirectURL: "http://localhost:3000/callback", + // Try preregistered client information. + PreregisteredClientConfig: &auth.PreregisteredClientConfig{ + ClientID: "pre-registered-client", + ClientSecret: "pre-registered-secret", + }, + } + + authHandler.AuthorizationURLHandler = func(ctx context.Context, authURL string) error { + // Normally this handler would trigger user browser to be opened. + // Here we query the authorization URL automatically and the AS is configured + // to authorize and redirect immediately. We save the resulting code. + code, state, err := fetchAuthorizationCodeAndState(ctx, authURL) + if err != nil { + return err + } + if err := authHandler.FinalizeAuthorization(code, state); err != nil { + return err + } + return nil + } + + session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) + if err != nil { + if !errors.Is(err, auth.ErrRedirected) { + return err + } + // Received auth.ErrRedirected. Normally we would wait for the callback triggered + // by the AS redirect to RedirectURL, but here we already have the authorization code + // so we can immediately retry. + session, err = connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) + if err != nil { + return nil + } + } + defer session.Close() + + if _, err := session.ListTools(ctx, nil); err != nil { + return fmt.Errorf("session.ListTools(): %v", err) + } + + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "test-tool", + Arguments: map[string]any{}, + }); err != nil { + return fmt.Errorf("session.CallTool('test-tool'): %v", err) + } + + return nil +} + // ============================================================================ // Main entry point // ============================================================================ @@ -214,6 +407,7 @@ func printUsageAndExit(format string, args ...any) { type connectConfig struct { clientOptions *mcp.ClientOptions + oauthHandler auth.OAuthHandler } type connectOption func(*connectConfig) @@ -224,6 +418,12 @@ func withClientOptions(opts *mcp.ClientOptions) connectOption { } } +func withOAuthHandler(handler auth.OAuthHandler) connectOption { + return func(c *connectConfig) { + c.oauthHandler = handler + } +} + // connectToServer connects to the MCP server and returns a client session. // The caller is responsible for closing the session. func connectToServer(ctx context.Context, serverURL string, opts ...connectOption) (*mcp.ClientSession, error) { @@ -237,11 +437,14 @@ func connectToServer(ctx context.Context, serverURL string, opts ...connectOptio Version: "1.0.0", }, config.clientOptions) - transport := &mcp.StreamableClientTransport{Endpoint: serverURL} + transport := &mcp.StreamableClientTransport{ + Endpoint: serverURL, + OAuthHandler: config.oauthHandler, + } session, err := client.Connect(ctx, transport, nil) if err != nil { - return nil, fmt.Errorf("client.Connect(): %v", err) + return nil, fmt.Errorf("client.Connect(): %w", err) } return session, nil diff --git a/mcp/streamable.go b/mcp/streamable.go index c6b96bb0..acbb15b4 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -25,13 +25,13 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/oauthex" ) const ( @@ -1425,6 +1425,9 @@ type StreamableClientTransport struct { // - You want to avoid maintaining a persistent connection DisableStandaloneSSE bool + // OAuthHandler is an optional field that, if provided, will be used to authorize the requests. + OAuthHandler auth.OAuthHandler + // TODO(rfindley): propose exporting these. // If strict is set, the transport is in 'strict mode', where any violation // of the MCP spec causes a failure. @@ -1500,6 +1503,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er cancel: cancel, failed: make(chan struct{}), disableStandaloneSSE: t.DisableStandaloneSSE, + oauthHandler: t.OAuthHandler, } return conn, nil } @@ -1518,6 +1522,9 @@ type streamableClientConn struct { // for receiving server-to-client notifications when no request is in flight. disableStandaloneSSE bool // from [StreamableClientTransport.DisableStandaloneSSE] + // oauthHandler is the OAuth handler for the connection. + oauthHandler auth.OAuthHandler // from [StreamableClientTransport.OAuthHandler] + // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once closeErr error @@ -1693,14 +1700,57 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") - c.setMCPHeaders(req) + doRequest := func() (*http.Response, error) { + if err := c.setMCPHeaders(req); err != nil { + // TODO: should we fail the connection here? + return nil, err + } + resp, err := c.client.Do(req) + if err != nil { + // Any error from client.Do means the request didn't reach the server. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + err = fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) + } + return resp, err + } - resp, err := c.client.Do(req) + resp, err := doRequest() if err != nil { - // Any error from client.Do means the request didn't reach the server. - // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr - // and permanently break the connection. - return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) + return err + } + + if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { + if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + // Wrap the authorization error as well for client inspection. + return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) + } + // Retry the request after successful authorization. + resp, err = doRequest() + if err != nil { + return err + } + } + if resp.StatusCode == http.StatusForbidden && c.oauthHandler != nil { + challenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) + if err != nil { + c.logger.Warn("%s: failed to parse WWW-Authenticate header: %v", requestSummary, err) + } else if oauthex.Error(challenges) == "insufficient_scope" { + // Trigger step-up authorization flow. + if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + // Wrap the authorization error as well for client inspection. + return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) + } + // Retry the request after successful authorization. + resp, err = doRequest() + if err != nil { + return err + } + } } if err := c.checkResponse(requestSummary, resp); err != nil { @@ -1768,23 +1818,30 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } -// testAuth controls whether a fake Authorization header is added to outgoing requests. -// TODO: replace with a better mechanism when client-side auth is in place. -var testAuth atomic.Bool - -func (c *streamableClientConn) setMCPHeaders(req *http.Request) { +func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { c.mu.Lock() defer c.mu.Unlock() + if c.oauthHandler != nil { + ts, err := c.oauthHandler.TokenSource(c.ctx) + if err != nil { + return err + } + if ts != nil { + token, err := ts.Token() + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + } + } if c.initializedResult != nil { req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) } if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } - if testAuth.Load() { - req.Header.Set("Authorization", "Bearer foo") - } + return nil } func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { @@ -2037,7 +2094,10 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin if err != nil { return nil, err } - c.setMCPHeaders(req) + if err := c.setMCPHeaders(req); err != nil { + // TODO: should we fail the connection here? + return nil, err + } if lastEventID != "" { req.Header.Set(lastEventIDHeader, lastEventID) } @@ -2068,8 +2128,11 @@ func (c *streamableClientConn) Close() error { if err != nil { c.closeErr = err } else { - c.setMCPHeaders(req) - if _, err := c.client.Do(req); err != nil { + if err := c.setMCPHeaders(req); err != nil { + // TODO: or setting headers should be best-effort and we should retry + // the request without them? + c.closeErr = err + } else if _, err := c.client.Do(req); err != nil { c.closeErr = err } } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index f1f6200f..79a3c531 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -34,6 +34,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/oauth2" ) func TestStreamableTransports(t *testing.T) { @@ -1666,10 +1667,20 @@ func textContent(t *testing.T, res *CallToolResult) string { return text.Text } +type testOAuthHandler struct { + token string +} + +func (h *testOAuthHandler) TokenSource(context.Context) (oauth2.TokenSource, error) { + return oauth2.StaticTokenSource(&oauth2.Token{AccessToken: h.token}), nil +} + +func (h *testOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + // 401 resonse is not expected in this test. We can simply fail. + return errors.New("unexpected 401") +} + func TestTokenInfo(t *testing.T) { - oldAuth := testAuth.Load() - defer testAuth.Store(oldAuth) - testAuth.Store(true) ctx := context.Background() // Create a server with a tool that returns TokenInfo. @@ -1680,7 +1691,10 @@ func TestTokenInfo(t *testing.T) { AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - verifier := func(context.Context, string, *http.Request) (*auth.TokenInfo, error) { + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + if token != "test-token" { + return nil, auth.ErrInvalidToken + } return &auth.TokenInfo{ Scopes: []string{"scope"}, // Expiration is far, far in the future. @@ -1691,7 +1705,10 @@ func TestTokenInfo(t *testing.T) { httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() - transport := &StreamableClientTransport{Endpoint: httpServer.URL} + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: &testOAuthHandler{token: "test-token"}, + } client := NewClient(testImpl, nil) session, err := client.Connect(ctx, transport, nil) if err != nil { diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 9aa0c8d7..a6ee07a8 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -14,6 +14,8 @@ import ( "errors" "fmt" "net/http" + "net/url" + "strings" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -113,11 +115,10 @@ type AuthServerMeta struct { // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of // PKCE code challenge methods supported by this authorization server. CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` -} -var wellKnownPaths = []string{ - "/.well-known/oauth-authorization-server", - "/.well-known/openid-configuration", + // ClientIDMetadataDocumentSupported is a boolean indicating whether the authorization server + // supports client ID metadata documents. + ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported,omitempty"` } // GetAuthServerMeta issues a GET request to retrieve authorization server metadata @@ -130,34 +131,72 @@ var wellKnownPaths = []string{ // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { - var errs []error - for _, p := range wellKnownPaths { - u, err := prependToPath(issuerURL, p) + for _, u := range AuthorizationServerMetadataURLs(issuerURL) { + asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) if err != nil { - // issuerURL is bad; no point in continuing. + var httpErr *httpStatusError + if errors.As(err, &httpErr) { + if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { + continue + } + return nil, fmt.Errorf("%v", err) // Do not expose wrapped errors. + } + } + // TODO: causes conformance test failure, filed https://github.com/modelcontextprotocol/conformance/issues/140. + // if asm.Issuer != issuerURL { + // // Validate the Issuer field (see RFC 8414, section 3.3). + // return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) + // } + + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) + } + + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { return nil, err } - asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) - if err == nil { - if asm.Issuer != issuerURL { // section 3.3 - // Security violation; don't keep trying. - return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) - } - if len(asm.CodeChallengeMethodsSupported) == 0 { - return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) - } + return asm, nil + } + // Authorization server metadata not found. Return nil error to allow a fallback. + return nil, nil +} - // Validate endpoint URLs to prevent XSS attacks (see #526). - if err := validateAuthServerMetaURLs(asm); err != nil { - return nil, err - } +// AuthorizationServerMetadataURLs returns a list of URLs to try when looking for +// authorization server metadata as mandated by the MCP specification. +func AuthorizationServerMetadataURLs(issuerURL string) []string { + var urls []string - return asm, nil - } - errs = append(errs, err) + // Produce candidates per + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. + + baseURL, err := url.Parse(issuerURL) + if err != nil { + return nil + } + + if baseURL.Path == "" { + // "OAuth 2.0 Authorization Server Metadata". + baseURL.Path = "/.well-known/oauth-authorization-server" + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0". + baseURL.Path = "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls } - return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) + + originalPath := baseURL.Path + // "OAuth 2.0 Authorization Server Metadata with path insertion". + baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path insertion". + baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 path appending". + baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls } // validateAuthServerMetaURLs validates all URL fields in AuthServerMeta diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index cdda695b..5b76116d 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -36,6 +36,14 @@ func prependToPath(urlStr, pre string) (string, error) { return u.String(), nil } +type httpStatusError struct { + StatusCode int +} + +func (e *httpStatusError) Error() string { + return fmt.Sprintf("bad status %d", e.StatusCode) +} + // getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both // RFC 9728 and RFC 8414. // It will not read more than limit bytes from the body. @@ -53,11 +61,9 @@ func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64 } defer res.Body.Close() - // Specs require a 200. if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("bad status %s", res.Status) + return nil, &httpStatusError{StatusCode: res.StatusCode} } - // Specs require application/json. ct := res.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(ct) if err != nil || mediaType != "application/json" { diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index bb61f797..107c5e63 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -13,6 +13,7 @@ import ( "context" "errors" "fmt" + "log" "net/http" "net/url" "path" @@ -38,6 +39,7 @@ const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resour // // It then retrieves the metadata at that location using the given client (or the // default client if nil) and validates its resource field against resourceID. +// Deprecated: Use [GetProtectedResourceMetadata] instead. func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) @@ -47,7 +49,10 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, } // Insert well-known URI into URL. u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) - return getPRM(ctx, u.String(), c, resourceID) + return GetProtectedResourceMetadata(ctx, ProtectedResourceMetadataURL{ + URL: u.String(), + Resource: resourceID, + }, c) } // GetProtectedResourceMetadataFromHeader retrieves protected resource metadata @@ -57,8 +62,8 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, // Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata // matches the serverURL (the URL that the client used to make the original request to the resource server). // If there is no metadata URL in the header, it returns nil, nil. +// Deprecated: Use [GetProtectedResourceMetadata] instead. func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { - defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] if len(headers) == 0 { return nil, nil @@ -71,22 +76,31 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL strin if metadataURL == "" { return nil, nil } - return getPRM(ctx, metadataURL, c, serverURL) + return GetProtectedResourceMetadata(ctx, ProtectedResourceMetadataURL{ + URL: metadataURL, + Resource: serverURL, + }, c) } -// getPRM makes a GET request to the given URL, and validates the response. -// As part of the validation, it compares the returned resource field to wantResource. -func getPRM(ctx context.Context, purl string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { - if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { - return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) - } - prm, err := getJSON[ProtectedResourceMetadata](ctx, c, purl, 1<<20) +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server. +// The metadataURL is typically a URL with a host:port and possibly a path. +// For example: +// +// https://example.com/server +func GetProtectedResourceMetadata(ctx context.Context, metadataURL ProtectedResourceMetadataURL, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadata(%q)", metadataURL) + // TODO: where HTTPS requirement comes from? conformance tests use HTTP. + // if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { + // return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) + // } + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, metadataURL.URL, 1<<20) if err != nil { return nil, err } // Validate the Resource field (see RFC 9728, section 3.3). - if prm.Resource != wantResource { - return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + if prm.Resource != metadataURL.Resource { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, metadataURL.Resource) } // Validate the authorization server URLs to prevent XSS attacks (see #526). for _, u := range prm.AuthorizationServers { @@ -97,6 +111,48 @@ func getPRM(ctx context.Context, purl string, c *http.Client, wantResource strin return prm, nil } +type ProtectedResourceMetadataURL struct { + // URL represents a URL where Protected Resource Metadata may be retrieved. + URL string + // Resource represents the corresponding resource URL for [URL]. + // It is required to perform validation described in RFC 9728, section 3.3. + Resource string +} + +// ProtectedResourceMetadataURLs returns a list of URLs to try when looking for +// protected resource metadata as mandated by the MCP specification. +func ProtectedResourceMetadataURLs(metadataURL, resourceURL string) []ProtectedResourceMetadataURL { + var urls []ProtectedResourceMetadataURL + if metadataURL != "" { + urls = append(urls, ProtectedResourceMetadataURL{ + URL: metadataURL, + Resource: resourceURL, + }) + } + // Produce fallbacks per + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#protected-resource-metadata-discovery-requirements + ru, err := url.Parse(resourceURL) + if err != nil { + return urls + } + mu := *ru + // "At the path of the server's MCP endpoint". + mu.Path = "/.well-known/oauth-protected-resource/" + strings.TrimLeft(ru.Path, "/") + urls = append(urls, ProtectedResourceMetadataURL{ + URL: mu.String(), + Resource: resourceURL, + }) + // "At the root". + mu.Path = "/.well-known/oauth-protected-resource" + ru.Path = "" + urls = append(urls, ProtectedResourceMetadataURL{ + URL: mu.String(), + Resource: ru.String(), + }) + log.Printf("Resource metadata URLs: %v", urls) + return urls +} + // challenge represents a single authentication challenge from a WWW-Authenticate header. // As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. type challenge struct { @@ -121,6 +177,24 @@ func ResourceMetadataURL(cs []challenge) string { return "" } +func Scopes(cs []challenge) []string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["scope"] != "" { + return strings.Fields(c.Params["scope"]) + } + } + return nil +} + +func Error(cs []challenge) string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["error"] != "" { + return c.Params["error"] + } + } + return "" +} + // ParseWWWAuthenticate parses a WWW-Authenticate header string. // The header format is defined in RFC 9110, Section 11.6.1, and can contain // one or more challenges, separated by commas. diff --git a/scripts/client-conformance.sh b/scripts/client-conformance.sh index c093c75f..17528450 100755 --- a/scripts/client-conformance.sh +++ b/scripts/client-conformance.sh @@ -10,6 +10,7 @@ set -e RESULT_DIR="" WORKDIR="" CONFORMANCE_REPO="" +SUITE="core" FINAL_EXIT_CODE=0 usage() { @@ -21,9 +22,11 @@ usage() { echo " --result_dir Save results to the specified directory" echo " --conformance_repo Run conformance tests from a local checkout" echo " instead of using the latest npm release" + echo " --suite Which suite to run (default: core)" echo " --help Show this help message" } + # Parse arguments. while [[ $# -gt 0 ]]; do case $1 in @@ -35,6 +38,10 @@ while [[ $# -gt 0 ]]; do CONFORMANCE_REPO="$2" shift 2 ;; + --suite) + SUITE="$2" + shift 2 + ;; --help) usage exit 0 @@ -56,7 +63,7 @@ else fi # Build the conformance server. -go build -o "$WORKDIR/conformance-client" ./conformance/everything-client +go build -tags mcp_go_client_oauth -o "$WORKDIR/conformance-client" ./conformance/everything-client # Run conformance tests from the work directory to avoid writing results to the repo. echo "Running conformance tests..." @@ -65,13 +72,13 @@ if [ -n "$CONFORMANCE_REPO" ]; then (cd "$WORKDIR" && \ npm --prefix "$CONFORMANCE_REPO" run start -- \ client --command "$WORKDIR/conformance-client" \ - --suite core \ + --suite "$SUITE" \ ${RESULT_DIR:+--output-dir "$RESULT_DIR"}) || FINAL_EXIT_CODE=$? else (cd "$WORKDIR" && \ npx @modelcontextprotocol/conformance@latest \ client --command "$WORKDIR/conformance-client" \ - --suite core \ + --suite "$SUITE" \ ${RESULT_DIR:+--output-dir "$RESULT_DIR"}) || FINAL_EXIT_CODE=$? fi diff --git a/scripts/server-conformance.sh b/scripts/server-conformance.sh index 0826086a..8ab4a3bf 100755 --- a/scripts/server-conformance.sh +++ b/scripts/server-conformance.sh @@ -67,7 +67,7 @@ else fi # Build the conformance server. -go build -o "$WORKDIR/conformance-server" ./conformance/everything-server +go build -tags mcp_go_client_oauth -o "$WORKDIR/conformance-server" ./conformance/everything-server # Start the server in the background echo "Starting conformance server on port $PORT..." From 622f4fc3b72d52fefd10f3e52b515c4ffbc5a390 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 16 Feb 2026 15:25:03 +0000 Subject: [PATCH 02/26] Simplify the client setup by reading the conformance contect. --- conformance/everything-client/main.go | 90 +++++++++------------------ 1 file changed, 29 insertions(+), 61 deletions(-) diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index f87e153c..c6ec1f50 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -9,6 +9,7 @@ package main import ( "context" + "encoding/json" "errors" "fmt" "log" @@ -26,7 +27,7 @@ import ( // scenarioHandler is the function signature for all conformance test scenarios. // It takes a context and the server URL to connect to. -type scenarioHandler func(ctx context.Context, serverURL string) error +type scenarioHandler func(ctx context.Context, serverURL string, configCtx map[string]any) error var ( // registry stores all registered scenario handlers. @@ -70,14 +71,14 @@ func init() { registerScenario(scenario, runAuthClient) } - registerScenario("auth/pre-registration", runPreregisteredClient) + registerScenario("auth/pre-registration", runAuthClient) } // ============================================================================ // Basic scenarios // ============================================================================ -func runBasicClient(ctx context.Context, serverURL string) error { +func runBasicClient(ctx context.Context, serverURL string, _ map[string]any) error { session, err := connectToServer(ctx, serverURL) if err != nil { return err @@ -92,7 +93,7 @@ func runBasicClient(ctx context.Context, serverURL string) error { return nil } -func runToolsCallClient(ctx context.Context, serverURL string) error { +func runToolsCallClient(ctx context.Context, serverURL string, _ map[string]any) error { session, err := connectToServer(ctx, serverURL) if err != nil { return err @@ -126,7 +127,7 @@ func runToolsCallClient(ctx context.Context, serverURL string) error { // Elicitation scenarios // ============================================================================ -func runElicitationDefaultsClient(ctx context.Context, serverURL string) error { +func runElicitationDefaultsClient(ctx context.Context, serverURL string, _ map[string]any) error { elicitationHandler := func(ctx context.Context, req *mcp.ElicitRequest) (*mcp.ElicitResult, error) { return &mcp.ElicitResult{ Action: "accept", @@ -170,7 +171,7 @@ func runElicitationDefaultsClient(ctx context.Context, serverURL string) error { // SSE retry scenario // ============================================================================ -func runSSERetryClient(ctx context.Context, serverURL string) error { +func runSSERetryClient(ctx context.Context, serverURL string, _ map[string]any) error { // TODO: this scenario is not passing yet. It requires a fix in the client SSE handling. session, err := connectToServer(ctx, serverURL) if err != nil { @@ -246,7 +247,7 @@ func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (code, return code, state, nil } -func runAuthClient(ctx context.Context, serverURL string) error { +func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]any) error { authHandler := &auth.AuthorizationCodeOAuthHandler{ RedirectURL: "http://localhost:3000/callback", // Try client ID metadata document based registration. @@ -260,6 +261,15 @@ func runAuthClient(ctx context.Context, serverURL string) error { }, }, } + // Try pre-registered client information if provided in the context. + if clientId, ok := configCtx["client_id"].(string); ok { + if clientSecret, ok := configCtx["client_secret"].(string); ok { + authHandler.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ + ClientID: clientId, + ClientSecret: clientSecret, + } + } + } authHandler.AuthorizationURLHandler = func(ctx context.Context, authURL string) error { // Normally this handler would trigger user browser to be opened. @@ -314,59 +324,6 @@ func runAuthClient(ctx context.Context, serverURL string) error { return nil } -func runPreregisteredClient(ctx context.Context, serverURL string) error { - authHandler := &auth.AuthorizationCodeOAuthHandler{ - RedirectURL: "http://localhost:3000/callback", - // Try preregistered client information. - PreregisteredClientConfig: &auth.PreregisteredClientConfig{ - ClientID: "pre-registered-client", - ClientSecret: "pre-registered-secret", - }, - } - - authHandler.AuthorizationURLHandler = func(ctx context.Context, authURL string) error { - // Normally this handler would trigger user browser to be opened. - // Here we query the authorization URL automatically and the AS is configured - // to authorize and redirect immediately. We save the resulting code. - code, state, err := fetchAuthorizationCodeAndState(ctx, authURL) - if err != nil { - return err - } - if err := authHandler.FinalizeAuthorization(code, state); err != nil { - return err - } - return nil - } - - session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) - if err != nil { - if !errors.Is(err, auth.ErrRedirected) { - return err - } - // Received auth.ErrRedirected. Normally we would wait for the callback triggered - // by the AS redirect to RedirectURL, but here we already have the authorization code - // so we can immediately retry. - session, err = connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) - if err != nil { - return nil - } - } - defer session.Close() - - if _, err := session.ListTools(ctx, nil); err != nil { - return fmt.Errorf("session.ListTools(): %v", err) - } - - if _, err := session.CallTool(ctx, &mcp.CallToolParams{ - Name: "test-tool", - Arguments: map[string]any{}, - }); err != nil { - return fmt.Errorf("session.CallTool('test-tool'): %v", err) - } - - return nil -} - // ============================================================================ // Main entry point // ============================================================================ @@ -378,6 +335,7 @@ func main() { serverURL := os.Args[1] scenarioName := os.Getenv("MCP_CONFORMANCE_SCENARIO") + configCtx := getConformanceContext() if scenarioName == "" { printUsageAndExit("MCP_CONFORMANCE_SCENARIO not set") @@ -389,11 +347,21 @@ func main() { } ctx := context.Background() - if err := handler(ctx, serverURL); err != nil { + if err := handler(ctx, serverURL, configCtx); err != nil { log.Fatalf("Scenario %q failed: %v", scenarioName, err) } } +func getConformanceContext() map[string]any { + ctxStr := os.Getenv("MCP_CONFORMANCE_CONTEXT") + if ctxStr == "" { + return nil + } + var ctx map[string]any + _ = json.Unmarshal([]byte(ctxStr), &ctx) + return ctx +} + func printUsageAndExit(format string, args ...any) { var scenarios []string for name := range registry { From d802adbdd83146e23b5e5b417fad986009101025 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 16 Feb 2026 15:34:39 +0000 Subject: [PATCH 03/26] Separate objects required by the mcp package to not be protected by the build tag. --- .github/workflows/conformance.yml | 3 +- auth/client.go | 156 ----------------------- auth/client_private.go | 169 +++++++++++++++++++++++++ oauthex/resource_meta.go | 183 --------------------------- oauthex/resource_meta_public.go | 198 ++++++++++++++++++++++++++++++ scripts/server-conformance.sh | 2 +- 6 files changed, 369 insertions(+), 342 deletions(-) create mode 100644 auth/client_private.go create mode 100644 oauthex/resource_meta_public.go diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 8d9dc46b..6d4b9950 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -25,8 +25,7 @@ jobs: go-version: "^1.25" - name: Start everything-server run: | - # TODO: mcp_go_client_oauth should not be needed here, but the client OAuth abstractions have spilled into the streamable transport. - go run -tags mcp_go_client_oauth ./conformance/everything-server/main.go -http=":3001" & + go run ./conformance/everything-server/main.go -http=":3001" & # Wait for the server to be ready. timeout 15 bash -c 'until curl -s http://localhost:3001/mcp; do sleep 0.5; done' - name: "Run conformance tests" diff --git a/auth/client.go b/auth/client.go index 73c5969a..aa921fbb 100644 --- a/auth/client.go +++ b/auth/client.go @@ -2,29 +2,16 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -//go:build mcp_go_client_oauth - package auth import ( - "bytes" "context" "errors" - "io" "net/http" - "sync" "golang.org/x/oauth2" ) -// An OAuthHandlerLegacy conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization -// is approved, or an error if not. -// The handler receives the HTTP request and response that triggered the authentication flow. -// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. -// Deprecated: Please use the new OAuthHandler abstraction that is built -// into the streamable transport. -type OAuthHandlerLegacy func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) - // Error that will be thrown if the call failed due to authorization. var ErrUnauthorized = errors.New("unauthorized") @@ -41,146 +28,3 @@ type OAuthHandler interface { // The function is responsible for closing the response body. Authorize(context.Context, *http.Request, *http.Response) error } - -// TokenStore is an interface than can be used by OAuthHandler implementations -// to save tokens to a persistent storage. -type TokenStore interface { - Save(context.Context, *oauth2.Token) error -} - -type persistentTokenSource struct { - wrapped oauth2.TokenSource - store TokenStore - ctx context.Context -} - -// NewPersistentTokenSource returns a [oauth2.TokenSource] that -// persists the token to a given [TokenStore] after every successful -// [oauth2.TokenSource.Token] call. -// It is especially useful when wrapping a [oauth2.TokenSource] -// that automatically refreshes the token when it expires. -// The passed context is used for [TokenStore.Save] calls. -func NewPersistentTokenSource(ctx context.Context, wrapped oauth2.TokenSource, store TokenStore) oauth2.TokenSource { - return &persistentTokenSource{ - wrapped: wrapped, - store: store, - ctx: ctx, - } -} - -func (t *persistentTokenSource) Token() (*oauth2.Token, error) { - token, err := t.wrapped.Token() - if err != nil { - return nil, err - } - if err := t.store.Save(t.ctx, token); err != nil { - return nil, err - } - return token, nil -} - -// HTTPTransport is an [http.RoundTripper] that follows the MCP -// OAuth protocol when it encounters a 401 Unauthorized response. -// Deprecated: Please use the new OAuthHandler abstraction that is built -// into the streamable transport. -type HTTPTransport struct { - handler OAuthHandlerLegacy - mu sync.Mutex // protects opts.Base - opts HTTPTransportOptions -} - -// NewHTTPTransport returns a new [*HTTPTransport]. -// The handler is invoked when an HTTP request results in a 401 Unauthorized status. -// It is called only once per transport. Once a TokenSource is obtained, it is used -// for the lifetime of the transport; subsequent 401s are not processed. -// Deprecated: Please use the new OAuthHandler abstraction that is built -// into the streamable transport. -func NewHTTPTransport(handler OAuthHandlerLegacy, opts *HTTPTransportOptions) (*HTTPTransport, error) { - if handler == nil { - return nil, errors.New("handler cannot be nil") - } - t := &HTTPTransport{ - handler: handler, - } - if opts != nil { - t.opts = *opts - } - if t.opts.Base == nil { - t.opts.Base = http.DefaultTransport - } - return t, nil -} - -// HTTPTransportOptions are options to [NewHTTPTransport]. -// Deprecated: Please use the new OAuthHandler abstraction that is built -// into the streamable transport. -type HTTPTransportOptions struct { - // Base is the [http.RoundTripper] to use. - // If nil, [http.DefaultTransport] is used. - Base http.RoundTripper -} - -func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - t.mu.Lock() - base := t.opts.Base - t.mu.Unlock() - - var ( - // If haveBody is set, the request has a nontrivial body, and we need avoid - // reading (or closing) it multiple times. In that case, bodyBytes is its - // content. - haveBody bool - bodyBytes []byte - ) - if req.Body != nil && req.Body != http.NoBody { - // if we're setting Body, we must mutate first. - req = req.Clone(req.Context()) - haveBody = true - var err error - bodyBytes, err = io.ReadAll(req.Body) - if err != nil { - return nil, err - } - // Now that we've read the request body, http.RoundTripper requires that we - // close it. - req.Body.Close() // ignore error - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - resp, err := base.RoundTrip(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusUnauthorized { - return resp, nil - } - if _, ok := base.(*oauth2.Transport); ok { - // We failed to authorize even with a token source; give up. - return resp, nil - } - - resp.Body.Close() - // Try to authorize. - t.mu.Lock() - defer t.mu.Unlock() - // If we don't have a token source, get one by following the OAuth flow. - // (We may have obtained one while t.mu was not held above.) - // TODO: We hold the lock for the entire OAuth flow. This could be a long - // time. Is there a better way? - if _, ok := t.opts.Base.(*oauth2.Transport); !ok { - ts, err := t.handler(req, resp) - if err != nil { - return nil, err - } - t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} - } - - // If we don't have a body, the request is reusable, though it will be cloned - // by the base. However, if we've had to read the body, we must clone. - if haveBody { - req = req.Clone(req.Context()) - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - return t.opts.Base.RoundTrip(req) -} diff --git a/auth/client_private.go b/auth/client_private.go new file mode 100644 index 00000000..d8e633e9 --- /dev/null +++ b/auth/client_private.go @@ -0,0 +1,169 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "sync" + + "golang.org/x/oauth2" +) + +// An OAuthHandlerLegacy conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +type OAuthHandlerLegacy func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) + +// TokenStore is an interface than can be used by OAuthHandler implementations +// to save tokens to a persistent storage. +type TokenStore interface { + Save(context.Context, *oauth2.Token) error +} + +type persistentTokenSource struct { + wrapped oauth2.TokenSource + store TokenStore + ctx context.Context +} + +// NewPersistentTokenSource returns a [oauth2.TokenSource] that +// persists the token to a given [TokenStore] after every successful +// [oauth2.TokenSource.Token] call. +// It is especially useful when wrapping a [oauth2.TokenSource] +// that automatically refreshes the token when it expires. +// The passed context is used for [TokenStore.Save] calls. +func NewPersistentTokenSource(ctx context.Context, wrapped oauth2.TokenSource, store TokenStore) oauth2.TokenSource { + return &persistentTokenSource{ + wrapped: wrapped, + store: store, + ctx: ctx, + } +} + +func (t *persistentTokenSource) Token() (*oauth2.Token, error) { + token, err := t.wrapped.Token() + if err != nil { + return nil, err + } + if err := t.store.Save(t.ctx, token); err != nil { + return nil, err + } + return token, nil +} + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +type HTTPTransport struct { + handler OAuthHandlerLegacy + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +func NewHTTPTransport(handler OAuthHandlerLegacy, opts *HTTPTransportOptions) (*HTTPTransport, error) { + if handler == nil { + return nil, errors.New("handler cannot be nil") + } + t := &HTTPTransport{ + handler: handler, + } + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. + Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + t.mu.Unlock() + + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if _, ok := base.(*oauth2.Transport); ok { + // We failed to authorize even with a token source; give up. + return resp, nil + } + + resp.Body.Close() + // Try to authorize. + t.mu.Lock() + defer t.mu.Unlock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + // TODO: We hold the lock for the entire OAuth flow. This could be a long + // time. Is there a better way? + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + ts, err := t.handler(req, resp) + if err != nil { + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. + if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) +} diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index 107c5e63..9875f54a 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -11,14 +11,12 @@ package oauthex import ( "context" - "errors" "fmt" "log" "net/http" "net/url" "path" "strings" - "unicode" "github.com/modelcontextprotocol/go-sdk/internal/util" ) @@ -153,19 +151,6 @@ func ProtectedResourceMetadataURLs(metadataURL, resourceURL string) []ProtectedR return urls } -// challenge represents a single authentication challenge from a WWW-Authenticate header. -// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. -type challenge struct { - // GENERATED BY GEMINI 2.5. - // - // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). - // It is case-insensitive. A parsed value will always be lower-case. - Scheme string - // Params is a map of authentication parameters. - // Keys are case-insensitive. Parsed keys are always lower-case. - Params map[string]string -} - // ResourceMetadataURL returns a resource metadata URL from the given challenges, // or the empty string if there is none. func ResourceMetadataURL(cs []challenge) string { @@ -185,171 +170,3 @@ func Scopes(cs []challenge) []string { } return nil } - -func Error(cs []challenge) string { - for _, c := range cs { - if c.Scheme == "bearer" && c.Params["error"] != "" { - return c.Params["error"] - } - } - return "" -} - -// ParseWWWAuthenticate parses a WWW-Authenticate header string. -// The header format is defined in RFC 9110, Section 11.6.1, and can contain -// one or more challenges, separated by commas. -// It returns a slice of challenges or an error if one of the headers is malformed. -func ParseWWWAuthenticate(headers []string) ([]challenge, error) { - // GENERATED BY GEMINI 2.5 (human-tweaked) - var challenges []challenge - for _, h := range headers { - challengeStrings, err := splitChallenges(h) - if err != nil { - return nil, err - } - for _, cs := range challengeStrings { - if strings.TrimSpace(cs) == "" { - continue - } - challenge, err := parseSingleChallenge(cs) - if err != nil { - return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) - } - challenges = append(challenges, challenge) - } - } - return challenges, nil -} - -// splitChallenges splits a header value containing one or more challenges. -// It correctly handles commas within quoted strings and distinguishes between -// commas separating auth-params and commas separating challenges. -func splitChallenges(header string) ([]string, error) { - // GENERATED BY GEMINI 2.5. - var challenges []string - inQuotes := false - start := 0 - for i, r := range header { - if r == '"' { - if i > 0 && header[i-1] != '\\' { - inQuotes = !inQuotes - } else if i == 0 { - // A challenge begins with an auth-scheme, which is a token, which cannot contain - // a quote. - return nil, errors.New(`challenge begins with '"'`) - } - } else if r == ',' && !inQuotes { - // This is a potential challenge separator. - // A new challenge does not start with `key=value`. - // We check if the part after the comma looks like a parameter. - lookahead := strings.TrimSpace(header[i+1:]) - eqPos := strings.Index(lookahead, "=") - - isParam := false - if eqPos > 0 { - // Check if the part before '=' is a single token (no spaces). - token := lookahead[:eqPos] - if strings.IndexFunc(token, unicode.IsSpace) == -1 { - isParam = true - } - } - - if !isParam { - // The part after the comma does not look like a parameter, - // so this comma separates challenges. - challenges = append(challenges, header[start:i]) - start = i + 1 - } - } - } - // Add the last (or only) challenge to the list. - challenges = append(challenges, header[start:]) - return challenges, nil -} - -// parseSingleChallenge parses a string containing exactly one challenge. -// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] -func parseSingleChallenge(s string) (challenge, error) { - // GENERATED BY GEMINI 2.5, human-tweaked. - s = strings.TrimSpace(s) - if s == "" { - return challenge{}, errors.New("empty challenge string") - } - - scheme, paramsStr, found := strings.Cut(s, " ") - c := challenge{Scheme: strings.ToLower(scheme)} - if !found { - return c, nil - } - - params := make(map[string]string) - - // Parse the key-value parameters. - for paramsStr != "" { - // Find the end of the parameter key. - keyEnd := strings.Index(paramsStr, "=") - if keyEnd <= 0 { - return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) - } - key := strings.TrimSpace(paramsStr[:keyEnd]) - - // Move the string past the key and the '='. - paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) - - var value string - if strings.HasPrefix(paramsStr, "\"") { - // The value is a quoted string. - paramsStr = paramsStr[1:] // Consume the opening quote. - var valBuilder strings.Builder - i := 0 - for ; i < len(paramsStr); i++ { - // Handle escaped characters. - if paramsStr[i] == '\\' && i+1 < len(paramsStr) { - valBuilder.WriteByte(paramsStr[i+1]) - i++ // We've consumed two characters. - } else if paramsStr[i] == '"' { - // End of the quoted string. - break - } else { - valBuilder.WriteByte(paramsStr[i]) - } - } - - // A quoted string must be terminated. - if i == len(paramsStr) { - return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") - } - - value = valBuilder.String() - // Move the string past the value and the closing quote. - paramsStr = strings.TrimSpace(paramsStr[i+1:]) - } else { - // The value is a token. It ends at the next comma or the end of the string. - commaPos := strings.Index(paramsStr, ",") - if commaPos == -1 { - value = paramsStr - paramsStr = "" - } else { - value = strings.TrimSpace(paramsStr[:commaPos]) - paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check - } - } - if value == "" { - return challenge{}, fmt.Errorf("no value for auth param %q", key) - } - - // Per RFC 9110, parameter keys are case-insensitive. - params[strings.ToLower(key)] = value - - // If there is a comma, consume it and continue to the next parameter. - if strings.HasPrefix(paramsStr, ",") { - paramsStr = strings.TrimSpace(paramsStr[1:]) - } else if paramsStr != "" { - // If there's content but it's not a new parameter, the format is wrong. - return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) - } - } - - // Per RFC 9110, the scheme is case-insensitive. - return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil -} diff --git a/oauthex/resource_meta_public.go b/oauthex/resource_meta_public.go new file mode 100644 index 00000000..44ad3e79 --- /dev/null +++ b/oauthex/resource_meta_public.go @@ -0,0 +1,198 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +// This is a temporary file to expose the required objects to the main package. + +package oauthex + +import ( + "errors" + "fmt" + "strings" + "unicode" +) + +// challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type challenge struct { + // GENERATED BY GEMINI 2.5. + // + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} + +func Error(cs []challenge) string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["error"] != "" { + return c.Params["error"] + } + } + return "" +} + +// ParseWWWAuthenticate parses a WWW-Authenticate header string. +// The header format is defined in RFC 9110, Section 11.6.1, and can contain +// one or more challenges, separated by commas. +// It returns a slice of challenges or an error if one of the headers is malformed. +func ParseWWWAuthenticate(headers []string) ([]challenge, error) { + // GENERATED BY GEMINI 2.5 (human-tweaked) + var challenges []challenge + for _, h := range headers { + challengeStrings, err := splitChallenges(h) + if err != nil { + return nil, err + } + for _, cs := range challengeStrings { + if strings.TrimSpace(cs) == "" { + continue + } + challenge, err := parseSingleChallenge(cs) + if err != nil { + return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) + } + challenges = append(challenges, challenge) + } + } + return challenges, nil +} + +// splitChallenges splits a header value containing one or more challenges. +// It correctly handles commas within quoted strings and distinguishes between +// commas separating auth-params and commas separating challenges. +func splitChallenges(header string) ([]string, error) { + // GENERATED BY GEMINI 2.5. + var challenges []string + inQuotes := false + start := 0 + for i, r := range header { + if r == '"' { + if i > 0 && header[i-1] != '\\' { + inQuotes = !inQuotes + } else if i == 0 { + // A challenge begins with an auth-scheme, which is a token, which cannot contain + // a quote. + return nil, errors.New(`challenge begins with '"'`) + } + } else if r == ',' && !inQuotes { + // This is a potential challenge separator. + // A new challenge does not start with `key=value`. + // We check if the part after the comma looks like a parameter. + lookahead := strings.TrimSpace(header[i+1:]) + eqPos := strings.Index(lookahead, "=") + + isParam := false + if eqPos > 0 { + // Check if the part before '=' is a single token (no spaces). + token := lookahead[:eqPos] + if strings.IndexFunc(token, unicode.IsSpace) == -1 { + isParam = true + } + } + + if !isParam { + // The part after the comma does not look like a parameter, + // so this comma separates challenges. + challenges = append(challenges, header[start:i]) + start = i + 1 + } + } + } + // Add the last (or only) challenge to the list. + challenges = append(challenges, header[start:]) + return challenges, nil +} + +// parseSingleChallenge parses a string containing exactly one challenge. +// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] +func parseSingleChallenge(s string) (challenge, error) { + // GENERATED BY GEMINI 2.5, human-tweaked. + s = strings.TrimSpace(s) + if s == "" { + return challenge{}, errors.New("empty challenge string") + } + + scheme, paramsStr, found := strings.Cut(s, " ") + c := challenge{Scheme: strings.ToLower(scheme)} + if !found { + return c, nil + } + + params := make(map[string]string) + + // Parse the key-value parameters. + for paramsStr != "" { + // Find the end of the parameter key. + keyEnd := strings.Index(paramsStr, "=") + if keyEnd <= 0 { + return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + } + key := strings.TrimSpace(paramsStr[:keyEnd]) + + // Move the string past the key and the '='. + paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) + + var value string + if strings.HasPrefix(paramsStr, "\"") { + // The value is a quoted string. + paramsStr = paramsStr[1:] // Consume the opening quote. + var valBuilder strings.Builder + i := 0 + for ; i < len(paramsStr); i++ { + // Handle escaped characters. + if paramsStr[i] == '\\' && i+1 < len(paramsStr) { + valBuilder.WriteByte(paramsStr[i+1]) + i++ // We've consumed two characters. + } else if paramsStr[i] == '"' { + // End of the quoted string. + break + } else { + valBuilder.WriteByte(paramsStr[i]) + } + } + + // A quoted string must be terminated. + if i == len(paramsStr) { + return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + } + + value = valBuilder.String() + // Move the string past the value and the closing quote. + paramsStr = strings.TrimSpace(paramsStr[i+1:]) + } else { + // The value is a token. It ends at the next comma or the end of the string. + commaPos := strings.Index(paramsStr, ",") + if commaPos == -1 { + value = paramsStr + paramsStr = "" + } else { + value = strings.TrimSpace(paramsStr[:commaPos]) + paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check + } + } + if value == "" { + return challenge{}, fmt.Errorf("no value for auth param %q", key) + } + + // Per RFC 9110, parameter keys are case-insensitive. + params[strings.ToLower(key)] = value + + // If there is a comma, consume it and continue to the next parameter. + if strings.HasPrefix(paramsStr, ",") { + paramsStr = strings.TrimSpace(paramsStr[1:]) + } else if paramsStr != "" { + // If there's content but it's not a new parameter, the format is wrong. + return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + } + } + + // Per RFC 9110, the scheme is case-insensitive. + return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil +} diff --git a/scripts/server-conformance.sh b/scripts/server-conformance.sh index 8ab4a3bf..0826086a 100755 --- a/scripts/server-conformance.sh +++ b/scripts/server-conformance.sh @@ -67,7 +67,7 @@ else fi # Build the conformance server. -go build -tags mcp_go_client_oauth -o "$WORKDIR/conformance-server" ./conformance/everything-server +go build -o "$WORKDIR/conformance-server" ./conformance/everything-server # Start the server in the background echo "Starting conformance server on port $PORT..." From bd4be14b900bb420f2f98133d16173f6b3601249 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 16 Feb 2026 15:43:36 +0000 Subject: [PATCH 04/26] Separate conformance client into two files, one protected by the tag. --- .../everything-client/client_private.go | 168 ++++++++++++++++++ conformance/everything-client/main.go | 148 --------------- 2 files changed, 168 insertions(+), 148 deletions(-) create mode 100644 conformance/everything-client/client_private.go diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go new file mode 100644 index 00000000..9fccb164 --- /dev/null +++ b/conformance/everything-client/client_private.go @@ -0,0 +1,168 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The conformance client implements features required for MCP conformance testing. +// It mirrors the functionality of the TypeScript conformance client at +// https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/conformance/everything-client.ts + +//go:build mcp_go_client_oauth + +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +func init() { + authScenarios := []string{ + "auth/2025-03-26-oauth-metadata-backcompat", + "auth/2025-03-26-oauth-endpoint-fallback", + "auth/basic-cimd", + "auth/metadata-default", + "auth/metadata-var1", + "auth/metadata-var2", + "auth/metadata-var3", + "auth/pre-registration", + "auth/resource-mismatch", + "auth/scope-from-www-authenticate", + "auth/scope-from-scopes-supported", + "auth/scope-omitted-when-undefined", + "auth/scope-step-up", + "auth/scope-retry-limit", + "auth/token-endpoint-auth-basic", + "auth/token-endpoint-auth-post", + "auth/token-endpoint-auth-none", + } + for _, scenario := range authScenarios { + registerScenario(scenario, runAuthClient) + } +} + +// ============================================================================ +// Auth scenarios +// ============================================================================ + +func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (code, state string, err error) { + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil) + if err != nil { + return "", "", err + } + + resp, err := client.Do(req) + if err != nil { + return "", "", err + } + defer resp.Body.Close() + + location := resp.Header.Get("Location") + if location == "" { + return "", "", fmt.Errorf("no Location header in redirect") + } + + locURL, err := url.Parse(location) + if err != nil { + return "", "", fmt.Errorf("parse location: %v", err) + } + + code = locURL.Query().Get("code") + if code == "" { + return "", "", fmt.Errorf("no code parameter in redirect URL") + } + state = locURL.Query().Get("state") + if state == "" { + return "", "", fmt.Errorf("no state parameter in redirect URL") + } + + return code, state, nil +} + +func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]any) error { + authHandler := &auth.AuthorizationCodeOAuthHandler{ + RedirectURL: "http://localhost:3000/callback", + // Try client ID metadata document based registration. + ClientIDMetadataDocumentConfig: &auth.ClientIDMetadataDocumentConfig{ + URL: "https://conformance-test.local/client-metadata.json", + }, + // Try dynamic client registration. + DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{"http://localhost:3000/callback"}, + }, + }, + } + // Try pre-registered client information if provided in the context. + if clientId, ok := configCtx["client_id"].(string); ok { + if clientSecret, ok := configCtx["client_secret"].(string); ok { + authHandler.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ + ClientID: clientId, + ClientSecret: clientSecret, + } + } + } + + authHandler.AuthorizationURLHandler = func(ctx context.Context, authURL string) error { + // Normally this handler would trigger user browser to be opened. + // Here we query the authorization URL automatically and the AS is configured + // to authorize and redirect immediately. We save the resulting code. + code, state, err := fetchAuthorizationCodeAndState(ctx, authURL) + if err != nil { + return err + } + if err := authHandler.FinalizeAuthorization(code, state); err != nil { + return err + } + return nil + } + + session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) + if err != nil { + if !errors.Is(err, auth.ErrRedirected) { + return err + } + // Received auth.ErrRedirected. Normally we would wait for the callback triggered + // by the AS redirect to RedirectURL, but here we already have the authorization code + // so we can immediately retry. + session, err = connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) + if err != nil { + return nil + } + } + defer session.Close() + + if _, err := session.ListTools(ctx, nil); err != nil { + // Retry for the scope step-up scenario. + if !errors.Is(err, auth.ErrRedirected) { + return fmt.Errorf("session.ListTools(): %v", err) + } + // Received auth.ErrRedirected. Normally we would wait for the callback triggered + // by the AS redirect to RedirectURL, but here we already have the authorization code + // so we can immediately retry. + _, err = session.ListTools(ctx, nil) + if err != nil { + return fmt.Errorf("session.ListTools(): %v", err) + } + } + + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "test-tool", + Arguments: map[string]any{}, + }); err != nil { + return fmt.Errorf("session.CallTool('test-tool'): %v", err) + } + + return nil +} diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index c6ec1f50..857c9f62 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -10,11 +10,8 @@ package main import ( "context" "encoding/json" - "errors" "fmt" "log" - "net/http" - "net/url" "os" "slices" "sort" @@ -22,7 +19,6 @@ import ( "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/modelcontextprotocol/go-sdk/oauthex" ) // scenarioHandler is the function signature for all conformance test scenarios. @@ -48,30 +44,6 @@ func init() { registerScenario("tools_call", runToolsCallClient) registerScenario("elicitation-sep1034-client-defaults", runElicitationDefaultsClient) registerScenario("sse-retry", runSSERetryClient) - - authScenarios := []string{ - "auth/2025-03-26-oauth-metadata-backcompat", - "auth/2025-03-26-oauth-endpoint-fallback", - "auth/basic-cimd", - "auth/metadata-default", - "auth/metadata-var1", - "auth/metadata-var2", - "auth/metadata-var3", - "auth/resource-mismatch", - "auth/scope-from-www-authenticate", - "auth/scope-from-scopes-supported", - "auth/scope-omitted-when-undefined", - "auth/scope-step-up", - "auth/scope-retry-limit", - "auth/token-endpoint-auth-basic", - "auth/token-endpoint-auth-post", - "auth/token-endpoint-auth-none", - } - for _, scenario := range authScenarios { - registerScenario(scenario, runAuthClient) - } - - registerScenario("auth/pre-registration", runAuthClient) } // ============================================================================ @@ -204,126 +176,6 @@ func runSSERetryClient(ctx context.Context, serverURL string, _ map[string]any) return nil } -// ============================================================================ -// Auth scenarios -// ============================================================================ - -func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (code, state string, err error) { - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil) - if err != nil { - return "", "", err - } - - resp, err := client.Do(req) - if err != nil { - return "", "", err - } - defer resp.Body.Close() - - location := resp.Header.Get("Location") - if location == "" { - return "", "", fmt.Errorf("no Location header in redirect") - } - - locURL, err := url.Parse(location) - if err != nil { - return "", "", fmt.Errorf("parse location: %v", err) - } - - code = locURL.Query().Get("code") - if code == "" { - return "", "", fmt.Errorf("no code parameter in redirect URL") - } - state = locURL.Query().Get("state") - if state == "" { - return "", "", fmt.Errorf("no state parameter in redirect URL") - } - - return code, state, nil -} - -func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]any) error { - authHandler := &auth.AuthorizationCodeOAuthHandler{ - RedirectURL: "http://localhost:3000/callback", - // Try client ID metadata document based registration. - ClientIDMetadataDocumentConfig: &auth.ClientIDMetadataDocumentConfig{ - URL: "https://conformance-test.local/client-metadata.json", - }, - // Try dynamic client registration. - DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ - Metadata: &oauthex.ClientRegistrationMetadata{ - RedirectURIs: []string{"http://localhost:3000/callback"}, - }, - }, - } - // Try pre-registered client information if provided in the context. - if clientId, ok := configCtx["client_id"].(string); ok { - if clientSecret, ok := configCtx["client_secret"].(string); ok { - authHandler.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ - ClientID: clientId, - ClientSecret: clientSecret, - } - } - } - - authHandler.AuthorizationURLHandler = func(ctx context.Context, authURL string) error { - // Normally this handler would trigger user browser to be opened. - // Here we query the authorization URL automatically and the AS is configured - // to authorize and redirect immediately. We save the resulting code. - code, state, err := fetchAuthorizationCodeAndState(ctx, authURL) - if err != nil { - return err - } - if err := authHandler.FinalizeAuthorization(code, state); err != nil { - return err - } - return nil - } - - session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) - if err != nil { - if !errors.Is(err, auth.ErrRedirected) { - return err - } - // Received auth.ErrRedirected. Normally we would wait for the callback triggered - // by the AS redirect to RedirectURL, but here we already have the authorization code - // so we can immediately retry. - session, err = connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) - if err != nil { - return nil - } - } - defer session.Close() - - if _, err := session.ListTools(ctx, nil); err != nil { - // Retry for the scope step-up scenario. - if !errors.Is(err, auth.ErrRedirected) { - return fmt.Errorf("session.ListTools(): %v", err) - } - // Received auth.ErrRedirected. Normally we would wait for the callback triggered - // by the AS redirect to RedirectURL, but here we already have the authorization code - // so we can immediately retry. - _, err = session.ListTools(ctx, nil) - if err != nil { - return fmt.Errorf("session.ListTools(): %v", err) - } - } - - if _, err := session.CallTool(ctx, &mcp.CallToolParams{ - Name: "test-tool", - Arguments: map[string]any{}, - }); err != nil { - return fmt.Errorf("session.CallTool('test-tool'): %v", err) - } - - return nil -} - // ============================================================================ // Main entry point // ============================================================================ From 21bf098f33f360d725114ff9863081cef48dd07b Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 16 Feb 2026 15:49:07 +0000 Subject: [PATCH 05/26] Fix conformance tests. --- .github/workflows/conformance.yml | 2 +- conformance/everything-client/client_private.go | 6 ++++++ conformance/everything-client/main.go | 6 ------ 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 6d4b9950..764e21d8 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -49,6 +49,6 @@ jobs: uses: modelcontextprotocol/conformance@c2f3fdaf781dcd5a862cb0d2f6454c1c210bf0f0 # v0.1.11 with: mode: client - command: go run -tags mcp_go_client_oauth ./conformance/everything-client/main.go + command: go run -tags mcp_go_client_oauth ./conformance/everything-client suite: core expected-failures: ./conformance/baseline.yml diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index 9fccb164..c7ed65f6 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -166,3 +166,9 @@ func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]a return nil } + +func withOAuthHandler(handler auth.OAuthHandler) connectOption { + return func(c *connectConfig) { + c.oauthHandler = handler + } +} diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index 857c9f62..d34e8328 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -238,12 +238,6 @@ func withClientOptions(opts *mcp.ClientOptions) connectOption { } } -func withOAuthHandler(handler auth.OAuthHandler) connectOption { - return func(c *connectConfig) { - c.oauthHandler = handler - } -} - // connectToServer connects to the MCP server and returns a client session. // The caller is responsible for closing the session. func connectToServer(ctx context.Context, serverURL string, opts ...connectOption) (*mcp.ClientSession, error) { From ff406be3d122f8793ff444a1bf159096c0e7036c Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Wed, 18 Feb 2026 14:23:13 +0000 Subject: [PATCH 06/26] all: remove token persistence APIs and add unexported method to OAuthHandler. --- auth/authorization_code.go | 16 +++------------- auth/client.go | 2 ++ auth/client_private.go | 38 -------------------------------------- auth/fake.go | 27 +++++++++++++++++++++++++++ mcp/streamable_test.go | 15 +-------------- 5 files changed, 33 insertions(+), 65 deletions(-) create mode 100644 auth/fake.go diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 8ea20772..38e1a327 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -91,9 +91,6 @@ type AuthorizationCodeOAuthHandler struct { // The state should be validated on the redirect callback. StateProvider func() string - // TokenStore is an optional object that allows persistent storage of tokens. - TokenStore TokenStore - // resolvedClientConfig used during the authorization flow. resolvedClientConfig *resolvedClientConfig // tokenSource is the token source to use for authorization. @@ -107,6 +104,8 @@ type AuthorizationCodeOAuthHandler struct { state string } +func (h *AuthorizationCodeOAuthHandler) isOAuthHandler() {} + func (h *AuthorizationCodeOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { return h.tokenSource, nil } @@ -205,12 +204,7 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http if err != nil { return fmt.Errorf("token exchange failed: %w", err) } - ts := cfg.TokenSource(ctx, token) - if h.TokenStore != nil { - // Persist the returned tokens to the store if requested. - ts = NewPersistentTokenSource(ctx, ts, h.TokenStore) - } - h.tokenSource = ts + h.tokenSource = cfg.TokenSource(ctx, token) return nil } @@ -235,10 +229,6 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http return ErrRedirected } -func (h *AuthorizationCodeOAuthHandler) SetTokenSource(ts oauth2.TokenSource) { - h.tokenSource = ts -} - func (h *AuthorizationCodeOAuthHandler) FinalizeAuthorization(code, state string) error { defer func() { // State has been used for validation, clear it. diff --git a/auth/client.go b/auth/client.go index aa921fbb..3aa6d29b 100644 --- a/auth/client.go +++ b/auth/client.go @@ -16,6 +16,8 @@ import ( var ErrUnauthorized = errors.New("unauthorized") type OAuthHandler interface { + isOAuthHandler() + // TokenSource returns a token source to be used for outgoing requests. TokenSource(context.Context) (oauth2.TokenSource, error) diff --git a/auth/client_private.go b/auth/client_private.go index d8e633e9..f161bdc6 100644 --- a/auth/client_private.go +++ b/auth/client_private.go @@ -8,7 +8,6 @@ package auth import ( "bytes" - "context" "errors" "io" "net/http" @@ -25,43 +24,6 @@ import ( // into the streamable transport. type OAuthHandlerLegacy func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) -// TokenStore is an interface than can be used by OAuthHandler implementations -// to save tokens to a persistent storage. -type TokenStore interface { - Save(context.Context, *oauth2.Token) error -} - -type persistentTokenSource struct { - wrapped oauth2.TokenSource - store TokenStore - ctx context.Context -} - -// NewPersistentTokenSource returns a [oauth2.TokenSource] that -// persists the token to a given [TokenStore] after every successful -// [oauth2.TokenSource.Token] call. -// It is especially useful when wrapping a [oauth2.TokenSource] -// that automatically refreshes the token when it expires. -// The passed context is used for [TokenStore.Save] calls. -func NewPersistentTokenSource(ctx context.Context, wrapped oauth2.TokenSource, store TokenStore) oauth2.TokenSource { - return &persistentTokenSource{ - wrapped: wrapped, - store: store, - ctx: ctx, - } -} - -func (t *persistentTokenSource) Token() (*oauth2.Token, error) { - token, err := t.wrapped.Token() - if err != nil { - return nil, err - } - if err := t.store.Save(t.ctx, token); err != nil { - return nil, err - } - return token, nil -} - // HTTPTransport is an [http.RoundTripper] that follows the MCP // OAuth protocol when it encounters a 401 Unauthorized response. // Deprecated: Please use the new OAuthHandler abstraction that is built diff --git a/auth/fake.go b/auth/fake.go new file mode 100644 index 00000000..b8d82f33 --- /dev/null +++ b/auth/fake.go @@ -0,0 +1,27 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package auth + +import ( + "context" + "net/http" + + "golang.org/x/oauth2" +) + +type FakeOAuthHandler struct { + Token *oauth2.Token + AuthorizeErr error +} + +func (h *FakeOAuthHandler) isOAuthHandler() {} + +func (h *FakeOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return oauth2.StaticTokenSource(h.Token), nil +} + +func (h *FakeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + return h.AuthorizeErr +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d9361243..11089535 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1667,19 +1667,6 @@ func textContent(t *testing.T, res *CallToolResult) string { return text.Text } -type testOAuthHandler struct { - token string -} - -func (h *testOAuthHandler) TokenSource(context.Context) (oauth2.TokenSource, error) { - return oauth2.StaticTokenSource(&oauth2.Token{AccessToken: h.token}), nil -} - -func (h *testOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { - // 401 resonse is not expected in this test. We can simply fail. - return errors.New("unexpected 401") -} - func TestTokenInfo(t *testing.T) { ctx := context.Background() @@ -1707,7 +1694,7 @@ func TestTokenInfo(t *testing.T) { transport := &StreamableClientTransport{ Endpoint: httpServer.URL, - OAuthHandler: &testOAuthHandler{token: "test-token"}, + OAuthHandler: &auth.FakeOAuthHandler{Token: &oauth2.Token{AccessToken: "test-token"}}, } client := NewClient(testImpl, nil) session, err := client.Connect(ctx, transport, nil) From 5be070b9a96d58945de6e879349a238135e09e00 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Thu, 19 Feb 2026 17:24:29 +0000 Subject: [PATCH 07/26] auth: improve the structure in authorization_code.go. --- auth/authorization_code.go | 296 ++++++++++++++++++++++--------------- 1 file changed, 173 insertions(+), 123 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 38e1a327..a27b8f06 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -36,9 +36,12 @@ type ClientIDMetadataDocumentConfig struct { // PreregisteredClientConfig is used to configure a pre-registered client per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration. type PreregisteredClientConfig struct { + // ClientID and ClientSecret to be used for client authentication. ClientID string ClientSecret string - AuthStyle oauth2.AuthStyle + // AuthStyle is an optional client authentication method. + // See [oauth2.AuthStyleAutoDetect] for the documentation of the zero value. + AuthStyle oauth2.AuthStyle } // DynamicClientRegistrationConfig is used to configure dynamic client registration per @@ -64,26 +67,35 @@ type resolvedClientConfig struct { authStyle oauth2.AuthStyle } +// AuthorizationCodeOAuthHandler is an implementation of [OAuthHandler] that uses +// the authorization code flow to obtain access tokens. +// The handler is stateful and can only handle one authorization flow at a time. type AuthorizationCodeOAuthHandler struct { // Client registration configuration. // It is attempted in the following order: - // 1. Client ID Metadata Document - // 2. Preregistration - // 3. Dynamic Client Registration + // + // 1. Client ID Metadata Document + // 2. Preregistration + // 3. Dynamic Client Registration + // + // At least one method must be configured. ClientIDMetadataDocumentConfig *ClientIDMetadataDocumentConfig PreregisteredClientConfig *PreregisteredClientConfig DynamicClientRegistrationConfig *DynamicClientRegistrationConfig - // RedirectURL is the URL to redirect to after authorization. + // RedirectURL is a required URL to redirect to after authorization. + // The caller is responsible for handling the redirect out of band. // If Dynamic Client Registration is used, the RedirectURL must be consistent // with [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. RedirectURL string - // AuthorizationURLHandler is called to handle the authorization URL. - // It is responsible for opening the URL in a browser. - // It should return once the redirect has been issued. - // The redirect callback should be handled by the caller and the authorization code - // should be set by calling [SetAuthorizationCode] before retrying the request. + // AuthorizationURLHandler is a required function called to handle the authorization URL. + // It is responsible for opening the URL in a browser for the user to start the authorization. + // It should return once the proccess has been initiated and the URL was + // presented to the user successfully. Once the Authorizatin Server redirects back + // to the [AuthorizationCodeOAuthHandler.RedirectURL], the caller should set + // the authorization code by calling [AuthorizationCodeOAuthHandler.SetAuthorizationCode] + // before retrying the request. AuthorizationURLHandler func(ctx context.Context, authorizationURL string) error // StateProvider is an optional function to generate a state string for authorization @@ -94,7 +106,7 @@ type AuthorizationCodeOAuthHandler struct { // resolvedClientConfig used during the authorization flow. resolvedClientConfig *resolvedClientConfig // tokenSource is the token source to use for authorization. - // It can be prepopulated by calling [SetTokenSource]. + // It can be prepopulated by calling [AuthorizationCodeOAuthHandler.SetTokenSource]. tokenSource oauth2.TokenSource // codeVerifier is the PKCE code verifier. codeVerifier string @@ -104,30 +116,37 @@ type AuthorizationCodeOAuthHandler struct { state string } +var _ OAuthHandler = (*AuthorizationCodeOAuthHandler)(nil) + func (h *AuthorizationCodeOAuthHandler) isOAuthHandler() {} func (h *AuthorizationCodeOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { return h.tokenSource, nil } -// TODO: extract some logic into helper functions. -// TODO: validate required args +// Authorize performs the authorization flow. +// It is designed to be reentrant and called in two phases. +// 1. It initiates the Authorization Grant flow by calling [AuthorizationCodeOAuthHandler.AuthorizationURLHandler]. +// It will return [ErrRedirected] if the authorization flow was initiated successfully. +// 2. It exchanges the authorization code for an access token. +// It will return a `nil` error if the authorization flow was completed successfully. +// From this point on, [AuthorizationCodeOAuthHandler.TokenSource] will return a token source with the fetched token. func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { defer resp.Body.Close() log.Printf("Authorize: %s %s", req.Method, req.URL) - - if h.resolvedClientConfig == nil && h.authorizationCode != "" { - return fmt.Errorf("exchanging authorization code with unregistered client is not allowed") + if err := h.validate(); err != nil { + return err } resourceURL := req.URL.String() - challenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) + wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) if err != nil { return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) } - log.Printf("WWW-Authenticate header: %v", challenges) + + log.Printf("WWW-Authenticate header: %v", wwwChallenges) var prm *oauthex.ProtectedResourceMetadata - for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(challenges), resourceURL) { + for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), resourceURL) { var err error log.Printf("Getting protected resource metadata from %q", url) prm, err = oauthex.GetProtectedResourceMetadata(ctx, url, http.DefaultClient) @@ -136,53 +155,27 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http } log.Printf("Failed to get protected resource metadata from %q: %v", url, err) } - var authServerURL string - if prm != nil && len(prm.AuthorizationServers) > 0 { - // Use the first authorization server, similarly to other SDKs. - authServerURL = prm.AuthorizationServers[0] - } else { - // Fallback to 2025-03-26 spec: MCP server base URL acts as Authorization Server. - authURL, err := url.Parse(resourceURL) - if err != nil { - return fmt.Errorf("failed to parse resource URL: %v", err) - } - authURL.Path = "" - authServerURL = authURL.String() - } - log.Printf("Authorization server URL: %s", authServerURL) - - asm, err := oauthex.GetAuthServerMeta(ctx, authServerURL, http.DefaultClient) + asm, err := h.getAuthServerMetadata(ctx, prm, resourceURL) if err != nil { - return fmt.Errorf("failed to get authorization server metadata: %w", err) + return err } - log.Print("Authorization server medatada fetched") - if err := h.handleRegistration(ctx, authServerURL, asm); err != nil { + if err := h.handleRegistration(ctx, asm); err != nil { return err } - scopes := oauthex.Scopes(challenges) + scopes := oauthex.Scopes(wwwChallenges) if len(scopes) == 0 && prm != nil && len(prm.ScopesSupported) > 0 { scopes = prm.ScopesSupported } - var authorizationEndpoint, tokenEndpoint string - if asm != nil { - authorizationEndpoint = asm.AuthorizationEndpoint - tokenEndpoint = asm.TokenEndpoint - } else { - // Fallback to 2025-03-26 spec: predefined endpoints if not provided by AS. - authorizationEndpoint = authServerURL + "/authorize" - tokenEndpoint = authServerURL + "/token" - } - cfg := &oauth2.Config{ ClientID: h.resolvedClientConfig.clientID, ClientSecret: h.resolvedClientConfig.clientSecret, Endpoint: oauth2.Endpoint{ - AuthURL: authorizationEndpoint, - TokenURL: tokenEndpoint, + AuthURL: asm.AuthorizationEndpoint, + TokenURL: asm.TokenEndpoint, // TODO: validate if the auth style is supported by the AS. AuthStyle: h.resolvedClientConfig.authStyle, }, @@ -191,42 +184,10 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http } if h.authorizationCode != "" { - log.Print("Authorization code is available, exchanging for token") - opts := []oauth2.AuthCodeOption{ - oauth2.VerifierOption(h.codeVerifier), - oauth2.SetAuthURLParam("resource", req.URL.String()), - } - token, err := cfg.Exchange(ctx, h.authorizationCode, opts...) - defer func() { - // Authorization code has been consumed, clear it. - h.authorizationCode = "" - }() - if err != nil { - return fmt.Errorf("token exchange failed: %w", err) - } - h.tokenSource = cfg.TokenSource(ctx, token) - return nil + return h.exchangeAuthorizationCode(ctx, cfg, resourceURL) } - h.codeVerifier = oauth2.GenerateVerifier() - h.state = rand.Text() - if h.StateProvider != nil { - h.state = h.StateProvider() - } - - authURL := cfg.AuthCodeURL(h.state, - oauth2.S256ChallengeOption(h.codeVerifier), - oauth2.SetAuthURLParam("resource", req.URL.String()), - ) - - log.Print("No authorization code available, opening authorization URL") - if h.AuthorizationURLHandler != nil { - if err := h.AuthorizationURLHandler(ctx, authURL); err != nil { - return fmt.Errorf("authorization URL handler failed: %w", err) - } - } - - return ErrRedirected + return h.startAuthFlow(ctx, cfg, req.URL.String()) } func (h *AuthorizationCodeOAuthHandler) FinalizeAuthorization(code, state string) error { @@ -241,28 +202,102 @@ func (h *AuthorizationCodeOAuthHandler) FinalizeAuthorization(code, state string return nil } -func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, authServerURL string, asm *oauthex.AuthServerMeta) error { +func (h *AuthorizationCodeOAuthHandler) validate() error { + if h.ClientIDMetadataDocumentConfig == nil && + h.PreregisteredClientConfig == nil && + h.DynamicClientRegistrationConfig == nil { + return errors.New("at least one client registration configuration must be provided") + } + if h.RedirectURL == "" { + return errors.New("field RedirectURL is required") + } + if h.AuthorizationURLHandler == nil { + return errors.New("field AuthorizationURLHandler is required") + } + if h.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(h.ClientIDMetadataDocumentConfig.URL) { + return fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") + } + if h.PreregisteredClientConfig != nil { + if h.PreregisteredClientConfig.ClientID == "" || h.PreregisteredClientConfig.ClientSecret == "" { + return fmt.Errorf("pre-registered client ID or secret is empty") + } + } + if h.DynamicClientRegistrationConfig != nil { + if h.DynamicClientRegistrationConfig.Metadata == nil { + return errors.New("field Metadata is required for dynamic client registration") + } + if !slices.Contains(h.DynamicClientRegistrationConfig.Metadata.RedirectURIs, h.RedirectURL) { + return fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", h.RedirectURL) + } + } + if h.resolvedClientConfig == nil && h.authorizationCode != "" { + return fmt.Errorf("exchanging authorization code with unregistered client is not allowed") + } + return nil +} + +func isNonRootHTTPSURL(u string) bool { + pu, err := url.Parse(u) + if err != nil { + return false + } + return pu.Scheme == "https" && pu.Path != "" +} + +// getAuthServerMetadata returns the authorization server metadata. +// If no metadata is available, it returns a minimal set of endpoints +// as a fallback to 2025-03-26 spec. +func (h *AuthorizationCodeOAuthHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata, resourceURL string) (*oauthex.AuthServerMeta, error) { + var authServerURL string + if prm != nil && len(prm.AuthorizationServers) > 0 { + // Use the first authorization server, similarly to other SDKs. + authServerURL = prm.AuthorizationServers[0] + } else { + // Fallback to 2025-03-26 spec: MCP server base URL acts as Authorization Server. + authURL, err := url.Parse(resourceURL) + if err != nil { + return nil, fmt.Errorf("failed to parse resource URL: %v", err) + } + authURL.Path = "" + authServerURL = authURL.String() + } + log.Printf("Authorization server URL: %s", authServerURL) + + asm, err := oauthex.GetAuthServerMeta(ctx, authServerURL, http.DefaultClient) + if err != nil { + return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) + } + log.Print("Authorization server medatada fetched") + if asm == nil { + log.Print("Authorization server metadata not found, using fallback") + // Fallback to 2025-03-26 spec: predefined endpoints. + // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery + asm = &oauthex.AuthServerMeta{ + Issuer: authServerURL, + AuthorizationEndpoint: authServerURL + "/authorize", + TokenEndpoint: authServerURL + "/token", + RegistrationEndpoint: authServerURL + "/register", + } + } + return asm, nil +} + +// handleRegistration handles client registration. +// The provided authorization server metadata must be non-nil. +// It must also have RegistrationEndpoint set if dynamic client registration is supported. +func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) error { // 1. Attempt to use Client ID Metadata Document (SEP-991). cimdCfg := h.ClientIDMetadataDocumentConfig - if cimdCfg != nil { - supportsCIMD := asm != nil && asm.ClientIDMetadataDocumentSupported - if supportsCIMD { - if !isNonRootHTTPSURL(cimdCfg.URL) { - return fmt.Errorf("client ID metadata document URL is not a non-root HTTPS URL") - } - h.resolvedClientConfig = &resolvedClientConfig{ - registrationType: registrationTypeClientIDMetadataDocument, - clientID: cimdCfg.URL, - } - return nil + if cimdCfg != nil && asm.ClientIDMetadataDocumentSupported { + h.resolvedClientConfig = &resolvedClientConfig{ + registrationType: registrationTypeClientIDMetadataDocument, + clientID: cimdCfg.URL, } + return nil } - // 2. Attempt to use pre-registered client ID. + // 2. Attempt to use pre-registered client configuration. pCfg := h.PreregisteredClientConfig if pCfg != nil { - if pCfg.ClientID == "" || pCfg.ClientSecret == "" { - return fmt.Errorf("pre-registered client ID or secret is empty") - } h.resolvedClientConfig = &resolvedClientConfig{ registrationType: registrationTypePreregistered, clientID: pCfg.ClientID, @@ -273,22 +308,8 @@ func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, } // 3. Attempt to use dynamic client registration. dcrCfg := h.DynamicClientRegistrationConfig - if dcrCfg != nil { - if !slices.Contains(dcrCfg.Metadata.RedirectURIs, h.RedirectURL) { - return fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", h.RedirectURL) - } - var registrationEndpoint string - if asm != nil { - if asm.RegistrationEndpoint == "" { - return fmt.Errorf("authorization server does not support dynamic client registration") - } - registrationEndpoint = asm.RegistrationEndpoint - } else { - // Fallback to 2025-03-26 spec: predefined endpoints if not provided by AS. - registrationEndpoint = authServerURL + "/register" - } - log.Printf("Attempting dynamic client registration at %v", registrationEndpoint) - regResp, err := oauthex.RegisterClient(ctx, registrationEndpoint, dcrCfg.Metadata, http.DefaultClient) + if dcrCfg != nil && asm.RegistrationEndpoint != "" { + regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, http.DefaultClient) if err != nil { return fmt.Errorf("failed to register client: %w", err) } @@ -312,15 +333,44 @@ func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, log.Printf("Client registered with client ID: %s", regResp.ClientID) return nil } - return fmt.Errorf("no client registration method configured") + return fmt.Errorf("no configured client registration methods are supported by the authorization server") } -func isNonRootHTTPSURL(u string) bool { - pu, err := url.Parse(u) +// exchangeAuthorizationCode exchanges the authorization code for a token and stores it in a token source. +func (h *AuthorizationCodeOAuthHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) error { + log.Print("Authorization code is available, exchanging for token") + opts := []oauth2.AuthCodeOption{ + oauth2.VerifierOption(h.codeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + } + defer func() { + // Authorization code has been consumed, clear it. + h.authorizationCode = "" + }() + token, err := cfg.Exchange(ctx, h.authorizationCode, opts...) if err != nil { - return false + return fmt.Errorf("token exchange failed: %w", err) } - return pu.Scheme == "https" && pu.Path != "" + h.tokenSource = cfg.TokenSource(ctx, token) + return nil } -var _ OAuthHandler = (*AuthorizationCodeOAuthHandler)(nil) +// startAuthFlow generates the authorization URL and redirects the user to the authorization server. +func (h *AuthorizationCodeOAuthHandler) startAuthFlow(ctx context.Context, cfg *oauth2.Config, resourceURL string) error { + h.codeVerifier = oauth2.GenerateVerifier() + h.state = rand.Text() + if h.StateProvider != nil { + h.state = h.StateProvider() + } + + authURL := cfg.AuthCodeURL(h.state, + oauth2.S256ChallengeOption(h.codeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + ) + + log.Print("No authorization code available, opening authorization URL") + if err := h.AuthorizationURLHandler(ctx, authURL); err != nil { + return fmt.Errorf("authorization URL handler failed: %w", err) + } + return ErrRedirected +} From 18d58384b2b4955dfd273f83d5bebc75b8187b11 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Fri, 20 Feb 2026 12:44:51 +0000 Subject: [PATCH 08/26] examples: add an example to display client OAuth. --- auth/auth.go | 3 + auth/authorization_code.go | 11 ++- examples/auth/client/main.go | 143 ++++++++++++++++++++++++++++++ examples/auth/server/main.go | 167 +++++++++++++++++++++++++++++++++++ oauthex/auth_meta.go | 3 + 5 files changed, 324 insertions(+), 3 deletions(-) create mode 100644 examples/auth/client/main.go create mode 100644 examples/auth/server/main.go diff --git a/auth/auth.go b/auth/auth.go index 87665121..29cca526 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -106,6 +106,9 @@ func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenO } return nil, err.Error(), http.StatusInternalServerError } + if tokenInfo == nil { + return nil, "token validation failed", http.StatusInternalServerError + } // Check scopes. All must be present. if opts != nil { diff --git a/auth/authorization_code.go b/auth/authorization_code.go index a27b8f06..269861be 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -155,13 +155,18 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http } log.Printf("Failed to get protected resource metadata from %q: %v", url, err) } + // log.Printf("Protected resource metadata: %+v", prm) asm, err := h.getAuthServerMetadata(ctx, prm, resourceURL) if err != nil { return err } + // log.Printf("Authorization server metadata: %+v", asm) - if err := h.handleRegistration(ctx, asm); err != nil { - return err + if h.resolvedClientConfig == nil { + // Client configuration is not resolved yet, try to resolve it. + if err := h.handleRegistration(ctx, asm); err != nil { + return err + } } scopes := oauthex.Scopes(wwwChallenges) @@ -202,6 +207,7 @@ func (h *AuthorizationCodeOAuthHandler) FinalizeAuthorization(code, state string return nil } +// TODO: validate on creation. func (h *AuthorizationCodeOAuthHandler) validate() error { if h.ClientIDMetadataDocumentConfig == nil && h.PreregisteredClientConfig == nil && @@ -267,7 +273,6 @@ func (h *AuthorizationCodeOAuthHandler) getAuthServerMetadata(ctx context.Contex if err != nil { return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) } - log.Print("Authorization server medatada fetched") if asm == nil { log.Print("Authorization server metadata not found, using fallback") // Fallback to 2025-03-26 spec: predefined endpoints. diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go new file mode 100644 index 00000000..d9d4a5e8 --- /dev/null +++ b/examples/auth/client/main.go @@ -0,0 +1,143 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Flags. +var ( + serverURL = flag.String("server_url", "http://localhost:8000/mcp", "Server URL") +) + +type authResult struct { + code string + state string + err error +} + +type codeReceiver struct { + authChan chan authResult + server *http.Server +} + +func (r *codeReceiver) startAuthorizationFlow(ctx context.Context, authorizationURL string) error { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + code := req.URL.Query().Get("code") + state := req.URL.Query().Get("state") + if code == "" { + http.Error(w, "authorization code not found", http.StatusBadRequest) + return + } + + r.authChan <- authResult{ + code: code, + state: state, + } + fmt.Fprint(w, "Authentication successful. You can close this window.") + }) + + r.server = &http.Server{ + Addr: "localhost:3142", + Handler: mux, + } + + go func() { + // We ignore ErrServerClosed as it is returned on Shutdown. + if err := r.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + r.authChan <- authResult{err: fmt.Errorf("server error: %w", err)} + } + }() + + fmt.Printf("Please authorize by visiting: %s\n", authorizationURL) + return nil +} + +func main() { + flag.Parse() + client := mcp.NewClient(&mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, nil) + + receiver := &codeReceiver{ + authChan: make(chan authResult), + } + + authHandler := &auth.AuthorizationCodeOAuthHandler{ + RedirectURL: "http://localhost:3142", + // Uncomment the client configuration you want to use. + // PreregisteredClientConfig: &auth.PreregisteredClientConfig{ + // ClientID: "", + // ClientSecret: "", + // }, + // DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ + // Metadata: &oauthex.ClientRegistrationMetadata{ + // ClientName: "Dynamically registered MCP client", + // RedirectURIs: []string{"http://localhost:3142"}, + // Scope: "read", + // }, + // }, + AuthorizationURLHandler: receiver.startAuthorizationFlow, + } + + transport := &mcp.StreamableClientTransport{ + Endpoint: *serverURL, + OAuthHandler: authHandler, + } + + ctx := context.Background() + var session *mcp.ClientSession + var err error + + for { + session, err = client.Connect(ctx, transport, nil) + if err == nil { + break + } + // If the error is ErrRedirected, it means the authorization flow has started + // and we need to wait for the code. + if errors.Is(err, auth.ErrRedirected) { + fmt.Println("Authorization flow started. Waiting for authorization code...") + res := <-receiver.authChan + if res.err != nil { + log.Fatalf("Authorization failed: %v", res.err) + } + + // Shutdown the temporary server + if err := receiver.server.Shutdown(ctx); err != nil { + log.Printf("Failed to shutdown server: %v", err) + } + + if err := authHandler.FinalizeAuthorization(res.code, res.state); err != nil { + log.Fatalf("Failed to finalize authorization: %v", err) + } + continue + } + log.Fatalf("client.Connect(): %v", err) + } + defer session.Close() + + tools, err := session.ListTools(ctx, nil) + if err != nil { + log.Fatalf("session.ListTools(): %v", err) + } + log.Println("Tools:") + for _, tool := range tools.Tools { + log.Printf("- %q", tool.Name) + } +} diff --git a/examples/auth/server/main.go b/examples/auth/server/main.go new file mode 100644 index 00000000..94ad9ae3 --- /dev/null +++ b/examples/auth/server/main.go @@ -0,0 +1,167 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// Flags. +var ( + port = flag.Int("port", 8000, "Port to listen on") +) + +// Configuration required for this example. +var ( + // Authorization server to return in the protected resource metadata. + authorizationServer = "" + // Introspection endpoint for verifying tokens. + introspectionEndpoint = "" + // Client credentials used in the introspection request. + clientID = "" + clientSecret = "" +) + +func verifyToken(ctx context.Context, token string, _ *http.Request) (*auth.TokenInfo, error) { + data := url.Values{} + data.Set("token", token) + data.Set("token_type_hint", "access_token") + + req, err := http.NewRequestWithContext(ctx, "POST", introspectionEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + req.SetBasicAuth(clientID, clientSecret) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + dump, _ := httputil.DumpResponse(resp, true) + log.Printf("Introspection failed: %s", dump) + return nil, fmt.Errorf("introspection failed with status %d", resp.StatusCode) + } + + var result struct { + Active bool `json:"active"` + Scope string `json:"scope"` + Exp int64 `json:"exp"` + Sub string `json:"sub"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + if !result.Active { + return nil, auth.ErrInvalidToken + } + + return &auth.TokenInfo{ + Scopes: strings.Fields(result.Scope), + Expiration: time.Unix(result.Exp, 0), + UserID: result.Sub, + }, nil +} + +type args struct { + Input string `json:"input"` +} + +func echo(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: args.Input}, + }, + }, nil, nil +} + +func main() { + flag.Parse() + metadata := &oauthex.ProtectedResourceMetadata{ + Resource: fmt.Sprintf("http://localhost:%d/mcp", *port), + AuthorizationServers: []string{authorizationServer}, + ScopesSupported: []string{"read"}, + } + http.Handle("/.well-known/oauth-protected-resource", auth.ProtectedResourceMetadataHandler(metadata)) + + server := mcp.NewServer(&mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, nil) + server.AddReceivingMiddleware(createLoggingMiddleware()) + mcp.AddTool(server, &mcp.Tool{Name: "echo"}, echo) + + handler := mcp.NewStreamableHTTPHandler(func(req *http.Request) *mcp.Server { + return server + }, nil) + + authMiddleware := auth.RequireBearerToken(verifyToken, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read"}, + ResourceMetadataURL: fmt.Sprintf("http://localhost:%d/.well-known/oauth-protected-resource", *port), + }) + + http.Handle("/mcp", authMiddleware(handler)) + + log.Printf("Starting server on http://localhost:%d", *port) + log.Fatal(http.ListenAndServe(fmt.Sprintf("localhost:%d", *port), nil)) +} + +// createLoggingMiddleware creates an MCP middleware that logs method calls. +func createLoggingMiddleware() mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func( + ctx context.Context, + method string, + req mcp.Request, + ) (mcp.Result, error) { + start := time.Now() + sessionID := req.GetSession().ID() + + // Log request details. + log.Printf("[REQUEST] Session: %s | Method: %s", + sessionID, + method) + + // Call the actual handler. + result, err := next(ctx, method, req) + + // Log response details. + duration := time.Since(start) + + if err != nil { + log.Printf("[RESPONSE] Session: %s | Method: %s | Status: ERROR | Duration: %v | Error: %v", + sessionID, + method, + duration, + err) + } else { + log.Printf("[RESPONSE] Session: %s | Method: %s | Status: OK | Duration: %v", + sessionID, + method, + duration) + } + + return result, err + } + } +} diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index a6ee07a8..44d7359d 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -13,6 +13,7 @@ import ( "context" "errors" "fmt" + "log" "net/http" "net/url" "strings" @@ -134,6 +135,7 @@ func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (* for _, u := range AuthorizationServerMetadataURLs(issuerURL) { asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) if err != nil { + log.Printf("Failed to get auth server metadata from %q: %v", u, err) var httpErr *httpStatusError if errors.As(err, &httpErr) { if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { @@ -156,6 +158,7 @@ func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (* if err := validateAuthServerMetaURLs(asm); err != nil { return nil, err } + log.Printf("Fetched authorization server metadata from %q", u) return asm, nil } From a7c3ab97a55ae69665f49a3897c21b77b8534507 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 23 Feb 2026 13:03:41 +0000 Subject: [PATCH 09/26] Make AuthorizationURLHandler blocking. --- auth/authorization_code.go | 206 +++++++++--------- .../everything-client/client_private.go | 61 ++---- examples/auth/client/main.go | 97 ++++----- 3 files changed, 158 insertions(+), 206 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 269861be..d24fc1f8 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -20,9 +20,6 @@ import ( "golang.org/x/oauth2" ) -// ErrRedirected is returned when the user was redirected to the authorization server. -var ErrRedirected = errors.New("redirected") - // ClientIDMetadataDocumentConfig is used to configure the Client ID Metadata Document // based client registration per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents. @@ -52,24 +49,17 @@ type DynamicClientRegistrationConfig struct { Metadata *oauthex.ClientRegistrationMetadata } -type registrationType int - -const ( - registrationTypeClientIDMetadataDocument registrationType = iota - registrationTypePreregistered - registrationTypeDynamic -) - -type resolvedClientConfig struct { - registrationType registrationType - clientID string - clientSecret string - authStyle oauth2.AuthStyle +// AuthorizationResult is the result of an authorization flow. +// It is returned by [AuthorizationCodeOAuthHandler.AuthorizationURLHandler] implementations. +type AuthorizationResult struct { + // AuthorizationCode is the authorization code obtained from the authorization server. + AuthorizationCode string + // State is the state string returned by the authorization server. + State string } // AuthorizationCodeOAuthHandler is an implementation of [OAuthHandler] that uses // the authorization code flow to obtain access tokens. -// The handler is stateful and can only handle one authorization flow at a time. type AuthorizationCodeOAuthHandler struct { // Client registration configuration. // It is attempted in the following order: @@ -91,29 +81,17 @@ type AuthorizationCodeOAuthHandler struct { // AuthorizationURLHandler is a required function called to handle the authorization URL. // It is responsible for opening the URL in a browser for the user to start the authorization. - // It should return once the proccess has been initiated and the URL was - // presented to the user successfully. Once the Authorizatin Server redirects back - // to the [AuthorizationCodeOAuthHandler.RedirectURL], the caller should set - // the authorization code by calling [AuthorizationCodeOAuthHandler.SetAuthorizationCode] - // before retrying the request. - AuthorizationURLHandler func(ctx context.Context, authorizationURL string) error + // It should return the authorization code and state once the Authorization Server + // redirects back to the [AuthorizationCodeOAuthHandler.RedirectURL]. + AuthorizationURLHandler func(ctx context.Context, authorizationURL string) (*AuthorizationResult, error) // StateProvider is an optional function to generate a state string for authorization // requests. If not provided, a random string will be generated. - // The state should be validated on the redirect callback. + // The state will be validated on the redirect callback. StateProvider func() string - // resolvedClientConfig used during the authorization flow. - resolvedClientConfig *resolvedClientConfig // tokenSource is the token source to use for authorization. - // It can be prepopulated by calling [AuthorizationCodeOAuthHandler.SetTokenSource]. tokenSource oauth2.TokenSource - // codeVerifier is the PKCE code verifier. - codeVerifier string - // authorizationCode is the authorization code obtained from the authorization server. - authorizationCode string - // state is the state string used in the authorization request. - state string } var _ OAuthHandler = (*AuthorizationCodeOAuthHandler)(nil) @@ -125,12 +103,8 @@ func (h *AuthorizationCodeOAuthHandler) TokenSource(ctx context.Context) (oauth2 } // Authorize performs the authorization flow. -// It is designed to be reentrant and called in two phases. -// 1. It initiates the Authorization Grant flow by calling [AuthorizationCodeOAuthHandler.AuthorizationURLHandler]. -// It will return [ErrRedirected] if the authorization flow was initiated successfully. -// 2. It exchanges the authorization code for an access token. -// It will return a `nil` error if the authorization flow was completed successfully. -// From this point on, [AuthorizationCodeOAuthHandler.TokenSource] will return a token source with the fetched token. +// It is designed to perform the whole Authorization Code Grant flow. +// On success, [AuthorizationCodeOAuthHandler.TokenSource] will return a token source with the fetched token. func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { defer resp.Body.Close() log.Printf("Authorize: %s %s", req.Method, req.URL) @@ -162,11 +136,9 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http } // log.Printf("Authorization server metadata: %+v", asm) - if h.resolvedClientConfig == nil { - // Client configuration is not resolved yet, try to resolve it. - if err := h.handleRegistration(ctx, asm); err != nil { - return err - } + resolvedClientConfig, err := h.handleRegistration(ctx, asm) + if err != nil { + return err } scopes := oauthex.Scopes(wwwChallenges) @@ -175,36 +147,25 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http } cfg := &oauth2.Config{ - ClientID: h.resolvedClientConfig.clientID, - ClientSecret: h.resolvedClientConfig.clientSecret, + ClientID: resolvedClientConfig.clientID, + ClientSecret: resolvedClientConfig.clientSecret, Endpoint: oauth2.Endpoint{ AuthURL: asm.AuthorizationEndpoint, TokenURL: asm.TokenEndpoint, // TODO: validate if the auth style is supported by the AS. - AuthStyle: h.resolvedClientConfig.authStyle, + AuthStyle: resolvedClientConfig.authStyle, }, RedirectURL: h.RedirectURL, Scopes: scopes, } - if h.authorizationCode != "" { - return h.exchangeAuthorizationCode(ctx, cfg, resourceURL) + authRes, err := h.getAuthorizationCode(ctx, cfg, req.URL.String()) + if err != nil { + return err } - return h.startAuthFlow(ctx, cfg, req.URL.String()) -} - -func (h *AuthorizationCodeOAuthHandler) FinalizeAuthorization(code, state string) error { - defer func() { - // State has been used for validation, clear it. - h.state = "" - }() - if state != h.state { - return fmt.Errorf("state mismatch: expected %q, got %q", h.state, state) - } - h.authorizationCode = code - return nil + return h.exchangeAuthorizationCode(ctx, cfg, authRes, resourceURL) } // TODO: validate on creation. @@ -236,9 +197,6 @@ func (h *AuthorizationCodeOAuthHandler) validate() error { return fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", h.RedirectURL) } } - if h.resolvedClientConfig == nil && h.authorizationCode != "" { - return fmt.Errorf("exchanging authorization code with unregistered client is not allowed") - } return nil } @@ -287,95 +245,125 @@ func (h *AuthorizationCodeOAuthHandler) getAuthServerMetadata(ctx context.Contex return asm, nil } +type registrationType int + +const ( + registrationTypeClientIDMetadataDocument registrationType = iota + registrationTypePreregistered + registrationTypeDynamic +) + +type resolvedClientConfig struct { + registrationType registrationType + clientID string + clientSecret string + authStyle oauth2.AuthStyle +} + // handleRegistration handles client registration. // The provided authorization server metadata must be non-nil. -// It must also have RegistrationEndpoint set if dynamic client registration is supported. -func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) error { +// Support for different registration methods is defined as follows: +// - Client ID Metadata Document: metadata must have +// `ClientIDMetadataDocumentSupported` set to true. +// - Pre-registered client: assumed to be supported. +// - Dynamic client registration: metadata must have +// `RegistrationEndpoint` set to a non-empty value. +func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) (*resolvedClientConfig, error) { // 1. Attempt to use Client ID Metadata Document (SEP-991). cimdCfg := h.ClientIDMetadataDocumentConfig if cimdCfg != nil && asm.ClientIDMetadataDocumentSupported { - h.resolvedClientConfig = &resolvedClientConfig{ + return &resolvedClientConfig{ registrationType: registrationTypeClientIDMetadataDocument, clientID: cimdCfg.URL, - } - return nil + }, nil } // 2. Attempt to use pre-registered client configuration. pCfg := h.PreregisteredClientConfig if pCfg != nil { - h.resolvedClientConfig = &resolvedClientConfig{ + return &resolvedClientConfig{ registrationType: registrationTypePreregistered, clientID: pCfg.ClientID, clientSecret: pCfg.ClientSecret, authStyle: pCfg.AuthStyle, - } - return nil + }, nil } // 3. Attempt to use dynamic client registration. dcrCfg := h.DynamicClientRegistrationConfig if dcrCfg != nil && asm.RegistrationEndpoint != "" { regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, http.DefaultClient) if err != nil { - return fmt.Errorf("failed to register client: %w", err) + return nil, fmt.Errorf("failed to register client: %w", err) } - h.resolvedClientConfig = &resolvedClientConfig{ + cfg := &resolvedClientConfig{ registrationType: registrationTypeDynamic, clientID: regResp.ClientID, clientSecret: regResp.ClientSecret, } switch regResp.TokenEndpointAuthMethod { case "client_secret_post": - h.resolvedClientConfig.authStyle = oauth2.AuthStyleInParams + cfg.authStyle = oauth2.AuthStyleInParams case "client_secret_basic": - h.resolvedClientConfig.authStyle = oauth2.AuthStyleInHeader + cfg.authStyle = oauth2.AuthStyleInHeader case "none": // "none" is equivalent to "client_secret_post" but without sending client secret. - h.resolvedClientConfig.authStyle = oauth2.AuthStyleInParams - h.resolvedClientConfig.clientSecret = "" + cfg.authStyle = oauth2.AuthStyleInParams + cfg.clientSecret = "" default: // We leave the AuthStyle set to zero value, which is auto-detection. } log.Printf("Client registered with client ID: %s", regResp.ClientID) - return nil + return cfg, nil } - return fmt.Errorf("no configured client registration methods are supported by the authorization server") + return nil, fmt.Errorf("no configured client registration methods are supported by the authorization server") } -// exchangeAuthorizationCode exchanges the authorization code for a token and stores it in a token source. -func (h *AuthorizationCodeOAuthHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) error { - log.Print("Authorization code is available, exchanging for token") - opts := []oauth2.AuthCodeOption{ - oauth2.VerifierOption(h.codeVerifier), - oauth2.SetAuthURLParam("resource", resourceURL), - } - defer func() { - // Authorization code has been consumed, clear it. - h.authorizationCode = "" - }() - token, err := cfg.Exchange(ctx, h.authorizationCode, opts...) - if err != nil { - return fmt.Errorf("token exchange failed: %w", err) - } - h.tokenSource = cfg.TokenSource(ctx, token) - return nil +type authResult struct { + *AuthorizationResult + // usedCodeVerifier is the PKCE code verifier used to obtain the authorization code. + // It is preserved for the token exchange step. + usedCodeVerifier string } -// startAuthFlow generates the authorization URL and redirects the user to the authorization server. -func (h *AuthorizationCodeOAuthHandler) startAuthFlow(ctx context.Context, cfg *oauth2.Config, resourceURL string) error { - h.codeVerifier = oauth2.GenerateVerifier() - h.state = rand.Text() +// getAuthorizationCode uses the [AuthorizationCodeOAuthHandler.AuthorizationURLHandler] +// to obtain an authorization code. +func (h *AuthorizationCodeOAuthHandler) getAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) (*authResult, error) { + codeVerifier := oauth2.GenerateVerifier() + state := rand.Text() if h.StateProvider != nil { - h.state = h.StateProvider() + state = h.StateProvider() } - authURL := cfg.AuthCodeURL(h.state, - oauth2.S256ChallengeOption(h.codeVerifier), + authURL := cfg.AuthCodeURL(state, + oauth2.S256ChallengeOption(codeVerifier), oauth2.SetAuthURLParam("resource", resourceURL), ) - log.Print("No authorization code available, opening authorization URL") - if err := h.AuthorizationURLHandler(ctx, authURL); err != nil { - return fmt.Errorf("authorization URL handler failed: %w", err) + log.Printf("Calling AuthorizationURLHandler: %q", authURL) + authRes, err := h.AuthorizationURLHandler(ctx, authURL) + if err != nil { + return nil, err } - return ErrRedirected + if authRes.State != state { + return nil, fmt.Errorf("state mismatch") + } + return &authResult{ + AuthorizationResult: authRes, + usedCodeVerifier: codeVerifier, + }, nil +} + +// exchangeAuthorizationCode exchanges the authorization code for a token +// and stores it in a token source. +func (h *AuthorizationCodeOAuthHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { + log.Printf("Exchanging authorization code for token") + opts := []oauth2.AuthCodeOption{ + oauth2.VerifierOption(authResult.usedCodeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + } + token, err := cfg.Exchange(ctx, authResult.AuthorizationCode, opts...) + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + h.tokenSource = cfg.TokenSource(ctx, token) + return nil } diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index c7ed65f6..5e5cf725 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -12,7 +12,6 @@ package main import ( "context" - "errors" "fmt" "net/http" "net/url" @@ -51,7 +50,7 @@ func init() { // Auth scenarios // ============================================================================ -func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (code, state string, err error) { +func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (*auth.AuthorizationResult, error) { client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse @@ -59,35 +58,38 @@ func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (code, } req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil) if err != nil { - return "", "", err + return nil, err } resp, err := client.Do(req) if err != nil { - return "", "", err + return nil, err } defer resp.Body.Close() location := resp.Header.Get("Location") if location == "" { - return "", "", fmt.Errorf("no Location header in redirect") + return nil, fmt.Errorf("no Location header in redirect") } locURL, err := url.Parse(location) if err != nil { - return "", "", fmt.Errorf("parse location: %v", err) + return nil, fmt.Errorf("parse location: %v", err) } - code = locURL.Query().Get("code") + code := locURL.Query().Get("code") if code == "" { - return "", "", fmt.Errorf("no code parameter in redirect URL") + return nil, fmt.Errorf("no code parameter in redirect URL") } - state = locURL.Query().Get("state") + state := locURL.Query().Get("state") if state == "" { - return "", "", fmt.Errorf("no state parameter in redirect URL") + return nil, fmt.Errorf("no state parameter in redirect URL") } - return code, state, nil + return &auth.AuthorizationResult{ + AuthorizationCode: code, + State: state, + }, nil } func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]any) error { @@ -114,47 +116,16 @@ func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]a } } - authHandler.AuthorizationURLHandler = func(ctx context.Context, authURL string) error { - // Normally this handler would trigger user browser to be opened. - // Here we query the authorization URL automatically and the AS is configured - // to authorize and redirect immediately. We save the resulting code. - code, state, err := fetchAuthorizationCodeAndState(ctx, authURL) - if err != nil { - return err - } - if err := authHandler.FinalizeAuthorization(code, state); err != nil { - return err - } - return nil - } + authHandler.AuthorizationURLHandler = fetchAuthorizationCodeAndState session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) if err != nil { - if !errors.Is(err, auth.ErrRedirected) { - return err - } - // Received auth.ErrRedirected. Normally we would wait for the callback triggered - // by the AS redirect to RedirectURL, but here we already have the authorization code - // so we can immediately retry. - session, err = connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) - if err != nil { - return nil - } + return err } defer session.Close() if _, err := session.ListTools(ctx, nil); err != nil { - // Retry for the scope step-up scenario. - if !errors.Is(err, auth.ErrRedirected) { - return fmt.Errorf("session.ListTools(): %v", err) - } - // Received auth.ErrRedirected. Normally we would wait for the callback triggered - // by the AS redirect to RedirectURL, but here we already have the authorization code - // so we can immediately retry. - _, err = session.ListTools(ctx, nil) - if err != nil { - return fmt.Errorf("session.ListTools(): %v", err) - } + return fmt.Errorf("session.ListTools(): %v", err) } if _, err := session.CallTool(ctx, &mcp.CallToolParams{ diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index d9d4a5e8..972f611c 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -12,29 +12,28 @@ import ( "flag" "fmt" "log" + "net" "net/http" "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/mcp" ) -// Flags. var ( - serverURL = flag.String("server_url", "http://localhost:8000/mcp", "Server URL") + // URL of the MCP server. + serverURL = flag.String("server_url", "http://localhost:8000/mcp", "URL of the MCP server.") + // Port for the local HTTP server that will receive the authorization code. + callbackPort = flag.Int("callback_port", 3142, "Port for the local HTTP server that will receive the authorization code.") ) -type authResult struct { - code string - state string - err error -} - type codeReceiver struct { - authChan chan authResult + authChan chan *auth.AuthorizationResult + errChan chan error + listener net.Listener server *http.Server } -func (r *codeReceiver) startAuthorizationFlow(ctx context.Context, authorizationURL string) error { +func (r *codeReceiver) serveRedirectHandler(listener net.Listener) error { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { code := req.URL.Query().Get("code") @@ -44,27 +43,38 @@ func (r *codeReceiver) startAuthorizationFlow(ctx context.Context, authorization return } - r.authChan <- authResult{ - code: code, - state: state, + r.authChan <- &auth.AuthorizationResult{ + AuthorizationCode: code, + State: state, } fmt.Fprint(w, "Authentication successful. You can close this window.") }) r.server = &http.Server{ - Addr: "localhost:3142", + Addr: fmt.Sprintf("localhost:%d", *callbackPort), Handler: mux, } + if err := r.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + r.errChan <- err + } + return nil +} - go func() { - // We ignore ErrServerClosed as it is returned on Shutdown. - if err := r.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - r.authChan <- authResult{err: fmt.Errorf("server error: %w", err)} - } - }() +func (r *codeReceiver) getAuthorizationCode(ctx context.Context, authorizationURL string) (*auth.AuthorizationResult, error) { + select { + case authRes := <-r.authChan: + return authRes, nil + case err := <-r.errChan: + return nil, err + case <-ctx.Done(): + return nil, ctx.Err() + } +} - fmt.Printf("Please authorize by visiting: %s\n", authorizationURL) - return nil +func (r *codeReceiver) close() { + if r.server != nil { + r.server.Close() + } } func main() { @@ -75,11 +85,18 @@ func main() { }, nil) receiver := &codeReceiver{ - authChan: make(chan authResult), + authChan: make(chan *auth.AuthorizationResult), + errChan: make(chan error), + } + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *callbackPort)) + if err != nil { + log.Fatalf("failed to listen: %v", err) } + go receiver.serveRedirectHandler(listener) + defer receiver.close() authHandler := &auth.AuthorizationCodeOAuthHandler{ - RedirectURL: "http://localhost:3142", + RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), // Uncomment the client configuration you want to use. // PreregisteredClientConfig: &auth.PreregisteredClientConfig{ // ClientID: "", @@ -88,11 +105,11 @@ func main() { // DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ // Metadata: &oauthex.ClientRegistrationMetadata{ // ClientName: "Dynamically registered MCP client", - // RedirectURIs: []string{"http://localhost:3142"}, + // RedirectURIs: []string{fmt.Sprintf("http://localhost:%d", *callbackPort)}, // Scope: "read", // }, // }, - AuthorizationURLHandler: receiver.startAuthorizationFlow, + AuthorizationURLHandler: receiver.getAuthorizationCode, } transport := &mcp.StreamableClientTransport{ @@ -101,33 +118,9 @@ func main() { } ctx := context.Background() - var session *mcp.ClientSession - var err error - for { - session, err = client.Connect(ctx, transport, nil) - if err == nil { - break - } - // If the error is ErrRedirected, it means the authorization flow has started - // and we need to wait for the code. - if errors.Is(err, auth.ErrRedirected) { - fmt.Println("Authorization flow started. Waiting for authorization code...") - res := <-receiver.authChan - if res.err != nil { - log.Fatalf("Authorization failed: %v", res.err) - } - - // Shutdown the temporary server - if err := receiver.server.Shutdown(ctx); err != nil { - log.Printf("Failed to shutdown server: %v", err) - } - - if err := authHandler.FinalizeAuthorization(res.code, res.state); err != nil { - log.Fatalf("Failed to finalize authorization: %v", err) - } - continue - } + session, err := client.Connect(ctx, transport, nil) + if err != nil { log.Fatalf("client.Connect(): %v", err) } defer session.Close() From d1b6bfd56f13ec052c1eecf1ad3a6ae885302cd9 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 23 Feb 2026 15:31:31 +0000 Subject: [PATCH 10/26] Document client-side OAuth. --- docs/protocol.md | 56 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/docs/protocol.md b/docs/protocol.md index 16ba0bfa..a859b5b8 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -306,9 +306,50 @@ The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/ ### Client -Client-side OAuth is implemented by setting -[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) -Additional support is forthcoming; see modelcontextprotocol/go-sdk#493. +Client-side authorization is supported via the +[`StreamableClientTransport.OAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableClientTransport.OAuthHandler) +field. If the handler is provided, the transport will automatically use it to +add an `Authorization: Bearer ` header to every request. The transport +will also call the handler's `Authorize` method if the server returns +`401 Unauthorized` or `403 Forbidden` errors to perform the authorization flow +or facilitate scope step-up authorization. + +The SDK implements the Authorization Code flow in +[`auth.AuthorizationCodeOAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeOAuthHandler). +This handler supports: + +- [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) +- Pre-registered clients +- [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) + +To use it, configure the handler and assign it to the transport: + +```go +authHandler := &auth.AuthorizationCodeOAuthHandler{ + RedirectURL: "https://myapp.com/oauth2-callback", + // Configure one of the following: + // ClientIDMetadataDocumentConfig: ... + // PreregisteredClientConfig: ... + // DynamicClientRegistrationConfig: ... + AuthorizationURLHandler: func(ctx context.Context, url string) (*auth.AuthorizationResult, error) { + // Open the URL in a browser and return the resulting code and state. + // See full example in examples/auth/client/main.go. + code := ... + state := ... + return &auth.AuthorizationResult{Code: code, State: state}, nil + }, +} + +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: authHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) +``` + +The handler automatically manages token exchange, token refreshing, and step-up +authentication (when the server returns `insufficient_scope` error). ## Security @@ -317,9 +358,12 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi ### Confused Deputy -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, -happens on the MCP client. At present we don't provide client-side OAuth support. - +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), +obtaining user consent for dynamically registered clients, is mostly the +responsibility of the MCP Proxy server implementation. The SDK client does +generate cryptographically secure random `state` values for each authorization +request by default and validates them when the authorization code is returned. +Mismatched state values will result in an error. ### Token Passthrough From f95953f54cead253033b1054c33e798206d55778 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 23 Feb 2026 15:34:12 +0000 Subject: [PATCH 11/26] Bring back issuer validation. Conformance tests have been fixed. --- oauthex/auth_meta.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 44d7359d..e6aab283 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -144,11 +144,10 @@ func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (* return nil, fmt.Errorf("%v", err) // Do not expose wrapped errors. } } - // TODO: causes conformance test failure, filed https://github.com/modelcontextprotocol/conformance/issues/140. - // if asm.Issuer != issuerURL { - // // Validate the Issuer field (see RFC 8414, section 3.3). - // return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) - // } + if asm.Issuer != issuerURL { + // Validate the Issuer field (see RFC 8414, section 3.3). + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) + } if len(asm.CodeChallengeMethodsSupported) == 0 { return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) From 2f9d4d5db029cd15cb36bae387055387e3b3a789 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 23 Feb 2026 15:36:27 +0000 Subject: [PATCH 12/26] Documentation fix. --- docs/protocol.md | 4 +-- internal/docs/protocol.src.md | 56 +++++++++++++++++++++++++++++++---- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/docs/protocol.md b/docs/protocol.md index a859b5b8..d69d252e 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -348,8 +348,8 @@ client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, session, err := client.Connect(ctx, transport, nil) ``` -The handler automatically manages token exchange, token refreshing, and step-up -authentication (when the server returns `insufficient_scope` error). +The `auth.AuthorizationCodeOAuthHandler` automatically manages token refreshing +and step-up authentication (when the server returns `insufficient_scope` error). ## Security diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index ada34371..0e533a49 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -232,9 +232,50 @@ The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/ ### Client -Client-side OAuth is implemented by setting -[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) -Additional support is forthcoming; see modelcontextprotocol/go-sdk#493. +Client-side authorization is supported via the +[`StreamableClientTransport.OAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableClientTransport.OAuthHandler) +field. If the handler is provided, the transport will automatically use it to +add an `Authorization: Bearer ` header to every request. The transport +will also call the handler's `Authorize` method if the server returns +`401 Unauthorized` or `403 Forbidden` errors to perform the authorization flow +or facilitate scope step-up authorization. + +The SDK implements the Authorization Code flow in +[`auth.AuthorizationCodeOAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeOAuthHandler). +This handler supports: + +- [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) +- Pre-registered clients +- [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) + +To use it, configure the handler and assign it to the transport: + +```go +authHandler := &auth.AuthorizationCodeOAuthHandler{ + RedirectURL: "https://myapp.com/oauth2-callback", + // Configure one of the following: + // ClientIDMetadataDocumentConfig: ... + // PreregisteredClientConfig: ... + // DynamicClientRegistrationConfig: ... + AuthorizationURLHandler: func(ctx context.Context, url string) (*auth.AuthorizationResult, error) { + // Open the URL in a browser and return the resulting code and state. + // See full example in examples/auth/client/main.go. + code := ... + state := ... + return &auth.AuthorizationResult{Code: code, State: state}, nil + }, +} + +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: authHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) +``` + +The `auth.AuthorizationCodeOAuthHandler` automatically manages token refreshing +and step-up authentication (when the server returns `insufficient_scope` error). ## Security @@ -243,9 +284,12 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi ### Confused Deputy -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, -happens on the MCP client. At present we don't provide client-side OAuth support. - +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), +obtaining user consent for dynamically registered clients, is mostly the +responsibility of the MCP Proxy server implementation. The SDK client does +generate cryptographically secure random `state` values for each authorization +request by default and validates them when the authorization code is returned. +Mismatched state values will result in an error. ### Token Passthrough From 57e5b2d6cf68be96f9f3912cdc632d57ca206951 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Mon, 23 Feb 2026 16:12:41 +0000 Subject: [PATCH 13/26] Introduce a constructor for the handler and validate on creation. --- auth/authorization_code.go | 97 ++++++++++--------- .../everything-client/client_private.go | 12 ++- examples/auth/client/main.go | 9 +- 3 files changed, 66 insertions(+), 52 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index d24fc1f8..e2d0e52b 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -58,9 +58,8 @@ type AuthorizationResult struct { State string } -// AuthorizationCodeOAuthHandler is an implementation of [OAuthHandler] that uses -// the authorization code flow to obtain access tokens. -type AuthorizationCodeOAuthHandler struct { +// AuthorizationCodeHandlerConfig is the configuration for [AuthorizationCodeOAuthHandler]. +type AuthorizationCodeHandlerConfig struct { // Client registration configuration. // It is attempted in the following order: // @@ -89,6 +88,12 @@ type AuthorizationCodeOAuthHandler struct { // requests. If not provided, a random string will be generated. // The state will be validated on the redirect callback. StateProvider func() string +} + +// AuthorizationCodeOAuthHandler is an implementation of [OAuthHandler] that uses +// the authorization code flow to obtain access tokens. +type AuthorizationCodeOAuthHandler struct { + config *AuthorizationCodeHandlerConfig // tokenSource is the token source to use for authorization. tokenSource oauth2.TokenSource @@ -102,15 +107,49 @@ func (h *AuthorizationCodeOAuthHandler) TokenSource(ctx context.Context) (oauth2 return h.tokenSource, nil } +// NewAuthorizationCodeOAuthHandler creates a new AuthorizationCodeOAuthHandler. +// It performs validation of the configuration and returns an error if it is invalid. +// The passed config is consumed by the handler and should not be modified after. +func NewAuthorizationCodeOAuthHandler(config *AuthorizationCodeHandlerConfig) (*AuthorizationCodeOAuthHandler, error) { + if config == nil { + return nil, errors.New("config must be provided") + } + if config.ClientIDMetadataDocumentConfig == nil && + config.PreregisteredClientConfig == nil && + config.DynamicClientRegistrationConfig == nil { + return nil, errors.New("at least one client registration configuration must be provided") + } + if config.RedirectURL == "" { + return nil, errors.New("field RedirectURL is required") + } + if config.AuthorizationURLHandler == nil { + return nil, errors.New("field AuthorizationURLHandler is required") + } + if config.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(config.ClientIDMetadataDocumentConfig.URL) { + return nil, fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") + } + if config.PreregisteredClientConfig != nil { + if config.PreregisteredClientConfig.ClientID == "" || config.PreregisteredClientConfig.ClientSecret == "" { + return nil, fmt.Errorf("pre-registered client ID or secret is empty") + } + } + if config.DynamicClientRegistrationConfig != nil { + if config.DynamicClientRegistrationConfig.Metadata == nil { + return nil, errors.New("field Metadata is required for dynamic client registration") + } + if !slices.Contains(config.DynamicClientRegistrationConfig.Metadata.RedirectURIs, config.RedirectURL) { + return nil, fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", config.RedirectURL) + } + } + return &AuthorizationCodeOAuthHandler{config: config}, nil +} + // Authorize performs the authorization flow. // It is designed to perform the whole Authorization Code Grant flow. // On success, [AuthorizationCodeOAuthHandler.TokenSource] will return a token source with the fetched token. func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { defer resp.Body.Close() log.Printf("Authorize: %s %s", req.Method, req.URL) - if err := h.validate(); err != nil { - return err - } resourceURL := req.URL.String() wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) @@ -156,7 +195,7 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http // TODO: validate if the auth style is supported by the AS. AuthStyle: resolvedClientConfig.authStyle, }, - RedirectURL: h.RedirectURL, + RedirectURL: h.config.RedirectURL, Scopes: scopes, } @@ -168,38 +207,6 @@ func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http return h.exchangeAuthorizationCode(ctx, cfg, authRes, resourceURL) } -// TODO: validate on creation. -func (h *AuthorizationCodeOAuthHandler) validate() error { - if h.ClientIDMetadataDocumentConfig == nil && - h.PreregisteredClientConfig == nil && - h.DynamicClientRegistrationConfig == nil { - return errors.New("at least one client registration configuration must be provided") - } - if h.RedirectURL == "" { - return errors.New("field RedirectURL is required") - } - if h.AuthorizationURLHandler == nil { - return errors.New("field AuthorizationURLHandler is required") - } - if h.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(h.ClientIDMetadataDocumentConfig.URL) { - return fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") - } - if h.PreregisteredClientConfig != nil { - if h.PreregisteredClientConfig.ClientID == "" || h.PreregisteredClientConfig.ClientSecret == "" { - return fmt.Errorf("pre-registered client ID or secret is empty") - } - } - if h.DynamicClientRegistrationConfig != nil { - if h.DynamicClientRegistrationConfig.Metadata == nil { - return errors.New("field Metadata is required for dynamic client registration") - } - if !slices.Contains(h.DynamicClientRegistrationConfig.Metadata.RedirectURIs, h.RedirectURL) { - return fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", h.RedirectURL) - } - } - return nil -} - func isNonRootHTTPSURL(u string) bool { pu, err := url.Parse(u) if err != nil { @@ -270,7 +277,7 @@ type resolvedClientConfig struct { // `RegistrationEndpoint` set to a non-empty value. func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) (*resolvedClientConfig, error) { // 1. Attempt to use Client ID Metadata Document (SEP-991). - cimdCfg := h.ClientIDMetadataDocumentConfig + cimdCfg := h.config.ClientIDMetadataDocumentConfig if cimdCfg != nil && asm.ClientIDMetadataDocumentSupported { return &resolvedClientConfig{ registrationType: registrationTypeClientIDMetadataDocument, @@ -278,7 +285,7 @@ func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, }, nil } // 2. Attempt to use pre-registered client configuration. - pCfg := h.PreregisteredClientConfig + pCfg := h.config.PreregisteredClientConfig if pCfg != nil { return &resolvedClientConfig{ registrationType: registrationTypePreregistered, @@ -288,7 +295,7 @@ func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, }, nil } // 3. Attempt to use dynamic client registration. - dcrCfg := h.DynamicClientRegistrationConfig + dcrCfg := h.config.DynamicClientRegistrationConfig if dcrCfg != nil && asm.RegistrationEndpoint != "" { regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, http.DefaultClient) if err != nil { @@ -329,8 +336,8 @@ type authResult struct { func (h *AuthorizationCodeOAuthHandler) getAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) (*authResult, error) { codeVerifier := oauth2.GenerateVerifier() state := rand.Text() - if h.StateProvider != nil { - state = h.StateProvider() + if h.config.StateProvider != nil { + state = h.config.StateProvider() } authURL := cfg.AuthCodeURL(state, @@ -339,7 +346,7 @@ func (h *AuthorizationCodeOAuthHandler) getAuthorizationCode(ctx context.Context ) log.Printf("Calling AuthorizationURLHandler: %q", authURL) - authRes, err := h.AuthorizationURLHandler(ctx, authURL) + authRes, err := h.config.AuthorizationURLHandler(ctx, authURL) if err != nil { return nil, err } diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index 5e5cf725..86855197 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -93,8 +93,9 @@ func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (*auth. } func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]any) error { - authHandler := &auth.AuthorizationCodeOAuthHandler{ - RedirectURL: "http://localhost:3000/callback", + authConfig := &auth.AuthorizationCodeHandlerConfig{ + RedirectURL: "http://localhost:3000/callback", + AuthorizationURLHandler: fetchAuthorizationCodeAndState, // Try client ID metadata document based registration. ClientIDMetadataDocumentConfig: &auth.ClientIDMetadataDocumentConfig{ URL: "https://conformance-test.local/client-metadata.json", @@ -109,14 +110,17 @@ func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]a // Try pre-registered client information if provided in the context. if clientId, ok := configCtx["client_id"].(string); ok { if clientSecret, ok := configCtx["client_secret"].(string); ok { - authHandler.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ + authConfig.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ ClientID: clientId, ClientSecret: clientSecret, } } } - authHandler.AuthorizationURLHandler = fetchAuthorizationCodeAndState + authHandler, err := auth.NewAuthorizationCodeOAuthHandler(authConfig) + if err != nil { + return fmt.Errorf("failed to create auth handler: %w", err) + } session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) if err != nil { diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index 972f611c..53db7909 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -95,8 +95,9 @@ func main() { go receiver.serveRedirectHandler(listener) defer receiver.close() - authHandler := &auth.AuthorizationCodeOAuthHandler{ - RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), + authHandler, err := auth.NewAuthorizationCodeOAuthHandler(&auth.AuthorizationCodeHandlerConfig{ + RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), + AuthorizationURLHandler: receiver.getAuthorizationCode, // Uncomment the client configuration you want to use. // PreregisteredClientConfig: &auth.PreregisteredClientConfig{ // ClientID: "", @@ -109,7 +110,9 @@ func main() { // Scope: "read", // }, // }, - AuthorizationURLHandler: receiver.getAuthorizationCode, + }) + if err != nil { + log.Fatalf("failed to create auth handler: %v", err) } transport := &mcp.StreamableClientTransport{ From 46231599cb230e5163d004649506740df3b14332 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Tue, 24 Feb 2026 08:59:53 +0000 Subject: [PATCH 14/26] Remove StateProvider for now, rename handler to AuthorizationCodeHandler. --- auth/authorization_code.go | 44 ++++++++----------- .../everything-client/client_private.go | 4 +- docs/protocol.md | 10 ++--- examples/auth/client/main.go | 2 +- internal/docs/protocol.src.md | 10 ++--- 5 files changed, 32 insertions(+), 38 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index e2d0e52b..7b637539 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -50,7 +50,7 @@ type DynamicClientRegistrationConfig struct { } // AuthorizationResult is the result of an authorization flow. -// It is returned by [AuthorizationCodeOAuthHandler.AuthorizationURLHandler] implementations. +// It is returned by [AuthorizationCodeHandler.AuthorizationURLHandler] implementations. type AuthorizationResult struct { // AuthorizationCode is the authorization code obtained from the authorization server. AuthorizationCode string @@ -58,7 +58,7 @@ type AuthorizationResult struct { State string } -// AuthorizationCodeHandlerConfig is the configuration for [AuthorizationCodeOAuthHandler]. +// AuthorizationCodeHandlerConfig is the configuration for [AuthorizationCodeHandler]. type AuthorizationCodeHandlerConfig struct { // Client registration configuration. // It is attempted in the following order: @@ -81,36 +81,31 @@ type AuthorizationCodeHandlerConfig struct { // AuthorizationURLHandler is a required function called to handle the authorization URL. // It is responsible for opening the URL in a browser for the user to start the authorization. // It should return the authorization code and state once the Authorization Server - // redirects back to the [AuthorizationCodeOAuthHandler.RedirectURL]. + // redirects back to the [AuthorizationCodeHandler.RedirectURL]. AuthorizationURLHandler func(ctx context.Context, authorizationURL string) (*AuthorizationResult, error) - - // StateProvider is an optional function to generate a state string for authorization - // requests. If not provided, a random string will be generated. - // The state will be validated on the redirect callback. - StateProvider func() string } -// AuthorizationCodeOAuthHandler is an implementation of [OAuthHandler] that uses +// AuthorizationCodeHandler is an implementation of [OAuthHandler] that uses // the authorization code flow to obtain access tokens. -type AuthorizationCodeOAuthHandler struct { +type AuthorizationCodeHandler struct { config *AuthorizationCodeHandlerConfig // tokenSource is the token source to use for authorization. tokenSource oauth2.TokenSource } -var _ OAuthHandler = (*AuthorizationCodeOAuthHandler)(nil) +var _ OAuthHandler = (*AuthorizationCodeHandler)(nil) -func (h *AuthorizationCodeOAuthHandler) isOAuthHandler() {} +func (h *AuthorizationCodeHandler) isOAuthHandler() {} -func (h *AuthorizationCodeOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { +func (h *AuthorizationCodeHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { return h.tokenSource, nil } -// NewAuthorizationCodeOAuthHandler creates a new AuthorizationCodeOAuthHandler. +// NewAuthorizationCodeHandler creates a new AuthorizationCodeHandler. // It performs validation of the configuration and returns an error if it is invalid. // The passed config is consumed by the handler and should not be modified after. -func NewAuthorizationCodeOAuthHandler(config *AuthorizationCodeHandlerConfig) (*AuthorizationCodeOAuthHandler, error) { +func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*AuthorizationCodeHandler, error) { if config == nil { return nil, errors.New("config must be provided") } @@ -141,13 +136,13 @@ func NewAuthorizationCodeOAuthHandler(config *AuthorizationCodeHandlerConfig) (* return nil, fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", config.RedirectURL) } } - return &AuthorizationCodeOAuthHandler{config: config}, nil + return &AuthorizationCodeHandler{config: config}, nil } // Authorize performs the authorization flow. // It is designed to perform the whole Authorization Code Grant flow. -// On success, [AuthorizationCodeOAuthHandler.TokenSource] will return a token source with the fetched token. -func (h *AuthorizationCodeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { +// On success, [AuthorizationCodeHandler.TokenSource] will return a token source with the fetched token. +func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { defer resp.Body.Close() log.Printf("Authorize: %s %s", req.Method, req.URL) @@ -218,7 +213,7 @@ func isNonRootHTTPSURL(u string) bool { // getAuthServerMetadata returns the authorization server metadata. // If no metadata is available, it returns a minimal set of endpoints // as a fallback to 2025-03-26 spec. -func (h *AuthorizationCodeOAuthHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata, resourceURL string) (*oauthex.AuthServerMeta, error) { +func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata, resourceURL string) (*oauthex.AuthServerMeta, error) { var authServerURL string if prm != nil && len(prm.AuthorizationServers) > 0 { // Use the first authorization server, similarly to other SDKs. @@ -275,7 +270,7 @@ type resolvedClientConfig struct { // - Pre-registered client: assumed to be supported. // - Dynamic client registration: metadata must have // `RegistrationEndpoint` set to a non-empty value. -func (h *AuthorizationCodeOAuthHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) (*resolvedClientConfig, error) { +func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) (*resolvedClientConfig, error) { // 1. Attempt to use Client ID Metadata Document (SEP-991). cimdCfg := h.config.ClientIDMetadataDocumentConfig if cimdCfg != nil && asm.ClientIDMetadataDocumentSupported { @@ -331,14 +326,11 @@ type authResult struct { usedCodeVerifier string } -// getAuthorizationCode uses the [AuthorizationCodeOAuthHandler.AuthorizationURLHandler] +// getAuthorizationCode uses the [AuthorizationCodeHandler.AuthorizationURLHandler] // to obtain an authorization code. -func (h *AuthorizationCodeOAuthHandler) getAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) (*authResult, error) { +func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) (*authResult, error) { codeVerifier := oauth2.GenerateVerifier() state := rand.Text() - if h.config.StateProvider != nil { - state = h.config.StateProvider() - } authURL := cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(codeVerifier), @@ -361,7 +353,7 @@ func (h *AuthorizationCodeOAuthHandler) getAuthorizationCode(ctx context.Context // exchangeAuthorizationCode exchanges the authorization code for a token // and stores it in a token source. -func (h *AuthorizationCodeOAuthHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { +func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { log.Printf("Exchanging authorization code for token") opts := []oauth2.AuthCodeOption{ oauth2.VerifierOption(authResult.usedCodeVerifier), diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index 86855197..f771cbdb 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -67,6 +67,8 @@ func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (*auth. } defer resp.Body.Close() + // In conformance tests the authorization server immediately redirects + // to the callback URL with the authorization code and state. location := resp.Header.Get("Location") if location == "" { return nil, fmt.Errorf("no Location header in redirect") @@ -117,7 +119,7 @@ func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]a } } - authHandler, err := auth.NewAuthorizationCodeOAuthHandler(authConfig) + authHandler, err := auth.NewAuthorizationCodeHandler(authConfig) if err != nil { return fmt.Errorf("failed to create auth handler: %w", err) } diff --git a/docs/protocol.md b/docs/protocol.md index d69d252e..34fc45d6 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -315,7 +315,7 @@ will also call the handler's `Authorize` method if the server returns or facilitate scope step-up authorization. The SDK implements the Authorization Code flow in -[`auth.AuthorizationCodeOAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeOAuthHandler). +[`auth.AuthorizationCodeHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeHandler). This handler supports: - [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) @@ -325,7 +325,7 @@ This handler supports: To use it, configure the handler and assign it to the transport: ```go -authHandler := &auth.AuthorizationCodeOAuthHandler{ +authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ RedirectURL: "https://myapp.com/oauth2-callback", // Configure one of the following: // ClientIDMetadataDocumentConfig: ... @@ -336,9 +336,9 @@ authHandler := &auth.AuthorizationCodeOAuthHandler{ // See full example in examples/auth/client/main.go. code := ... state := ... - return &auth.AuthorizationResult{Code: code, State: state}, nil + return &auth.AuthorizationResult{AuthorizationCode: code, State: state}, nil }, -} +}) transport := &mcp.StreamableClientTransport{ Endpoint: "https://example.com/mcp", @@ -348,7 +348,7 @@ client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, session, err := client.Connect(ctx, transport, nil) ``` -The `auth.AuthorizationCodeOAuthHandler` automatically manages token refreshing +The `auth.AuthorizationCodeHandler` automatically manages token refreshing and step-up authentication (when the server returns `insufficient_scope` error). ## Security diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index 53db7909..91865a21 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -95,7 +95,7 @@ func main() { go receiver.serveRedirectHandler(listener) defer receiver.close() - authHandler, err := auth.NewAuthorizationCodeOAuthHandler(&auth.AuthorizationCodeHandlerConfig{ + authHandler, err := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), AuthorizationURLHandler: receiver.getAuthorizationCode, // Uncomment the client configuration you want to use. diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 0e533a49..dd2b23af 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -241,7 +241,7 @@ will also call the handler's `Authorize` method if the server returns or facilitate scope step-up authorization. The SDK implements the Authorization Code flow in -[`auth.AuthorizationCodeOAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeOAuthHandler). +[`auth.AuthorizationCodeHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeHandler). This handler supports: - [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) @@ -251,7 +251,7 @@ This handler supports: To use it, configure the handler and assign it to the transport: ```go -authHandler := &auth.AuthorizationCodeOAuthHandler{ +authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ RedirectURL: "https://myapp.com/oauth2-callback", // Configure one of the following: // ClientIDMetadataDocumentConfig: ... @@ -262,9 +262,9 @@ authHandler := &auth.AuthorizationCodeOAuthHandler{ // See full example in examples/auth/client/main.go. code := ... state := ... - return &auth.AuthorizationResult{Code: code, State: state}, nil + return &auth.AuthorizationResult{AuthorizationCode: code, State: state}, nil }, -} +}) transport := &mcp.StreamableClientTransport{ Endpoint: "https://example.com/mcp", @@ -274,7 +274,7 @@ client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, session, err := client.Connect(ctx, transport, nil) ``` -The `auth.AuthorizationCodeOAuthHandler` automatically manages token refreshing +The `auth.AuthorizationCodeHandler` automatically manages token refreshing and step-up authentication (when the server returns `insufficient_scope` error). ## Security From ae5e5b32570cdc7d9018a011d27180b1d9a79604 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Tue, 24 Feb 2026 10:18:28 +0000 Subject: [PATCH 15/26] Self review adjustments. --- auth/authorization_code.go | 82 ++++++++----- .../everything-client/client_private.go | 20 +--- docs/protocol.md | 2 +- examples/auth/client/main.go | 20 +--- internal/docs/protocol.src.md | 2 +- oauthex/auth_meta.go | 106 ++++++++++------- oauthex/auth_meta_test.go | 7 +- oauthex/oauth2_test.go | 18 +-- oauthex/oauthex.go | 86 -------------- oauthex/resource_meta.go | 4 +- oauthex/resource_meta_public.go | 112 ++++++++++++++++-- oauthex/url_scheme_test.go | 6 +- 12 files changed, 244 insertions(+), 221 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 7b637539..be6470d4 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -54,7 +54,7 @@ type DynamicClientRegistrationConfig struct { type AuthorizationResult struct { // AuthorizationCode is the authorization code obtained from the authorization server. AuthorizationCode string - // State is the state string returned by the authorization server. + // State string returned by the authorization server. State string } @@ -78,8 +78,8 @@ type AuthorizationCodeHandlerConfig struct { // with [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. RedirectURL string - // AuthorizationURLHandler is a required function called to handle the authorization URL. - // It is responsible for opening the URL in a browser for the user to start the authorization. + // AuthorizationURLHandler is a required function called to handle the authorization request. + // It is responsible for opening the URL in a browser for the user to start the authorization process. // It should return the authorization code and state once the Authorization Server // redirects back to the [AuthorizationCodeHandler.RedirectURL]. AuthorizationURLHandler func(ctx context.Context, authorizationURL string) (*AuthorizationResult, error) @@ -139,6 +139,14 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho return &AuthorizationCodeHandler{config: config}, nil } +func isNonRootHTTPSURL(u string) bool { + pu, err := url.Parse(u) + if err != nil { + return false + } + return pu.Scheme == "https" && pu.Path != "" +} + // Authorize performs the authorization flow. // It is designed to perform the whole Authorization Code Grant flow. // On success, [AuthorizationCodeHandler.TokenSource] will return a token source with the fetched token. @@ -151,19 +159,14 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ if err != nil { return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) } - log.Printf("WWW-Authenticate header: %v", wwwChallenges) - var prm *oauthex.ProtectedResourceMetadata - for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), resourceURL) { - var err error - log.Printf("Getting protected resource metadata from %q", url) - prm, err = oauthex.GetProtectedResourceMetadata(ctx, url, http.DefaultClient) - if err == nil { - break - } - log.Printf("Failed to get protected resource metadata from %q: %v", url, err) + + prm, err := h.getProtectedResourceMetadata(ctx, wwwChallenges, resourceURL) + if err != nil { + return err } // log.Printf("Protected resource metadata: %+v", prm) + asm, err := h.getAuthServerMetadata(ctx, prm, resourceURL) if err != nil { return err @@ -202,16 +205,27 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ return h.exchangeAuthorizationCode(ctx, cfg, authRes, resourceURL) } -func isNonRootHTTPSURL(u string) bool { - pu, err := url.Parse(u) - if err != nil { - return false +// getProtectedResourceMetadata returns the protected resource metadata. +// If no metadata was found or the fetched metadata fails security checks, +// it returns an error. +func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Context, wwwChallenges []oauthex.Challenge, resourceURL string) (*oauthex.ProtectedResourceMetadata, error) { + var errs []error + for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), resourceURL) { + log.Printf("Getting protected resource metadata from %q", url) + prm, err := oauthex.GetProtectedResourceMetadata(ctx, url, http.DefaultClient) + if err != nil { + errs = append(errs, err) + continue + } + return prm, nil } - return pu.Scheme == "https" && pu.Path != "" + return nil, fmt.Errorf("failed to get protected resource metadata: %v", errors.Join(errs...)) } // getAuthServerMetadata returns the authorization server metadata. -// If no metadata is available, it returns a minimal set of endpoints +// It returns an error if the metadata request fails with non-4xx HTTP status code +// or the fetched metadata fails security checks. +// If no metadata was found, it returns a minimal set of endpoints // as a fallback to 2025-03-26 spec. func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata, resourceURL string) (*oauthex.AuthServerMeta, error) { var authServerURL string @@ -229,20 +243,24 @@ func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, pr } log.Printf("Authorization server URL: %s", authServerURL) - asm, err := oauthex.GetAuthServerMeta(ctx, authServerURL, http.DefaultClient) - if err != nil { - return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) - } - if asm == nil { - log.Print("Authorization server metadata not found, using fallback") - // Fallback to 2025-03-26 spec: predefined endpoints. - // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery - asm = &oauthex.AuthServerMeta{ - Issuer: authServerURL, - AuthorizationEndpoint: authServerURL + "/authorize", - TokenEndpoint: authServerURL + "/token", - RegistrationEndpoint: authServerURL + "/register", + for _, u := range oauthex.AuthorizationServerMetadataURLs(authServerURL) { + asm, err := oauthex.GetAuthServerMeta(ctx, u, http.DefaultClient) + if err != nil { + return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) } + if asm != nil { + return asm, nil + } + } + + log.Print("Authorization server metadata not found, using fallback") + // Fallback to 2025-03-26 spec: predefined endpoints. + // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery + asm := &oauthex.AuthServerMeta{ + Issuer: authServerURL, + AuthorizationEndpoint: authServerURL + "/authorize", + TokenEndpoint: authServerURL + "/token", + RegistrationEndpoint: authServerURL + "/register", } return asm, nil } diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index f771cbdb..11412a18 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -69,28 +69,14 @@ func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (*auth. // In conformance tests the authorization server immediately redirects // to the callback URL with the authorization code and state. - location := resp.Header.Get("Location") - if location == "" { - return nil, fmt.Errorf("no Location header in redirect") - } - - locURL, err := url.Parse(location) + locURL, err := url.Parse(resp.Header.Get("Location")) if err != nil { return nil, fmt.Errorf("parse location: %v", err) } - code := locURL.Query().Get("code") - if code == "" { - return nil, fmt.Errorf("no code parameter in redirect URL") - } - state := locURL.Query().Get("state") - if state == "" { - return nil, fmt.Errorf("no state parameter in redirect URL") - } - return &auth.AuthorizationResult{ - AuthorizationCode: code, - State: state, + AuthorizationCode: locURL.Query().Get("code"), + State: locURL.Query().Get("state"), }, nil } diff --git a/docs/protocol.md b/docs/protocol.md index 34fc45d6..0ed6b3af 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -319,7 +319,7 @@ The SDK implements the Authorization Code flow in This handler supports: - [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) -- Pre-registered clients +- [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) - [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) To use it, configure the handler and assign it to the transport: diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index 91865a21..dfc3d102 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -36,16 +36,9 @@ type codeReceiver struct { func (r *codeReceiver) serveRedirectHandler(listener net.Listener) error { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - code := req.URL.Query().Get("code") - state := req.URL.Query().Get("state") - if code == "" { - http.Error(w, "authorization code not found", http.StatusBadRequest) - return - } - r.authChan <- &auth.AuthorizationResult{ - AuthorizationCode: code, - State: state, + AuthorizationCode: req.URL.Query().Get("code"), + State: req.URL.Query().Get("state"), } fmt.Fprint(w, "Authentication successful. You can close this window.") }) @@ -79,11 +72,6 @@ func (r *codeReceiver) close() { func main() { flag.Parse() - client := mcp.NewClient(&mcp.Implementation{ - Name: "test-client", - Version: "1.0.0", - }, nil) - receiver := &codeReceiver{ authChan: make(chan *auth.AuthorizationResult), errChan: make(chan error), @@ -121,6 +109,10 @@ func main() { } ctx := context.Background() + client := mcp.NewClient(&mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, nil) session, err := client.Connect(ctx, transport, nil) if err != nil { diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index dd2b23af..22a39f90 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -245,7 +245,7 @@ The SDK implements the Authorization Code flow in This handler supports: - [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) -- Pre-registered clients +- [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) - [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) To use it, configure the handler and assign it to the transport: diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index e6aab283..d9fcc9d8 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -123,55 +123,60 @@ type AuthServerMeta struct { } // GetAuthServerMeta issues a GET request to retrieve authorization server metadata -// from an OAuth authorization server with the given issuerURL. +// from an OAuth authorization server with the given metadataURL. // // It follows [RFC 8414]: -// - The well-known paths specified there are inserted into the URL's path, one at time. -// The first to succeed is used. -// - The Issuer field is checked against issuerURL. +// - The Issuer field is checked against metadataURL.Issuer. +// +// It also verifies that the authorization server supports PKCE and that the URLs +// in the metadata don't use dangerous schemes. +// +// It returns an error if the request fails with a non-4xx status code or the fetched +// metadata doesn't pass security validations. // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 -func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { - for _, u := range AuthorizationServerMetadataURLs(issuerURL) { - asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) - if err != nil { - log.Printf("Failed to get auth server metadata from %q: %v", u, err) - var httpErr *httpStatusError - if errors.As(err, &httpErr) { - if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { - continue - } - return nil, fmt.Errorf("%v", err) // Do not expose wrapped errors. +func GetAuthServerMeta(ctx context.Context, metadataURL AuthorizationServerMetadataURL, c *http.Client) (*AuthServerMeta, error) { + asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL.URL, 1<<20) + if err != nil { + log.Printf("Failed to get auth server metadata from %q: %v", metadataURL.URL, err) + var httpErr *httpStatusError + if errors.As(err, &httpErr) { + if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { + return nil, nil } + return nil, fmt.Errorf("%v", err) // Do not expose error types. } - if asm.Issuer != issuerURL { - // Validate the Issuer field (see RFC 8414, section 3.3). - return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) - } - - if len(asm.CodeChallengeMethodsSupported) == 0 { - return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) - } + } + if asm.Issuer != metadataURL.Issuer { + // Validate the Issuer field (see RFC 8414, section 3.3). + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, metadataURL.Issuer) + } - // Validate endpoint URLs to prevent XSS attacks (see #526). - if err := validateAuthServerMetaURLs(asm); err != nil { - return nil, err - } - log.Printf("Fetched authorization server metadata from %q", u) + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", metadataURL.Issuer) + } - return asm, nil + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { + return nil, err } - // Authorization server metadata not found. Return nil error to allow a fallback. - return nil, nil + log.Printf("Fetched authorization server metadata from %q", metadataURL.URL) + + return asm, nil } -// AuthorizationServerMetadataURLs returns a list of URLs to try when looking for -// authorization server metadata as mandated by the MCP specification. -func AuthorizationServerMetadataURLs(issuerURL string) []string { - var urls []string +type AuthorizationServerMetadataURL struct { + // URL where the Authorization Server Metadata may be retrieved. + URL string + // Issuer that was used to construct the [URL]. + Issuer string +} - // Produce candidates per - // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. +// AuthorizationServerMetadataURLs returns a list of URLs to try when looking for +// authorization server metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. +func AuthorizationServerMetadataURLs(issuerURL string) []AuthorizationServerMetadataURL { + var urls []AuthorizationServerMetadataURL baseURL, err := url.Parse(issuerURL) if err != nil { @@ -181,23 +186,38 @@ func AuthorizationServerMetadataURLs(issuerURL string) []string { if baseURL.Path == "" { // "OAuth 2.0 Authorization Server Metadata". baseURL.Path = "/.well-known/oauth-authorization-server" - urls = append(urls, baseURL.String()) + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) // "OpenID Connect Discovery 1.0". baseURL.Path = "/.well-known/openid-configuration" - urls = append(urls, baseURL.String()) + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) return urls } originalPath := baseURL.Path // "OAuth 2.0 Authorization Server Metadata with path insertion". baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, baseURL.String()) + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) // "OpenID Connect Discovery 1.0 with path insertion". baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0 path appending". + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) + // "OpenID Connect Discovery 1.0 with path appending". baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" - urls = append(urls, baseURL.String()) + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) return urls } diff --git a/oauthex/auth_meta_test.go b/oauthex/auth_meta_test.go index 1e608824..6363e098 100644 --- a/oauthex/auth_meta_test.go +++ b/oauthex/auth_meta_test.go @@ -85,7 +85,10 @@ func TestGetAuthServerMetaPKCESupport(t *testing.T) { // The fake server sets issuer to https://localhost:, so compute that issuer. u, _ := url.Parse(ts.URL) - issuer := "https://localhost:" + u.Port() + metadataURL := AuthorizationServerMetadataURL{ + URL: "https://localhost:" + u.Port() + "/.well-known/oauth-authorization-server", + Issuer: "https://localhost:" + u.Port(), + } // The fake server presents a cert for example.com; set ServerName accordingly. httpClient := ts.Client() @@ -95,7 +98,7 @@ func TestGetAuthServerMetaPKCESupport(t *testing.T) { httpClient.Transport = clone } - meta, err := GetAuthServerMeta(ctx, issuer, httpClient) + meta, err := GetAuthServerMeta(ctx, metadataURL, httpClient) if tt.wantError != "" { if err == nil { t.Fatal("wanted error but got none") diff --git a/oauthex/oauth2_test.go b/oauthex/oauth2_test.go index 08d2d314..36f732e8 100644 --- a/oauthex/oauth2_test.go +++ b/oauthex/oauth2_test.go @@ -82,13 +82,13 @@ func TestParseSingleChallenge(t *testing.T) { tests := []struct { name string input string - want challenge + want Challenge wantErr bool }{ { name: "scheme only", input: "Basic", - want: challenge{ + want: Challenge{ Scheme: "basic", }, wantErr: false, @@ -96,7 +96,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with one quoted param", input: `Bearer realm="example.com"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, @@ -105,7 +105,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with one unquoted param", input: `Bearer realm=example.com`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, @@ -114,7 +114,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with multiple params", input: `Bearer realm="example", error="invalid_token", error_description="The token expired"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{ "realm": "example", @@ -127,7 +127,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with multiple unquoted params", input: `Bearer realm=example, error=invalid_token, error_description=The token expired`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{ "realm": "example", @@ -140,7 +140,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "case-insensitive scheme and keys", input: `BEARER ReAlM="example"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example"}, }, @@ -149,7 +149,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "param with escaped quote", input: `Bearer realm="example \"foo\" bar"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": `example "foo" bar`}, }, @@ -158,7 +158,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "param without quotes (token)", input: "Bearer realm=example.com", - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, diff --git a/oauthex/oauthex.go b/oauthex/oauthex.go index 34ed55b5..151da7e5 100644 --- a/oauthex/oauthex.go +++ b/oauthex/oauthex.go @@ -4,89 +4,3 @@ // Package oauthex implements extensions to OAuth2. package oauthex - -// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, -// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. -// -// The following features are not supported: -// - additional keys (§2, last sentence) -// - human-readable metadata (§2.1) -// - signed metadata (§2.2) -type ProtectedResourceMetadata struct { - // GENERATED BY GEMINI 2.5. - - // Resource (resource) is the protected resource's resource identifier. - // Required. - Resource string `json:"resource"` - - // AuthorizationServers (authorization_servers) is an optional slice containing a list of - // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be - // used with this protected resource. - AuthorizationServers []string `json:"authorization_servers,omitempty"` - - // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set - // document. This contains public keys belonging to the protected resource, such as - // signing key(s) that the resource server uses to sign resource responses. - JWKSURI string `json:"jwks_uri,omitempty"` - - // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope - // values (as defined in RFC 6749) used in authorization requests to request access - // to this protected resource. - ScopesSupported []string `json:"scopes_supported,omitempty"` - - // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing - // a list of the supported methods of sending an OAuth 2.0 bearer token to the - // protected resource. Defined values are "header", "body", and "query". - BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` - - // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional - // slice of JWS signing algorithms (alg values) supported by the protected - // resource for signing resource responses. - ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` - - // ResourceName (resource_name) is a human-readable name of the protected resource - // intended for display to the end user. It is RECOMMENDED that this field be included. - // This value may be internationalized. - ResourceName string `json:"resource_name,omitempty"` - - // ResourceDocumentation (resource_documentation) is an optional URL of a page containing - // human-readable information for developers using the protected resource. - // This value may be internationalized. - ResourceDocumentation string `json:"resource_documentation,omitempty"` - - // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing - // human-readable policy information on how a client can use the data provided. - // This value may be internationalized. - ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` - - // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected - // resource's human-readable terms of service. This value may be internationalized. - ResourceTOSURI string `json:"resource_tos_uri,omitempty"` - - // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an - // optional boolean indicating support for mutual-TLS client certificate-bound - // access tokens (RFC 8705). Defaults to false if omitted. - TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` - - // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional - // slice of 'type' values supported by the resource server for the - // 'authorization_details' parameter (RFC 9396). - AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` - - // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional - // slice of JWS signing algorithms supported by the resource server for validating - // DPoP proof JWTs (RFC 9449). - DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` - - // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean - // specifying whether the protected resource always requires the use of DPoP-bound - // access tokens (RFC 9449). Defaults to false if omitted. - DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` - - // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters - // about the protected resource as claims. If present, these values take precedence - // over values conveyed in plain JSON. - // TODO:implement. - // Note that §2.2 says it's okay to ignore this. - // SignedMetadata string `json:"signed_metadata,omitempty"` -} diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index 9875f54a..bd869fa0 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -153,7 +153,7 @@ func ProtectedResourceMetadataURLs(metadataURL, resourceURL string) []ProtectedR // ResourceMetadataURL returns a resource metadata URL from the given challenges, // or the empty string if there is none. -func ResourceMetadataURL(cs []challenge) string { +func ResourceMetadataURL(cs []Challenge) string { for _, c := range cs { if u := c.Params["resource_metadata"]; u != "" { return u @@ -162,7 +162,7 @@ func ResourceMetadataURL(cs []challenge) string { return "" } -func Scopes(cs []challenge) []string { +func Scopes(cs []Challenge) []string { for _, c := range cs { if c.Scheme == "bearer" && c.Params["scope"] != "" { return strings.Fields(c.Params["scope"]) diff --git a/oauthex/resource_meta_public.go b/oauthex/resource_meta_public.go index 44ad3e79..443d5ba8 100644 --- a/oauthex/resource_meta_public.go +++ b/oauthex/resource_meta_public.go @@ -16,9 +16,95 @@ import ( "unicode" ) -// challenge represents a single authentication challenge from a WWW-Authenticate header. +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (§2, last sentence) +// - human-readable metadata (§2.1) +// - signed metadata (§2.2) +type ProtectedResourceMetadata struct { + // GENERATED BY GEMINI 2.5. + + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. + JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. + ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that §2.2 says it's okay to ignore this. + // SignedMetadata string `json:"signed_metadata,omitempty"` +} + +// Challenge represents a single authentication challenge from a WWW-Authenticate header. // As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. -type challenge struct { +type Challenge struct { // GENERATED BY GEMINI 2.5. // // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). @@ -29,7 +115,7 @@ type challenge struct { Params map[string]string } -func Error(cs []challenge) string { +func Error(cs []Challenge) string { for _, c := range cs { if c.Scheme == "bearer" && c.Params["error"] != "" { return c.Params["error"] @@ -42,9 +128,9 @@ func Error(cs []challenge) string { // The header format is defined in RFC 9110, Section 11.6.1, and can contain // one or more challenges, separated by commas. // It returns a slice of challenges or an error if one of the headers is malformed. -func ParseWWWAuthenticate(headers []string) ([]challenge, error) { +func ParseWWWAuthenticate(headers []string) ([]Challenge, error) { // GENERATED BY GEMINI 2.5 (human-tweaked) - var challenges []challenge + var challenges []Challenge for _, h := range headers { challengeStrings, err := splitChallenges(h) if err != nil { @@ -112,15 +198,15 @@ func splitChallenges(header string) ([]string, error) { // parseSingleChallenge parses a string containing exactly one challenge. // challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] -func parseSingleChallenge(s string) (challenge, error) { +func parseSingleChallenge(s string) (Challenge, error) { // GENERATED BY GEMINI 2.5, human-tweaked. s = strings.TrimSpace(s) if s == "" { - return challenge{}, errors.New("empty challenge string") + return Challenge{}, errors.New("empty challenge string") } scheme, paramsStr, found := strings.Cut(s, " ") - c := challenge{Scheme: strings.ToLower(scheme)} + c := Challenge{Scheme: strings.ToLower(scheme)} if !found { return c, nil } @@ -132,7 +218,7 @@ func parseSingleChallenge(s string) (challenge, error) { // Find the end of the parameter key. keyEnd := strings.Index(paramsStr, "=") if keyEnd <= 0 { - return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + return Challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) } key := strings.TrimSpace(paramsStr[:keyEnd]) @@ -160,7 +246,7 @@ func parseSingleChallenge(s string) (challenge, error) { // A quoted string must be terminated. if i == len(paramsStr) { - return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + return Challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") } value = valBuilder.String() @@ -178,7 +264,7 @@ func parseSingleChallenge(s string) (challenge, error) { } } if value == "" { - return challenge{}, fmt.Errorf("no value for auth param %q", key) + return Challenge{}, fmt.Errorf("no value for auth param %q", key) } // Per RFC 9110, parameter keys are case-insensitive. @@ -189,10 +275,10 @@ func parseSingleChallenge(s string) (challenge, error) { paramsStr = strings.TrimSpace(paramsStr[1:]) } else if paramsStr != "" { // If there's content but it's not a new parameter, the format is wrong. - return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + return Challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) } } // Per RFC 9110, the scheme is case-insensitive. - return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil + return Challenge{Scheme: strings.ToLower(scheme), Params: params}, nil } diff --git a/oauthex/url_scheme_test.go b/oauthex/url_scheme_test.go index 531a1f9c..c13bc1d5 100644 --- a/oauthex/url_scheme_test.go +++ b/oauthex/url_scheme_test.go @@ -226,7 +226,11 @@ func TestGetAuthServerMetaRejectsDangerousURLs(t *testing.T) { defer server.Close() ctx := context.Background() - _, err := GetAuthServerMeta(ctx, server.URL, server.Client()) + metadataURL := AuthorizationServerMetadataURL{ + URL: server.URL, + Issuer: server.URL, + } + _, err := GetAuthServerMeta(ctx, metadataURL, server.Client()) if err == nil { t.Fatal("GetAuthServerMeta(): got nil error, want error") } From 86cdd9fc01b98accaa870735c85dba4bce712650 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Tue, 24 Feb 2026 10:44:14 +0000 Subject: [PATCH 16/26] Finalize error handling in streaming.go. --- auth/client.go | 8 +++----- mcp/streamable.go | 17 +++++++++-------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/auth/client.go b/auth/client.go index 3aa6d29b..6ddd4a29 100644 --- a/auth/client.go +++ b/auth/client.go @@ -6,24 +6,22 @@ package auth import ( "context" - "errors" "net/http" "golang.org/x/oauth2" ) -// Error that will be thrown if the call failed due to authorization. -var ErrUnauthorized = errors.New("unauthorized") - type OAuthHandler interface { isOAuthHandler() // TokenSource returns a token source to be used for outgoing requests. + // Returned token source might be nil. In that case, the transport will not + // add any authorization headers to the request. TokenSource(context.Context) (oauth2.TokenSource, error) // Authorize is called when an HTTP request results in an error that may // be addressed by the authorization flow (currently 401 Unauthorized and 403 Forbidden). - // It is responsible for initiating the OAuth flow to obtain a token source. + // It is responsible for performing the OAuth flow to obtain an access token. // The arguments are the request that failed and the response that was received for it. // If the returned error is nil, [TokenSource] is expected to return a non-nil token source. // After a successful call to [Authorize], the HTTP request should be retried by the transport. diff --git a/mcp/streamable.go b/mcp/streamable.go index 337ebcf3..9155f7bf 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1733,15 +1733,17 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e req.Header.Set("Accept", "application/json, text/event-stream") doRequest := func() (*http.Response, error) { if err := c.setMCPHeaders(req); err != nil { - // TODO: should we fail the connection here? - return nil, err + // Failure to set headers means that the request was not sent. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + return nil, fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) } resp, err := c.client.Do(req) if err != nil { // Any error from client.Do means the request didn't reach the server. // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. - err = fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) + err = fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) } return resp, err } @@ -1774,7 +1776,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. // Wrap the authorization error as well for client inspection. - return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) + return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) } // Retry the request after successful authorization. resp, err = doRequest() @@ -1863,7 +1865,9 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { if err != nil { return err } - req.Header.Set("Authorization", "Bearer "+token.AccessToken) + if token != nil { + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + } } } if c.initializedResult != nil { @@ -2126,7 +2130,6 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin return nil, err } if err := c.setMCPHeaders(req); err != nil { - // TODO: should we fail the connection here? return nil, err } if lastEventID != "" { @@ -2160,8 +2163,6 @@ func (c *streamableClientConn) Close() error { c.closeErr = err } else { if err := c.setMCPHeaders(req); err != nil { - // TODO: or setting headers should be best-effort and we should retry - // the request without them? c.closeErr = err } else if _, err := c.client.Do(req); err != nil { c.closeErr = err From 69a0853beabd988a78f184754c4054b463cd52fc Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Tue, 24 Feb 2026 13:13:06 +0000 Subject: [PATCH 17/26] Refactor client_secret handling and authentication methods selection. --- auth/authorization_code.go | 116 ++++++++++++++---- .../everything-client/client_private.go | 8 +- 2 files changed, 95 insertions(+), 29 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index be6470d4..4600ea27 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -20,6 +20,46 @@ import ( "golang.org/x/oauth2" ) +// ClientSecretAuthMethod defines "client_secret_*" authentication methods per +// https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml#token-endpoint-auth-method. +// "client_secret_jwt" is not currently supported. +type ClientSecretAuthMethod int + +const ( + // ClientSecretAuthMethodBasic uses the "client_secret_basic" authentication method. + ClientSecretAuthMethodBasic ClientSecretAuthMethod = iota + // ClientSecretAuthMethodPost uses the "client_secret_post" authentication method. + ClientSecretAuthMethodPost +) + +func (m ClientSecretAuthMethod) String() string { + switch m { + case ClientSecretAuthMethodBasic: + return "client_secret_basic" + case ClientSecretAuthMethodPost: + return "client_secret_post" + default: + return "" + } +} + +// ClientSecretAuthConfig is used to configure client authentication using client_secret. +type ClientSecretAuthConfig struct { + // ClientID is the client ID to be used for client authentication. + ClientID string + // ClientSecret is the client secret to be used for client authentication. + ClientSecret string + // PreferredClientSecretAuthMethod to be used for client authentication. + // If not specified or unsupported by the authorization server, the method + // will be selected based on the authorization server's supported methods, + // according to the following preference order: + // + // 1. "client_secret_post" + // 2. "client_secret_basic" + // + PreferredClientSecretAuthMethod ClientSecretAuthMethod +} + // ClientIDMetadataDocumentConfig is used to configure the Client ID Metadata Document // based client registration per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents. @@ -32,13 +72,10 @@ type ClientIDMetadataDocumentConfig struct { // PreregisteredClientConfig is used to configure a pre-registered client per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration. +// Currently only "client_secret_basic" and "client_secret_post" authentication methods are supported. type PreregisteredClientConfig struct { - // ClientID and ClientSecret to be used for client authentication. - ClientID string - ClientSecret string - // AuthStyle is an optional client authentication method. - // See [oauth2.AuthStyleAutoDetect] for the documentation of the zero value. - AuthStyle oauth2.AuthStyle + // ClientSecretAuthConfig is the client_secret based configuration to be used for client authentication. + ClientSecretAuthConfig *ClientSecretAuthConfig } // DynamicClientRegistrationConfig is used to configure dynamic client registration per @@ -123,8 +160,12 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho if config.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(config.ClientIDMetadataDocumentConfig.URL) { return nil, fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") } - if config.PreregisteredClientConfig != nil { - if config.PreregisteredClientConfig.ClientID == "" || config.PreregisteredClientConfig.ClientSecret == "" { + preCfg := config.PreregisteredClientConfig + if preCfg != nil { + if preCfg.ClientSecretAuthConfig == nil { + return nil, errors.New("field ClientSecretAuthConfig is required for pre-registered client") + } + if preCfg.ClientSecretAuthConfig.ClientID == "" || preCfg.ClientSecretAuthConfig.ClientSecret == "" { return nil, fmt.Errorf("pre-registered client ID or secret is empty") } } @@ -188,9 +229,8 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ ClientSecret: resolvedClientConfig.clientSecret, Endpoint: oauth2.Endpoint{ - AuthURL: asm.AuthorizationEndpoint, - TokenURL: asm.TokenEndpoint, - // TODO: validate if the auth style is supported by the AS. + AuthURL: asm.AuthorizationEndpoint, + TokenURL: asm.TokenEndpoint, AuthStyle: resolvedClientConfig.authStyle, }, RedirectURL: h.config.RedirectURL, @@ -199,6 +239,7 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ authRes, err := h.getAuthorizationCode(ctx, cfg, req.URL.String()) if err != nil { + // Purposefully leaving the error unwrappable so it can be handled by the caller. return err } @@ -280,6 +321,38 @@ type resolvedClientConfig struct { authStyle oauth2.AuthStyle } +func selectTokenAuthMethod(supported []string, preferred ClientSecretAuthMethod) oauth2.AuthStyle { + if slices.Contains(supported, preferred.String()) { + return authMethodToStyle(preferred.String()) + } + prefOrder := []string{ + // Preferred in OAuth 2.1 draft: https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-14.html#name-client-secret. + "client_secret_post", + "client_secret_basic", + } + for _, method := range prefOrder { + if slices.Contains(supported, method) { + return authMethodToStyle(method) + } + } + return oauth2.AuthStyleAutoDetect +} + +func authMethodToStyle(method string) oauth2.AuthStyle { + switch method { + case "client_secret_post": + return oauth2.AuthStyleInParams + case "client_secret_basic": + return oauth2.AuthStyleInHeader + case "none": + // "none" is equivalent to "client_secret_post" but without sending client secret. + return oauth2.AuthStyleInParams + default: + // "client_secret_basic" is the default per https://datatracker.ietf.org/doc/html/rfc7591#section-2. + return oauth2.AuthStyleInHeader + } +} + // handleRegistration handles client registration. // The provided authorization server metadata must be non-nil. // Support for different registration methods is defined as follows: @@ -300,11 +373,12 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm * // 2. Attempt to use pre-registered client configuration. pCfg := h.config.PreregisteredClientConfig if pCfg != nil { + authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported, pCfg.ClientSecretAuthConfig.PreferredClientSecretAuthMethod) return &resolvedClientConfig{ registrationType: registrationTypePreregistered, - clientID: pCfg.ClientID, - clientSecret: pCfg.ClientSecret, - authStyle: pCfg.AuthStyle, + clientID: pCfg.ClientSecretAuthConfig.ClientID, + clientSecret: pCfg.ClientSecretAuthConfig.ClientSecret, + authStyle: authStyle, }, nil } // 3. Attempt to use dynamic client registration. @@ -318,18 +392,7 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm * registrationType: registrationTypeDynamic, clientID: regResp.ClientID, clientSecret: regResp.ClientSecret, - } - switch regResp.TokenEndpointAuthMethod { - case "client_secret_post": - cfg.authStyle = oauth2.AuthStyleInParams - case "client_secret_basic": - cfg.authStyle = oauth2.AuthStyleInHeader - case "none": - // "none" is equivalent to "client_secret_post" but without sending client secret. - cfg.authStyle = oauth2.AuthStyleInParams - cfg.clientSecret = "" - default: - // We leave the AuthStyle set to zero value, which is auto-detection. + authStyle: authMethodToStyle(regResp.TokenEndpointAuthMethod), } log.Printf("Client registered with client ID: %s", regResp.ClientID) return cfg, nil @@ -358,6 +421,7 @@ func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg log.Printf("Calling AuthorizationURLHandler: %q", authURL) authRes, err := h.config.AuthorizationURLHandler(ctx, authURL) if err != nil { + // Purposefully leaving the error unwrappable so it can be handled by the caller. return nil, err } if authRes.State != state { diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index 11412a18..6b4654ef 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -96,11 +96,13 @@ func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]a }, } // Try pre-registered client information if provided in the context. - if clientId, ok := configCtx["client_id"].(string); ok { + if clientID, ok := configCtx["client_id"].(string); ok { if clientSecret, ok := configCtx["client_secret"].(string); ok { authConfig.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ - ClientID: clientId, - ClientSecret: clientSecret, + ClientSecretAuthConfig: &auth.ClientSecretAuthConfig{ + ClientID: clientID, + ClientSecret: clientSecret, + }, } } } From 857a3516fc499b5b71bfa3753a49740c260568a6 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Wed, 25 Feb 2026 10:11:31 +0000 Subject: [PATCH 18/26] Add unit tests for authorization_code.go. --- auth/authorization_code.go | 24 +- auth/authorization_code_test.go | 589 ++++++++++++++++++ .../oauthtest/fake_authorization_server.go | 304 +++++++++ oauthex/auth_meta.go | 3 +- 4 files changed, 910 insertions(+), 10 deletions(-) create mode 100644 auth/authorization_code_test.go create mode 100644 internal/oauthtest/fake_authorization_server.go diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 4600ea27..8e181c81 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -195,20 +195,19 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ defer resp.Body.Close() log.Printf("Authorize: %s %s", req.Method, req.URL) - resourceURL := req.URL.String() wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) if err != nil { return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) } log.Printf("WWW-Authenticate header: %v", wwwChallenges) - prm, err := h.getProtectedResourceMetadata(ctx, wwwChallenges, resourceURL) + prm, err := h.getProtectedResourceMetadata(ctx, wwwChallenges, req.URL.String()) if err != nil { return err } // log.Printf("Protected resource metadata: %+v", prm) - asm, err := h.getAuthServerMetadata(ctx, prm, resourceURL) + asm, err := h.getAuthServerMetadata(ctx, prm) if err != nil { return err } @@ -243,39 +242,46 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ return err } - return h.exchangeAuthorizationCode(ctx, cfg, authRes, resourceURL) + return h.exchangeAuthorizationCode(ctx, cfg, authRes, prm.Resource) } // getProtectedResourceMetadata returns the protected resource metadata. // If no metadata was found or the fetched metadata fails security checks, // it returns an error. -func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Context, wwwChallenges []oauthex.Challenge, resourceURL string) (*oauthex.ProtectedResourceMetadata, error) { +func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Context, wwwChallenges []oauthex.Challenge, mcpServerURL string) (*oauthex.ProtectedResourceMetadata, error) { var errs []error - for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), resourceURL) { + // Use MCP server URL as the resource URI per + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri. + for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), mcpServerURL) { log.Printf("Getting protected resource metadata from %q", url) prm, err := oauthex.GetProtectedResourceMetadata(ctx, url, http.DefaultClient) if err != nil { errs = append(errs, err) continue } + if prm == nil { + errs = append(errs, fmt.Errorf("protected resource metadata is nil")) + continue + } return prm, nil } return nil, fmt.Errorf("failed to get protected resource metadata: %v", errors.Join(errs...)) } // getAuthServerMetadata returns the authorization server metadata. +// The provided Protected Resource Metadata must not be nil. // It returns an error if the metadata request fails with non-4xx HTTP status code // or the fetched metadata fails security checks. // If no metadata was found, it returns a minimal set of endpoints // as a fallback to 2025-03-26 spec. -func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata, resourceURL string) (*oauthex.AuthServerMeta, error) { +func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata) (*oauthex.AuthServerMeta, error) { var authServerURL string - if prm != nil && len(prm.AuthorizationServers) > 0 { + if len(prm.AuthorizationServers) > 0 { // Use the first authorization server, similarly to other SDKs. authServerURL = prm.AuthorizationServers[0] } else { // Fallback to 2025-03-26 spec: MCP server base URL acts as Authorization Server. - authURL, err := url.Parse(resourceURL) + authURL, err := url.Parse(prm.Resource) if err != nil { return nil, fmt.Errorf("failed to parse resource URL: %v", err) } diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go new file mode 100644 index 00000000..8b4dc0f1 --- /dev/null +++ b/auth/authorization_code_test.go @@ -0,0 +1,589 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/http/httputil" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/internal/oauthtest" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +func TestAuthorize(t *testing.T) { + authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + RegistrationConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "test_client_id": { + Secret: "test_client_secret", + RedirectURIs: []string{"http://localhost:12345/callback"}, + }, + }, + }, + }) + authServer.Start(t) + + resourceMux := http.NewServeMux() + resourceServer := httptest.NewServer(resourceMux) + t.Cleanup(resourceServer.Close) + resourceURL := resourceServer.URL + "/resource" + + resourceMux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{authServer.URL()}, + })) + + handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{ + RedirectURL: "http://localhost:12345/callback", + PreregisteredClientConfig: &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientID: "test_client_id", + ClientSecret: "test_client_secret", + }, + }, + AuthorizationURLHandler: func(ctx context.Context, authURL string) (*AuthorizationResult, error) { + // The fake authorization server will redirect to an URL with code and state. + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get(authURL) + if err != nil { + return nil, fmt.Errorf("failed to visit auth URL: %v", err) + } + defer resp.Body.Close() + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + t.Fatalf("failed to dump response: %v", err) + } + t.Log(string(dump)) + + location, err := resp.Location() + if err != nil { + return nil, fmt.Errorf("failed to get location header: %v", err) + } + return &AuthorizationResult{ + AuthorizationCode: location.Query().Get("code"), + State: location.Query().Get("state"), + }, nil + }, + }) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, resourceURL, nil) + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: make(http.Header), + Body: http.NoBody, + Request: req, + } + resp.Header.Set( + "WWW-Authenticate", + "Bearer resource_metadata="+resourceServer.URL+"/.well-known/oauth-protected-resource/resource", + ) + + if err := handler.Authorize(context.Background(), req, resp); err != nil { + t.Fatalf("Authorize failed: %v", err) + } + + tokenSource, err := handler.TokenSource(t.Context()) + if err != nil { + t.Fatalf("Failed to get token source: %v", err) + } + token, err := tokenSource.Token() + if err != nil { + t.Fatalf("Failed to get token: %v", err) + } + if token.AccessToken != "test_access_token" { + t.Errorf("Expected access token 'test_access_token', got '%s'", token.AccessToken) + } +} + +func TestNewAuthorizationCodeHandler(t *testing.T) { + validConfig := func() *AuthorizationCodeHandlerConfig { + return &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"}, + RedirectURL: "https://example.com/callback", + AuthorizationURLHandler: func(ctx context.Context, authURL string) (*AuthorizationResult, error) { + return nil, nil + }, + } + } + // Ensure the base config is valid + if _, err := NewAuthorizationCodeHandler(validConfig()); err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + + tests := []struct { + name string + config func() *AuthorizationCodeHandlerConfig + }{ + { + name: "NilConfig", + config: func() *AuthorizationCodeHandlerConfig { + return nil + }, + }, + { + name: "NoRegistrationConfig", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.ClientIDMetadataDocumentConfig = nil + cfg.PreregisteredClientConfig = nil + cfg.DynamicClientRegistrationConfig = nil + return cfg + }, + }, + { + name: "MissingRedirectURL", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.RedirectURL = "" + return cfg + }, + }, + { + name: "MissingAuthorizationURLHandler", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.AuthorizationURLHandler = nil + return cfg + }, + }, + { + name: "InvalidMetadataURL", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.ClientIDMetadataDocumentConfig.URL = "https://example.com" + return cfg + }, + }, + { + name: "InvalidPreregistered_MissingSecretConfig", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.PreregisteredClientConfig = &PreregisteredClientConfig{} + return cfg + }, + }, + { + name: "InvalidPreregistered_EmptyID", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.PreregisteredClientConfig = &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientSecret: "secret", + }, + } + return cfg + }, + }, + { + name: "InvalidPreregistered_EmptySecret", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.PreregisteredClientConfig = &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientID: "test_client_id", + }, + } + return cfg + }, + }, + { + name: "InvalidDynamic_MissingMetadata", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.DynamicClientRegistrationConfig = &DynamicClientRegistrationConfig{} + return cfg + }, + }, + { + name: "InvalidDynamic_InconsistentRedirectURI", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.DynamicClientRegistrationConfig = &DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback1"}, + }, + } + cfg.RedirectURL = "https://example.com/callback2" + return cfg + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewAuthorizationCodeHandler(tt.config()) + if err == nil { + t.Errorf("NewAuthorizationCodeHandler() = nil, want error") + } + }) + } +} + +func TestGetProtectedResourceMetadata(t *testing.T) { + handler := &AuthorizationCodeHandler{} // No config needed for this method + pathForChallenge := "/protected-resource" + + tests := []struct { + name string + challengesProvided bool + prmPath string + resourcePath string + wantError bool + }{ + { + name: "FromChallenges", + challengesProvided: true, + prmPath: pathForChallenge, + resourcePath: "/resource", + wantError: false, + }, + { + name: "FallbackToEndpoint", + challengesProvided: false, + prmPath: "/.well-known/oauth-protected-resource/resource", + resourcePath: "/resource", + wantError: false, + }, + { + name: "FallbackToRoot", + challengesProvided: false, + prmPath: "/.well-known/oauth-protected-resource", + resourcePath: "", + wantError: false, + }, + { + name: "NoMetadata", + challengesProvided: false, + prmPath: "/incorrect", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + resourceURL := server.URL + tt.resourcePath + metadata := &oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + ScopesSupported: []string{"read", "write"}, + } + mux.Handle(tt.prmPath, ProtectedResourceMetadataHandler(metadata)) + var challenges []oauthex.Challenge + if tt.challengesProvided { + challenges = []oauthex.Challenge{ + { + Scheme: "Bearer", + Params: map[string]string{ + "resource_metadata": server.URL + pathForChallenge, + }, + }, + } + } + + got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, resourceURL) + if err != nil { + if !tt.wantError { + t.Fatalf("getProtectedResourceMetadata() error = %v, want nil", err) + } + return + } + if got == nil { + t.Fatal("getProtectedResourceMetadata() got nil, want metadata") + } + if diff := cmp.Diff(metadata, got); diff != "" { + t.Errorf("getProtectedResourceMetadata() metadata mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestGetAuthServerMetadata(t *testing.T) { + handler := &AuthorizationCodeHandler{} // No config needed for this method + + tests := []struct { + name string + authorizationAtMCPServer bool + issuerPath string + endpointConfig *oauthtest.MetadataEndpointConfig + }{ + { + name: "OAuthEndpoint_Root", + authorizationAtMCPServer: false, + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOAuthInsertedEndpoint: true, + }, + }, + { + name: "OpenIDEndpoint_Root", + authorizationAtMCPServer: false, + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDInsertedEndpoint: true, + }, + }, + { + name: "OAuthEndpoint_Path", + authorizationAtMCPServer: false, + issuerPath: "/oauth", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOAuthInsertedEndpoint: true, + }, + }, + { + name: "OpenIDEndpoint_Path", + authorizationAtMCPServer: false, + issuerPath: "/openid", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDInsertedEndpoint: true, + }, + }, + { + name: "OpenIDAppendedEndpoint_Path", + authorizationAtMCPServer: false, + issuerPath: "/openid", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDAppendedEndpoint: true, + }, + }, + { + name: "FallbackToMCPServer", + authorizationAtMCPServer: true, + }, + { + name: "NoMetadata", + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + // All metadata endpoints disabled. + ServeOAuthInsertedEndpoint: false, + ServeOpenIDInsertedEndpoint: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + IssuerPath: tt.issuerPath, + MetadataEndpointConfig: tt.endpointConfig, + }) + s.Start(t) + issuerURL := s.URL() + tt.issuerPath + resourceURL := "https://example.com/resource" + authServers := []string{issuerURL} + if tt.authorizationAtMCPServer { + resourceURL = issuerURL + authServers = nil + } + prm := &oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: authServers, + } + + got, err := handler.getAuthServerMetadata(t.Context(), prm) + if err != nil { + t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) + } + if got == nil { + t.Fatal("getAuthServerMetadata() got nil, want metadata") + } + if got.Issuer != issuerURL { + t.Errorf("getAuthServerMetadata() issuer = %q, want %q", got.Issuer, issuerURL) + } + }) + } +} + +func TestSelectTokenAuthMethod(t *testing.T) { + tests := []struct { + name string + supported []string + preferred ClientSecretAuthMethod + want oauth2.AuthStyle + }{ + { + name: "PreferredBasic_Supported", + supported: []string{"client_secret_basic", "client_secret_post"}, + preferred: ClientSecretAuthMethodBasic, + want: oauth2.AuthStyleInHeader, + }, + { + name: "PreferredPost_Supported", + supported: []string{"client_secret_basic", "client_secret_post"}, + preferred: ClientSecretAuthMethodPost, + want: oauth2.AuthStyleInParams, + }, + { + name: "PreferredBasic_NotSupported", + supported: []string{"client_secret_post"}, + preferred: ClientSecretAuthMethodBasic, + want: oauth2.AuthStyleInParams, + }, + { + name: "PreferredPost_NotSupported", + supported: []string{"client_secret_basic"}, + preferred: ClientSecretAuthMethodPost, + want: oauth2.AuthStyleInHeader, + }, + { + name: "NoneSupported", + supported: []string{"private_key_jwt"}, + preferred: ClientSecretAuthMethodBasic, + want: oauth2.AuthStyleAutoDetect, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := selectTokenAuthMethod(tt.supported, tt.preferred) + if got != tt.want { + t.Errorf("selectTokenAuthMethod() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHandleRegistration(t *testing.T) { + tests := []struct { + name string + serverConfig *oauthtest.RegistrationConfig + handlerConfig *AuthorizationCodeHandlerConfig + asm *oauthex.AuthServerMeta + want *resolvedClientConfig + wantError bool + }{ + { + name: "ClientIDMetadataDocument", + serverConfig: &oauthtest.RegistrationConfig{ + ClientIDMetadataDocumentSupported: true, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com"}, + }, + want: &resolvedClientConfig{ + registrationType: registrationTypeClientIDMetadataDocument, + clientID: "https://client.example.com", + }, + }, + { + name: "Preregistered", + serverConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "pre_client_id": { + Secret: "pre_client_secret", + }, + }, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + PreregisteredClientConfig: &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientID: "pre_client_id", + ClientSecret: "pre_client_secret", + PreferredClientSecretAuthMethod: ClientSecretAuthMethodBasic, + }, + }, + }, + want: &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: "pre_client_id", + clientSecret: "pre_client_secret", + authStyle: oauth2.AuthStyleInHeader, + }, + }, + { + name: "NoneSupported", + handlerConfig: &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com"}, + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{RegistrationConfig: tt.serverConfig}) + s.Start(t) + handler := &AuthorizationCodeHandler{config: tt.handlerConfig} + asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{ + AuthorizationServers: []string{s.URL()}, + }) + if err != nil { + t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) + } + got, err := handler.handleRegistration(t.Context(), asm) + if err != nil { + if !tt.wantError { + t.Fatalf("handleRegistration() unexpected error = %v", err) + } + return + } + if got.registrationType != tt.want.registrationType { + t.Errorf("handleRegistration() registrationType = %v, want %v", got.registrationType, tt.want.registrationType) + } + if got.clientID != tt.want.clientID { + t.Errorf("handleRegistration() clientID = %q, want %q", got.clientID, tt.want.clientID) + } + if got.clientSecret != tt.want.clientSecret { + t.Errorf("handleRegistration() clientSecret = %q, want %q", got.clientSecret, tt.want.clientSecret) + } + if got.authStyle != tt.want.authStyle { + t.Errorf("handleRegistration() authStyle = %v, want %v", got.authStyle, tt.want.authStyle) + } + }) + } +} + +func TestDynamicRegistration(t *testing.T) { + s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + RegistrationConfig: &oauthtest.RegistrationConfig{ + DynamicClientRegistrationEnabled: true, + }, + }) + s.Start(t) + handler := &AuthorizationCodeHandler{config: &AuthorizationCodeHandlerConfig{ + DynamicClientRegistrationConfig: &DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{}, + }, + }} + asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{ + AuthorizationServers: []string{s.URL()}, + }) + if err != nil { + t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) + } + got, err := handler.handleRegistration(t.Context(), asm) + if err != nil { + t.Fatalf("handleRegistration() error = %v, want nil", err) + } + if got.registrationType != registrationTypeDynamic { + t.Errorf("handleRegistration() registrationType = %v, want %v", got.registrationType, registrationTypeDynamic) + } + if got.clientID == "" { + t.Errorf("handleRegistration() clientID = %q, want non-empty", got.clientID) + } + if got.clientSecret == "" { + t.Errorf("handleRegistration() clientSecret = %q, want non-empty", got.clientSecret) + } + if got.authStyle != oauth2.AuthStyleInHeader { + t.Errorf("handleRegistration() authStyle = %v, want %v", got.authStyle, oauth2.AuthStyleInHeader) + } +} diff --git a/internal/oauthtest/fake_authorization_server.go b/internal/oauthtest/fake_authorization_server.go new file mode 100644 index 00000000..1e81f102 --- /dev/null +++ b/internal/oauthtest/fake_authorization_server.go @@ -0,0 +1,304 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthtest + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "net/http" + "net/http/httptest" + "slices" + "testing" + + internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +type ClientInfo struct { + Secret string + RedirectURIs []string +} + +type MetadataEndpointConfig struct { + // Whether to serve the OAuth Authorization Server Metadata at + // /.well-known/oauth-authorization-server + issuerPath. + ServeOAuthInsertedEndpoint bool + // Whether to serve the OAuth Authorization Server Metadata at + // /.well-known/openid-configuration + issuerPath. + ServeOpenIDInsertedEndpoint bool + // Whether to serve the OAuth Authorization Server Metadata at + // issuerPath + /.well-known/openid-configuration. + // Should be used when issuerPath is not empty. + ServeOpenIDAppendedEndpoint bool +} + +type RegistrationConfig struct { + // Whether the client ID metadata document is supported. + ClientIDMetadataDocumentSupported bool + // PreregisteredClients is a map of valid ClientIDs to ClientSecrets. + PreregisteredClients map[string]ClientInfo + // Whether dynamic client registration is enabled. + DynamicClientRegistrationEnabled bool +} + +// Config holds configuration for FakeAuthorizationServer. +type Config struct { + // The optional path component of the issuer URL. + // If non-empty, it should start with a "/". It should not end with a "/". + // It affects the paths of the server endpoints. + IssuerPath string + // Configuration of the metadata endpoint. + MetadataEndpointConfig *MetadataEndpointConfig + // Configuration for client registration. + RegistrationConfig *RegistrationConfig +} + +// FakeAuthorizationServer is a fake OAuth 2.0 Authorization Server for testing. +type FakeAuthorizationServer struct { + server *httptest.Server + Mux *http.ServeMux + config Config + clients map[string]ClientInfo + codes map[string]codeInfo +} + +type codeInfo struct { + CodeChallenge string +} + +// NewFakeAuthorizationServer creates a new FakeAuthorizationServer. +// The server is simple and should not be used outside of testing. +// It supports: +// - Only the authorization Code Grant +// - PKCE verification +// - Client tracking & dynamic registration +// - Client authentication +func NewFakeAuthorizationServer(config Config) *FakeAuthorizationServer { + s := &FakeAuthorizationServer{ + Mux: http.NewServeMux(), + config: config, + codes: make(map[string]codeInfo), + } + if config.RegistrationConfig != nil { + s.clients = maps.Clone(config.RegistrationConfig.PreregisteredClients) + } + if s.clients == nil { + s.clients = make(map[string]ClientInfo) + } + + s.Mux.HandleFunc(s.config.IssuerPath+"/authorize", s.handleAuthorize) + s.Mux.HandleFunc(s.config.IssuerPath+"/token", s.handleToken) + if config.MetadataEndpointConfig != nil { + if config.MetadataEndpointConfig.ServeOAuthInsertedEndpoint { + s.Mux.HandleFunc("/.well-known/oauth-authorization-server"+s.config.IssuerPath, s.handleMetadata) + } + if config.MetadataEndpointConfig.ServeOpenIDInsertedEndpoint { + s.Mux.HandleFunc("/.well-known/openid-configuration"+s.config.IssuerPath, s.handleMetadata) + } + if config.MetadataEndpointConfig.ServeOpenIDAppendedEndpoint && s.config.IssuerPath != "" { + s.Mux.HandleFunc(s.config.IssuerPath+"/.well-known/openid-configuration", s.handleMetadata) + } + } else { + // Serve the default OAuth endpoint. + s.Mux.HandleFunc("/.well-known/oauth-authorization-server", s.handleMetadata) + } + if config.RegistrationConfig != nil && config.RegistrationConfig.DynamicClientRegistrationEnabled { + s.Mux.HandleFunc(s.config.IssuerPath+"/register", s.handleRegister) + } + s.server = httptest.NewUnstartedServer(s.Mux) + + return s +} + +// Start starts the HTTP server and registers a cleanup function on t to close the server. +func (s *FakeAuthorizationServer) Start(t testing.TB) { + s.server.Start() + t.Cleanup(s.server.Close) +} + +// URL returns the base URL of the server (Issuer). +func (s *FakeAuthorizationServer) URL() string { + return s.server.URL +} + +func (s *FakeAuthorizationServer) handleMetadata(w http.ResponseWriter, r *http.Request) { + cimdSupported := false + var registrationEndpoint string + if s.config.RegistrationConfig != nil { + cimdSupported = s.config.RegistrationConfig.ClientIDMetadataDocumentSupported + if s.config.RegistrationConfig.DynamicClientRegistrationEnabled { + registrationEndpoint = s.URL() + s.config.IssuerPath + "/register" + } + } + meta := &oauthex.AuthServerMeta{ + Issuer: s.URL() + s.config.IssuerPath, + AuthorizationEndpoint: s.URL() + s.config.IssuerPath + "/authorize", + TokenEndpoint: s.URL() + s.config.IssuerPath + "/token", + RegistrationEndpoint: registrationEndpoint, + ResponseTypesSupported: []string{"code"}, + CodeChallengeMethodsSupported: []string{"S256"}, + ClientIDMetadataDocumentSupported: cimdSupported, + TokenEndpointAuthMethodsSupported: []string{"client_secret_post", "client_secret_basic"}, + } + // Set CORS headers for cross-origin client discovery. + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + // Handle CORS preflight requests + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + // Only GET allowed for metadata retrieval + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(meta); err != nil { + http.Error(w, "Failed to encode metadata", http.StatusInternalServerError) + return + } +} + +func (s *FakeAuthorizationServer) handleRegister(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var metadata oauthex.ClientRegistrationMetadata + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read request body", http.StatusBadRequest) + return + } + if err := internaljson.Unmarshal(body, &metadata); err != nil { + http.Error(w, "failed to parse request", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + clientID := rand.Text() + ci := ClientInfo{ + Secret: rand.Text(), + RedirectURIs: metadata.RedirectURIs, + } + s.clients[clientID] = ci + metadata.TokenEndpointAuthMethod = "client_secret_basic" + json.NewEncoder(w).Encode(&oauthex.ClientRegistrationResponse{ + ClientID: clientID, + ClientSecret: ci.Secret, + ClientRegistrationMetadata: metadata, + }) +} + +func (s *FakeAuthorizationServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + clientID := r.URL.Query().Get("client_id") + clientInfo, ok := s.clients[clientID] + if !ok { + http.Error(w, "unknown client_id", http.StatusBadRequest) + return + } + + redirectURI := r.URL.Query().Get("redirect_uri") + if redirectURI == "" { + http.Error(w, "missing redirect_uri", http.StatusBadRequest) + return + } + if !slices.Contains(clientInfo.RedirectURIs, redirectURI) { + http.Error(w, "invalid redirect_uri", http.StatusBadRequest) + return + } + codeChallenge := r.URL.Query().Get("code_challenge") + if codeChallenge == "" { + http.Error(w, "missing code_challenge", http.StatusBadRequest) + return + } + code := rand.Text() + s.codes[code] = codeInfo{ + CodeChallenge: codeChallenge, + } + + state := r.URL.Query().Get("state") + + redirectURL := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, code, state) + http.Redirect(w, r, redirectURL, http.StatusFound) +} + +func (s *FakeAuthorizationServer) handleToken(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := s.authenticateClient(r); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + if r.Form.Get("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + code := r.Form.Get("code") + if code == "" { + http.Error(w, "missing code", http.StatusBadRequest) + return + } + codeInfo, ok := s.codes[code] + if !ok { + http.Error(w, "unknown authorization code", http.StatusBadRequest) + return + } + verifier := r.Form.Get("code_verifier") + if verifier == "" { + http.Error(w, "missing code_verifier", http.StatusBadRequest) + return + } + sha := sha256.Sum256([]byte(verifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(sha[:]) + if expectedChallenge != codeInfo.CodeChallenge { + http.Error(w, "PKCE verification failed", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "test_access_token", + "token_type": "Bearer", + "expires_in": 3600, + }) +} + +func (s *FakeAuthorizationServer) authenticateClient(r *http.Request) error { + clientID, clientSecret, ok := r.BasicAuth() + if !ok { + clientID = r.Form.Get("client_id") + clientSecret = r.Form.Get("client_secret") + } + + clientInfo, ok := s.clients[clientID] + if !ok || clientInfo.Secret != clientSecret { + return errors.New("client not found") + } + return nil +} diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index d9fcc9d8..b2782d7d 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -133,6 +133,7 @@ type AuthServerMeta struct { // // It returns an error if the request fails with a non-4xx status code or the fetched // metadata doesn't pass security validations. +// It returns nil if the request fails with a 4xx status code. // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 func GetAuthServerMeta(ctx context.Context, metadataURL AuthorizationServerMetadataURL, c *http.Client) (*AuthServerMeta, error) { @@ -144,8 +145,8 @@ func GetAuthServerMeta(ctx context.Context, metadataURL AuthorizationServerMetad if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { return nil, nil } - return nil, fmt.Errorf("%v", err) // Do not expose error types. } + return nil, fmt.Errorf("%v", err) // Do not expose error types. } if asm.Issuer != metadataURL.Issuer { // Validate the Issuer field (see RFC 8414, section 3.3). From 707173d5d52e28f1b7916a26a720aa2dd5ddaeb6 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Wed, 25 Feb 2026 17:08:10 +0000 Subject: [PATCH 19/26] Add unit tests for streamable.go. --- auth/client.go | 2 + auth/fake.go | 10 +++- mcp/streamable.go | 26 +++++----- mcp/streamable_client_test.go | 89 ++++++++++++++++++++++++++++++++++- 4 files changed, 112 insertions(+), 15 deletions(-) diff --git a/auth/client.go b/auth/client.go index 6ddd4a29..12e3537f 100644 --- a/auth/client.go +++ b/auth/client.go @@ -23,6 +23,8 @@ type OAuthHandler interface { // be addressed by the authorization flow (currently 401 Unauthorized and 403 Forbidden). // It is responsible for performing the OAuth flow to obtain an access token. // The arguments are the request that failed and the response that was received for it. + // Currently the body of the passed request is consumed by the transport + // before [Authorize] is called. Please file an issue if you need the body to be available. // If the returned error is nil, [TokenSource] is expected to return a non-nil token source. // After a successful call to [Authorize], the HTTP request should be retried by the transport. // The function is responsible for closing the response body. diff --git a/auth/fake.go b/auth/fake.go index b8d82f33..e890802c 100644 --- a/auth/fake.go +++ b/auth/fake.go @@ -12,16 +12,24 @@ import ( ) type FakeOAuthHandler struct { - Token *oauth2.Token + // Token to be returned via [TokenSource]. If nil, [TokenSource] returns nil. + Token *oauth2.Token + // AuthorizeErr is an error to be returned from [Authorize]. AuthorizeErr error + // AuthorizeCalled is true if [Authorize] was called. + AuthorizeCalled bool } func (h *FakeOAuthHandler) isOAuthHandler() {} func (h *FakeOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + if h.Token == nil { + return nil, nil + } return oauth2.StaticTokenSource(h.Token), nil } func (h *FakeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + h.AuthorizeCalled = true return h.AuthorizeErr } diff --git a/mcp/streamable.go b/mcp/streamable.go index 9155f7bf..23310b39 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1725,18 +1725,18 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("%s: %v", requestSummary, err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - doRequest := func() (*http.Response, error) { + doRequest := func() (*http.Request, *http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) + if err != nil { + return nil, nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") if err := c.setMCPHeaders(req); err != nil { // Failure to set headers means that the request was not sent. // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. - return nil, fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + return nil, nil, fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) } resp, err := c.client.Do(req) if err != nil { @@ -1745,10 +1745,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e // and permanently break the connection. err = fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) } - return resp, err + return req, resp, err } - resp, err := doRequest() + req, resp, err := doRequest() if err != nil { return err } @@ -1761,7 +1761,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) } // Retry the request after successful authorization. - resp, err = doRequest() + req, resp, err = doRequest() if err != nil { return err } @@ -1776,10 +1776,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. // Wrap the authorization error as well for client inspection. - return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) } // Retry the request after successful authorization. - resp, err = doRequest() + req, resp, err = doRequest() if err != nil { return err } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index e3adfe98..23e57803 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -17,9 +17,10 @@ import ( "time" "github.com/google/go-cmp/cmp" - + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/oauth2" ) type streamableRequestKey struct { @@ -909,3 +910,89 @@ func TestStreamableClientDisableStandaloneSSE(t *testing.T) { }) } } + +func TestStreamableClientOAuth_AuthorizationHeader(t *testing.T) { + ctx := context.Background() + token := &oauth2.Token{AccessToken: "test-token"} + oauthHandler := &auth.FakeOAuthHandler{Token: token} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", "", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + }, + {"DELETE", "123", "", ""}: {}, + }, + } + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + if token != "test-token" { + return nil, auth.ErrInvalidToken + } + return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + } + httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + session.Close() +} + +func TestStreamableClientOAuth_401(t *testing.T) { + ctx := context.Background() + oauthHandler := &auth.FakeOAuthHandler{Token: nil} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + }, + } + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + // Accept any token. + return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + } + httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + _, err := client.Connect(ctx, transport, nil) + if err == nil || !strings.Contains(err.Error(), "Unauthorized") { + t.Fatalf("client.Connect() error does not contain 'Unauthorized': %v", err) + } + + if !oauthHandler.AuthorizeCalled { + t.Errorf("expected Authorize to be called") + } +} From c9d304790f5eccbd3c0cf1e5c66805809dd14e37 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Wed, 25 Feb 2026 21:01:46 +0000 Subject: [PATCH 20/26] First batch of addressing the review comments. --- auth/authorization_code.go | 108 ++++++++------- auth/authorization_code_test.go | 127 ++++++++++++++---- .../everything-client/client_private.go | 8 +- examples/auth/client/main.go | 6 +- mcp/streamable.go | 24 +--- oauthex/auth_meta.go | 50 ++----- oauthex/auth_meta_test.go | 8 +- oauthex/url_scheme_test.go | 8 +- 8 files changed, 180 insertions(+), 159 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 8e181c81..b3596a4a 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -20,44 +20,17 @@ import ( "golang.org/x/oauth2" ) -// ClientSecretAuthMethod defines "client_secret_*" authentication methods per -// https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml#token-endpoint-auth-method. -// "client_secret_jwt" is not currently supported. -type ClientSecretAuthMethod int - -const ( - // ClientSecretAuthMethodBasic uses the "client_secret_basic" authentication method. - ClientSecretAuthMethodBasic ClientSecretAuthMethod = iota - // ClientSecretAuthMethodPost uses the "client_secret_post" authentication method. - ClientSecretAuthMethodPost -) - -func (m ClientSecretAuthMethod) String() string { - switch m { - case ClientSecretAuthMethodBasic: - return "client_secret_basic" - case ClientSecretAuthMethodPost: - return "client_secret_post" - default: - return "" - } -} - // ClientSecretAuthConfig is used to configure client authentication using client_secret. +// Authentication method will be selected based on the authorization server's supported methods, +// according to the following preference order: +// +// 1. "client_secret_post" +// 2. "client_secret_basic" type ClientSecretAuthConfig struct { // ClientID is the client ID to be used for client authentication. ClientID string // ClientSecret is the client secret to be used for client authentication. ClientSecret string - // PreferredClientSecretAuthMethod to be used for client authentication. - // If not specified or unsupported by the authorization server, the method - // will be selected based on the authorization server's supported methods, - // according to the following preference order: - // - // 1. "client_secret_post" - // 2. "client_secret_basic" - // - PreferredClientSecretAuthMethod ClientSecretAuthMethod } // ClientIDMetadataDocumentConfig is used to configure the Client ID Metadata Document @@ -95,6 +68,12 @@ type AuthorizationResult struct { State string } +// AuthorizationInput is the input to [AuthorizationCodeHandlerConfig.AuthorizationCodeFetcher]. +type AuthorizationInput struct { + // Authorization URL to be opened in a browser for the user to start the authorization process. + URL string +} + // AuthorizationCodeHandlerConfig is the configuration for [AuthorizationCodeHandler]. type AuthorizationCodeHandlerConfig struct { // Client registration configuration. @@ -111,15 +90,20 @@ type AuthorizationCodeHandlerConfig struct { // RedirectURL is a required URL to redirect to after authorization. // The caller is responsible for handling the redirect out of band. - // If Dynamic Client Registration is used, the RedirectURL must be consistent - // with [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. + // + // If Dynamic Client Registration is used: + // - this field is permitted to be empty, in which case it will be set + // to the first redirect URI from + // [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. + // - if the field is not empty, it must be one of the redirect URIs in + // [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. RedirectURL string - // AuthorizationURLHandler is a required function called to handle the authorization request. + // AuthorizationCodeFetcher is a required function called to initiate the authorization flow. // It is responsible for opening the URL in a browser for the user to start the authorization process. // It should return the authorization code and state once the Authorization Server - // redirects back to the [AuthorizationCodeHandler.RedirectURL]. - AuthorizationURLHandler func(ctx context.Context, authorizationURL string) (*AuthorizationResult, error) + // redirects back to the [AuthorizationCodeHandlerConfig.RedirectURL]. + AuthorizationCodeFetcher func(ctx context.Context, authorizationInput *AuthorizationInput) (*AuthorizationResult, error) } // AuthorizationCodeHandler is an implementation of [OAuthHandler] that uses @@ -151,11 +135,8 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho config.DynamicClientRegistrationConfig == nil { return nil, errors.New("at least one client registration configuration must be provided") } - if config.RedirectURL == "" { - return nil, errors.New("field RedirectURL is required") - } - if config.AuthorizationURLHandler == nil { - return nil, errors.New("field AuthorizationURLHandler is required") + if config.AuthorizationCodeFetcher == nil { + return nil, errors.New("AuthorizationURLHandler is required") } if config.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(config.ClientIDMetadataDocumentConfig.URL) { return nil, fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") @@ -163,20 +144,31 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho preCfg := config.PreregisteredClientConfig if preCfg != nil { if preCfg.ClientSecretAuthConfig == nil { - return nil, errors.New("field ClientSecretAuthConfig is required for pre-registered client") + return nil, errors.New("ClientSecretAuthConfig is required for pre-registered client") } if preCfg.ClientSecretAuthConfig.ClientID == "" || preCfg.ClientSecretAuthConfig.ClientSecret == "" { return nil, fmt.Errorf("pre-registered client ID or secret is empty") } } - if config.DynamicClientRegistrationConfig != nil { - if config.DynamicClientRegistrationConfig.Metadata == nil { - return nil, errors.New("field Metadata is required for dynamic client registration") + dCfg := config.DynamicClientRegistrationConfig + if dCfg != nil { + if dCfg.Metadata == nil { + return nil, errors.New("Metadata is required for dynamic client registration") + } + if len(dCfg.Metadata.RedirectURIs) == 0 { + return nil, errors.New("Metadata.RedirectURIs is required for dynamic client registration") } - if !slices.Contains(config.DynamicClientRegistrationConfig.Metadata.RedirectURIs, config.RedirectURL) { - return nil, fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", config.RedirectURL) + if config.RedirectURL == "" { + config.RedirectURL = dCfg.Metadata.RedirectURIs[0] + } else if !slices.Contains(dCfg.Metadata.RedirectURIs, config.RedirectURL) { + return nil, fmt.Errorf("RedirectURL %q is not in the list of allowed redirect URIs for dynamic client registration", config.RedirectURL) } } + if config.RedirectURL == "" { + // If the RedirectURL was supposed to be set by the dynamic client registration, + // it should have been set by now. Otherwise, it is required. + return nil, errors.New("RedirectURL is required") + } return &AuthorizationCodeHandler{config: config}, nil } @@ -201,6 +193,15 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ } log.Printf("WWW-Authenticate header: %v", wwwChallenges) + if resp.StatusCode == http.StatusForbidden && oauthex.Error(wwwChallenges) != "insufficient_scope" { + // We only want to perform step-up authorization for insufficient_scope errors. + // Returning nil, so that the call is retried immediately and the response + // is handled appropriately by the connection. + // Step-up authorization is defined at + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#step-up-authorization-flow + return nil + } + prm, err := h.getProtectedResourceMetadata(ctx, wwwChallenges, req.URL.String()) if err != nil { return err @@ -291,7 +292,7 @@ func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, pr log.Printf("Authorization server URL: %s", authServerURL) for _, u := range oauthex.AuthorizationServerMetadataURLs(authServerURL) { - asm, err := oauthex.GetAuthServerMeta(ctx, u, http.DefaultClient) + asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, http.DefaultClient) if err != nil { return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) } @@ -327,10 +328,7 @@ type resolvedClientConfig struct { authStyle oauth2.AuthStyle } -func selectTokenAuthMethod(supported []string, preferred ClientSecretAuthMethod) oauth2.AuthStyle { - if slices.Contains(supported, preferred.String()) { - return authMethodToStyle(preferred.String()) - } +func selectTokenAuthMethod(supported []string) oauth2.AuthStyle { prefOrder := []string{ // Preferred in OAuth 2.1 draft: https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-14.html#name-client-secret. "client_secret_post", @@ -379,7 +377,7 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm * // 2. Attempt to use pre-registered client configuration. pCfg := h.config.PreregisteredClientConfig if pCfg != nil { - authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported, pCfg.ClientSecretAuthConfig.PreferredClientSecretAuthMethod) + authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported) return &resolvedClientConfig{ registrationType: registrationTypePreregistered, clientID: pCfg.ClientSecretAuthConfig.ClientID, @@ -425,7 +423,7 @@ func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg ) log.Printf("Calling AuthorizationURLHandler: %q", authURL) - authRes, err := h.config.AuthorizationURLHandler(ctx, authURL) + authRes, err := h.config.AuthorizationCodeFetcher(ctx, &AuthorizationInput{URL: authURL}) if err != nil { // Purposefully leaving the error unwrappable so it can be handled by the caller. return nil, err diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index 8b4dc0f1..5a1e06c3 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -51,14 +51,14 @@ func TestAuthorize(t *testing.T) { ClientSecret: "test_client_secret", }, }, - AuthorizationURLHandler: func(ctx context.Context, authURL string) (*AuthorizationResult, error) { + AuthorizationCodeFetcher: func(ctx context.Context, input *AuthorizationInput) (*AuthorizationResult, error) { // The fake authorization server will redirect to an URL with code and state. client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } - resp, err := client.Get(authURL) + resp, err := client.Get(input.URL) if err != nil { return nil, fmt.Errorf("failed to visit auth URL: %v", err) } @@ -112,12 +112,98 @@ func TestAuthorize(t *testing.T) { } } -func TestNewAuthorizationCodeHandler(t *testing.T) { +func TestAuthorize_ForbiddenUnhandledError(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/resource", nil) + resp := &http.Response{ + StatusCode: http.StatusForbidden, + Header: make(http.Header), + Body: http.NoBody, + Request: req, + } + resp.Header.Set( + "WWW-Authenticate", + "Bearer error=invalid_token", + ) + handler := &AuthorizationCodeHandler{} // No config needed for this test. + err := handler.Authorize(t.Context(), req, resp) + if err != nil { + t.Fatalf("Authorize() failed: %v", err) + } +} + +func TestNewAuthorizationCodeHandler_Success(t *testing.T) { + simpleHandler := func(ctx context.Context, input *AuthorizationInput) (*AuthorizationResult, error) { + return nil, nil + } + tests := []struct { + name string + config *AuthorizationCodeHandlerConfig + }{ + { + name: "ClientIDMetadataDocumentConfig", + config: &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"}, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: simpleHandler, + }, + }, + { + name: "PreregisteredClientConfig", + config: &AuthorizationCodeHandlerConfig{ + PreregisteredClientConfig: &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientID: "test_client_id", + ClientSecret: "test_client_secret", + }, + }, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: simpleHandler, + }, + }, + { + name: "DynamicClientRegistrationConfig_NoRedirectURL", + config: &AuthorizationCodeHandlerConfig{ + DynamicClientRegistrationConfig: &DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{ + "https://example.com/callback", + }, + }, + }, + AuthorizationCodeFetcher: simpleHandler, + }, + }, + { + name: "DynamicClientRegistrationConfig_WithRedirectURL", + config: &AuthorizationCodeHandlerConfig{ + DynamicClientRegistrationConfig: &DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{ + "https://example.com/callback", + }, + }, + }, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: simpleHandler, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := NewAuthorizationCodeHandler(tt.config); err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + }) + } +} + +func TestNewAuthorizationCodeHandler_Error(t *testing.T) { validConfig := func() *AuthorizationCodeHandlerConfig { return &AuthorizationCodeHandlerConfig{ ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"}, RedirectURL: "https://example.com/callback", - AuthorizationURLHandler: func(ctx context.Context, authURL string) (*AuthorizationResult, error) { + AuthorizationCodeFetcher: func(ctx context.Context, input *AuthorizationInput) (*AuthorizationResult, error) { return nil, nil }, } @@ -159,7 +245,7 @@ func TestNewAuthorizationCodeHandler(t *testing.T) { name: "MissingAuthorizationURLHandler", config: func() *AuthorizationCodeHandlerConfig { cfg := validConfig() - cfg.AuthorizationURLHandler = nil + cfg.AuthorizationCodeFetcher = nil return cfg }, }, @@ -417,44 +503,28 @@ func TestSelectTokenAuthMethod(t *testing.T) { tests := []struct { name string supported []string - preferred ClientSecretAuthMethod want oauth2.AuthStyle }{ { - name: "PreferredBasic_Supported", + name: "PostPreferredOverBasic", supported: []string{"client_secret_basic", "client_secret_post"}, - preferred: ClientSecretAuthMethodBasic, - want: oauth2.AuthStyleInHeader, - }, - { - name: "PreferredPost_Supported", - supported: []string{"client_secret_basic", "client_secret_post"}, - preferred: ClientSecretAuthMethodPost, - want: oauth2.AuthStyleInParams, - }, - { - name: "PreferredBasic_NotSupported", - supported: []string{"client_secret_post"}, - preferred: ClientSecretAuthMethodBasic, want: oauth2.AuthStyleInParams, }, { - name: "PreferredPost_NotSupported", - supported: []string{"client_secret_basic"}, - preferred: ClientSecretAuthMethodPost, + name: "BasicChosenIfPostNotSupported", + supported: []string{"private_key_jwt", "client_secret_basic"}, want: oauth2.AuthStyleInHeader, }, { name: "NoneSupported", supported: []string{"private_key_jwt"}, - preferred: ClientSecretAuthMethodBasic, want: oauth2.AuthStyleAutoDetect, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := selectTokenAuthMethod(tt.supported, tt.preferred) + got := selectTokenAuthMethod(tt.supported) if got != tt.want { t.Errorf("selectTokenAuthMethod() = %v, want %v", got, tt.want) } @@ -496,9 +566,8 @@ func TestHandleRegistration(t *testing.T) { handlerConfig: &AuthorizationCodeHandlerConfig{ PreregisteredClientConfig: &PreregisteredClientConfig{ ClientSecretAuthConfig: &ClientSecretAuthConfig{ - ClientID: "pre_client_id", - ClientSecret: "pre_client_secret", - PreferredClientSecretAuthMethod: ClientSecretAuthMethodBasic, + ClientID: "pre_client_id", + ClientSecret: "pre_client_secret", }, }, }, @@ -506,7 +575,7 @@ func TestHandleRegistration(t *testing.T) { registrationType: registrationTypePreregistered, clientID: "pre_client_id", clientSecret: "pre_client_secret", - authStyle: oauth2.AuthStyleInHeader, + authStyle: oauth2.AuthStyleInParams, }, }, { diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index 6b4654ef..7ca8d07e 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -50,13 +50,13 @@ func init() { // Auth scenarios // ============================================================================ -func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (*auth.AuthorizationResult, error) { +func fetchAuthorizationCodeAndState(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } - req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", input.URL, nil) if err != nil { return nil, err } @@ -82,8 +82,8 @@ func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (*auth. func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]any) error { authConfig := &auth.AuthorizationCodeHandlerConfig{ - RedirectURL: "http://localhost:3000/callback", - AuthorizationURLHandler: fetchAuthorizationCodeAndState, + RedirectURL: "http://localhost:3000/callback", + AuthorizationCodeFetcher: fetchAuthorizationCodeAndState, // Try client ID metadata document based registration. ClientIDMetadataDocumentConfig: &auth.ClientIDMetadataDocumentConfig{ URL: "https://conformance-test.local/client-metadata.json", diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index dfc3d102..f9efd198 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -53,7 +53,7 @@ func (r *codeReceiver) serveRedirectHandler(listener net.Listener) error { return nil } -func (r *codeReceiver) getAuthorizationCode(ctx context.Context, authorizationURL string) (*auth.AuthorizationResult, error) { +func (r *codeReceiver) getAuthorizationCode(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { select { case authRes := <-r.authChan: return authRes, nil @@ -84,8 +84,8 @@ func main() { defer receiver.close() authHandler, err := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ - RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), - AuthorizationURLHandler: receiver.getAuthorizationCode, + RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), + AuthorizationCodeFetcher: receiver.getAuthorizationCode, // Uncomment the client configuration you want to use. // PreregisteredClientConfig: &auth.PreregisteredClientConfig{ // ClientID: "", diff --git a/mcp/streamable.go b/mcp/streamable.go index 23310b39..14499ffe 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -35,7 +35,6 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" - "github.com/modelcontextprotocol/go-sdk/oauthex" ) const ( @@ -1753,7 +1752,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } - if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { + if slices.Contains([]int{http.StatusUnauthorized, http.StatusForbidden}, resp.StatusCode) && c.oauthHandler != nil { if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. @@ -1761,30 +1760,11 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) } // Retry the request after successful authorization. - req, resp, err = doRequest() + _, resp, err = doRequest() if err != nil { return err } } - if resp.StatusCode == http.StatusForbidden && c.oauthHandler != nil { - challenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) - if err != nil { - c.logger.Warn("%s: failed to parse WWW-Authenticate header: %v", requestSummary, err) - } else if oauthex.Error(challenges) == "insufficient_scope" { - // Trigger step-up authorization flow. - if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { - // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr - // and permanently break the connection. - // Wrap the authorization error as well for client inspection. - return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) - } - // Retry the request after successful authorization. - req, resp, err = doRequest() - if err != nil { - return err - } - } - } if err := c.checkResponse(requestSummary, resp); err != nil { // Only fail the connection for non-transient errors. diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index b2782d7d..a0973512 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -136,10 +136,10 @@ type AuthServerMeta struct { // It returns nil if the request fails with a 4xx status code. // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 -func GetAuthServerMeta(ctx context.Context, metadataURL AuthorizationServerMetadataURL, c *http.Client) (*AuthServerMeta, error) { - asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL.URL, 1<<20) +func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http.Client) (*AuthServerMeta, error) { + asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL, 1<<20) if err != nil { - log.Printf("Failed to get auth server metadata from %q: %v", metadataURL.URL, err) + log.Printf("Failed to get auth server metadata from %q: %v", metadataURL, err) var httpErr *httpStatusError if errors.As(err, &httpErr) { if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { @@ -148,36 +148,29 @@ func GetAuthServerMeta(ctx context.Context, metadataURL AuthorizationServerMetad } return nil, fmt.Errorf("%v", err) // Do not expose error types. } - if asm.Issuer != metadataURL.Issuer { + if asm.Issuer != issuer { // Validate the Issuer field (see RFC 8414, section 3.3). - return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, metadataURL.Issuer) + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuer) } if len(asm.CodeChallengeMethodsSupported) == 0 { - return nil, fmt.Errorf("authorization server at %s does not implement PKCE", metadataURL.Issuer) + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuer) } // Validate endpoint URLs to prevent XSS attacks (see #526). if err := validateAuthServerMetaURLs(asm); err != nil { return nil, err } - log.Printf("Fetched authorization server metadata from %q", metadataURL.URL) + log.Printf("Fetched authorization server metadata from %q", metadataURL) return asm, nil } -type AuthorizationServerMetadataURL struct { - // URL where the Authorization Server Metadata may be retrieved. - URL string - // Issuer that was used to construct the [URL]. - Issuer string -} - // AuthorizationServerMetadataURLs returns a list of URLs to try when looking for // authorization server metadata as mandated by the MCP specification: // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. -func AuthorizationServerMetadataURLs(issuerURL string) []AuthorizationServerMetadataURL { - var urls []AuthorizationServerMetadataURL +func AuthorizationServerMetadataURLs(issuerURL string) []string { + var urls []string baseURL, err := url.Parse(issuerURL) if err != nil { @@ -187,38 +180,23 @@ func AuthorizationServerMetadataURLs(issuerURL string) []AuthorizationServerMeta if baseURL.Path == "" { // "OAuth 2.0 Authorization Server Metadata". baseURL.Path = "/.well-known/oauth-authorization-server" - urls = append(urls, AuthorizationServerMetadataURL{ - URL: baseURL.String(), - Issuer: issuerURL, - }) + urls = append(urls, baseURL.String()) // "OpenID Connect Discovery 1.0". baseURL.Path = "/.well-known/openid-configuration" - urls = append(urls, AuthorizationServerMetadataURL{ - URL: baseURL.String(), - Issuer: issuerURL, - }) + urls = append(urls, baseURL.String()) return urls } originalPath := baseURL.Path // "OAuth 2.0 Authorization Server Metadata with path insertion". baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, AuthorizationServerMetadataURL{ - URL: baseURL.String(), - Issuer: issuerURL, - }) + urls = append(urls, baseURL.String()) // "OpenID Connect Discovery 1.0 with path insertion". baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, AuthorizationServerMetadataURL{ - URL: baseURL.String(), - Issuer: issuerURL, - }) + urls = append(urls, baseURL.String()) // "OpenID Connect Discovery 1.0 with path appending". baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" - urls = append(urls, AuthorizationServerMetadataURL{ - URL: baseURL.String(), - Issuer: issuerURL, - }) + urls = append(urls, baseURL.String()) return urls } diff --git a/oauthex/auth_meta_test.go b/oauthex/auth_meta_test.go index 6363e098..8b67e59b 100644 --- a/oauthex/auth_meta_test.go +++ b/oauthex/auth_meta_test.go @@ -85,10 +85,8 @@ func TestGetAuthServerMetaPKCESupport(t *testing.T) { // The fake server sets issuer to https://localhost:, so compute that issuer. u, _ := url.Parse(ts.URL) - metadataURL := AuthorizationServerMetadataURL{ - URL: "https://localhost:" + u.Port() + "/.well-known/oauth-authorization-server", - Issuer: "https://localhost:" + u.Port(), - } + issuer := "https://localhost:" + u.Port() + metadataURL := issuer + "/.well-known/oauth-authorization-server" // The fake server presents a cert for example.com; set ServerName accordingly. httpClient := ts.Client() @@ -98,7 +96,7 @@ func TestGetAuthServerMetaPKCESupport(t *testing.T) { httpClient.Transport = clone } - meta, err := GetAuthServerMeta(ctx, metadataURL, httpClient) + meta, err := GetAuthServerMeta(ctx, metadataURL, issuer, httpClient) if tt.wantError != "" { if err == nil { t.Fatal("wanted error but got none") diff --git a/oauthex/url_scheme_test.go b/oauthex/url_scheme_test.go index c13bc1d5..83eeb5e1 100644 --- a/oauthex/url_scheme_test.go +++ b/oauthex/url_scheme_test.go @@ -226,11 +226,9 @@ func TestGetAuthServerMetaRejectsDangerousURLs(t *testing.T) { defer server.Close() ctx := context.Background() - metadataURL := AuthorizationServerMetadataURL{ - URL: server.URL, - Issuer: server.URL, - } - _, err := GetAuthServerMeta(ctx, metadataURL, server.Client()) + issuer := server.URL + metadataURL := issuer + _, err := GetAuthServerMeta(ctx, metadataURL, issuer, server.Client()) if err == nil { t.Fatal("GetAuthServerMeta(): got nil error, want error") } From f4e4014abb2e7265bc3cc9f811bef7d2e6d2d734 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Thu, 26 Feb 2026 10:05:53 +0000 Subject: [PATCH 21/26] Second batch of addressing the review comments. --- auth/authorization_code.go | 82 ++++++++++++++++++++++++++++++++++++-- oauthex/auth_meta.go | 46 +++++---------------- oauthex/resource_meta.go | 81 +++++++++---------------------------- 3 files changed, 109 insertions(+), 100 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index b3596a4a..99b20440 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -15,6 +15,7 @@ import ( "net/http" "net/url" "slices" + "strings" "github.com/modelcontextprotocol/go-sdk/oauthex" "golang.org/x/oauth2" @@ -253,9 +254,9 @@ func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Cont var errs []error // Use MCP server URL as the resource URI per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri. - for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), mcpServerURL) { + for _, url := range protectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), mcpServerURL) { log.Printf("Getting protected resource metadata from %q", url) - prm, err := oauthex.GetProtectedResourceMetadata(ctx, url, http.DefaultClient) + prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, http.DefaultClient) if err != nil { errs = append(errs, err) continue @@ -269,6 +270,47 @@ func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Cont return nil, fmt.Errorf("failed to get protected resource metadata: %v", errors.Join(errs...)) } +type prmURL struct { + // URL represents a URL where Protected Resource Metadata may be retrieved. + URL string + // Resource represents the corresponding resource URL for [URL]. + // It is required to perform validation described in RFC 9728, section 3.3. + Resource string +} + +// protectedResourceMetadataURLs returns a list of URLs to try when looking for +// protected resource metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#protected-resource-metadata-discovery-requirements +func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL { + var urls []prmURL + if metadataURL != "" { + urls = append(urls, prmURL{ + URL: metadataURL, + Resource: resourceURL, + }) + } + ru, err := url.Parse(resourceURL) + if err != nil { + return urls + } + mu := *ru + // "At the path of the server's MCP endpoint". + mu.Path = "/.well-known/oauth-protected-resource/" + strings.TrimLeft(ru.Path, "/") + urls = append(urls, prmURL{ + URL: mu.String(), + Resource: resourceURL, + }) + // "At the root". + mu.Path = "/.well-known/oauth-protected-resource" + ru.Path = "" + urls = append(urls, prmURL{ + URL: mu.String(), + Resource: ru.String(), + }) + log.Printf("Resource metadata URLs: %v", urls) + return urls +} + // getAuthServerMetadata returns the authorization server metadata. // The provided Protected Resource Metadata must not be nil. // It returns an error if the metadata request fails with non-4xx HTTP status code @@ -291,7 +333,7 @@ func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, pr } log.Printf("Authorization server URL: %s", authServerURL) - for _, u := range oauthex.AuthorizationServerMetadataURLs(authServerURL) { + for _, u := range authorizationServerMetadataURLs(authServerURL) { asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, http.DefaultClient) if err != nil { return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) @@ -313,6 +355,40 @@ func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, pr return asm, nil } +// authorizationServerMetadataURLs returns a list of URLs to try when looking for +// authorization server metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. +func authorizationServerMetadataURLs(issuerURL string) []string { + var urls []string + + baseURL, err := url.Parse(issuerURL) + if err != nil { + return nil + } + + if baseURL.Path == "" { + // "OAuth 2.0 Authorization Server Metadata". + baseURL.Path = "/.well-known/oauth-authorization-server" + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0". + baseURL.Path = "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls + } + + originalPath := baseURL.Path + // "OAuth 2.0 Authorization Server Metadata with path insertion". + baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path insertion". + baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path appending". + baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls +} + type registrationType int const ( diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index a0973512..ff5b05a1 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -16,7 +16,8 @@ import ( "log" "net/http" "net/url" - "strings" + + "github.com/modelcontextprotocol/go-sdk/internal/util" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -126,6 +127,7 @@ type AuthServerMeta struct { // from an OAuth authorization server with the given metadataURL. // // It follows [RFC 8414]: +// - The metadataURL must use HTTPS or be a local address. // - The Issuer field is checked against metadataURL.Issuer. // // It also verifies that the authorization server supports PKCE and that the URLs @@ -137,6 +139,14 @@ type AuthServerMeta struct { // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http.Client) (*AuthServerMeta, error) { + u, err := url.Parse(metadataURL) + if err != nil { + return nil, err + } + // Only allow HTTP for local addresses (testing or development purposes). + if !util.IsLoopback(u.Host) && u.Scheme != "https" { + return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL) + } asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL, 1<<20) if err != nil { log.Printf("Failed to get auth server metadata from %q: %v", metadataURL, err) @@ -166,40 +176,6 @@ func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http. return asm, nil } -// AuthorizationServerMetadataURLs returns a list of URLs to try when looking for -// authorization server metadata as mandated by the MCP specification: -// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. -func AuthorizationServerMetadataURLs(issuerURL string) []string { - var urls []string - - baseURL, err := url.Parse(issuerURL) - if err != nil { - return nil - } - - if baseURL.Path == "" { - // "OAuth 2.0 Authorization Server Metadata". - baseURL.Path = "/.well-known/oauth-authorization-server" - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0". - baseURL.Path = "/.well-known/openid-configuration" - urls = append(urls, baseURL.String()) - return urls - } - - originalPath := baseURL.Path - // "OAuth 2.0 Authorization Server Metadata with path insertion". - baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0 with path insertion". - baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0 with path appending". - baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" - urls = append(urls, baseURL.String()) - return urls -} - // validateAuthServerMetaURLs validates all URL fields in AuthServerMeta // to ensure they don't use dangerous schemes that could enable XSS attacks. func validateAuthServerMetaURLs(asm *AuthServerMeta) error { diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index bd869fa0..b9c3afb7 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -12,7 +12,6 @@ package oauthex import ( "context" "fmt" - "log" "net/http" "net/url" "path" @@ -47,10 +46,7 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, } // Insert well-known URI into URL. u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) - return GetProtectedResourceMetadata(ctx, ProtectedResourceMetadataURL{ - URL: u.String(), - Resource: resourceID, - }, c) + return GetProtectedResourceMetadata(ctx, u.String(), resourceID, c) } // GetProtectedResourceMetadataFromHeader retrieves protected resource metadata @@ -74,31 +70,34 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL strin if metadataURL == "" { return nil, nil } - return GetProtectedResourceMetadata(ctx, ProtectedResourceMetadataURL{ - URL: metadataURL, - Resource: serverURL, - }, c) + return GetProtectedResourceMetadata(ctx, metadataURL, serverURL, c) } // GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource // metadata from a resource server. // The metadataURL is typically a URL with a host:port and possibly a path. -// For example: -// -// https://example.com/server -func GetProtectedResourceMetadata(ctx context.Context, metadataURL ProtectedResourceMetadataURL, c *http.Client) (_ *ProtectedResourceMetadata, err error) { +// The resourceURL is the resource URI that the metadataURL is for. +// The following checks are performed: +// - The metadataURL must use HTTPS or be a local address. +// - The resource field of the resulting metadata must match the resourceURL. +// - The authorization_servers field of the resulting metadata is checked for dangerous URL schemes. +func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { defer util.Wrapf(&err, "GetProtectedResourceMetadata(%q)", metadataURL) - // TODO: where HTTPS requirement comes from? conformance tests use HTTP. - // if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { - // return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) - // } - prm, err := getJSON[ProtectedResourceMetadata](ctx, c, metadataURL.URL, 1<<20) + u, err := url.Parse(metadataURL) + if err != nil { + return nil, err + } + // Only allow HTTP for local addresses (testing or development purposes). + if !util.IsLoopback(u.Host) && u.Scheme != "https" { + return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL) + } + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, metadataURL, 1<<20) if err != nil { return nil, err } // Validate the Resource field (see RFC 9728, section 3.3). - if prm.Resource != metadataURL.Resource { - return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, metadataURL.Resource) + if prm.Resource != resourceURL { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, resourceURL) } // Validate the authorization server URLs to prevent XSS attacks (see #526). for _, u := range prm.AuthorizationServers { @@ -109,48 +108,6 @@ func GetProtectedResourceMetadata(ctx context.Context, metadataURL ProtectedReso return prm, nil } -type ProtectedResourceMetadataURL struct { - // URL represents a URL where Protected Resource Metadata may be retrieved. - URL string - // Resource represents the corresponding resource URL for [URL]. - // It is required to perform validation described in RFC 9728, section 3.3. - Resource string -} - -// ProtectedResourceMetadataURLs returns a list of URLs to try when looking for -// protected resource metadata as mandated by the MCP specification. -func ProtectedResourceMetadataURLs(metadataURL, resourceURL string) []ProtectedResourceMetadataURL { - var urls []ProtectedResourceMetadataURL - if metadataURL != "" { - urls = append(urls, ProtectedResourceMetadataURL{ - URL: metadataURL, - Resource: resourceURL, - }) - } - // Produce fallbacks per - // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#protected-resource-metadata-discovery-requirements - ru, err := url.Parse(resourceURL) - if err != nil { - return urls - } - mu := *ru - // "At the path of the server's MCP endpoint". - mu.Path = "/.well-known/oauth-protected-resource/" + strings.TrimLeft(ru.Path, "/") - urls = append(urls, ProtectedResourceMetadataURL{ - URL: mu.String(), - Resource: resourceURL, - }) - // "At the root". - mu.Path = "/.well-known/oauth-protected-resource" - ru.Path = "" - urls = append(urls, ProtectedResourceMetadataURL{ - URL: mu.String(), - Resource: ru.String(), - }) - log.Printf("Resource metadata URLs: %v", urls) - return urls -} - // ResourceMetadataURL returns a resource metadata URL from the given challenges, // or the empty string if there is none. func ResourceMetadataURL(cs []Challenge) string { From 80a033e26557ff0fc293b3edd02592c0107de20b Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Thu, 26 Feb 2026 12:13:43 +0000 Subject: [PATCH 22/26] Final cleanup. --- auth/authorization_code.go | 12 ------------ conformance/everything-client/main.go | 1 - oauthex/auth_meta.go | 3 --- 3 files changed, 16 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 99b20440..06713f4a 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -11,7 +11,6 @@ import ( "crypto/rand" "errors" "fmt" - "log" "net/http" "net/url" "slices" @@ -186,13 +185,11 @@ func isNonRootHTTPSURL(u string) bool { // On success, [AuthorizationCodeHandler.TokenSource] will return a token source with the fetched token. func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { defer resp.Body.Close() - log.Printf("Authorize: %s %s", req.Method, req.URL) wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) if err != nil { return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) } - log.Printf("WWW-Authenticate header: %v", wwwChallenges) if resp.StatusCode == http.StatusForbidden && oauthex.Error(wwwChallenges) != "insufficient_scope" { // We only want to perform step-up authorization for insufficient_scope errors. @@ -207,13 +204,11 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ if err != nil { return err } - // log.Printf("Protected resource metadata: %+v", prm) asm, err := h.getAuthServerMetadata(ctx, prm) if err != nil { return err } - // log.Printf("Authorization server metadata: %+v", asm) resolvedClientConfig, err := h.handleRegistration(ctx, asm) if err != nil { @@ -255,7 +250,6 @@ func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Cont // Use MCP server URL as the resource URI per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri. for _, url := range protectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), mcpServerURL) { - log.Printf("Getting protected resource metadata from %q", url) prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, http.DefaultClient) if err != nil { errs = append(errs, err) @@ -307,7 +301,6 @@ func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL { URL: mu.String(), Resource: ru.String(), }) - log.Printf("Resource metadata URLs: %v", urls) return urls } @@ -331,7 +324,6 @@ func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, pr authURL.Path = "" authServerURL = authURL.String() } - log.Printf("Authorization server URL: %s", authServerURL) for _, u := range authorizationServerMetadataURLs(authServerURL) { asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, http.DefaultClient) @@ -343,7 +335,6 @@ func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, pr } } - log.Print("Authorization server metadata not found, using fallback") // Fallback to 2025-03-26 spec: predefined endpoints. // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery asm := &oauthex.AuthServerMeta{ @@ -474,7 +465,6 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm * clientSecret: regResp.ClientSecret, authStyle: authMethodToStyle(regResp.TokenEndpointAuthMethod), } - log.Printf("Client registered with client ID: %s", regResp.ClientID) return cfg, nil } return nil, fmt.Errorf("no configured client registration methods are supported by the authorization server") @@ -498,7 +488,6 @@ func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg oauth2.SetAuthURLParam("resource", resourceURL), ) - log.Printf("Calling AuthorizationURLHandler: %q", authURL) authRes, err := h.config.AuthorizationCodeFetcher(ctx, &AuthorizationInput{URL: authURL}) if err != nil { // Purposefully leaving the error unwrappable so it can be handled by the caller. @@ -516,7 +505,6 @@ func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg // exchangeAuthorizationCode exchanges the authorization code for a token // and stores it in a token source. func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { - log.Printf("Exchanging authorization code for token") opts := []oauth2.AuthCodeOption{ oauth2.VerifierOption(authResult.usedCodeVerifier), oauth2.SetAuthURLParam("resource", resourceURL), diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index d34e8328..c05fa6f7 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -144,7 +144,6 @@ func runElicitationDefaultsClient(ctx context.Context, serverURL string, _ map[s // ============================================================================ func runSSERetryClient(ctx context.Context, serverURL string, _ map[string]any) error { - // TODO: this scenario is not passing yet. It requires a fix in the client SSE handling. session, err := connectToServer(ctx, serverURL) if err != nil { return err diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index ff5b05a1..c97620eb 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -13,7 +13,6 @@ import ( "context" "errors" "fmt" - "log" "net/http" "net/url" @@ -149,7 +148,6 @@ func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http. } asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL, 1<<20) if err != nil { - log.Printf("Failed to get auth server metadata from %q: %v", metadataURL, err) var httpErr *httpStatusError if errors.As(err, &httpErr) { if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { @@ -171,7 +169,6 @@ func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http. if err := validateAuthServerMetaURLs(asm); err != nil { return nil, err } - log.Printf("Fetched authorization server metadata from %q", metadataURL) return asm, nil } From 9a4f3335f779cd2420061be83c3d3424e9cf94dc Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Thu, 26 Feb 2026 13:21:29 +0000 Subject: [PATCH 23/26] Address review feedback. --- auth/client.go | 6 +++--- examples/auth/client/main.go | 3 +-- mcp/streamable.go | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/auth/client.go b/auth/client.go index 12e3537f..de0f37c9 100644 --- a/auth/client.go +++ b/auth/client.go @@ -23,10 +23,10 @@ type OAuthHandler interface { // be addressed by the authorization flow (currently 401 Unauthorized and 403 Forbidden). // It is responsible for performing the OAuth flow to obtain an access token. // The arguments are the request that failed and the response that was received for it. - // Currently the body of the passed request is consumed by the transport - // before [Authorize] is called. Please file an issue if you need the body to be available. + // The headers of the request are available, but the body will have already been consumed + // when Authorize is called. // If the returned error is nil, [TokenSource] is expected to return a non-nil token source. - // After a successful call to [Authorize], the HTTP request should be retried by the transport. + // After a successful call to [Authorize], the HTTP request will be retried by the transport. // The function is responsible for closing the response body. Authorize(context.Context, *http.Request, *http.Response) error } diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index f9efd198..b0b99e1b 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -33,7 +33,7 @@ type codeReceiver struct { server *http.Server } -func (r *codeReceiver) serveRedirectHandler(listener net.Listener) error { +func (r *codeReceiver) serveRedirectHandler(listener net.Listener) { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { r.authChan <- &auth.AuthorizationResult{ @@ -50,7 +50,6 @@ func (r *codeReceiver) serveRedirectHandler(listener net.Listener) error { if err := r.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { r.errChan <- err } - return nil } func (r *codeReceiver) getAuthorizationCode(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { diff --git a/mcp/streamable.go b/mcp/streamable.go index 14499ffe..e4e7695b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1752,7 +1752,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } - if slices.Contains([]int{http.StatusUnauthorized, http.StatusForbidden}, resp.StatusCode) && c.oauthHandler != nil { + if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && c.oauthHandler != nil { if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. From 22bfc54695ca6a0d02eda86f77ee3ace877671b9 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Fri, 27 Feb 2026 10:13:35 +0000 Subject: [PATCH 24/26] Doc comment fixes. --- auth/authorization_code.go | 33 +++++++++++++++------------------ auth/authorization_code_test.go | 2 +- auth/client.go | 14 ++++++++++++-- auth/client_private.go | 20 ++++++++++++-------- auth/fake.go | 6 +++--- docs/protocol.md | 4 ++-- internal/docs/protocol.src.md | 4 ++-- oauthex/resource_meta.go | 18 +++++++++++------- oauthex/resource_meta_public.go | 2 ++ 9 files changed, 60 insertions(+), 43 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 06713f4a..25b90e03 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -23,9 +23,8 @@ import ( // ClientSecretAuthConfig is used to configure client authentication using client_secret. // Authentication method will be selected based on the authorization server's supported methods, // according to the following preference order: -// -// 1. "client_secret_post" -// 2. "client_secret_basic" +// 1. client_secret_post +// 2. client_secret_basic type ClientSecretAuthConfig struct { // ClientID is the client ID to be used for client authentication. ClientID string @@ -60,7 +59,7 @@ type DynamicClientRegistrationConfig struct { } // AuthorizationResult is the result of an authorization flow. -// It is returned by [AuthorizationCodeHandler.AuthorizationURLHandler] implementations. +// It is returned by [AuthorizationCodeHandler].AuthorizationCodeFetcher implementations. type AuthorizationResult struct { // AuthorizationCode is the authorization code obtained from the authorization server. AuthorizationCode string @@ -68,7 +67,7 @@ type AuthorizationResult struct { State string } -// AuthorizationInput is the input to [AuthorizationCodeHandlerConfig.AuthorizationCodeFetcher]. +// AuthorizationInput is the input to [AuthorizationCodeHandlerConfig].AuthorizationCodeFetcher. type AuthorizationInput struct { // Authorization URL to be opened in a browser for the user to start the authorization process. URL string @@ -78,11 +77,9 @@ type AuthorizationInput struct { type AuthorizationCodeHandlerConfig struct { // Client registration configuration. // It is attempted in the following order: - // - // 1. Client ID Metadata Document - // 2. Preregistration - // 3. Dynamic Client Registration - // + // 1. Client ID Metadata Document + // 2. Preregistration + // 3. Dynamic Client Registration // At least one method must be configured. ClientIDMetadataDocumentConfig *ClientIDMetadataDocumentConfig PreregisteredClientConfig *PreregisteredClientConfig @@ -92,17 +89,17 @@ type AuthorizationCodeHandlerConfig struct { // The caller is responsible for handling the redirect out of band. // // If Dynamic Client Registration is used: - // - this field is permitted to be empty, in which case it will be set - // to the first redirect URI from - // [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. - // - if the field is not empty, it must be one of the redirect URIs in - // [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. + // - this field is permitted to be empty, in which case it will be set + // to the first redirect URI from + // DynamicClientRegistrationConfig.Metadata.RedirectURIs. + // - if the field is not empty, it must be one of the redirect URIs in + // DynamicClientRegistrationConfig.Metadata.RedirectURIs. RedirectURL string // AuthorizationCodeFetcher is a required function called to initiate the authorization flow. // It is responsible for opening the URL in a browser for the user to start the authorization process. // It should return the authorization code and state once the Authorization Server - // redirects back to the [AuthorizationCodeHandlerConfig.RedirectURL]. + // redirects back to the RedirectURL. AuthorizationCodeFetcher func(ctx context.Context, authorizationInput *AuthorizationInput) (*AuthorizationResult, error) } @@ -136,7 +133,7 @@ func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*Autho return nil, errors.New("at least one client registration configuration must be provided") } if config.AuthorizationCodeFetcher == nil { - return nil, errors.New("AuthorizationURLHandler is required") + return nil, errors.New("AuthorizationCodeFetcher is required") } if config.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(config.ClientIDMetadataDocumentConfig.URL) { return nil, fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") @@ -477,7 +474,7 @@ type authResult struct { usedCodeVerifier string } -// getAuthorizationCode uses the [AuthorizationCodeHandler.AuthorizationURLHandler] +// getAuthorizationCode uses the [AuthorizationCodeHandler.AuthorizationCodeFetcher] // to obtain an authorization code. func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) (*authResult, error) { codeVerifier := oauth2.GenerateVerifier() diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index 5a1e06c3..d8822322 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -242,7 +242,7 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) { }, }, { - name: "MissingAuthorizationURLHandler", + name: "MissingAuthorizationCodeFetcher", config: func() *AuthorizationCodeHandlerConfig { cfg := validConfig() cfg.AuthorizationCodeFetcher = nil diff --git a/auth/client.go b/auth/client.go index de0f37c9..0af6963f 100644 --- a/auth/client.go +++ b/auth/client.go @@ -11,6 +11,16 @@ import ( "golang.org/x/oauth2" ) +// OAuthHandler is an interface for handling OAuth flows. +// +// If a transport wishes to support OAuth 2 authorization, it should support +// being configured with an OAuthHandler. It should call the handler's +// TokenSource method whenever it sends an HTTP request to set the +// Authorization header. If a request fails with a 401 or 403, it should call +// Authorize, and if that returns nil, it should retry the request. It should +// not call Authorize after the second failure. See +// [github.com/modelcontextprotocol/go-sdk/mcp.StreamableClientTransport] +// for an example. type OAuthHandler interface { isOAuthHandler() @@ -25,8 +35,8 @@ type OAuthHandler interface { // The arguments are the request that failed and the response that was received for it. // The headers of the request are available, but the body will have already been consumed // when Authorize is called. - // If the returned error is nil, [TokenSource] is expected to return a non-nil token source. - // After a successful call to [Authorize], the HTTP request will be retried by the transport. + // If the returned error is nil, TokenSource is expected to return a non-nil token source. + // After a successful call to Authorize, the HTTP request will be retried by the transport. // The function is responsible for closing the response body. Authorize(context.Context, *http.Request, *http.Response) error } diff --git a/auth/client_private.go b/auth/client_private.go index f161bdc6..767c59ee 100644 --- a/auth/client_private.go +++ b/auth/client_private.go @@ -20,14 +20,16 @@ import ( // is approved, or an error if not. // The handler receives the HTTP request and response that triggered the authentication flow. // To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. -// Deprecated: Please use the new OAuthHandler abstraction that is built -// into the streamable transport. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. type OAuthHandlerLegacy func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) // HTTPTransport is an [http.RoundTripper] that follows the MCP // OAuth protocol when it encounters a 401 Unauthorized response. -// Deprecated: Please use the new OAuthHandler abstraction that is built -// into the streamable transport. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. type HTTPTransport struct { handler OAuthHandlerLegacy mu sync.Mutex // protects opts.Base @@ -38,8 +40,9 @@ type HTTPTransport struct { // The handler is invoked when an HTTP request results in a 401 Unauthorized status. // It is called only once per transport. Once a TokenSource is obtained, it is used // for the lifetime of the transport; subsequent 401s are not processed. -// Deprecated: Please use the new OAuthHandler abstraction that is built -// into the streamable transport. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. func NewHTTPTransport(handler OAuthHandlerLegacy, opts *HTTPTransportOptions) (*HTTPTransport, error) { if handler == nil { return nil, errors.New("handler cannot be nil") @@ -57,8 +60,9 @@ func NewHTTPTransport(handler OAuthHandlerLegacy, opts *HTTPTransportOptions) (* } // HTTPTransportOptions are options to [NewHTTPTransport]. -// Deprecated: Please use the new OAuthHandler abstraction that is built -// into the streamable transport. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. type HTTPTransportOptions struct { // Base is the [http.RoundTripper] to use. // If nil, [http.DefaultTransport] is used. diff --git a/auth/fake.go b/auth/fake.go index e890802c..0318527b 100644 --- a/auth/fake.go +++ b/auth/fake.go @@ -12,11 +12,11 @@ import ( ) type FakeOAuthHandler struct { - // Token to be returned via [TokenSource]. If nil, [TokenSource] returns nil. + // Token to be returned from TokenSource. If nil, TokenSource also returns nil. Token *oauth2.Token - // AuthorizeErr is an error to be returned from [Authorize]. + // AuthorizeErr is an error to be returned from Authorize. AuthorizeErr error - // AuthorizeCalled is true if [Authorize] was called. + // AuthorizeCalled is true if Authorize was called. AuthorizeCalled bool } diff --git a/docs/protocol.md b/docs/protocol.md index 0ed6b3af..703f3b2c 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -331,8 +331,8 @@ authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandle // ClientIDMetadataDocumentConfig: ... // PreregisteredClientConfig: ... // DynamicClientRegistrationConfig: ... - AuthorizationURLHandler: func(ctx context.Context, url string) (*auth.AuthorizationResult, error) { - // Open the URL in a browser and return the resulting code and state. + AuthorizationCodeFetcher: func(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { + // Open the input.URL in a browser and return the resulting code and state. // See full example in examples/auth/client/main.go. code := ... state := ... diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 22a39f90..80bf35cc 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -257,8 +257,8 @@ authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandle // ClientIDMetadataDocumentConfig: ... // PreregisteredClientConfig: ... // DynamicClientRegistrationConfig: ... - AuthorizationURLHandler: func(ctx context.Context, url string) (*auth.AuthorizationResult, error) { - // Open the URL in a browser and return the resulting code and state. + AuthorizationCodeFetcher: func(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { + // Open the input.URL in a browser and return the resulting code and state. // See full example in examples/auth/client/main.go. code := ... state := ... diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index b9c3afb7..2d865674 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -36,7 +36,8 @@ const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resour // // It then retrieves the metadata at that location using the given client (or the // default client if nil) and validates its resource field against resourceID. -// Deprecated: Use [GetProtectedResourceMetadata] instead. +// +// Deprecated: Use [GetProtectedResourceMetadata] instead. This function will be removed in v1.5.0. func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) @@ -56,7 +57,8 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, // Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata // matches the serverURL (the URL that the client used to make the original request to the resource server). // If there is no metadata URL in the header, it returns nil, nil. -// Deprecated: Use [GetProtectedResourceMetadata] instead. +// +// Deprecated: Use [GetProtectedResourceMetadata] instead. This function will be removed in v1.5.0. func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] if len(headers) == 0 { @@ -76,11 +78,11 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL strin // GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource // metadata from a resource server. // The metadataURL is typically a URL with a host:port and possibly a path. -// The resourceURL is the resource URI that the metadataURL is for. +// The resourceURL is the resource URI the metadataURL is for. // The following checks are performed: -// - The metadataURL must use HTTPS or be a local address. -// - The resource field of the resulting metadata must match the resourceURL. -// - The authorization_servers field of the resulting metadata is checked for dangerous URL schemes. +// - The metadataURL must use HTTPS or be a local address. +// - The resource field of the resulting metadata must match the resourceURL. +// - The authorization_servers field of the resulting metadata is checked for dangerous URL schemes. func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { defer util.Wrapf(&err, "GetProtectedResourceMetadata(%q)", metadataURL) u, err := url.Parse(metadataURL) @@ -108,7 +110,7 @@ func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL return prm, nil } -// ResourceMetadataURL returns a resource metadata URL from the given challenges, +// ResourceMetadataURL returns a resource metadata URL from the given "WWW-Authenticate" header challenges, // or the empty string if there is none. func ResourceMetadataURL(cs []Challenge) string { for _, c := range cs { @@ -119,6 +121,8 @@ func ResourceMetadataURL(cs []Challenge) string { return "" } +// Scopes returns the scopes from the given "WWW-Authenticate" header challenges. +// It only looks at challenges with the "Bearer" scheme. func Scopes(cs []Challenge) []string { for _, c := range cs { if c.Scheme == "bearer" && c.Params["scope"] != "" { diff --git a/oauthex/resource_meta_public.go b/oauthex/resource_meta_public.go index 443d5ba8..75b541d7 100644 --- a/oauthex/resource_meta_public.go +++ b/oauthex/resource_meta_public.go @@ -115,6 +115,8 @@ type Challenge struct { Params map[string]string } +// Error returns the error from the given "WWW-Authenticate" header challenges. +// It only looks at challenges with the "Bearer" scheme. func Error(cs []Challenge) string { for _, c := range cs { if c.Scheme == "bearer" && c.Params["error"] != "" { From 753c0eba04d2cc8f9db83a48fcf65ba8f51f01cf Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Fri, 27 Feb 2026 10:34:45 +0000 Subject: [PATCH 25/26] Explicit note about client OAuth support. --- README.md | 13 +++++++++---- docs/protocol.md | 5 +++++ internal/docs/protocol.src.md | 5 +++++ internal/readme/README.src.md | 13 +++++++++---- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 2b015c6c..ffaf33e1 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,15 @@ contains feature documentation, mapping the MCP spec to the packages above. The following table shows which versions of the Go SDK support which versions of the MCP specification: -| SDK Version | Latest MCP Spec | All Supported MCP Specs | -|-----------------|-------------------|------------------------------------------------| -| v1.2.0+ | 2025-06-18 | 2025-11-25, 2025-06-18, 2025-03-26, 2024-11-05 | -| v1.0.0 - v1.1.0 | 2025-06-18 | 2025-06-18, 2025-03-26, 2024-11-05 | +| SDK Version | Latest MCP Spec | All Supported MCP Specs | +|-----------------|-------------------|----------------------------------------------------| +| v1.4.0+ | 2025-11-25\* | 2025-11-25\*, 2025-06-18, 2025-03-26, 2024-11-05 | +| v1.2.0 - v1.3.1 | 2025-11-25\*\* | 2025-11-25\*\*, 2025-06-18, 2025-03-26, 2024-11-05 | +| v1.0.0 - v1.1.0 | 2025-06-18 | 2025-06-18, 2025-03-26, 2024-11-05 | + +\* Client side OAuth has experimental support. + +\*\* Partial support for 2025-11-25 (client side OAuth and Sampling with tools not available). New releases of the SDK target only supported versions of Go. See https://go.dev/doc/devel/release#policy for more information. diff --git a/docs/protocol.md b/docs/protocol.md index 703f3b2c..3b043ad0 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -306,6 +306,11 @@ The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/ ### Client +> [!IMPORTANT] +> Client-side OAuth support is currently experimental and requires the `mcp_go_client_oauth` build tag to compile. +> API changes may still be made, based on developer feedback. The build tag will be removed in `v1.5.0`, which +> is planned to be released by the end of March 2026. + Client-side authorization is supported via the [`StreamableClientTransport.OAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableClientTransport.OAuthHandler) field. If the handler is provided, the transport will automatically use it to diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 80bf35cc..6a5e6b59 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -232,6 +232,11 @@ The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/ ### Client +> [!IMPORTANT] +> Client-side OAuth support is currently experimental and requires the `mcp_go_client_oauth` build tag to compile. +> API changes may still be made, based on developer feedback. The build tag will be removed in `v1.5.0`, which +> is planned to be released by the end of March 2026. + Client-side authorization is supported via the [`StreamableClientTransport.OAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableClientTransport.OAuthHandler) field. If the handler is provided, the transport will automatically use it to diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index fce0fa44..d419b022 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -33,10 +33,15 @@ contains feature documentation, mapping the MCP spec to the packages above. The following table shows which versions of the Go SDK support which versions of the MCP specification: -| SDK Version | Latest MCP Spec | All Supported MCP Specs | -|-----------------|-------------------|------------------------------------------------| -| v1.2.0+ | 2025-06-18 | 2025-11-25, 2025-06-18, 2025-03-26, 2024-11-05 | -| v1.0.0 - v1.1.0 | 2025-06-18 | 2025-06-18, 2025-03-26, 2024-11-05 | +| SDK Version | Latest MCP Spec | All Supported MCP Specs | +|-----------------|-------------------|----------------------------------------------------| +| v1.4.0+ | 2025-11-25\* | 2025-11-25\*, 2025-06-18, 2025-03-26, 2024-11-05 | +| v1.2.0 - v1.3.1 | 2025-11-25\*\* | 2025-11-25\*\*, 2025-06-18, 2025-03-26, 2024-11-05 | +| v1.0.0 - v1.1.0 | 2025-06-18 | 2025-06-18, 2025-03-26, 2024-11-05 | + +\* Client side OAuth has experimental support. + +\*\* Partial support for 2025-11-25 (client side OAuth and Sampling with tools not available). New releases of the SDK target only supported versions of Go. See https://go.dev/doc/devel/release#policy for more information. From b26dc14d05e4376a5d2a5b0f652a0639d051ea73 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Fri, 27 Feb 2026 14:09:49 +0000 Subject: [PATCH 26/26] Address review feedback. --- auth/auth.go | 3 +- auth/authorization_code.go | 59 ++++-- auth/authorization_code_test.go | 12 +- auth/fake.go | 35 ---- .../everything-client/client_private.go | 8 +- docs/protocol.md | 6 +- examples/auth/client/main.go | 6 +- internal/docs/protocol.src.md | 6 +- mcp/streamable_client_auth_test.go | 179 +++++++++++++++++ mcp/streamable_client_test.go | 88 --------- mcp/streamable_test.go | 53 ----- oauthex/auth_meta.go | 2 - oauthex/oauth2.go | 17 -- oauthex/resource_meta.go | 177 +++++++++++++++-- oauthex/resource_meta_public.go | 181 ------------------ 15 files changed, 407 insertions(+), 425 deletions(-) delete mode 100644 auth/fake.go create mode 100644 mcp/streamable_client_auth_test.go diff --git a/auth/auth.go b/auth/auth.go index 29cca526..36ff259e 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -25,8 +25,7 @@ type TokenInfo struct { // session hijacking by ensuring that all requests for a given session // come from the same user. UserID string - // TODO: add standard JWT fields - Extra map[string]any + Extra map[string]any } // The error that a TokenVerifier should return if the token cannot be verified. diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 25b90e03..2a6ed32b 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -61,14 +61,14 @@ type DynamicClientRegistrationConfig struct { // AuthorizationResult is the result of an authorization flow. // It is returned by [AuthorizationCodeHandler].AuthorizationCodeFetcher implementations. type AuthorizationResult struct { - // AuthorizationCode is the authorization code obtained from the authorization server. - AuthorizationCode string + // Code is the authorization code obtained from the authorization server. + Code string // State string returned by the authorization server. State string } -// AuthorizationInput is the input to [AuthorizationCodeHandlerConfig].AuthorizationCodeFetcher. -type AuthorizationInput struct { +// AuthorizationArgs is the input to [AuthorizationCodeHandlerConfig].AuthorizationCodeFetcher. +type AuthorizationArgs struct { // Authorization URL to be opened in a browser for the user to start the authorization process. URL string } @@ -100,7 +100,7 @@ type AuthorizationCodeHandlerConfig struct { // It is responsible for opening the URL in a browser for the user to start the authorization process. // It should return the authorization code and state once the Authorization Server // redirects back to the RedirectURL. - AuthorizationCodeFetcher func(ctx context.Context, authorizationInput *AuthorizationInput) (*AuthorizationResult, error) + AuthorizationCodeFetcher func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) } // AuthorizationCodeHandler is an implementation of [OAuthHandler] that uses @@ -188,7 +188,7 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) } - if resp.StatusCode == http.StatusForbidden && oauthex.Error(wwwChallenges) != "insufficient_scope" { + if resp.StatusCode == http.StatusForbidden && errorFromChallenges(wwwChallenges) != "insufficient_scope" { // We only want to perform step-up authorization for insufficient_scope errors. // Returning nil, so that the call is retried immediately and the response // is handled appropriately by the connection. @@ -212,9 +212,9 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ return err } - scopes := oauthex.Scopes(wwwChallenges) - if len(scopes) == 0 && prm != nil && len(prm.ScopesSupported) > 0 { - scopes = prm.ScopesSupported + scps := scopesFromChallenges(wwwChallenges) + if len(scps) == 0 && len(prm.ScopesSupported) > 0 { + scps = prm.ScopesSupported } cfg := &oauth2.Config{ @@ -227,7 +227,7 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ AuthStyle: resolvedClientConfig.authStyle, }, RedirectURL: h.config.RedirectURL, - Scopes: scopes, + Scopes: scps, } authRes, err := h.getAuthorizationCode(ctx, cfg, req.URL.String()) @@ -239,6 +239,39 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ return h.exchangeAuthorizationCode(ctx, cfg, authRes, prm.Resource) } +// resourceMetadataURLFromChallenges returns a resource metadata URL from the given "WWW-Authenticate" header challenges, +// or the empty string if there is none. +func resourceMetadataURLFromChallenges(cs []oauthex.Challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// scopesFromChallenges returns the scopes from the given "WWW-Authenticate" header challenges. +// It only looks at challenges with the "Bearer" scheme. +func scopesFromChallenges(cs []oauthex.Challenge) []string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["scope"] != "" { + return strings.Fields(c.Params["scope"]) + } + } + return nil +} + +// errorFromChallenges returns the error from the given "WWW-Authenticate" header challenges. +// It only looks at challenges with the "Bearer" scheme. +func errorFromChallenges(cs []oauthex.Challenge) string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["error"] != "" { + return c.Params["error"] + } + } + return "" +} + // getProtectedResourceMetadata returns the protected resource metadata. // If no metadata was found or the fetched metadata fails security checks, // it returns an error. @@ -246,7 +279,7 @@ func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Cont var errs []error // Use MCP server URL as the resource URI per // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri. - for _, url := range protectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), mcpServerURL) { + for _, url := range protectedResourceMetadataURLs(resourceMetadataURLFromChallenges(wwwChallenges), mcpServerURL) { prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, http.DefaultClient) if err != nil { errs = append(errs, err) @@ -485,7 +518,7 @@ func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg oauth2.SetAuthURLParam("resource", resourceURL), ) - authRes, err := h.config.AuthorizationCodeFetcher(ctx, &AuthorizationInput{URL: authURL}) + authRes, err := h.config.AuthorizationCodeFetcher(ctx, &AuthorizationArgs{URL: authURL}) if err != nil { // Purposefully leaving the error unwrappable so it can be handled by the caller. return nil, err @@ -506,7 +539,7 @@ func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context oauth2.VerifierOption(authResult.usedCodeVerifier), oauth2.SetAuthURLParam("resource", resourceURL), } - token, err := cfg.Exchange(ctx, authResult.AuthorizationCode, opts...) + token, err := cfg.Exchange(ctx, authResult.Code, opts...) if err != nil { return fmt.Errorf("token exchange failed: %w", err) } diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index d8822322..77214cd9 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -51,14 +51,14 @@ func TestAuthorize(t *testing.T) { ClientSecret: "test_client_secret", }, }, - AuthorizationCodeFetcher: func(ctx context.Context, input *AuthorizationInput) (*AuthorizationResult, error) { + AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { // The fake authorization server will redirect to an URL with code and state. client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } - resp, err := client.Get(input.URL) + resp, err := client.Get(args.URL) if err != nil { return nil, fmt.Errorf("failed to visit auth URL: %v", err) } @@ -74,8 +74,8 @@ func TestAuthorize(t *testing.T) { return nil, fmt.Errorf("failed to get location header: %v", err) } return &AuthorizationResult{ - AuthorizationCode: location.Query().Get("code"), - State: location.Query().Get("state"), + Code: location.Query().Get("code"), + State: location.Query().Get("state"), }, nil }, }) @@ -132,7 +132,7 @@ func TestAuthorize_ForbiddenUnhandledError(t *testing.T) { } func TestNewAuthorizationCodeHandler_Success(t *testing.T) { - simpleHandler := func(ctx context.Context, input *AuthorizationInput) (*AuthorizationResult, error) { + simpleHandler := func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { return nil, nil } tests := []struct { @@ -203,7 +203,7 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) { return &AuthorizationCodeHandlerConfig{ ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"}, RedirectURL: "https://example.com/callback", - AuthorizationCodeFetcher: func(ctx context.Context, input *AuthorizationInput) (*AuthorizationResult, error) { + AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { return nil, nil }, } diff --git a/auth/fake.go b/auth/fake.go deleted file mode 100644 index 0318527b..00000000 --- a/auth/fake.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by the license -// that can be found in the LICENSE file. - -package auth - -import ( - "context" - "net/http" - - "golang.org/x/oauth2" -) - -type FakeOAuthHandler struct { - // Token to be returned from TokenSource. If nil, TokenSource also returns nil. - Token *oauth2.Token - // AuthorizeErr is an error to be returned from Authorize. - AuthorizeErr error - // AuthorizeCalled is true if Authorize was called. - AuthorizeCalled bool -} - -func (h *FakeOAuthHandler) isOAuthHandler() {} - -func (h *FakeOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { - if h.Token == nil { - return nil, nil - } - return oauth2.StaticTokenSource(h.Token), nil -} - -func (h *FakeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { - h.AuthorizeCalled = true - return h.AuthorizeErr -} diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go index 7ca8d07e..3b0c6592 100644 --- a/conformance/everything-client/client_private.go +++ b/conformance/everything-client/client_private.go @@ -50,13 +50,13 @@ func init() { // Auth scenarios // ============================================================================ -func fetchAuthorizationCodeAndState(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { +func fetchAuthorizationCodeAndState(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } - req, err := http.NewRequestWithContext(ctx, "GET", input.URL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", args.URL, nil) if err != nil { return nil, err } @@ -75,8 +75,8 @@ func fetchAuthorizationCodeAndState(ctx context.Context, input *auth.Authorizati } return &auth.AuthorizationResult{ - AuthorizationCode: locURL.Query().Get("code"), - State: locURL.Query().Get("state"), + Code: locURL.Query().Get("code"), + State: locURL.Query().Get("state"), }, nil } diff --git a/docs/protocol.md b/docs/protocol.md index 3b043ad0..abdf50fa 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -336,12 +336,12 @@ authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandle // ClientIDMetadataDocumentConfig: ... // PreregisteredClientConfig: ... // DynamicClientRegistrationConfig: ... - AuthorizationCodeFetcher: func(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { - // Open the input.URL in a browser and return the resulting code and state. + AuthorizationCodeFetcher: func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + // Open the args.URL in a browser and return the resulting code and state. // See full example in examples/auth/client/main.go. code := ... state := ... - return &auth.AuthorizationResult{AuthorizationCode: code, State: state}, nil + return &auth.AuthorizationResult{Code: code, State: state}, nil }, }) diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index b0b99e1b..32de488e 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -37,8 +37,8 @@ func (r *codeReceiver) serveRedirectHandler(listener net.Listener) { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { r.authChan <- &auth.AuthorizationResult{ - AuthorizationCode: req.URL.Query().Get("code"), - State: req.URL.Query().Get("state"), + Code: req.URL.Query().Get("code"), + State: req.URL.Query().Get("state"), } fmt.Fprint(w, "Authentication successful. You can close this window.") }) @@ -52,7 +52,7 @@ func (r *codeReceiver) serveRedirectHandler(listener net.Listener) { } } -func (r *codeReceiver) getAuthorizationCode(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { +func (r *codeReceiver) getAuthorizationCode(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { select { case authRes := <-r.authChan: return authRes, nil diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 6a5e6b59..3771b581 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -262,12 +262,12 @@ authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandle // ClientIDMetadataDocumentConfig: ... // PreregisteredClientConfig: ... // DynamicClientRegistrationConfig: ... - AuthorizationCodeFetcher: func(ctx context.Context, input *auth.AuthorizationInput) (*auth.AuthorizationResult, error) { - // Open the input.URL in a browser and return the resulting code and state. + AuthorizationCodeFetcher: func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + // Open the args.URL in a browser and return the resulting code and state. // See full example in examples/auth/client/main.go. code := ... state := ... - return &auth.AuthorizationResult{AuthorizationCode: code, State: state}, nil + return &auth.AuthorizationResult{Code: code, State: state}, nil }, }) diff --git a/mcp/streamable_client_auth_test.go b/mcp/streamable_client_auth_test.go new file mode 100644 index 00000000..a1211e48 --- /dev/null +++ b/mcp/streamable_client_auth_test.go @@ -0,0 +1,179 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package mcp + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "golang.org/x/oauth2" +) + +type mockOAuthHandler struct { + // Embed to satisfy the interface. + auth.AuthorizationCodeHandler + + token *oauth2.Token + authorizeErr error + authorizeCalled bool +} + +func (h *mockOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + if h.token == nil { + return nil, nil + } + return oauth2.StaticTokenSource(h.token), nil +} + +func (h *mockOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + h.authorizeCalled = true + return h.authorizeErr +} + +func TestStreamableClientOAuth_AuthorizationHeader(t *testing.T) { + ctx := context.Background() + token := &oauth2.Token{AccessToken: "test-token"} + oauthHandler := &mockOAuthHandler{token: token} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", "", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + }, + {"DELETE", "123", "", ""}: {}, + }, + } + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + if token != "test-token" { + return nil, auth.ErrInvalidToken + } + return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + } + httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + session.Close() +} + +func TestStreamableClientOAuth_401(t *testing.T) { + ctx := context.Background() + oauthHandler := &mockOAuthHandler{token: nil} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + }, + } + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + // Accept any token. + return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + } + httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + _, err := client.Connect(ctx, transport, nil) + if err == nil || !strings.Contains(err.Error(), "Unauthorized") { + t.Fatalf("client.Connect() error does not contain 'Unauthorized': %v", err) + } + + if !oauthHandler.authorizeCalled { + t.Errorf("expected Authorize to be called") + } +} + +func TestTokenInfo(t *testing.T) { + ctx := context.Background() + + // Create a server with a tool that returns TokenInfo. + tokenInfo := func(ctx context.Context, req *CallToolRequest, _ struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil, nil + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) + + streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + if token != "test-token" { + return nil, auth.ErrInvalidToken + } + return &auth.TokenInfo{ + Scopes: []string{"scope"}, + // Expiration is far, far in the future. + Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + } + handler := auth.RequireBearerToken(verifier, nil)(streamHandler) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: &mockOAuthHandler{token: &oauth2.Token{AccessToken: "test-token"}}, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"}) + if err != nil { + t.Fatal(err) + } + if len(res.Content) == 0 { + t.Fatal("missing content") + } + tc, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatal("not TextContent") + } + if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 23e57803..d189ca41 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -17,10 +17,8 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" - "golang.org/x/oauth2" ) type streamableRequestKey struct { @@ -910,89 +908,3 @@ func TestStreamableClientDisableStandaloneSSE(t *testing.T) { }) } } - -func TestStreamableClientOAuth_AuthorizationHeader(t *testing.T) { - ctx := context.Background() - token := &oauth2.Token{AccessToken: "test-token"} - oauthHandler := &auth.FakeOAuthHandler{Token: token} - - fake := &fakeStreamableServer{ - t: t, - responses: fakeResponses{ - {"POST", "", methodInitialize, ""}: { - header: header{ - "Content-Type": "application/json", - sessionIDHeader: "123", - }, - body: jsonBody(t, initResp), - }, - {"POST", "123", notificationInitialized, ""}: { - status: http.StatusAccepted, - wantProtocolVersion: latestProtocolVersion, - }, - {"GET", "123", "", ""}: { - header: header{ - "Content-Type": "text/event-stream", - }, - }, - {"DELETE", "123", "", ""}: {}, - }, - } - verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { - if token != "test-token" { - return nil, auth.ErrInvalidToken - } - return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil - } - httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) - t.Cleanup(httpServer.Close) - - transport := &StreamableClientTransport{ - Endpoint: httpServer.URL, - OAuthHandler: oauthHandler, - } - client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport, nil) - if err != nil { - t.Fatalf("client.Connect() failed: %v", err) - } - session.Close() -} - -func TestStreamableClientOAuth_401(t *testing.T) { - ctx := context.Background() - oauthHandler := &auth.FakeOAuthHandler{Token: nil} - - fake := &fakeStreamableServer{ - t: t, - responses: fakeResponses{ - {"POST", "", methodInitialize, ""}: { - header: header{ - "Content-Type": "application/json", - sessionIDHeader: "123", - }, - body: jsonBody(t, initResp), - }, - }, - } - verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { - // Accept any token. - return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil - } - httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) - t.Cleanup(httpServer.Close) - - transport := &StreamableClientTransport{ - Endpoint: httpServer.URL, - OAuthHandler: oauthHandler, - } - client := NewClient(testImpl, nil) - _, err := client.Connect(ctx, transport, nil) - if err == nil || !strings.Contains(err.Error(), "Unauthorized") { - t.Fatalf("client.Connect() error does not contain 'Unauthorized': %v", err) - } - - if !oauthHandler.AuthorizeCalled { - t.Errorf("expected Authorize to be called") - } -} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 11089535..22d0d1c6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -34,7 +34,6 @@ import ( "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" - "golang.org/x/oauth2" ) func TestStreamableTransports(t *testing.T) { @@ -1667,58 +1666,6 @@ func textContent(t *testing.T, res *CallToolResult) string { return text.Text } -func TestTokenInfo(t *testing.T) { - ctx := context.Background() - - // Create a server with a tool that returns TokenInfo. - tokenInfo := func(ctx context.Context, req *CallToolRequest, _ struct{}) (*CallToolResult, any, error) { - return &CallToolResult{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil, nil - } - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) - - streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { - if token != "test-token" { - return nil, auth.ErrInvalidToken - } - return &auth.TokenInfo{ - Scopes: []string{"scope"}, - // Expiration is far, far in the future. - Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), - }, nil - } - handler := auth.RequireBearerToken(verifier, nil)(streamHandler) - httpServer := httptest.NewServer(mustNotPanic(t, handler)) - defer httpServer.Close() - - transport := &StreamableClientTransport{ - Endpoint: httpServer.URL, - OAuthHandler: &auth.FakeOAuthHandler{Token: &oauth2.Token{AccessToken: "test-token"}}, - } - client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport, nil) - if err != nil { - t.Fatalf("client.Connect() failed: %v", err) - } - defer session.Close() - - res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"}) - if err != nil { - t.Fatal(err) - } - if len(res.Content) == 0 { - t.Fatal("missing content") - } - tc, ok := res.Content[0].(*TextContent) - if !ok { - t.Fatal("not TextContent") - } - if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { - t.Errorf("got %q, want %q", g, w) - } -} - func TestSessionHijackingPrevention(t *testing.T) { // This test verifies that sessions bound to a user ID cannot be accessed // by a different user (session hijacking prevention). diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index c97620eb..b05d80b6 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -31,8 +31,6 @@ import ( // // [RFC 8414]: https://tools.ietf.org/html/rfc8414) type AuthServerMeta struct { - // GENERATED BY GEMINI 2.5. - // Issuer is the REQUIRED URL identifying the authorization server. Issuer string `json:"issuer"` diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index 5b76116d..836a4201 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -19,23 +19,6 @@ import ( "strings" ) -// prependToPath prepends pre to the path of urlStr. -// When pre is the well-known path, this is the algorithm specified in both RFC 9728 -// section 3.1 and RFC 8414 section 3.1. -func prependToPath(urlStr, pre string) (string, error) { - u, err := url.Parse(urlStr) - if err != nil { - return "", err - } - p := "/" + strings.Trim(pre, "/") - if u.Path != "" { - p += "/" - } - - u.Path = p + strings.TrimLeft(u.Path, "/") - return u.String(), nil -} - type httpStatusError struct { StatusCode int } diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index 2d865674..8b911cad 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -11,11 +11,13 @@ package oauthex import ( "context" + "errors" "fmt" "net/http" "net/url" "path" "strings" + "unicode" "github.com/modelcontextprotocol/go-sdk/internal/util" ) @@ -68,13 +70,24 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL strin if err != nil { return nil, err } - metadataURL := ResourceMetadataURL(cs) + metadataURL := resourceMetadataURL(cs) if metadataURL == "" { return nil, nil } return GetProtectedResourceMetadata(ctx, metadataURL, serverURL, c) } +// resourceMetadataURL returns a resource metadata URL from the given "WWW-Authenticate" header challenges, +// or the empty string if there is none. +func resourceMetadataURL(cs []Challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + // GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource // metadata from a resource server. // The metadataURL is typically a URL with a host:port and possibly a path. @@ -110,24 +123,158 @@ func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL return prm, nil } -// ResourceMetadataURL returns a resource metadata URL from the given "WWW-Authenticate" header challenges, -// or the empty string if there is none. -func ResourceMetadataURL(cs []Challenge) string { - for _, c := range cs { - if u := c.Params["resource_metadata"]; u != "" { - return u +// ParseWWWAuthenticate parses a WWW-Authenticate header string. +// The header format is defined in RFC 9110, Section 11.6.1, and can contain +// one or more challenges, separated by commas. +// It returns a slice of challenges or an error if one of the headers is malformed. +func ParseWWWAuthenticate(headers []string) ([]Challenge, error) { + var challenges []Challenge + for _, h := range headers { + challengeStrings, err := splitChallenges(h) + if err != nil { + return nil, err + } + for _, cs := range challengeStrings { + if strings.TrimSpace(cs) == "" { + continue + } + challenge, err := parseSingleChallenge(cs) + if err != nil { + return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) + } + challenges = append(challenges, challenge) } } - return "" + return challenges, nil } -// Scopes returns the scopes from the given "WWW-Authenticate" header challenges. -// It only looks at challenges with the "Bearer" scheme. -func Scopes(cs []Challenge) []string { - for _, c := range cs { - if c.Scheme == "bearer" && c.Params["scope"] != "" { - return strings.Fields(c.Params["scope"]) +// splitChallenges splits a header value containing one or more challenges. +// It correctly handles commas within quoted strings and distinguishes between +// commas separating auth-params and commas separating challenges. +func splitChallenges(header string) ([]string, error) { + var challenges []string + inQuotes := false + start := 0 + for i, r := range header { + if r == '"' { + if i > 0 && header[i-1] != '\\' { + inQuotes = !inQuotes + } else if i == 0 { + // A challenge begins with an auth-scheme, which is a token, which cannot contain + // a quote. + return nil, errors.New(`challenge begins with '"'`) + } + } else if r == ',' && !inQuotes { + // This is a potential challenge separator. + // A new challenge does not start with `key=value`. + // We check if the part after the comma looks like a parameter. + lookahead := strings.TrimSpace(header[i+1:]) + eqPos := strings.Index(lookahead, "=") + + isParam := false + if eqPos > 0 { + // Check if the part before '=' is a single token (no spaces). + token := lookahead[:eqPos] + if strings.IndexFunc(token, unicode.IsSpace) == -1 { + isParam = true + } + } + + if !isParam { + // The part after the comma does not look like a parameter, + // so this comma separates challenges. + challenges = append(challenges, header[start:i]) + start = i + 1 + } } } - return nil + // Add the last (or only) challenge to the list. + challenges = append(challenges, header[start:]) + return challenges, nil +} + +// parseSingleChallenge parses a string containing exactly one challenge. +// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] +func parseSingleChallenge(s string) (Challenge, error) { + s = strings.TrimSpace(s) + if s == "" { + return Challenge{}, errors.New("empty challenge string") + } + + scheme, paramsStr, found := strings.Cut(s, " ") + c := Challenge{Scheme: strings.ToLower(scheme)} + if !found { + return c, nil + } + + params := make(map[string]string) + + // Parse the key-value parameters. + for paramsStr != "" { + // Find the end of the parameter key. + keyEnd := strings.Index(paramsStr, "=") + if keyEnd <= 0 { + return Challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + } + key := strings.TrimSpace(paramsStr[:keyEnd]) + + // Move the string past the key and the '='. + paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) + + var value string + if strings.HasPrefix(paramsStr, "\"") { + // The value is a quoted string. + paramsStr = paramsStr[1:] // Consume the opening quote. + var valBuilder strings.Builder + i := 0 + for ; i < len(paramsStr); i++ { + // Handle escaped characters. + if paramsStr[i] == '\\' && i+1 < len(paramsStr) { + valBuilder.WriteByte(paramsStr[i+1]) + i++ // We've consumed two characters. + } else if paramsStr[i] == '"' { + // End of the quoted string. + break + } else { + valBuilder.WriteByte(paramsStr[i]) + } + } + + // A quoted string must be terminated. + if i == len(paramsStr) { + return Challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + } + + value = valBuilder.String() + // Move the string past the value and the closing quote. + paramsStr = strings.TrimSpace(paramsStr[i+1:]) + } else { + // The value is a token. It ends at the next comma or the end of the string. + commaPos := strings.Index(paramsStr, ",") + if commaPos == -1 { + value = paramsStr + paramsStr = "" + } else { + value = strings.TrimSpace(paramsStr[:commaPos]) + paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check + } + } + if value == "" { + return Challenge{}, fmt.Errorf("no value for auth param %q", key) + } + + // Per RFC 9110, parameter keys are case-insensitive. + params[strings.ToLower(key)] = value + + // If there is a comma, consume it and continue to the next parameter. + if strings.HasPrefix(paramsStr, ",") { + paramsStr = strings.TrimSpace(paramsStr[1:]) + } else if paramsStr != "" { + // If there's content but it's not a new parameter, the format is wrong. + return Challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + } + } + + // Per RFC 9110, the scheme is case-insensitive. + return Challenge{Scheme: strings.ToLower(scheme), Params: params}, nil } diff --git a/oauthex/resource_meta_public.go b/oauthex/resource_meta_public.go index 75b541d7..3bf7d9ac 100644 --- a/oauthex/resource_meta_public.go +++ b/oauthex/resource_meta_public.go @@ -9,13 +9,6 @@ package oauthex -import ( - "errors" - "fmt" - "strings" - "unicode" -) - // ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, // as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. // @@ -24,8 +17,6 @@ import ( // - human-readable metadata (§2.1) // - signed metadata (§2.2) type ProtectedResourceMetadata struct { - // GENERATED BY GEMINI 2.5. - // Resource (resource) is the protected resource's resource identifier. // Required. Resource string `json:"resource"` @@ -105,8 +96,6 @@ type ProtectedResourceMetadata struct { // Challenge represents a single authentication challenge from a WWW-Authenticate header. // As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. type Challenge struct { - // GENERATED BY GEMINI 2.5. - // // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). // It is case-insensitive. A parsed value will always be lower-case. Scheme string @@ -114,173 +103,3 @@ type Challenge struct { // Keys are case-insensitive. Parsed keys are always lower-case. Params map[string]string } - -// Error returns the error from the given "WWW-Authenticate" header challenges. -// It only looks at challenges with the "Bearer" scheme. -func Error(cs []Challenge) string { - for _, c := range cs { - if c.Scheme == "bearer" && c.Params["error"] != "" { - return c.Params["error"] - } - } - return "" -} - -// ParseWWWAuthenticate parses a WWW-Authenticate header string. -// The header format is defined in RFC 9110, Section 11.6.1, and can contain -// one or more challenges, separated by commas. -// It returns a slice of challenges or an error if one of the headers is malformed. -func ParseWWWAuthenticate(headers []string) ([]Challenge, error) { - // GENERATED BY GEMINI 2.5 (human-tweaked) - var challenges []Challenge - for _, h := range headers { - challengeStrings, err := splitChallenges(h) - if err != nil { - return nil, err - } - for _, cs := range challengeStrings { - if strings.TrimSpace(cs) == "" { - continue - } - challenge, err := parseSingleChallenge(cs) - if err != nil { - return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) - } - challenges = append(challenges, challenge) - } - } - return challenges, nil -} - -// splitChallenges splits a header value containing one or more challenges. -// It correctly handles commas within quoted strings and distinguishes between -// commas separating auth-params and commas separating challenges. -func splitChallenges(header string) ([]string, error) { - // GENERATED BY GEMINI 2.5. - var challenges []string - inQuotes := false - start := 0 - for i, r := range header { - if r == '"' { - if i > 0 && header[i-1] != '\\' { - inQuotes = !inQuotes - } else if i == 0 { - // A challenge begins with an auth-scheme, which is a token, which cannot contain - // a quote. - return nil, errors.New(`challenge begins with '"'`) - } - } else if r == ',' && !inQuotes { - // This is a potential challenge separator. - // A new challenge does not start with `key=value`. - // We check if the part after the comma looks like a parameter. - lookahead := strings.TrimSpace(header[i+1:]) - eqPos := strings.Index(lookahead, "=") - - isParam := false - if eqPos > 0 { - // Check if the part before '=' is a single token (no spaces). - token := lookahead[:eqPos] - if strings.IndexFunc(token, unicode.IsSpace) == -1 { - isParam = true - } - } - - if !isParam { - // The part after the comma does not look like a parameter, - // so this comma separates challenges. - challenges = append(challenges, header[start:i]) - start = i + 1 - } - } - } - // Add the last (or only) challenge to the list. - challenges = append(challenges, header[start:]) - return challenges, nil -} - -// parseSingleChallenge parses a string containing exactly one challenge. -// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] -func parseSingleChallenge(s string) (Challenge, error) { - // GENERATED BY GEMINI 2.5, human-tweaked. - s = strings.TrimSpace(s) - if s == "" { - return Challenge{}, errors.New("empty challenge string") - } - - scheme, paramsStr, found := strings.Cut(s, " ") - c := Challenge{Scheme: strings.ToLower(scheme)} - if !found { - return c, nil - } - - params := make(map[string]string) - - // Parse the key-value parameters. - for paramsStr != "" { - // Find the end of the parameter key. - keyEnd := strings.Index(paramsStr, "=") - if keyEnd <= 0 { - return Challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) - } - key := strings.TrimSpace(paramsStr[:keyEnd]) - - // Move the string past the key and the '='. - paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) - - var value string - if strings.HasPrefix(paramsStr, "\"") { - // The value is a quoted string. - paramsStr = paramsStr[1:] // Consume the opening quote. - var valBuilder strings.Builder - i := 0 - for ; i < len(paramsStr); i++ { - // Handle escaped characters. - if paramsStr[i] == '\\' && i+1 < len(paramsStr) { - valBuilder.WriteByte(paramsStr[i+1]) - i++ // We've consumed two characters. - } else if paramsStr[i] == '"' { - // End of the quoted string. - break - } else { - valBuilder.WriteByte(paramsStr[i]) - } - } - - // A quoted string must be terminated. - if i == len(paramsStr) { - return Challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") - } - - value = valBuilder.String() - // Move the string past the value and the closing quote. - paramsStr = strings.TrimSpace(paramsStr[i+1:]) - } else { - // The value is a token. It ends at the next comma or the end of the string. - commaPos := strings.Index(paramsStr, ",") - if commaPos == -1 { - value = paramsStr - paramsStr = "" - } else { - value = strings.TrimSpace(paramsStr[:commaPos]) - paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check - } - } - if value == "" { - return Challenge{}, fmt.Errorf("no value for auth param %q", key) - } - - // Per RFC 9110, parameter keys are case-insensitive. - params[strings.ToLower(key)] = value - - // If there is a comma, consume it and continue to the next parameter. - if strings.HasPrefix(paramsStr, ",") { - paramsStr = strings.TrimSpace(paramsStr[1:]) - } else if paramsStr != "" { - // If there's content but it's not a new parameter, the format is wrong. - return Challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) - } - } - - // Per RFC 9110, the scheme is case-insensitive. - return Challenge{Scheme: strings.ToLower(scheme), Params: params}, nil -}