Skip to content

Commit 4af1f8e

Browse files
committed
refactor: decouple MCP header utilities from http.Request by accepting http.Header instead
1 parent 7df5ab6 commit 4af1f8e

3 files changed

Lines changed: 20 additions & 26 deletions

File tree

mcp/mcp_http_headers.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,32 +44,32 @@ func extractName(method string, params json.RawMessage) (string, bool) {
4444
return "", false
4545
}
4646

47-
func setStandardHeaders(httpReq *http.Request, msg jsonrpc.Message) {
47+
func setStandardHeaders(header http.Header, msg jsonrpc.Message) {
4848
if msg == nil {
4949
return
5050
}
51-
if httpReq.Header.Get(ProtocolVersionHeader) == "" || httpReq.Header.Get(ProtocolVersionHeader) < MinVersionForStandardHeaders {
51+
if header.Get(ProtocolVersionHeader) == "" || header.Get(ProtocolVersionHeader) < MinVersionForStandardHeaders {
5252
return
5353
}
5454

5555
switch msg := msg.(type) {
5656
case *jsonrpc.Request:
57-
httpReq.Header.Set(MethodHeader, msg.Method)
57+
header.Set(MethodHeader, msg.Method)
5858
if name, ok := extractName(msg.Method, msg.Params); ok {
59-
httpReq.Header.Set(NameHeader, name)
59+
header.Set(NameHeader, name)
6060
}
6161
}
6262
}
6363

64-
func validateMcpHeaders(req *http.Request, msg jsonrpc.Message) error {
65-
protocolVersion := req.Header.Get(ProtocolVersionHeader)
64+
func validateMcpHeaders(header http.Header, msg jsonrpc.Message) error {
65+
protocolVersion := header.Get(ProtocolVersionHeader)
6666
if protocolVersion == "" || protocolVersion < MinVersionForStandardHeaders {
6767
return nil
6868
}
6969

7070
switch msg := msg.(type) {
7171
case *jsonrpc.Request:
72-
methodInHeader := req.Header.Get(MethodHeader)
72+
methodInHeader := header.Get(MethodHeader)
7373
if methodInHeader == "" {
7474
return errors.New("missing required Mcp-Method header")
7575
}
@@ -78,7 +78,7 @@ func validateMcpHeaders(req *http.Request, msg jsonrpc.Message) error {
7878
}
7979

8080
if msg.Method == "tools/call" || msg.Method == "resources/read" || msg.Method == "prompts/get" {
81-
nameInHeader := req.Header.Get(NameHeader)
81+
nameInHeader := header.Get(NameHeader)
8282
if nameInHeader == "" {
8383
return fmt.Errorf("missing required Mcp-Name header for method %q", msg.Method)
8484
}

mcp/mcp_http_headers_test.go

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -193,20 +193,17 @@ func TestSetStandardHeaders(t *testing.T) {
193193

194194
for _, tt := range tests {
195195
t.Run(tt.name, func(t *testing.T) {
196-
httpReq, err := http.NewRequest("POST", "http://localhost/mcp", nil)
197-
if err != nil {
198-
t.Fatal(err)
199-
}
196+
header := http.Header{}
200197
if tt.protocolVersion != "" {
201-
httpReq.Header.Set(ProtocolVersionHeader, tt.protocolVersion)
198+
header.Set(ProtocolVersionHeader, tt.protocolVersion)
202199
}
203200

204-
setStandardHeaders(httpReq, tt.msg)
201+
setStandardHeaders(header, tt.msg)
205202

206-
if got := httpReq.Header.Get(MethodHeader); got != tt.wantMethodHeader {
203+
if got := header.Get(MethodHeader); got != tt.wantMethodHeader {
207204
t.Errorf("MethodHeader = %q, want %q", got, tt.wantMethodHeader)
208205
}
209-
if got := httpReq.Header.Get(NameHeader); got != tt.wantNameHeader {
206+
if got := header.Get(NameHeader); got != tt.wantNameHeader {
210207
t.Errorf("NameHeader = %q, want %q", got, tt.wantNameHeader)
211208
}
212209
})
@@ -406,21 +403,18 @@ func TestValidateMcpHeaders(t *testing.T) {
406403

407404
for _, tt := range tests {
408405
t.Run(tt.name, func(t *testing.T) {
409-
httpReq, err := http.NewRequest("POST", "http://localhost/mcp", nil)
410-
if err != nil {
411-
t.Fatal(err)
412-
}
406+
header := http.Header{}
413407
if tt.version != "" {
414-
httpReq.Header.Set(ProtocolVersionHeader, tt.version)
408+
header.Set(ProtocolVersionHeader, tt.version)
415409
}
416410
if tt.methodHeader != "" {
417-
httpReq.Header.Set(MethodHeader, tt.methodHeader)
411+
header.Set(MethodHeader, tt.methodHeader)
418412
}
419413
if tt.nameHeader != "" {
420-
httpReq.Header.Set(NameHeader, tt.nameHeader)
414+
header.Set(NameHeader, tt.nameHeader)
421415
}
422416

423-
err = validateMcpHeaders(httpReq, tt.msg)
417+
err := validateMcpHeaders(header, tt.msg)
424418
if tt.wantErr {
425419
if err == nil {
426420
t.Fatalf("validateMcpHeaders() = nil, want error containing %q", tt.wantErrContain)

mcp/streamable.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
11871187

11881188
// Validate MCP standard headers (Mcp-Method, Mcp-Name)
11891189
if !isBatch && len(incoming) == 1 {
1190-
if err := validateMcpHeaders(req, incoming[0]); err != nil {
1190+
if err := validateMcpHeaders(req.Header, incoming[0]); err != nil {
11911191
resp := &jsonrpc.Response{
11921192
Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()),
11931193
}
@@ -1809,7 +1809,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
18091809
// and permanently break the connection.
18101810
return nil, nil, fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
18111811
}
1812-
setStandardHeaders(req, msg)
1812+
setStandardHeaders(req.Header, msg)
18131813
resp, err := c.client.Do(req)
18141814
if err != nil {
18151815
// Any error from client.Do means the request didn't reach the server.

0 commit comments

Comments
 (0)