diff --git a/go/core/internal/mcp/mcp_handler.go b/go/core/internal/mcp/mcp_handler.go index 8182df6fb..aed69d71c 100644 --- a/go/core/internal/mcp/mcp_handler.go +++ b/go/core/internal/mcp/mcp_handler.go @@ -21,14 +21,28 @@ import ( "trpc.group/trpc-go/trpc-a2a-go/protocol" ) -// MCPHandler handles MCP requests and bridges them to A2A endpoints +// MCPHandler handles MCP requests and bridges them to A2A endpoints. +// +// Agent filtering: +// Callers may restrict which agents are visible in a session by appending an +// "agents" query parameter to the MCP endpoint URL, e.g.: +// +// http://kagent-controller:8083/mcp?agents=kagent/k8s-agent,kagent/helm-agent +// +// When the parameter is present, list_agents returns only the matching agents +// and invoke_agent rejects calls targeting any agent not in the list. +// Omitting the parameter preserves the original behaviour: all agents are +// accessible. +// +// The allow-list is parsed once per session in the server factory (on the +// initialize request) and closed over in the tool handlers, making it +// immutable for the session's lifetime. type MCPHandler struct { kubeClient client.Client a2aBaseURL string a2aTimeout time.Duration authenticator auth.AuthProvider httpHandler *mcpsdk.StreamableHTTPHandler - server *mcpsdk.Server a2aClients sync.Map } @@ -60,8 +74,8 @@ type InvokeAgentOutput struct { // the configured default streaming timeout. const defaultA2ATimeout = 10 * time.Minute -// NewMCPHandler creates a new MCP handler -// Wraps the StreamableHTTPHandler and adds A2A bridging and context management. +// NewMCPHandler creates a new MCP handler. +// It wraps the StreamableHTTPHandler and adds A2A bridging and context management. func NewMCPHandler(kubeClient client.Client, a2aBaseURL string, authenticator auth.AuthProvider, a2aTimeout time.Duration) (*MCPHandler, error) { if a2aTimeout <= 0 { a2aTimeout = defaultA2ATimeout @@ -73,47 +87,87 @@ func NewMCPHandler(kubeClient client.Client, a2aBaseURL string, authenticator au authenticator: authenticator, } - // Create MCP server - impl := &mcpsdk.Implementation{ + // The server factory is called exactly once per MCP session (on the + // initialize request). Parsing the allow-list here and closing over it in + // the tool handlers makes the filter immutable for the session's lifetime + // — no context plumbing, no per-request re-parsing, no bypass window. + handler.httpHandler = mcpsdk.NewStreamableHTTPHandler( + func(r *http.Request) *mcpsdk.Server { + return handler.newMCPServer(parseAllowedAgents(r)) + }, + nil, + ) + + return handler, nil +} + +// newMCPServer creates a new MCP server with the given agent allow-list +// closed over in the tool handlers. A nil allow-list means all agents are +// accessible. +func (h *MCPHandler) newMCPServer(allowed map[string]struct{}) *mcpsdk.Server { + server := mcpsdk.NewServer(&mcpsdk.Implementation{ Name: "kagent-agents", Version: version.Version, - } - server := mcpsdk.NewServer(impl, nil) - handler.server = server + }, nil) - // Add list_agents tool mcpsdk.AddTool[ListAgentsInput, ListAgentsOutput]( server, &mcpsdk.Tool{ Name: "list_agents", Description: "List invokable kagent agents (accepted + deploymentReady)", }, - handler.handleListAgents, + func(ctx context.Context, req *mcpsdk.CallToolRequest, input ListAgentsInput) (*mcpsdk.CallToolResult, ListAgentsOutput, error) { + return h.handleListAgents(ctx, req, input, allowed) + }, ) - // Add invoke_agent tool mcpsdk.AddTool[InvokeAgentInput, InvokeAgentOutput]( server, &mcpsdk.Tool{ Name: "invoke_agent", Description: "Invoke a kagent agent via A2A", }, - handler.handleInvokeAgent, - ) - - // Create HTTP handler - handler.httpHandler = mcpsdk.NewStreamableHTTPHandler( - func(*http.Request) *mcpsdk.Server { - return server + func(ctx context.Context, req *mcpsdk.CallToolRequest, input InvokeAgentInput) (*mcpsdk.CallToolResult, InvokeAgentOutput, error) { + return h.handleInvokeAgent(ctx, req, input, allowed) }, - nil, ) - return handler, nil + return server } -// handleListAgents handles the list_agents MCP tool -func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolRequest, input ListAgentsInput) (*mcpsdk.CallToolResult, ListAgentsOutput, error) { +// parseAllowedAgents reads the "agents" query parameter from the request and +// returns a set of permitted agent refs (e.g. "kagent/k8s-agent"). +// +// The parameter is a comma-separated list of "namespace/name" values: +// +// ?agents=kagent/k8s-agent,kagent/helm-agent +// +// Returns nil when the parameter is absent or empty, which means no filtering +// is applied and all agents are accessible. +func parseAllowedAgents(r *http.Request) map[string]struct{} { + raw := r.URL.Query().Get("agents") + if raw == "" { + return nil + } + + set := make(map[string]struct{}) + for ref := range strings.SplitSeq(raw, ",") { + ref = strings.TrimSpace(ref) + if ref != "" { + set[ref] = struct{}{} + } + } + + if len(set) == 0 { + return nil + } + return set +} + +// handleListAgents handles the list_agents MCP tool. +// When an agent allow-list is active for the session, only agents whose ref +// appears in the list are returned. +func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolRequest, input ListAgentsInput, allowed map[string]struct{}) (*mcpsdk.CallToolResult, ListAgentsOutput, error) { log := ctrllog.FromContext(ctx).WithName("mcp-handler").WithValues("tool", "list_agents") agentList := &v1alpha2.AgentList{} @@ -128,7 +182,7 @@ func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolR agents := make([]AgentSummary, 0) for _, agent := range agentList.Items { - // Check if agent is accepted and deployment ready + // Only include agents that are both accepted and deployment-ready. deploymentReady := false accepted := false for _, condition := range agent.Status.Conditions { @@ -145,14 +199,21 @@ func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolR } ref := agent.Namespace + "/" + agent.Name - description := agent.Spec.Description + + // When an allow-list is active, skip agents that are not in it. + if allowed != nil { + if _, ok := allowed[ref]; !ok { + continue + } + } + agents = append(agents, AgentSummary{ Ref: ref, - Description: description, + Description: agent.Spec.Description, }) } - log.Info("Listed agents", "count", len(agents)) + log.Info("Listed agents", "count", len(agents), "filtered", allowed != nil) output := ListAgentsOutput{Agents: agents} @@ -179,8 +240,10 @@ func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolR }, output, nil } -// handleInvokeAgent handles the invoke_agent MCP tool -func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallToolRequest, input InvokeAgentInput) (*mcpsdk.CallToolResult, InvokeAgentOutput, error) { +// handleInvokeAgent handles the invoke_agent MCP tool. +// When an agent allow-list is active for the session, requests targeting an +// agent not in the list are rejected before any A2A call is made. +func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallToolRequest, input InvokeAgentInput, allowed map[string]struct{}) (*mcpsdk.CallToolResult, InvokeAgentOutput, error) { log := ctrllog.FromContext(ctx).WithName("mcp-handler").WithValues("tool", "invoke_agent") // Parse agent reference (namespace/name or just name) @@ -196,27 +259,43 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool agentRef := agentNS + "/" + agentName agentNns := types.NamespacedName{Namespace: agentNS, Name: agentName} - // Get context ID from client request (stateless mode) - // If not provided, contextIDPtr will be nil and a new conversation will start + // Enforce the allow-list before touching any downstream service. + // This is the hard access-control boundary: if the caller's MCP session was + // scoped to a subset of agents (via the "agents" query parameter), any attempt + // to invoke an agent outside that subset is rejected here. + if allowed != nil { + if _, ok := allowed[agentRef]; !ok { + log.Info("Rejected invoke_agent: agent not in allow-list", "agent", agentRef) + return &mcpsdk.CallToolResult{ + Content: []mcpsdk.Content{ + &mcpsdk.TextContent{Text: fmt.Sprintf("agent %q is not available in this session", agentRef)}, + }, + IsError: true, + }, InvokeAgentOutput{}, nil + } + } + + // Get context ID from client request (stateless mode). + // If not provided, contextIDPtr will be nil and a new conversation will start. var contextIDPtr *string if input.ContextID != "" { contextIDPtr = &input.ContextID log.V(1).Info("Using context_id from client request", "context_id", input.ContextID) } - // Get or create cached A2A client for this agent + // Get or create cached A2A client for this agent. a2aURL := fmt.Sprintf("%s/%s/", h.a2aBaseURL, agentRef) var a2aClient *a2aclient.A2AClient if cached, ok := h.a2aClients.Load(agentRef); ok { - if client, ok := cached.(*a2aclient.A2AClient); ok { - a2aClient = client + if c, ok := cached.(*a2aclient.A2AClient); ok { + a2aClient = c } } - // Create new client if not cached + // Create new client if not cached. if a2aClient == nil { - // Build A2A client options with authentication propagation + // Build A2A client options with authentication propagation. a2aOpts := []a2aclient.Option{ a2aclient.WithTimeout(h.a2aTimeout), a2aclient.WithHTTPReqHandler( @@ -238,12 +317,12 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool }, InvokeAgentOutput{}, nil } - // Cache the client + // Cache the client. h.a2aClients.Store(agentRef, newClient) a2aClient = newClient } - // Send message via A2A + // Send message via A2A. result, err := a2aClient.SendMessage(ctx, protocol.SendMessageParams{ Message: protocol.Message{ Kind: protocol.KindMessage, @@ -262,7 +341,7 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool }, InvokeAgentOutput{}, nil } - // Extract response text and context ID + // Extract response text and context ID. var responseText, newContextID string switch a2aResult := result.Result.(type) { case *protocol.Message: @@ -270,7 +349,7 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool if a2aResult.ContextID != nil { newContextID = *a2aResult.ContextID } - // Kagent A2A only returns Task type for now + // Kagent A2A only returns Task type for now. case *protocol.Task: newContextID = a2aResult.ContextID if a2aResult.Status.Message != nil { @@ -296,7 +375,7 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool log.Info("Invoked agent", "agent", agentRef, "hasContextID", newContextID != "") - // Return context_id in response so client can store it for stateless operation + // Return context_id in response so the client can store it for stateless operation. output := InvokeAgentOutput{ Agent: agentRef, Text: responseText, @@ -312,15 +391,14 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool }, output, nil } -// ServeHTTP implements http.Handler interface +// ServeHTTP implements http.Handler. func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // The MCP HTTP handler handles all the routing internally h.httpHandler.ServeHTTP(w, r) } -// Shutdown gracefully shuts down the MCP handler +// Shutdown gracefully shuts down the MCP handler. func (h *MCPHandler) Shutdown(ctx context.Context) error { - // The new SDK doesn't have an explicit Shutdown method on StreamableHTTPHandler - // The server will be shut down when the context is cancelled + // The new SDK doesn't have an explicit Shutdown method on StreamableHTTPHandler. + // The server will be shut down when the context is cancelled. return nil } diff --git a/go/core/internal/mcp/mcp_handler_test.go b/go/core/internal/mcp/mcp_handler_test.go new file mode 100644 index 000000000..1e4626354 --- /dev/null +++ b/go/core/internal/mcp/mcp_handler_test.go @@ -0,0 +1,342 @@ +package mcp + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/kagent-dev/kagent/go/api/v1alpha2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +// --- parseAllowedAgents unit tests --- + +func TestParseAllowedAgents(t *testing.T) { + tests := []struct { + name string + agentsVal string // raw value of the "agents" query parameter (empty = omit the param) + wantNil bool + wantRefs []string // expected keys in the returned set + }{ + { + name: "no agents parameter returns nil (all agents allowed)", + wantNil: true, + }, + { + name: "empty agents parameter returns nil", + agentsVal: "", + wantNil: true, + }, + { + name: "single agent ref", + agentsVal: "kagent/k8s-agent", + wantRefs: []string{"kagent/k8s-agent"}, + }, + { + name: "multiple agent refs", + agentsVal: "kagent/k8s-agent,kagent/helm-agent,kagent/observability-agent", + wantRefs: []string{"kagent/k8s-agent", "kagent/helm-agent", "kagent/observability-agent"}, + }, + { + name: "whitespace around refs is trimmed", + agentsVal: " kagent/k8s-agent , kagent/helm-agent ", + wantRefs: []string{"kagent/k8s-agent", "kagent/helm-agent"}, + }, + { + name: "comma-only value returns nil", + agentsVal: ",,,", + wantNil: true, + }, + { + name: "duplicate refs are deduplicated", + agentsVal: "kagent/k8s-agent,kagent/k8s-agent", + wantRefs: []string{"kagent/k8s-agent"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "http://example.com/mcp", nil) + if tt.agentsVal != "" { + q := url.Values{} + q.Set("agents", tt.agentsVal) + r.URL.RawQuery = q.Encode() + } + + got := parseAllowedAgents(r) + + if tt.wantNil { + assert.Nil(t, got) + return + } + + require.NotNil(t, got) + assert.Len(t, got, len(tt.wantRefs)) + for _, ref := range tt.wantRefs { + assert.Contains(t, got, ref, "expected ref %q in allow-list", ref) + } + }) + } +} + +// --- MCP handler integration tests --- + +// scheme holds the CRD types used in fake client construction. +var testScheme = func() *runtime.Scheme { + s := runtime.NewScheme() + if err := v1alpha2.AddToScheme(s); err != nil { + panic(fmt.Sprintf("failed to add v1alpha2 to scheme: %v", err)) + } + return s +}() + +// readyAgent returns an Agent object with both Accepted and DeploymentReady conditions. +func readyAgent(namespace, name, description string) *v1alpha2.Agent { + return &v1alpha2.Agent{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Spec: v1alpha2.AgentSpec{ + Description: description, + }, + Status: v1alpha2.AgentStatus{ + Conditions: []metav1.Condition{ + { + Type: "Accepted", + Status: metav1.ConditionTrue, + }, + { + Type: "Ready", + Reason: "DeploymentReady", + Status: metav1.ConditionTrue, + }, + }, + }, + } +} + +// mcpSession holds an established MCP session for use across multiple calls. +type mcpSession struct { + handler http.Handler + sessionID string + targetURL string // base URL including any ?agents= query param +} + +// newMCPSession performs the MCP initialize handshake with the given handler at +// the given URL (which may include query parameters such as ?agents=…) and +// returns a session ready for tool calls. +func newMCPSession(t *testing.T, handler http.Handler, targetURL string) *mcpSession { + t.Helper() + + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{"name": "test", "version": "1.0"}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code, "initialize must succeed") + + sessionID := rr.Header().Get("Mcp-Session-Id") + require.NotEmpty(t, sessionID, "server must return Mcp-Session-Id after initialize") + + return &mcpSession{handler: handler, sessionID: sessionID, targetURL: targetURL} +} + +// call sends a tools/call request within this session and returns the first +// "data:" event payload. +func (s *mcpSession) call(t *testing.T, toolName string, args any) map[string]any { + t.Helper() + + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": toolName, + "arguments": args, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, s.targetURL, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("Mcp-Session-Id", s.sessionID) + + rr := httptest.NewRecorder() + s.handler.ServeHTTP(rr, req) + + // Parse the SSE stream and return the first data event. + scanner := bufio.NewScanner(rr.Body) + for scanner.Scan() { + line := scanner.Text() + if after, ok := strings.CutPrefix(line, "data: "); ok { + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(after), &result)) + return result + } + } + t.Fatalf("no data event found in MCP response (status %d, session %s): %s", + rr.Code, s.sessionID, rr.Body.String()) + return nil +} + +// extractToolResult returns the "result.content[0].text" field from an MCP +// tools/call response, and whether the result carries IsError=true. +func extractToolResult(t *testing.T, resp map[string]any) (text string, isError bool) { + t.Helper() + + result, ok := resp["result"].(map[string]any) + require.True(t, ok, "expected result field in response") + + isError, _ = result["isError"].(bool) + + content, ok := result["content"].([]any) + require.True(t, ok, "expected content array in result") + require.NotEmpty(t, content) + + first, ok := content[0].(map[string]any) + require.True(t, ok) + + text, _ = first["text"].(string) + return text, isError +} + +// newTestHandler creates an MCPHandler backed by a fake Kubernetes client +// pre-populated with the given agents. The a2aBaseURL is left intentionally +// empty: tests that exercise invoke_agent for blocked agents never reach the +// A2A layer, so no real backend is needed. +func newTestHandler(t *testing.T, agents ...*v1alpha2.Agent) *MCPHandler { + t.Helper() + + objs := make([]runtime.Object, len(agents)) + for i, a := range agents { + objs[i] = a + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(testScheme). + WithRuntimeObjects(objs...). + WithStatusSubresource(&v1alpha2.Agent{}). + Build() + + handler, err := NewMCPHandler(fakeClient, "http://unused-a2a-base", nil, 0) + require.NoError(t, err) + return handler +} + +// TestListAgents_NoFilter verifies that without a filter all ready agents +// are returned. +func TestListAgents_NoFilter(t *testing.T) { + handler := newTestHandler(t, + readyAgent("kagent", "k8s-agent", "Kubernetes expert"), + readyAgent("kagent", "helm-agent", "Helm expert"), + ) + + sess := newMCPSession(t, handler, "/mcp") + resp := sess.call(t, "list_agents", map[string]any{}) + + text, isError := extractToolResult(t, resp) + assert.False(t, isError) + assert.Contains(t, text, "kagent/k8s-agent") + assert.Contains(t, text, "kagent/helm-agent") +} + +// TestListAgents_WithFilter verifies that the allow-list restricts the +// list_agents response to only the permitted agents. +func TestListAgents_WithFilter(t *testing.T) { + handler := newTestHandler(t, + readyAgent("kagent", "k8s-agent", "Kubernetes expert"), + readyAgent("kagent", "helm-agent", "Helm expert"), + readyAgent("kagent", "observability-agent", "Observability expert"), + ) + + // Establish a session scoped to k8s-agent only. + sess := newMCPSession(t, handler, "/mcp?agents=kagent%2Fk8s-agent") + resp := sess.call(t, "list_agents", map[string]any{}) + + text, isError := extractToolResult(t, resp) + assert.False(t, isError) + assert.Contains(t, text, "kagent/k8s-agent", "allowed agent must appear in result") + assert.NotContains(t, text, "kagent/helm-agent", "non-allowed agent must not appear") + assert.NotContains(t, text, "kagent/observability-agent", "non-allowed agent must not appear") +} + +// TestListAgents_MultipleFilter verifies that a multi-agent allow-list permits +// exactly the specified agents and no others. +func TestListAgents_MultipleFilter(t *testing.T) { + handler := newTestHandler(t, + readyAgent("kagent", "k8s-agent", "Kubernetes expert"), + readyAgent("kagent", "helm-agent", "Helm expert"), + readyAgent("kagent", "observability-agent", "Observability expert"), + ) + + sess := newMCPSession(t, handler, "/mcp?agents=kagent%2Fk8s-agent,kagent%2Fhelm-agent") + resp := sess.call(t, "list_agents", map[string]any{}) + + text, isError := extractToolResult(t, resp) + assert.False(t, isError) + assert.Contains(t, text, "kagent/k8s-agent") + assert.Contains(t, text, "kagent/helm-agent") + assert.NotContains(t, text, "kagent/observability-agent") +} + +// TestInvokeAgent_BlockedByFilter verifies that invoke_agent is rejected +// for agents outside the session allow-list, without touching the A2A layer. +func TestInvokeAgent_BlockedByFilter(t *testing.T) { + handler := newTestHandler(t, + readyAgent("kagent", "k8s-agent", "Kubernetes expert"), + readyAgent("kagent", "helm-agent", "Helm expert"), + ) + + // Session is scoped to k8s-agent only — helm-agent must be rejected. + sess := newMCPSession(t, handler, "/mcp?agents=kagent%2Fk8s-agent") + resp := sess.call(t, "invoke_agent", map[string]any{ + "agent": "kagent/helm-agent", + "task": "List all releases", + }) + + text, isError := extractToolResult(t, resp) + assert.True(t, isError, "response must carry IsError=true for a blocked agent") + assert.Contains(t, text, "not available in this session") + assert.Contains(t, text, "kagent/helm-agent") +} + +// TestInvokeAgent_InvalidRef verifies that invoke_agent rejects refs that +// do not follow the namespace/name format. +func TestInvokeAgent_InvalidRef(t *testing.T) { + handler := newTestHandler(t) + + sess := newMCPSession(t, handler, "/mcp") + resp := sess.call(t, "invoke_agent", map[string]any{ + "agent": "no-slash-here", + "task": "do something", + }) + + text, isError := extractToolResult(t, resp) + assert.True(t, isError) + assert.Contains(t, text, "namespace/name") +}