Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,15 @@ func paginateAll[T any](
logConn.Printf("list%s: received page of %d %s from serverID=%s", itemKind, len(first.Items), itemKind, serverID)

cursor := first.NextCursor
seenCursors := make(map[string]struct{})
for pageCount := 1; cursor != ""; pageCount++ {
if pageCount >= paginateAllMaxPages {
return nil, fmt.Errorf("list%s: backend serverID=%s returned more than %d pages; aborting to prevent unbounded memory growth", itemKind, serverID, paginateAllMaxPages)
}
if _, seen := seenCursors[cursor]; seen {
return nil, fmt.Errorf("list%s: backend serverID=%s returned cyclical cursor %q", itemKind, serverID, cursor)
}
seenCursors[cursor] = struct{}{}
page, err := fetch(cursor)
if err != nil {
return nil, err
Expand Down
31 changes: 29 additions & 2 deletions internal/mcp/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -978,16 +978,43 @@ func TestPaginateAll(t *testing.T) {
})

t.Run("exceeding max pages returns error", func(t *testing.T) {
// Each call returns a cursor so the loop never ends naturally.
// Each call returns a unique cursor so the loop never ends naturally.
callCount := 0
_, err := paginateAll("server1", "tools", func(cursor string) (paginatedPage[string], error) {
callCount++
return paginatedPage[string]{Items: []string{"x"}, NextCursor: "next"}, nil
nextCursor := "next"
if cursor != "" {
nextCursor = cursor + "next"
}
return paginatedPage[string]{Items: []string{"x"}, NextCursor: nextCursor}, nil
})
require.Error(t, err)
assert.Contains(t, err.Error(), "more than")
assert.Contains(t, err.Error(), "pages")
// Must stop at the page limit, not run forever.
assert.Equal(t, paginateAllMaxPages, callCount)
})

t.Run("cyclical cursor returns error", func(t *testing.T) {
callCount := 0
_, err := paginateAll("server1", "tools", func(cursor string) (paginatedPage[string], error) {
callCount++
switch cursor {
case "":
return paginatedPage[string]{Items: []string{"a"}, NextCursor: "page2"}, nil
case "page2":
return paginatedPage[string]{Items: []string{"b"}, NextCursor: "page3"}, nil
case "page3":
return paginatedPage[string]{Items: []string{"c"}, NextCursor: "page2"}, nil
default:
return paginatedPage[string]{Items: nil, NextCursor: ""}, nil
}
})

require.Error(t, err)
assert.Contains(t, err.Error(), "cyclical cursor")
assert.Contains(t, err.Error(), "page2")
// Initial page + 2 unique cursor fetches, then cycle detected before another fetch.
assert.Equal(t, 3, callCount)
})
}
6 changes: 5 additions & 1 deletion internal/mcp/http_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ const (
HTTPTransportPlainJSON HTTPTransportType = "plain-json"
)

// MCPProtocolVersion is the MCP protocol version used in initialization requests.
// MCPProtocolVersion is the MCP protocol version used only by the plain JSON-RPC
// fallback path in this package. Streamable and SSE transports are SDK-managed
// and negotiate protocol versions internally.
const MCPProtocolVersion = "2025-11-25"

// requestIDCounter is used to generate unique request IDs for HTTP requests
Expand Down Expand Up @@ -78,6 +80,8 @@ func isSessionNotFoundError(err error) bool {
if errors.Is(err, sdk.ErrSessionMissing) {
return true
}
// Plain JSON-RPC fallback requests bypass SDK session types, so they cannot
// return sdk.ErrSessionMissing and are matched by backend error text instead.
return strings.Contains(strings.ToLower(err.Error()), "session not found")
}

Expand Down
28 changes: 28 additions & 0 deletions internal/server/routed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,24 @@ func TestRegisterToolWithoutValidation(t *testing.T) {
},
}, handler)

// This canary verifies the key behavior relied on by registerToolWithoutValidation:
// tool calls are not rejected by SDK argument-value validation.
var strictHandlerCalled bool
registerToolWithoutValidation(server, &sdk.Tool{
Name: "strict_tool",
Description: "A strict-schema tool",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"count": map[string]interface{}{"type": "integer"},
},
"required": []interface{}{"count"},
},
}, func(ctx context.Context, req *sdk.CallToolRequest, state interface{}) (*sdk.CallToolResult, interface{}, error) {
strictHandlerCalled = true
return &sdk.CallToolResult{IsError: false}, state, nil
})

// Use in-memory transports to connect a client to the server and invoke the tool
serverTransport, clientTransport := sdk.NewInMemoryTransports()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
Expand All @@ -721,6 +739,16 @@ func TestRegisterToolWithoutValidation(t *testing.T) {
require.NoError(err)
assert.False(result.IsError)
assert.True(handlerCalled, "Handler should have been called")

// Provide an intentionally invalid value for the strict schema ("count" must be integer).
// If SDK starts validating argument values on this registration path, this call will fail.
strictResult, err := clientSession.CallTool(ctx, &sdk.CallToolParams{
Name: "strict_tool",
Arguments: map[string]interface{}{"count": "not-an-integer"},
})
require.NoError(err)
assert.False(strictResult.IsError)
assert.True(strictHandlerCalled, "Strict handler should be called even with schema-invalid arguments")
}

// TestCreateHTTPServerForRoutedMode_OAuth tests OAuth discovery endpoint in routed mode
Expand Down
Loading