Skip to content

Commit 15e93a2

Browse files
mcp: add DNS rebinding and cross origin protections to SSE transport (#891)
Similar protections were already introduced before to the Streamable transport.
1 parent 1209861 commit 15e93a2

5 files changed

Lines changed: 236 additions & 14 deletions

File tree

docs/mcpgodebug.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Options listed below will be removed in the 1.6.0 version of the SDK.
2424

2525
- `disablecrossoriginprotection` added. If set to `1`, newly added cross-origin
2626
protection will be disabled. The default behavior was changed to enable
27-
cross-origin protection.
27+
cross-origin protection. **Removal of this option was postponed until 1.7.0.**
2828

2929
### 1.4.0
3030

@@ -37,5 +37,6 @@ Options listed below will be removed in the 1.6.0 version of the SDK.
3737
- `disablelocalhostprotection` added. If set to `1`, newly added DNS rebinding
3838
protection will be disabled. The default behavior was changed to enable DNS rebinding
3939
protection. The protection can also be disabled by setting the
40-
`DisableLocalhostProtection` field in the `StreamableHTTPOptions` struct to
41-
`true`, which is the recommended way to disable the protection long term.
40+
`DisableLocalhostProtection` field in the `StreamableHTTPOptions` or
41+
`SSEOptions` struct to `true`, which is the recommended way to disable
42+
the protection long term. **Removal of this option was postponed until 1.7.0.**

internal/docs/mcpgodebug.src.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Options listed below will be removed in the 1.6.0 version of the SDK.
2323

2424
- `disablecrossoriginprotection` added. If set to `1`, newly added cross-origin
2525
protection will be disabled. The default behavior was changed to enable
26-
cross-origin protection.
26+
cross-origin protection. **Removal of this option was postponed until 1.7.0.**
2727

2828
### 1.4.0
2929

@@ -36,5 +36,6 @@ Options listed below will be removed in the 1.6.0 version of the SDK.
3636
- `disablelocalhostprotection` added. If set to `1`, newly added DNS rebinding
3737
protection will be disabled. The default behavior was changed to enable DNS rebinding
3838
protection. The protection can also be disabled by setting the
39-
`DisableLocalhostProtection` field in the `StreamableHTTPOptions` struct to
40-
`true`, which is the recommended way to disable the protection long term.
39+
`DisableLocalhostProtection` field in the `StreamableHTTPOptions` or
40+
`SSEOptions` struct to `true`, which is the recommended way to disable
41+
the protection long term. **Removal of this option was postponed until 1.7.0.**

mcp/sse.go

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ import (
1010
"crypto/rand"
1111
"fmt"
1212
"io"
13+
"net"
1314
"net/http"
1415
"net/url"
1516
"sync"
1617

1718
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
19+
"github.com/modelcontextprotocol/go-sdk/internal/util"
1820
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
1921
)
2022

@@ -52,9 +54,25 @@ type SSEHandler struct {
5254
}
5355

5456
// SSEOptions specifies options for an [SSEHandler].
55-
// for now, it is empty, but may be extended in future.
56-
// https://github.com/modelcontextprotocol/go-sdk/issues/507
57-
type SSEOptions struct{}
57+
type SSEOptions struct {
58+
// DisableLocalhostProtection disables automatic DNS rebinding protection.
59+
// By default, requests arriving via a localhost address (127.0.0.1, [::1])
60+
// that have a non-localhost Host header are rejected with 403 Forbidden.
61+
// This protects against DNS rebinding attacks regardless of whether the
62+
// server is listening on localhost specifically or on 0.0.0.0.
63+
//
64+
// Only disable this if you understand the security implications.
65+
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
66+
DisableLocalhostProtection bool
67+
68+
// CrossOriginProtection allows to customize cross-origin protection.
69+
// The deny handler set in the CrossOriginProtection through SetDenyHandler
70+
// is ignored.
71+
// If nil, default (zero-value) cross-origin protection will be used.
72+
// Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter
73+
// to disable the default protection until v1.7.0.
74+
CrossOriginProtection *http.CrossOriginProtection
75+
}
5876

5977
// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP
6078
// sessions created via incoming HTTP requests.
@@ -79,6 +97,10 @@ func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptio
7997
s.opts = *opts
8098
}
8199

100+
if s.opts.CrossOriginProtection == nil {
101+
s.opts.CrossOriginProtection = &http.CrossOriginProtection{}
102+
}
103+
82104
return s
83105
}
84106

@@ -179,9 +201,34 @@ func (t *SSEServerTransport) Connect(context.Context) (Connection, error) {
179201
}
180202

181203
func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
182-
sessionID := req.URL.Query().Get("sessionid")
204+
// DNS rebinding protection: auto-enabled for localhost servers.
205+
// See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise
206+
if !h.opts.DisableLocalhostProtection && disablelocalhostprotection != "1" {
207+
if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok && localAddr != nil {
208+
if util.IsLoopback(localAddr.String()) && !util.IsLoopback(req.Host) {
209+
http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden)
210+
return
211+
}
212+
}
213+
}
183214

184-
// TODO: consider checking Content-Type here. For now, we are lax.
215+
if disablecrossoriginprotection != "1" {
216+
// Verify the 'Origin' header to protect against CSRF attacks.
217+
if err := h.opts.CrossOriginProtection.Check(req); err != nil {
218+
http.Error(w, err.Error(), http.StatusForbidden)
219+
return
220+
}
221+
// Validate 'Content-Type' header.
222+
if req.Method == http.MethodPost {
223+
contentType := req.Header.Get("Content-Type")
224+
if contentType != "application/json" {
225+
http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType)
226+
return
227+
}
228+
}
229+
}
230+
231+
sessionID := req.URL.Query().Get("sessionid")
185232

186233
// For POST requests, the message body is a message to send to a session.
187234
if req.Method == http.MethodPost {

mcp/sse_test.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ import (
99
"context"
1010
"fmt"
1111
"io"
12+
"net"
1213
"net/http"
1314
"net/http/httptest"
15+
"strings"
1416
"sync/atomic"
1517
"testing"
1618

@@ -221,3 +223,174 @@ func TestSSE405AllowHeader(t *testing.T) {
221223
})
222224
}
223225
}
226+
227+
// TestSSELocalhostProtection verifies that DNS rebinding protection
228+
// is automatically enabled for localhost servers.
229+
func TestSSELocalhostProtection(t *testing.T) {
230+
server := NewServer(testImpl, nil)
231+
232+
tests := []struct {
233+
name string
234+
listenAddr string
235+
hostHeader string
236+
disableProtection bool
237+
wantStatus int
238+
}{
239+
{
240+
name: "127.0.0.1 accepts 127.0.0.1",
241+
listenAddr: "127.0.0.1:0",
242+
hostHeader: "127.0.0.1:1234",
243+
wantStatus: http.StatusOK,
244+
},
245+
{
246+
name: "127.0.0.1 accepts localhost",
247+
listenAddr: "127.0.0.1:0",
248+
hostHeader: "localhost:1234",
249+
wantStatus: http.StatusOK,
250+
},
251+
{
252+
name: "127.0.0.1 rejects evil.com",
253+
listenAddr: "127.0.0.1:0",
254+
hostHeader: "evil.com",
255+
wantStatus: http.StatusForbidden,
256+
},
257+
{
258+
name: "127.0.0.1 rejects evil.com:80",
259+
listenAddr: "127.0.0.1:0",
260+
hostHeader: "evil.com:80",
261+
wantStatus: http.StatusForbidden,
262+
},
263+
{
264+
name: "127.0.0.1 rejects localhost.evil.com",
265+
listenAddr: "127.0.0.1:0",
266+
hostHeader: "localhost.evil.com",
267+
wantStatus: http.StatusForbidden,
268+
},
269+
{
270+
name: "0.0.0.0 via localhost rejects evil.com",
271+
listenAddr: "0.0.0.0:0",
272+
hostHeader: "evil.com",
273+
wantStatus: http.StatusForbidden,
274+
},
275+
{
276+
name: "disabled accepts evil.com",
277+
listenAddr: "127.0.0.1:0",
278+
hostHeader: "evil.com",
279+
disableProtection: true,
280+
wantStatus: http.StatusOK,
281+
},
282+
}
283+
284+
for _, tt := range tests {
285+
t.Run(tt.name, func(t *testing.T) {
286+
opts := &SSEOptions{
287+
DisableLocalhostProtection: tt.disableProtection,
288+
}
289+
handler := NewSSEHandler(func(req *http.Request) *Server { return server }, opts)
290+
291+
listener, err := net.Listen("tcp", tt.listenAddr)
292+
if err != nil {
293+
t.Fatalf("Failed to listen on %s: %v", tt.listenAddr, err)
294+
}
295+
defer listener.Close()
296+
297+
srv := &http.Server{Handler: handler}
298+
go srv.Serve(listener)
299+
defer srv.Close()
300+
301+
// Use a GET request since it's the entry point for SSE sessions.
302+
// For accepted requests, the response will be a hanging SSE stream,
303+
// but we only need to check the initial status code.
304+
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s", listener.Addr().String()), nil)
305+
if err != nil {
306+
t.Fatal(err)
307+
}
308+
req.Host = tt.hostHeader
309+
req.Header.Set("Accept", "text/event-stream")
310+
311+
resp, err := http.DefaultClient.Do(req)
312+
if err != nil {
313+
t.Fatal(err)
314+
}
315+
defer resp.Body.Close()
316+
317+
if got := resp.StatusCode; got != tt.wantStatus {
318+
t.Errorf("Status code: got %d, want %d", got, tt.wantStatus)
319+
}
320+
})
321+
}
322+
}
323+
324+
func TestSSEOriginProtection(t *testing.T) {
325+
server := NewServer(testImpl, nil)
326+
327+
tests := []struct {
328+
name string
329+
protection *http.CrossOriginProtection
330+
requestOrigin string
331+
wantStatusCode int
332+
}{
333+
{
334+
name: "default protection with Origin header",
335+
protection: nil,
336+
requestOrigin: "https://example.com",
337+
wantStatusCode: http.StatusForbidden,
338+
},
339+
{
340+
name: "custom protection with trusted origin and same Origin",
341+
protection: func() *http.CrossOriginProtection {
342+
p := http.NewCrossOriginProtection()
343+
if err := p.AddTrustedOrigin("https://example.com"); err != nil {
344+
t.Fatal(err)
345+
}
346+
return p
347+
}(),
348+
requestOrigin: "https://example.com",
349+
wantStatusCode: http.StatusNotFound, // origin accepted; session not found
350+
},
351+
{
352+
name: "custom protection with trusted origin and different Origin",
353+
protection: func() *http.CrossOriginProtection {
354+
p := http.NewCrossOriginProtection()
355+
if err := p.AddTrustedOrigin("https://example.com"); err != nil {
356+
t.Fatal(err)
357+
}
358+
return p
359+
}(),
360+
requestOrigin: "https://malicious.com",
361+
wantStatusCode: http.StatusForbidden,
362+
},
363+
}
364+
365+
for _, tt := range tests {
366+
t.Run(tt.name, func(t *testing.T) {
367+
opts := &SSEOptions{
368+
CrossOriginProtection: tt.protection,
369+
}
370+
handler := NewSSEHandler(func(req *http.Request) *Server { return server }, opts)
371+
httpServer := httptest.NewServer(handler)
372+
defer httpServer.Close()
373+
374+
// Use POST with a valid session-like URL to test origin protection
375+
// without creating a hanging GET connection.
376+
reqReader := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping"}`)
377+
req, err := http.NewRequest(http.MethodPost, httpServer.URL+"?sessionid=nonexistent", reqReader)
378+
if err != nil {
379+
t.Fatal(err)
380+
}
381+
req.Header.Set("Content-Type", "application/json")
382+
req.Header.Set("Origin", tt.requestOrigin)
383+
384+
resp, err := http.DefaultClient.Do(req)
385+
if err != nil {
386+
t.Fatal(err)
387+
}
388+
defer resp.Body.Close()
389+
390+
if got := resp.StatusCode; got != tt.wantStatusCode {
391+
body, _ := io.ReadAll(resp.Body)
392+
t.Errorf("Status code: got %d, want %d (body: %s)", got, tt.wantStatusCode, body)
393+
}
394+
})
395+
}
396+
}

mcp/streamable.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ type StreamableHTTPOptions struct {
182182
// is ignored.
183183
// If nil, default (zero-value) cross-origin protection will be used.
184184
// Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter
185-
// to disable the default protection until v1.6.0.
185+
// to disable the default protection until v1.7.0.
186186
CrossOriginProtection *http.CrossOriginProtection
187187
}
188188

@@ -235,14 +235,14 @@ func (h *StreamableHTTPHandler) closeAll() {
235235
// disablelocalhostprotection is a compatibility parameter that allows to disable
236236
// DNS rebinding protection, which was added in the 1.4.0 version of the SDK.
237237
// See the documentation for the mcpgodebug package for instructions how to enable it.
238-
// The option will be removed in the 1.6.0 version of the SDK.
238+
// The option will be removed in the 1.7.0 version of the SDK.
239239
var disablelocalhostprotection = mcpgodebug.Value("disablelocalhostprotection")
240240

241241
// disablecrossoriginprotection is a compatibility parameter that allows to disable
242242
// the verification of the 'Origin' and 'Content-Type' headers, which was added in
243243
// the 1.4.1 version of the SDK. See the documentation for the mcpgodebug package
244244
// for instructions how to enable it.
245-
// The option will be removed in the 1.6.0 version of the SDK.
245+
// The option will be removed in the 1.7.0 version of the SDK.
246246
var disablecrossoriginprotection = mcpgodebug.Value("disablecrossoriginprotection")
247247

248248
func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {

0 commit comments

Comments
 (0)