Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 124 additions & 46 deletions go/core/internal/mcp/mcp_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,28 @@ import (
"trpc.group/trpc-go/trpc-a2a-go/protocol"
)

// MCPHandler handles MCP requests and bridges them to A2A endpoints
// MCPHandler handles MCP requests and bridges them to A2A endpoints.
//
// Agent filtering:
// Callers may restrict which agents are visible in a session by appending an
// "agents" query parameter to the MCP endpoint URL, e.g.:
//
// http://kagent-controller:8083/mcp?agents=kagent/k8s-agent,kagent/helm-agent
//
// When the parameter is present, list_agents returns only the matching agents
// and invoke_agent rejects calls targeting any agent not in the list.
// Omitting the parameter preserves the original behaviour: all agents are
// accessible.
//
// The allow-list is parsed once per session in the server factory (on the
// initialize request) and closed over in the tool handlers, making it
// immutable for the session's lifetime.
type MCPHandler struct {
kubeClient client.Client
a2aBaseURL string
a2aTimeout time.Duration
authenticator auth.AuthProvider
httpHandler *mcpsdk.StreamableHTTPHandler
server *mcpsdk.Server
a2aClients sync.Map
}

Expand Down Expand Up @@ -60,8 +74,8 @@ type InvokeAgentOutput struct {
// the configured default streaming timeout.
const defaultA2ATimeout = 10 * time.Minute

// NewMCPHandler creates a new MCP handler
// Wraps the StreamableHTTPHandler and adds A2A bridging and context management.
// NewMCPHandler creates a new MCP handler.
// It wraps the StreamableHTTPHandler and adds A2A bridging and context management.
func NewMCPHandler(kubeClient client.Client, a2aBaseURL string, authenticator auth.AuthProvider, a2aTimeout time.Duration) (*MCPHandler, error) {
if a2aTimeout <= 0 {
a2aTimeout = defaultA2ATimeout
Expand All @@ -73,47 +87,87 @@ func NewMCPHandler(kubeClient client.Client, a2aBaseURL string, authenticator au
authenticator: authenticator,
}

// Create MCP server
impl := &mcpsdk.Implementation{
// The server factory is called exactly once per MCP session (on the
// initialize request). Parsing the allow-list here and closing over it in
// the tool handlers makes the filter immutable for the session's lifetime
// — no context plumbing, no per-request re-parsing, no bypass window.
handler.httpHandler = mcpsdk.NewStreamableHTTPHandler(
func(r *http.Request) *mcpsdk.Server {
return handler.newMCPServer(parseAllowedAgents(r))
},
nil,
)

return handler, nil
}

// newMCPServer creates a new MCP server with the given agent allow-list
// closed over in the tool handlers. A nil allow-list means all agents are
// accessible.
func (h *MCPHandler) newMCPServer(allowed map[string]struct{}) *mcpsdk.Server {
server := mcpsdk.NewServer(&mcpsdk.Implementation{
Name: "kagent-agents",
Version: version.Version,
}
server := mcpsdk.NewServer(impl, nil)
handler.server = server
}, nil)

// Add list_agents tool
mcpsdk.AddTool[ListAgentsInput, ListAgentsOutput](
server,
&mcpsdk.Tool{
Name: "list_agents",
Description: "List invokable kagent agents (accepted + deploymentReady)",
},
handler.handleListAgents,
func(ctx context.Context, req *mcpsdk.CallToolRequest, input ListAgentsInput) (*mcpsdk.CallToolResult, ListAgentsOutput, error) {
return h.handleListAgents(ctx, req, input, allowed)
},
)

// Add invoke_agent tool
mcpsdk.AddTool[InvokeAgentInput, InvokeAgentOutput](
server,
&mcpsdk.Tool{
Name: "invoke_agent",
Description: "Invoke a kagent agent via A2A",
},
handler.handleInvokeAgent,
)

// Create HTTP handler
handler.httpHandler = mcpsdk.NewStreamableHTTPHandler(
func(*http.Request) *mcpsdk.Server {
return server
func(ctx context.Context, req *mcpsdk.CallToolRequest, input InvokeAgentInput) (*mcpsdk.CallToolResult, InvokeAgentOutput, error) {
return h.handleInvokeAgent(ctx, req, input, allowed)
},
nil,
)

return handler, nil
return server
}

// handleListAgents handles the list_agents MCP tool
func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolRequest, input ListAgentsInput) (*mcpsdk.CallToolResult, ListAgentsOutput, error) {
// parseAllowedAgents reads the "agents" query parameter from the request and
// returns a set of permitted agent refs (e.g. "kagent/k8s-agent").
//
// The parameter is a comma-separated list of "namespace/name" values:
//
// ?agents=kagent/k8s-agent,kagent/helm-agent
//
// Returns nil when the parameter is absent or empty, which means no filtering
// is applied and all agents are accessible.
func parseAllowedAgents(r *http.Request) map[string]struct{} {
raw := r.URL.Query().Get("agents")
if raw == "" {
return nil
}

set := make(map[string]struct{})
for ref := range strings.SplitSeq(raw, ",") {
ref = strings.TrimSpace(ref)
if ref != "" {
set[ref] = struct{}{}
}
}

if len(set) == 0 {
return nil
}
return set
}

// handleListAgents handles the list_agents MCP tool.
// When an agent allow-list is active for the session, only agents whose ref
// appears in the list are returned.
func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolRequest, input ListAgentsInput, allowed map[string]struct{}) (*mcpsdk.CallToolResult, ListAgentsOutput, error) {
log := ctrllog.FromContext(ctx).WithName("mcp-handler").WithValues("tool", "list_agents")

agentList := &v1alpha2.AgentList{}
Expand All @@ -128,7 +182,7 @@ func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolR

agents := make([]AgentSummary, 0)
for _, agent := range agentList.Items {
// Check if agent is accepted and deployment ready
// Only include agents that are both accepted and deployment-ready.
deploymentReady := false
accepted := false
for _, condition := range agent.Status.Conditions {
Expand All @@ -145,14 +199,21 @@ func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolR
}

ref := agent.Namespace + "/" + agent.Name
description := agent.Spec.Description

// When an allow-list is active, skip agents that are not in it.
if allowed != nil {
if _, ok := allowed[ref]; !ok {
continue
}
}

agents = append(agents, AgentSummary{
Ref: ref,
Description: description,
Description: agent.Spec.Description,
})
}

log.Info("Listed agents", "count", len(agents))
log.Info("Listed agents", "count", len(agents), "filtered", allowed != nil)

output := ListAgentsOutput{Agents: agents}

Expand All @@ -179,8 +240,10 @@ func (h *MCPHandler) handleListAgents(ctx context.Context, req *mcpsdk.CallToolR
}, output, nil
}

// handleInvokeAgent handles the invoke_agent MCP tool
func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallToolRequest, input InvokeAgentInput) (*mcpsdk.CallToolResult, InvokeAgentOutput, error) {
// handleInvokeAgent handles the invoke_agent MCP tool.
// When an agent allow-list is active for the session, requests targeting an
// agent not in the list are rejected before any A2A call is made.
func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallToolRequest, input InvokeAgentInput, allowed map[string]struct{}) (*mcpsdk.CallToolResult, InvokeAgentOutput, error) {
log := ctrllog.FromContext(ctx).WithName("mcp-handler").WithValues("tool", "invoke_agent")

// Parse agent reference (namespace/name or just name)
Expand All @@ -196,27 +259,43 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool
agentRef := agentNS + "/" + agentName
agentNns := types.NamespacedName{Namespace: agentNS, Name: agentName}

// Get context ID from client request (stateless mode)
// If not provided, contextIDPtr will be nil and a new conversation will start
// Enforce the allow-list before touching any downstream service.
// This is the hard access-control boundary: if the caller's MCP session was
// scoped to a subset of agents (via the "agents" query parameter), any attempt
// to invoke an agent outside that subset is rejected here.
if allowed != nil {
if _, ok := allowed[agentRef]; !ok {
log.Info("Rejected invoke_agent: agent not in allow-list", "agent", agentRef)
return &mcpsdk.CallToolResult{
Content: []mcpsdk.Content{
&mcpsdk.TextContent{Text: fmt.Sprintf("agent %q is not available in this session", agentRef)},
},
IsError: true,
}, InvokeAgentOutput{}, nil
}
}

// Get context ID from client request (stateless mode).
// If not provided, contextIDPtr will be nil and a new conversation will start.
var contextIDPtr *string
if input.ContextID != "" {
contextIDPtr = &input.ContextID
log.V(1).Info("Using context_id from client request", "context_id", input.ContextID)
}

// Get or create cached A2A client for this agent
// Get or create cached A2A client for this agent.
a2aURL := fmt.Sprintf("%s/%s/", h.a2aBaseURL, agentRef)
var a2aClient *a2aclient.A2AClient

if cached, ok := h.a2aClients.Load(agentRef); ok {
if client, ok := cached.(*a2aclient.A2AClient); ok {
a2aClient = client
if c, ok := cached.(*a2aclient.A2AClient); ok {
a2aClient = c
}
}

// Create new client if not cached
// Create new client if not cached.
if a2aClient == nil {
// Build A2A client options with authentication propagation
// Build A2A client options with authentication propagation.
a2aOpts := []a2aclient.Option{
a2aclient.WithTimeout(h.a2aTimeout),
a2aclient.WithHTTPReqHandler(
Expand All @@ -238,12 +317,12 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool
}, InvokeAgentOutput{}, nil
}

// Cache the client
// Cache the client.
h.a2aClients.Store(agentRef, newClient)
a2aClient = newClient
}

// Send message via A2A
// Send message via A2A.
result, err := a2aClient.SendMessage(ctx, protocol.SendMessageParams{
Message: protocol.Message{
Kind: protocol.KindMessage,
Expand All @@ -262,15 +341,15 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool
}, InvokeAgentOutput{}, nil
}

// Extract response text and context ID
// Extract response text and context ID.
var responseText, newContextID string
switch a2aResult := result.Result.(type) {
case *protocol.Message:
responseText = a2a.ExtractText(*a2aResult)
if a2aResult.ContextID != nil {
newContextID = *a2aResult.ContextID
}
// Kagent A2A only returns Task type for now
// Kagent A2A only returns Task type for now.
case *protocol.Task:
newContextID = a2aResult.ContextID
if a2aResult.Status.Message != nil {
Expand All @@ -296,7 +375,7 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool

log.Info("Invoked agent", "agent", agentRef, "hasContextID", newContextID != "")

// Return context_id in response so client can store it for stateless operation
// Return context_id in response so the client can store it for stateless operation.
output := InvokeAgentOutput{
Agent: agentRef,
Text: responseText,
Expand All @@ -312,15 +391,14 @@ func (h *MCPHandler) handleInvokeAgent(ctx context.Context, req *mcpsdk.CallTool
}, output, nil
}

// ServeHTTP implements http.Handler interface
// ServeHTTP implements http.Handler.
func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// The MCP HTTP handler handles all the routing internally
h.httpHandler.ServeHTTP(w, r)
}

// Shutdown gracefully shuts down the MCP handler
// Shutdown gracefully shuts down the MCP handler.
func (h *MCPHandler) Shutdown(ctx context.Context) error {
// The new SDK doesn't have an explicit Shutdown method on StreamableHTTPHandler
// The server will be shut down when the context is cancelled
// The new SDK doesn't have an explicit Shutdown method on StreamableHTTPHandler.
// The server will be shut down when the context is cancelled.
return nil
}
Loading