Skip to content

Commit 017e0fc

Browse files
Merge branch 'main' into guglielmoc/SEP-2243_http_standardization
2 parents 57659c0 + db50910 commit 017e0fc

9 files changed

Lines changed: 159 additions & 124 deletions

File tree

docs/rough_edges.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,7 @@ v2.
5959
wrapper) we need to first unmarshal into a `map[string]any` in order to do
6060
server-side validation of required fields. CallToolParams could have just had
6161
a map[string]any.
62+
63+
- `StreamableHTTPOptions.CrossOriginProtection` should not have been part of
64+
the SDK API. Cross-origin protection is a general HTTP concern, not specific
65+
to MCP, and can be applied as standard HTTP middleware.

examples/http/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ The example implements:
1515
```bash
1616
go run . server
1717
```
18-
This starts an MCP server on `http://localhost:8080` (default) that provides a `cityTime` tool.
18+
This starts an MCP server on `http://localhost:8000` (default) that provides a `cityTime` tool.
1919

2020
To run a client in another terminal:
2121

internal/docs/rough_edges.src.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,7 @@ v2.
5858
wrapper) we need to first unmarshal into a `map[string]any` in order to do
5959
server-side validation of required fields. CallToolParams could have just had
6060
a map[string]any.
61+
62+
- `StreamableHTTPOptions.CrossOriginProtection` should not have been part of
63+
the SDK API. Cross-origin protection is a general HTTP concern, not specific
64+
to MCP, and can be applied as standard HTTP middleware.

mcp/sse.go

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"crypto/rand"
1111
"fmt"
1212
"io"
13+
"mime"
1314
"net"
1415
"net/http"
1516
"net/url"
@@ -64,14 +65,6 @@ type SSEOptions struct {
6465
// Only disable this if you understand the security implications.
6566
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
6667
DisableLocalhostProtection bool
67-
68-
// CrossOriginProtection allows to customize cross-origin protection.
69-
// The deny handler set in the CrossOriginProtection through SetDenyHandler
70-
// is ignored.
71-
// If nil, default (zero-value) cross-origin protection will be used.
72-
// Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter
73-
// to disable the default protection until v1.7.0.
74-
CrossOriginProtection *http.CrossOriginProtection
7568
}
7669

7770
// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP
@@ -97,10 +90,6 @@ func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptio
9790
s.opts = *opts
9891
}
9992

100-
if s.opts.CrossOriginProtection == nil {
101-
s.opts.CrossOriginProtection = &http.CrossOriginProtection{}
102-
}
103-
10493
return s
10594
}
10695

@@ -212,20 +201,13 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
212201
}
213202
}
214203

215-
if disablecrossoriginprotection != "1" {
216-
// Verify the 'Origin' header to protect against CSRF attacks.
217-
if err := h.opts.CrossOriginProtection.Check(req); err != nil {
218-
http.Error(w, err.Error(), http.StatusForbidden)
204+
// Validate 'Content-Type' header.
205+
if req.Method == http.MethodPost {
206+
mediaType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
207+
if err != nil || mediaType != "application/json" {
208+
http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType)
219209
return
220210
}
221-
// Validate 'Content-Type' header.
222-
if req.Method == http.MethodPost {
223-
contentType := req.Header.Get("Content-Type")
224-
if contentType != "application/json" {
225-
http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType)
226-
return
227-
}
228-
}
229211
}
230212

231213
sessionID := req.URL.Query().Get("sessionid")

mcp/sse_test.go

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"net"
1313
"net/http"
1414
"net/http/httptest"
15-
"strings"
1615
"sync/atomic"
1716
"testing"
1817

@@ -320,77 +319,3 @@ func TestSSELocalhostProtection(t *testing.T) {
320319
})
321320
}
322321
}
323-
324-
func TestSSEOriginProtection(t *testing.T) {
325-
server := NewServer(testImpl, nil)
326-
327-
tests := []struct {
328-
name string
329-
protection *http.CrossOriginProtection
330-
requestOrigin string
331-
wantStatusCode int
332-
}{
333-
{
334-
name: "default protection with Origin header",
335-
protection: nil,
336-
requestOrigin: "https://example.com",
337-
wantStatusCode: http.StatusForbidden,
338-
},
339-
{
340-
name: "custom protection with trusted origin and same Origin",
341-
protection: func() *http.CrossOriginProtection {
342-
p := http.NewCrossOriginProtection()
343-
if err := p.AddTrustedOrigin("https://example.com"); err != nil {
344-
t.Fatal(err)
345-
}
346-
return p
347-
}(),
348-
requestOrigin: "https://example.com",
349-
wantStatusCode: http.StatusNotFound, // origin accepted; session not found
350-
},
351-
{
352-
name: "custom protection with trusted origin and different Origin",
353-
protection: func() *http.CrossOriginProtection {
354-
p := http.NewCrossOriginProtection()
355-
if err := p.AddTrustedOrigin("https://example.com"); err != nil {
356-
t.Fatal(err)
357-
}
358-
return p
359-
}(),
360-
requestOrigin: "https://malicious.com",
361-
wantStatusCode: http.StatusForbidden,
362-
},
363-
}
364-
365-
for _, tt := range tests {
366-
t.Run(tt.name, func(t *testing.T) {
367-
opts := &SSEOptions{
368-
CrossOriginProtection: tt.protection,
369-
}
370-
handler := NewSSEHandler(func(req *http.Request) *Server { return server }, opts)
371-
httpServer := httptest.NewServer(handler)
372-
defer httpServer.Close()
373-
374-
// Use POST with a valid session-like URL to test origin protection
375-
// without creating a hanging GET connection.
376-
reqReader := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping"}`)
377-
req, err := http.NewRequest(http.MethodPost, httpServer.URL+"?sessionid=nonexistent", reqReader)
378-
if err != nil {
379-
t.Fatal(err)
380-
}
381-
req.Header.Set("Content-Type", "application/json")
382-
req.Header.Set("Origin", tt.requestOrigin)
383-
384-
resp, err := http.DefaultClient.Do(req)
385-
if err != nil {
386-
t.Fatal(err)
387-
}
388-
defer resp.Body.Close()
389-
390-
if got := resp.StatusCode; got != tt.wantStatusCode {
391-
body, _ := io.ReadAll(resp.Body)
392-
t.Errorf("Status code: got %d, want %d (body: %s)", got, tt.wantStatusCode, body)
393-
}
394-
})
395-
}
396-
}

mcp/streamable.go

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,15 @@ type StreamableHTTPOptions struct {
174174
// CrossOriginProtection allows to customize cross-origin protection.
175175
// The deny handler set in the CrossOriginProtection through SetDenyHandler
176176
// is ignored.
177-
// If nil, default (zero-value) cross-origin protection will be used.
178-
// Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter
179-
// to disable the default protection until v1.7.0.
177+
// If nil, no cross-origin protection is applied. Use the `enableoriginverification`
178+
// MCPGODEBUG compatibility parameter to enable the default protection until v1.8.0.
179+
//
180+
// Deprecated: wrap the handler with cross-origin protection middleware
181+
// instead. For example:
182+
//
183+
// handler := mcp.NewStreamableHTTPHandler(...)
184+
// protection := http.NewCrossOriginProtection()
185+
// protectedHandler := protection.Handler(handler)
180186
CrossOriginProtection *http.CrossOriginProtection
181187
}
182188

@@ -196,7 +202,7 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
196202

197203
h.opts.Logger = ensureLogger(h.opts.Logger)
198204

199-
if h.opts.CrossOriginProtection == nil {
205+
if h.opts.CrossOriginProtection == nil && enableoriginverification == "1" {
200206
h.opts.CrossOriginProtection = &http.CrossOriginProtection{}
201207
}
202208

@@ -229,15 +235,16 @@ func (h *StreamableHTTPHandler) closeAll() {
229235
// disablelocalhostprotection is a compatibility parameter that allows to disable
230236
// DNS rebinding protection, which was added in the 1.4.0 version of the SDK.
231237
// See the documentation for the mcpgodebug package for instructions how to enable it.
232-
// The option will be removed in the 1.7.0 version of the SDK.
238+
// The option will be removed in the 1.6.0 version of the SDK.
233239
var disablelocalhostprotection = mcpgodebug.Value("disablelocalhostprotection")
234240

235-
// disablecrossoriginprotection is a compatibility parameter that allows to disable
236-
// the verification of the 'Origin' and 'Content-Type' headers, which was added in
237-
// the 1.4.1 version of the SDK. See the documentation for the mcpgodebug package
238-
// for instructions how to enable it.
239-
// The option will be removed in the 1.7.0 version of the SDK.
240-
var disablecrossoriginprotection = mcpgodebug.Value("disablecrossoriginprotection")
241+
// enableoriginverification is a compatibility parameter that restores the
242+
// default cross-origin protection behavior from v1.4.1-v1.5.0. When set to
243+
// "1", a zero-value CrossOriginProtection will be applied if none is
244+
// explicitly provided in StreamableHTTPOptions.
245+
// See the documentation for the mcpgodebug package for instructions how to enable it.
246+
// The option will be removed in the 1.8.0 version of the SDK.
247+
var enableoriginverification = mcpgodebug.Value("enableoriginverification")
241248

242249
func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
243250
// DNS rebinding protection: auto-enabled for localhost servers.
@@ -251,17 +258,18 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
251258
}
252259
}
253260

254-
if disablecrossoriginprotection != "1" {
261+
if h.opts.CrossOriginProtection != nil {
255262
// Verify the 'Origin' header to protect against CSRF attacks.
256263
if err := h.opts.CrossOriginProtection.Check(req); err != nil {
257264
http.Error(w, err.Error(), http.StatusForbidden)
258265
return
259266
}
260-
// Validate 'Content-Type' header.
261-
if req.Method == http.MethodPost && baseMediaType(req.Header.Get("Content-Type")) != "application/json" {
262-
http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType)
263-
return
264-
}
267+
}
268+
269+
// Validate 'Content-Type' header.
270+
if req.Method == http.MethodPost && baseMediaType(req.Header.Get("Content-Type")) != "application/json" {
271+
http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType)
272+
return
265273
}
266274

267275
// Allow multiple 'Accept' headers.
@@ -1799,14 +1807,14 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
17991807
// Failure to set headers means that the request was not sent.
18001808
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
18011809
// and permanently break the connection.
1802-
return nil, nil, fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err)
1810+
return nil, nil, fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
18031811
}
18041812
resp, err := c.client.Do(req)
18051813
if err != nil {
18061814
// Any error from client.Do means the request didn't reach the server.
18071815
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
18081816
// and permanently break the connection.
1809-
err = fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err)
1817+
err = fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
18101818
}
18111819
return req, resp, err
18121820
}
@@ -1818,6 +1826,22 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
18181826

18191827
if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && c.oauthHandler != nil {
18201828
if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil {
1829+
// If the caller's context was cancelled while we were running the
1830+
// authorization flow, treat the connection as failed so subsequent
1831+
// operations on it (e.g. the cancellation notify the call layer
1832+
// sends in response to ctx cancellation) short-circuit instead of
1833+
// re-invoking the OAuth handler. Otherwise the user gets prompted
1834+
// to authorize a request they have already abandoned. See #882.
1835+
//
1836+
// We check ctx.Err() rather than the error returned by Authorize,
1837+
// because the handler is user-implemented and may return an error
1838+
// that does not wrap context.Canceled (e.g. a custom sentinel or
1839+
// a fmt.Errorf with %v). The context itself is the authoritative
1840+
// source for whether the caller abandoned the request.
1841+
ctxErr := ctx.Err()
1842+
if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
1843+
c.fail(fmt.Errorf("%s: authorization cancelled: %w", requestSummary, err))
1844+
}
18211845
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
18221846
// and permanently break the connection.
18231847
// Wrap the authorization error as well for client inspection.

mcp/streamable_client_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,94 @@ func TestStreamableClientOAuth_401(t *testing.T) {
10171017
}
10181018
}
10191019

1020+
// blockingCountingOAuthHandler is an OAuthHandler that blocks inside
1021+
// Authorize until the caller's context is cancelled, then returns a custom
1022+
// error that does NOT wrap context.Canceled. This mirrors real-world OAuth
1023+
// handlers that catch the cancellation internally and surface their own
1024+
// error type. The fix for #882 checks ctx.Err() directly rather than
1025+
// relying on the error from Authorize, so this must still trigger c.fail().
1026+
// It records how many times Authorize is invoked.
1027+
type blockingCountingOAuthHandler struct {
1028+
mu sync.Mutex
1029+
callCount int
1030+
}
1031+
1032+
func (h *blockingCountingOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
1033+
return nil, nil
1034+
}
1035+
1036+
func (h *blockingCountingOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error {
1037+
h.mu.Lock()
1038+
h.callCount++
1039+
h.mu.Unlock()
1040+
// Block until the caller's context is cancelled, mirroring an
1041+
// interactive OAuth flow that the user has abandoned.
1042+
<-ctx.Done()
1043+
// Return a custom error that does not wrap context.Canceled, as a
1044+
// real-world handler might. The code under test must check ctx.Err()
1045+
// to detect the cancellation, not this error.
1046+
return fmt.Errorf("oauth flow interrupted")
1047+
}
1048+
1049+
func (h *blockingCountingOAuthHandler) Calls() int {
1050+
h.mu.Lock()
1051+
defer h.mu.Unlock()
1052+
return h.callCount
1053+
}
1054+
1055+
// TestStreamableClientOAuth_CancelledAuthorize_NoReprompt is a regression
1056+
// test for #882. When OAuthHandler.Authorize returns a context-cancelled
1057+
// error, the connection must enter a failed state so that the cancellation
1058+
// notification the call layer sends in response to ctx cancellation does
1059+
// not flow back through the same broken auth path and re-invoke Authorize.
1060+
func TestStreamableClientOAuth_CancelledAuthorize_NoReprompt(t *testing.T) {
1061+
handler := &blockingCountingOAuthHandler{}
1062+
1063+
fake := &fakeStreamableServer{
1064+
t: t,
1065+
responses: fakeResponses{
1066+
{"POST", "", methodInitialize, ""}: {
1067+
header: header{
1068+
"Content-Type": "application/json",
1069+
sessionIDHeader: "123",
1070+
},
1071+
body: jsonBody(t, initResp),
1072+
},
1073+
},
1074+
}
1075+
verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) {
1076+
return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil
1077+
}
1078+
httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake))
1079+
t.Cleanup(httpServer.Close)
1080+
1081+
transport := &StreamableClientTransport{
1082+
Endpoint: httpServer.URL,
1083+
OAuthHandler: handler,
1084+
}
1085+
client := NewClient(testImpl, nil)
1086+
1087+
// Use a context with a tight deadline so the cancellation path runs
1088+
// while the auth flow is in progress.
1089+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
1090+
defer cancel()
1091+
1092+
_, err := client.Connect(ctx, transport, nil)
1093+
if err == nil {
1094+
t.Fatal("expected client.Connect to fail")
1095+
}
1096+
1097+
// Give the cancellation Notify path a moment to (try to) run.
1098+
time.Sleep(50 * time.Millisecond)
1099+
1100+
// Authorize should be invoked exactly once. The bug in #882 caused
1101+
// it to be invoked a second time when the call layer sent the
1102+
// cancellation notification through the same auth-broken connection.
1103+
if got := handler.Calls(); got != 1 {
1104+
t.Errorf("expected Authorize to be called exactly 1 time, got %d", got)
1105+
}
1106+
}
1107+
10201108
func TestTokenInfo(t *testing.T) {
10211109
ctx := context.Background()
10221110

mcp/streamable_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2839,10 +2839,10 @@ func TestStreamableOriginProtection(t *testing.T) {
28392839
wantStatusCode int
28402840
}{
28412841
{
2842-
name: "default protection with Origin header",
2842+
name: "no protection with Origin header",
28432843
protection: nil,
28442844
requestOrigin: "https://example.com",
2845-
wantStatusCode: http.StatusForbidden,
2845+
wantStatusCode: http.StatusOK,
28462846
},
28472847
{
28482848
name: "custom protection with trusted origin and same Origin",

0 commit comments

Comments
 (0)