From 2af98838bcaf40e1b33655522103e335e03edf2d Mon Sep 17 00:00:00 2001 From: Nick Misasi Date: Fri, 22 May 2026 20:46:00 -0400 Subject: [PATCH 1/7] dynamic mcp: harden client catalog discovery Co-authored-by: Cursor --- llm/context.go | 91 ++++ llm/context_test.go | 62 +++ llm/tools.go | 205 ++++++++- llm/tools_test.go | 318 ++++++++++++- mcp/client.go | 320 +++++++++++-- mcp/client_embedded_oauth_test.go | 367 +++++++++++++++ mcp/client_manager.go | 96 ++-- mcp/client_manager_filter_test.go | 50 ++ mcp/client_manager_test.go | 142 +++++- mcp/client_test.go | 736 ++++++++++++++++++++++-------- mcp/retrieval_overrides.go | 14 + mcp/tools_cache_test.go | 10 + mcp/user_clients.go | 290 ++++++++---- mcp/user_clients_test.go | 205 ++++++++- mcpserver/eval_helpers_test.go | 2 +- search/search_test.go | 13 +- telemetry/integration_test.go | 6 +- 17 files changed, 2555 insertions(+), 372 deletions(-) create mode 100644 mcp/client_embedded_oauth_test.go create mode 100644 mcp/retrieval_overrides.go diff --git a/llm/context.go b/llm/context.go index 339923845..9981c8c8c 100644 --- a/llm/context.go +++ b/llm/context.go @@ -4,6 +4,7 @@ package llm import ( + stdcontext "context" "fmt" "strings" "time" @@ -36,6 +37,13 @@ type Context struct { // User that is making the request RequestingUser *model.User + // RequestContext carries the caller's request-scoped context for downstream + // work such as MCP tool discovery. May be nil in tests. + RequestContext stdcontext.Context + + // ConversationID identifies the conversation whose context is being built. + ConversationID string + // Bot Specific BotName string BotUsername string @@ -47,6 +55,42 @@ type Context struct { Tools *ToolStore DisabledToolsInfo []ToolInfo // Info about tools that are unavailable in the current context (e.g., DM-only tools in a channel) Parameters map[string]interface{} + + // MCPDynamicToolLoading indicates this context uses strict MCP dynamic loading. + MCPDynamicToolLoading bool + // MCPDynamicToolTelemetry receives low-cardinality dynamic MCP tool events. + MCPDynamicToolTelemetry MCPDynamicToolTelemetry + MCPDynamicToolSearchUsed bool + MCPDynamicLoadedToolNames map[string]bool + MCPDynamicSearchLoadCallSuccessRecorded map[string]bool + + // DisabledMCPServerOrigins contains per-user disabled MCP server origins that + // must be removed before strict registry construction. + DisabledMCPServerOrigins []string + + // KeepMCPTool, when non-nil, is applied to MCP tools before strict registry + // construction and before flag-off visible MCP insertion. + KeepMCPTool func(Tool) bool + + // PreloadedMCPTools contains exact-or-bare MCP tool selectors for internal + // predefined flows. They are selected only from the already-authorized MCP + // catalog and are request scoped. + PreloadedMCPTools []EnabledMCPTool + + // MCPToolRegistry holds the strict MCP tool registry that was built + // alongside Tools, when MCP dynamic tool loading is enabled. It is stashed + // here so callers can replay loaded-tool restoration after the conversation + // row exists without rebuilding the entire tool store. + // + // Stored as `any` to avoid an llm -> mcp import cycle: the mcp package + // already imports llm, and the only consumer that needs the concrete type + // is the llmcontext package, which can import both. Type-assert to + // *mcp.ToolRegistry there. + MCPToolRegistry any +} + +type MCPDynamicToolTelemetry interface { + ObserveMCPDynamicToolEvent(botName, event, result string) } // ContextOption defines a function that configures a Context @@ -99,6 +143,53 @@ func (c *Context) CustomPromptVars() map[string]string { return vars } +func (c *Context) ObserveMCPDynamicToolEvent(event, result string) { + if c == nil || c.MCPDynamicToolTelemetry == nil { + return + } + + botName := c.BotUsername + if botName == "" { + botName = c.BotName + } + if botName == "" { + botName = "unknown" + } + + c.MCPDynamicToolTelemetry.ObserveMCPDynamicToolEvent(botName, event, result) +} + +func (c *Context) MarkMCPDynamicToolSearch() { + if c == nil { + return + } + c.MCPDynamicToolSearchUsed = true +} + +func (c *Context) MarkMCPDynamicToolLoaded(name string) { + if c == nil || name == "" { + return + } + if c.MCPDynamicLoadedToolNames == nil { + c.MCPDynamicLoadedToolNames = make(map[string]bool) + } + c.MCPDynamicLoadedToolNames[name] = true +} + +func (c *Context) ShouldRecordMCPDynamicSearchLoadCallSuccess(name string) bool { + if c == nil || name == "" || !c.MCPDynamicToolSearchUsed || !c.MCPDynamicLoadedToolNames[name] { + return false + } + if c.MCPDynamicSearchLoadCallSuccessRecorded == nil { + c.MCPDynamicSearchLoadCallSuccessRecorded = make(map[string]bool) + } + if c.MCPDynamicSearchLoadCallSuccessRecorded[name] { + return false + } + c.MCPDynamicSearchLoadCallSuccessRecorded[name] = true + return true +} + func (c Context) String() string { var result strings.Builder result.WriteString(fmt.Sprintf("Time: %v\nServerName: %v\nCompanyName: %v", c.Time, c.ServerName, c.CompanyName)) diff --git a/llm/context_test.go b/llm/context_test.go index ae6ec4eb1..1e99e4830 100644 --- a/llm/context_test.go +++ b/llm/context_test.go @@ -10,6 +10,20 @@ import ( "github.com/stretchr/testify/assert" ) +type contextTelemetryEvent struct { + botName string + event string + result string +} + +type fakeMCPDynamicTelemetry struct { + events []contextTelemetryEvent +} + +func (t *fakeMCPDynamicTelemetry) ObserveMCPDynamicToolEvent(botName, event, result string) { + t.events = append(t.events, contextTelemetryEvent{botName: botName, event: event, result: result}) +} + func TestContext_SetBotFields(t *testing.T) { c := NewContext() c.SetBotFields("BotDisplay", "botuser", "user-id-123", "gpt-4", "openai", "Be helpful and concise") @@ -107,3 +121,51 @@ func TestContext_CustomPromptVars(t *testing.T) { }) } } + +func TestContextObserveMCPDynamicToolEventBotLabelFallbacks(t *testing.T) { + tests := []struct { + name string + context *Context + wantBotName string + }{ + { + name: "username", + context: &Context{BotUsername: "matty", BotName: "Matty"}, + wantBotName: "matty", + }, + { + name: "display name", + context: &Context{BotName: "Matty"}, + wantBotName: "Matty", + }, + { + name: "unknown", + context: &Context{}, + wantBotName: "unknown", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + telemetry := &fakeMCPDynamicTelemetry{} + tt.context.MCPDynamicToolTelemetry = telemetry + + tt.context.ObserveMCPDynamicToolEvent("search", "success") + + assert.Equal(t, []contextTelemetryEvent{{botName: tt.wantBotName, event: "search", result: "success"}}, telemetry.events) + }) + } +} + +func TestContextMCPDynamicSearchLoadCallSuccessState(t *testing.T) { + c := &Context{} + + assert.False(t, c.ShouldRecordMCPDynamicSearchLoadCallSuccess("jira__get_issue")) + + c.MarkMCPDynamicToolSearch() + assert.False(t, c.ShouldRecordMCPDynamicSearchLoadCallSuccess("jira__get_issue")) + + c.MarkMCPDynamicToolLoaded("jira__get_issue") + assert.True(t, c.ShouldRecordMCPDynamicSearchLoadCallSuccess("jira__get_issue")) + assert.False(t, c.ShouldRecordMCPDynamicSearchLoadCallSuccess("jira__get_issue")) +} diff --git a/llm/tools.go b/llm/tools.go index 965516495..1a80f3d68 100644 --- a/llm/tools.go +++ b/llm/tools.go @@ -235,8 +235,10 @@ type ToolCall struct { Name string `json:"name"` Description string `json:"description"` Arguments json.RawMessage `json:"arguments"` + Schema any `json:"schema,omitempty"` Result string `json:"result"` Status ToolCallStatus `json:"status"` + MCPBareName string `json:"mcp_bare_name,omitempty"` // ServerOrigin identifies the MCP server this tool came from (the BaseURL). // Empty for built-in tools. Used for auto-approval decisions. @@ -318,8 +320,19 @@ type ToolAuthError struct { } type ToolStore struct { - tools map[string]Tool - authErrors []ToolAuthError + tools map[string]Tool + unloadedMCPTools map[string]ToolInfo + log TraceLog + doTrace bool + authErrors []ToolAuthError +} + +type TraceLog interface { + Info(message string, keyValuePairs ...any) +} + +type warnTraceLog interface { + Warn(message string, keyValuePairs ...any) } // NewJSONSchemaFromStruct creates a JSONSchema from a Go struct using generics @@ -336,13 +349,26 @@ func NewJSONSchemaFromStruct[T any]() *jsonschema.Schema { func NewNoTools() *ToolStore { return &ToolStore{ tools: make(map[string]Tool), + log: nil, + doTrace: false, authErrors: []ToolAuthError{}, } } -func NewToolStore() *ToolStore { +func NewToolStore(options ...any) *ToolStore { + var log TraceLog + var doTrace bool + if len(options) > 0 { + log, _ = options[0].(TraceLog) + } + if len(options) > 1 { + doTrace, _ = options[1].(bool) + } + return &ToolStore{ tools: make(map[string]Tool), + log: log, + doTrace: doTrace, authErrors: []ToolAuthError{}, } } @@ -350,6 +376,9 @@ func NewToolStore() *ToolStore { func (s *ToolStore) AddTools(tools []Tool) { for _, tool := range tools { s.tools[tool.Name] = tool + if s.unloadedMCPTools != nil { + delete(s.unloadedMCPTools, tool.Name) + } } } @@ -361,12 +390,15 @@ func (s *ToolStore) ResolveTool(ctx context.Context, name string, argsGetter Too tool, ok := s.tools[name] if !ok { + s.LogUnknownToolWarning(name, argsGetter) + s.TraceUnknown(name, argsGetter) err := errors.New("unknown tool " + name) span.RecordError(err) span.SetStatus(otelcodes.Error, err.Error()) return "", err } result, err := tool.Resolver(llmCtx, argsGetter) + s.TraceResolved(name, argsGetter, result, err) if err != nil { span.RecordError(err) span.SetStatus(otelcodes.Error, err.Error()) @@ -384,12 +416,55 @@ func (s *ToolStore) GetTools() []Tool { // GetTool returns a pointer to a tool by name, or nil if not found func (s *ToolStore) GetTool(name string) *Tool { + if s == nil { + return nil + } if tool, ok := s.tools[name]; ok { return &tool } return nil } +func (s *ToolStore) SetUnloadedMCPTools(tools []Tool) { + if s == nil { + return + } + if len(tools) == 0 { + s.unloadedMCPTools = nil + return + } + + s.unloadedMCPTools = make(map[string]ToolInfo, len(tools)) + for _, tool := range tools { + if tool.Name == "" || s.GetTool(tool.Name) != nil { + continue + } + s.unloadedMCPTools[tool.Name] = ToolInfo{ + Name: tool.Name, + Description: tool.Description, + } + } + if len(s.unloadedMCPTools) == 0 { + s.unloadedMCPTools = nil + } +} + +func (s *ToolStore) IsUnloadedMCPTool(name string) bool { + if s == nil || s.GetTool(name) != nil { + return false + } + _, ok := s.unloadedMCPTools[name] + return ok +} + +func (s *ToolStore) GetUnloadedMCPToolInfo(name string) (ToolInfo, bool) { + if s == nil || s.GetTool(name) != nil { + return ToolInfo{}, false + } + info, ok := s.unloadedMCPTools[name] + return info, ok +} + // GetServerOrigin returns the ServerOrigin for a tool by name. // Returns empty string if the tool is not found or has no server origin (built-in tools). func (s *ToolStore) GetServerOrigin(toolName string) string { @@ -414,16 +489,25 @@ func (s *ToolStore) KeepToolsIf(keep func(Tool) bool) { // RemoveToolsByServerOrigin removes all tools whose ServerOrigin matches // any of the provided origins. This is used for user-disabled provider // filtering in Copilot DM contexts. +func normalizeToolServerOrigin(origin string) string { + return strings.TrimRight(strings.TrimSpace(origin), "/") +} + func (s *ToolStore) RemoveToolsByServerOrigin(disabledOrigins []string) { if s == nil || len(disabledOrigins) == 0 { return } + disabledSet := make(map[string]bool, len(disabledOrigins)) for _, origin := range disabledOrigins { + origin = normalizeToolServerOrigin(origin) + if origin == "" { + continue + } disabledSet[origin] = true } for name, tool := range s.tools { - if disabledSet[tool.ServerOrigin] { + if disabledSet[normalizeToolServerOrigin(tool.ServerOrigin)] { delete(s.tools, name) } } @@ -432,6 +516,72 @@ func (s *ToolStore) RemoveToolsByServerOrigin(disabledOrigins []string) { // MCPServerToolWildcard in EnabledMCPTool.ToolName means every tool from that ServerOrigin is allowed. const MCPServerToolWildcard = "*" +const MCPToolNameSeparator = "__" + +func NamespaceMCPToolName(serverSlug, bareToolName string) string { + if serverSlug == "" || bareToolName == "" { + return bareToolName + } + return serverSlug + MCPToolNameSeparator + bareToolName +} + +func BareMCPToolName(toolName string) string { + _, bareName, ok := strings.Cut(toolName, MCPToolNameSeparator) + if !ok { + return toolName + } + return bareName +} + +// IsBareMCPToolName reports whether name is non-empty and has no MCP server +// namespace prefix (e.g. "get_issue" rather than "jira__get_issue"). +func IsBareMCPToolName(name string) bool { + return name != "" && BareMCPToolName(name) == name +} + +func MCPToolNameMatches(runtimeName, configuredName string) bool { + return runtimeName == configuredName || BareMCPToolName(runtimeName) == configuredName +} + +// mcpToolAllowed reports whether a tool passes the allowlist filter. Built-in +// tools (empty ServerOrigin) always pass. MCP tools pass when the allowlist +// map contains the key for either the namespaced runtime name or the bare +// name (see BareMCPToolName). Allowlist keys use the format +// "serverOrigin\x00toolName". +func mcpToolAllowed(tool Tool, allowlist map[string]bool) bool { + if tool.ServerOrigin == "" { + return true + } + if allowlist[tool.ServerOrigin+"\x00"+tool.Name] { + return true + } + return allowlist[tool.ServerOrigin+"\x00"+BareMCPToolName(tool.Name)] +} + +// FilterMCPToolsByAllowlist returns a new slice containing every built-in tool +// (empty ServerOrigin) plus every MCP tool whose (ServerOrigin, Name) pair is +// present in the allowlist map. Allowlist keys use the format +// "serverOrigin\x00toolName"; both the namespaced runtime name and the bare +// name (see BareMCPToolName) are checked, so persisted allowlists with legacy +// bare names continue to match. +// +// An empty or nil allowlist drops every MCP tool while still keeping built-in +// tools. The input slice is never mutated. This helper does not interpret +// MCPServerToolWildcard entries; callers that need wildcard semantics should +// pre-expand wildcards into the allowlist map before calling. +func FilterMCPToolsByAllowlist(tools []Tool, allowlist map[string]bool) []Tool { + if len(tools) == 0 { + return tools + } + filtered := make([]Tool, 0, len(tools)) + for _, tool := range tools { + if mcpToolAllowed(tool, allowlist) { + filtered = append(filtered, tool) + } + } + return filtered +} + // RetainOnlyMCPTools filters the tool store to only retain MCP tools whose // (ServerOrigin, Name) pair appears in the allowlist. Built-in tools (those // with empty ServerOrigin) are never removed by this method. @@ -456,17 +606,13 @@ func (s *ToolStore) RetainOnlyMCPTools(allowlist []EnabledMCPTool) { } for name, tool := range s.tools { - // Never filter built-in tools (empty ServerOrigin) - if tool.ServerOrigin == "" { + if mcpToolAllowed(tool, allowed) { continue } if wildcardOrigins[tool.ServerOrigin] { continue } - // Remove MCP tools not in the allowlist - if !allowed[tool.ServerOrigin+"\x00"+tool.Name] { - delete(s.tools, name) - } + delete(s.tools, name) } } @@ -488,6 +634,45 @@ func (s *ToolStore) GetToolsInfo() []ToolInfo { return result } +func (s *ToolStore) TraceUnknown(name string, argsGetter ToolArgumentGetter) { + if s.log != nil && s.doTrace { + s.log.Info("unknown tool called", "name", name, "args", toolArgsForLog(argsGetter)) + } +} + +func (s *ToolStore) TraceResolved(name string, argsGetter ToolArgumentGetter, result string, err error) { + if s.log != nil && s.doTrace { + s.log.Info("tool resolved", "name", name, "args", toolArgsForLog(argsGetter), "result", result, "error", err) + } +} + +// maxToolArgsLogBytes caps the size of the JSON arg snippet we emit to logs. +// Tool calls (especially failures) can carry large payloads; truncating keeps +// log output bounded without losing the diagnostic head of the args. +const maxToolArgsLogBytes = 512 + +func (s *ToolStore) LogUnknownToolWarning(name string, argsGetter ToolArgumentGetter) { + if s == nil || s.log == nil { + return + } + warnLog, ok := s.log.(warnTraceLog) + if !ok { + return + } + warnLog.Warn("unknown tool called", "name", name, "args", toolArgsForLog(argsGetter), "available_tool_count", len(s.tools)) +} + +func toolArgsForLog(argsGetter ToolArgumentGetter) string { + var raw json.RawMessage + if err := argsGetter(&raw); err != nil { + return fmt.Sprintf("failed to get tool args: %v", err) + } + if len(raw) > maxToolArgsLogBytes { + return string(raw[:maxToolArgsLogBytes]) + "...(truncated)" + } + return string(raw) +} + // AddAuthError adds an authentication error to the tool store func (s *ToolStore) AddAuthError(authError ToolAuthError) { s.authErrors = append(s.authErrors, authError) diff --git a/llm/tools_test.go b/llm/tools_test.go index f64517eea..7df48d9a7 100644 --- a/llm/tools_test.go +++ b/llm/tools_test.go @@ -4,7 +4,9 @@ package llm import ( + "context" "encoding/json" + "errors" "sort" "testing" @@ -108,6 +110,117 @@ func TestSanitizeNonPrintableChars(t *testing.T) { } } +type logEntry struct { + message string + fields []any +} + +type captureToolLog struct { + infos []logEntry + warns []logEntry +} + +func (l *captureToolLog) Info(message string, keyValuePairs ...any) { + l.infos = append(l.infos, logEntry{message: message, fields: keyValuePairs}) +} + +func (l *captureToolLog) Warn(message string, keyValuePairs ...any) { + l.warns = append(l.warns, logEntry{message: message, fields: keyValuePairs}) +} + +type infoOnlyToolLog struct { + infos []logEntry +} + +func (l *infoOnlyToolLog) Info(message string, keyValuePairs ...any) { + l.infos = append(l.infos, logEntry{message: message, fields: keyValuePairs}) +} + +func logFields(entry logEntry) map[string]any { + fields := make(map[string]any, len(entry.fields)/2) + for i := 0; i+1 < len(entry.fields); i += 2 { + key, ok := entry.fields[i].(string) + if ok { + fields[key] = entry.fields[i+1] + } + } + return fields +} + +func rawArgsGetter(raw string) ToolArgumentGetter { + return func(args any) error { + return json.Unmarshal([]byte(raw), args) + } +} + +func TestResolveToolUnknownWarnsWithoutTrace(t *testing.T) { + log := &captureToolLog{} + store := NewToolStore(log, false) + + _, err := store.ResolveTool(context.Background(), "ghost_tool", rawArgsGetter(`{"query":"hello"}`), &Context{}) + + require.EqualError(t, err, "unknown tool ghost_tool") + require.Len(t, log.warns, 1) + assert.Empty(t, log.infos) + assert.Equal(t, "unknown tool called", log.warns[0].message) + fields := logFields(log.warns[0]) + assert.Equal(t, "ghost_tool", fields["name"]) + assert.Equal(t, `{"query":"hello"}`, fields["args"]) + assert.Equal(t, 0, fields["available_tool_count"]) +} + +func TestResolveToolUnknownPreservesTrace(t *testing.T) { + log := &captureToolLog{} + store := NewToolStore(log, true) + + _, err := store.ResolveTool(context.Background(), "ghost_tool", rawArgsGetter(`{"query":"hello"}`), &Context{}) + + require.EqualError(t, err, "unknown tool ghost_tool") + require.Len(t, log.warns, 1) + require.Len(t, log.infos, 1) + assert.Equal(t, "unknown tool called", log.warns[0].message) + assert.Equal(t, "unknown tool called", log.infos[0].message) + assert.Equal(t, `{"query":"hello"}`, logFields(log.infos[0])["args"]) +} + +func TestResolveToolUnknownWithInfoOnlyLoggerStillTracesWhenEnabled(t *testing.T) { + log := &infoOnlyToolLog{} + store := NewToolStore(log, true) + + _, err := store.ResolveTool(context.Background(), "ghost_tool", rawArgsGetter(`{"query":"hello"}`), &Context{}) + + require.EqualError(t, err, "unknown tool ghost_tool") + require.Len(t, log.infos, 1) + assert.Equal(t, "unknown tool called", log.infos[0].message) +} + +func TestResolveToolUnknownLogsArgumentGetterError(t *testing.T) { + log := &captureToolLog{} + store := NewToolStore(log, true) + argsErr := errors.New("bad arguments") + + _, err := store.ResolveTool(context.Background(), "ghost_tool", func(any) error { return argsErr }, &Context{}) + + require.EqualError(t, err, "unknown tool ghost_tool") + require.Len(t, log.warns, 1) + require.Len(t, log.infos, 1) + assert.Equal(t, "failed to get tool args: bad arguments", logFields(log.warns[0])["args"]) + assert.Equal(t, "failed to get tool args: bad arguments", logFields(log.infos[0])["args"]) +} + +func TestGetToolKnownAndUnknown(t *testing.T) { + store := NewToolStore(nil, false) + store.AddTools([]Tool{{ + Name: "known", + Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + return "ok", nil + }, + }}) + + require.NotNil(t, store.GetTool("known")) + assert.Nil(t, store.GetTool("ghost")) +} + func TestToolCall_SanitizeArguments(t *testing.T) { tests := []struct { name string @@ -190,7 +303,7 @@ func TestGetServerOrigin(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - store := NewToolStore() + store := NewToolStore(nil, false) store.AddTools(tc.tools) result := store.GetServerOrigin(tc.lookupName) assert.Equal(t, tc.expectedURL, result) @@ -358,6 +471,15 @@ func TestRemoveToolsByServerOrigin(t *testing.T) { disabledOrigins: []string{"https://server-a.com"}, expectedTools: []string{"builtin_tool"}, }, + { + name: "normalizes disabled origins before removal", + tools: []Tool{ + {Name: "tool_a", ServerOrigin: "https://server-a.com/"}, + {Name: "tool_b", ServerOrigin: "https://server-b.com"}, + }, + disabledOrigins: []string{" https://server-a.com "}, + expectedTools: []string{"tool_b"}, + }, { name: "all tools removed when all origins are disabled", tools: []Tool{ @@ -373,7 +495,7 @@ func TestRemoveToolsByServerOrigin(t *testing.T) { t.Run(tc.name, func(t *testing.T) { var store *ToolStore if tc.tools != nil { - store = NewToolStore() + store = NewToolStore(nil, false) store.AddTools(tc.tools) } @@ -395,6 +517,29 @@ func TestRemoveToolsByServerOrigin(t *testing.T) { } } +func TestMCPToolNameHelpers(t *testing.T) { + assert.Equal(t, "jira__get_issue", NamespaceMCPToolName("jira", "get_issue")) + assert.Equal(t, "get_issue", NamespaceMCPToolName("", "get_issue")) + assert.Equal(t, "", NamespaceMCPToolName("jira", "")) + + assert.Equal(t, "get_issue", BareMCPToolName("jira__get_issue")) + assert.Equal(t, "search", BareMCPToolName("search")) + assert.Equal(t, "foo__bar", BareMCPToolName("server__foo__bar")) + + assert.True(t, MCPToolNameMatches("jira__get_issue", "jira__get_issue")) + assert.True(t, MCPToolNameMatches("jira__get_issue", "get_issue")) + assert.True(t, MCPToolNameMatches("server__foo__bar", "foo__bar")) + assert.False(t, MCPToolNameMatches("jira__get_issue", "create_issue")) +} + +func TestIsBareMCPToolName(t *testing.T) { + assert.True(t, IsBareMCPToolName("get_issue")) + assert.True(t, IsBareMCPToolName("search")) + assert.False(t, IsBareMCPToolName("jira__get_issue")) + assert.False(t, IsBareMCPToolName("server__foo__bar")) + assert.False(t, IsBareMCPToolName("")) +} + func TestRetainOnlyMCPTools(t *testing.T) { tests := []struct { name string @@ -445,18 +590,37 @@ func TestRetainOnlyMCPTools(t *testing.T) { wantToolNames: []string{}, }, { - name: "same tool name different server origins — last write wins", + name: "namespaced tools with same bare name are retained independently per origin", tools: []Tool{ - {Name: "search", ServerOrigin: "https://server-a.com"}, - {Name: "search", ServerOrigin: "https://server-b.com"}, + {Name: "jira__search", ServerOrigin: "https://server-a.com"}, + {Name: "github__search", ServerOrigin: "https://server-b.com"}, }, allowlist: []EnabledMCPTool{ {ServerOrigin: "https://server-a.com", ToolName: "search"}, }, - // ToolStore uses tool.Name as map key, so server-b overwrites - // server-a. The allowlist references server-a, which no longer - // exists in the store, so the result is empty. - wantToolNames: []string{}, + wantToolNames: []string{"jira__search"}, + }, + { + name: "bare allowlist retains namespaced runtime tool", + tools: []Tool{ + {Name: "jira__get_issue", ServerOrigin: "https://mcp.atlassian.com"}, + {Name: "jira__create_issue", ServerOrigin: "https://mcp.atlassian.com"}, + }, + allowlist: []EnabledMCPTool{ + {ServerOrigin: "https://mcp.atlassian.com", ToolName: "get_issue"}, + }, + wantToolNames: []string{"jira__get_issue"}, + }, + { + name: "namespaced allowlist retains namespaced runtime tool", + tools: []Tool{ + {Name: "jira__get_issue", ServerOrigin: "https://mcp.atlassian.com"}, + {Name: "jira__create_issue", ServerOrigin: "https://mcp.atlassian.com"}, + }, + allowlist: []EnabledMCPTool{ + {ServerOrigin: "https://mcp.atlassian.com", ToolName: "jira__get_issue"}, + }, + wantToolNames: []string{"jira__get_issue"}, }, { name: "server wildcard entry retains every tool from that origin", @@ -485,6 +649,19 @@ func TestRetainOnlyMCPTools(t *testing.T) { }, wantToolNames: []string{"jira_get", "jira_create", "slack_post"}, }, + { + name: "server wildcard retains namespaced runtime tools", + tools: []Tool{ + {Name: "builtin_search", ServerOrigin: ""}, + {Name: "jira__get_issue", ServerOrigin: "https://mcp.atlassian.com"}, + {Name: "jira__create_issue", ServerOrigin: "https://mcp.atlassian.com"}, + {Name: "github__search", ServerOrigin: "https://api.githubcopilot.com"}, + }, + allowlist: []EnabledMCPTool{ + {ServerOrigin: "https://mcp.atlassian.com", ToolName: MCPServerToolWildcard}, + }, + wantToolNames: []string{"builtin_search", "jira__get_issue", "jira__create_issue"}, + }, { name: "nil ToolStore is safe", tools: nil, // will test on nil *ToolStore @@ -504,7 +681,7 @@ func TestRetainOnlyMCPTools(t *testing.T) { return } - s := NewToolStore() + s := NewToolStore(nil, false) s.AddTools(tt.tools) s.RetainOnlyMCPTools(tt.allowlist) @@ -520,3 +697,124 @@ func TestRetainOnlyMCPTools(t *testing.T) { }) } } + +func TestFilterMCPToolsByAllowlist(t *testing.T) { + builtin := Tool{Name: "builtin_search", ServerOrigin: ""} + atlassianGet := Tool{Name: "jira_get", ServerOrigin: "https://mcp.atlassian.com"} + atlassianCreate := Tool{Name: "jira_create", ServerOrigin: "https://mcp.atlassian.com"} + atlassianNamespacedGet := Tool{Name: "jira__get_issue", ServerOrigin: "https://mcp.atlassian.com"} + slackPost := Tool{Name: "slack_post", ServerOrigin: "https://mcp.slack.com"} + + tests := []struct { + name string + tools []Tool + allowlist map[string]bool + want []Tool + }{ + { + name: "built-in tool always kept", + tools: []Tool{builtin}, + allowlist: map[string]bool{}, + want: []Tool{builtin}, + }, + { + name: "MCP tool with full namespaced name match is kept", + tools: []Tool{atlassianNamespacedGet}, + allowlist: map[string]bool{ + "https://mcp.atlassian.com\x00jira__get_issue": true, + }, + want: []Tool{atlassianNamespacedGet}, + }, + { + name: "MCP tool with bare name match is kept", + tools: []Tool{atlassianNamespacedGet}, + allowlist: map[string]bool{ + "https://mcp.atlassian.com\x00get_issue": true, + }, + want: []Tool{atlassianNamespacedGet}, + }, + { + name: "MCP tool with no match is dropped", + tools: []Tool{atlassianGet}, + allowlist: map[string]bool{ + "https://mcp.slack.com\x00slack_post": true, + }, + want: []Tool{}, + }, + { + name: "mixed slice keeps built-ins and matching MCP tools only", + tools: []Tool{builtin, atlassianGet, atlassianCreate, slackPost}, + allowlist: map[string]bool{ + "https://mcp.atlassian.com\x00jira_get": true, + "https://mcp.slack.com\x00slack_post": true, + }, + want: []Tool{builtin, atlassianGet, slackPost}, + }, + { + name: "empty allowlist drops all MCP tools but keeps built-in", + tools: []Tool{builtin, atlassianGet, slackPost}, + allowlist: map[string]bool{}, + want: []Tool{builtin}, + }, + { + name: "nil allowlist drops all MCP tools but keeps built-in", + tools: []Tool{builtin, atlassianGet, slackPost}, + allowlist: nil, + want: []Tool{builtin}, + }, + { + name: "same bare name across origins is matched per-origin", + tools: []Tool{ + {Name: "jira__search", ServerOrigin: "https://server-a.com"}, + {Name: "github__search", ServerOrigin: "https://server-b.com"}, + }, + allowlist: map[string]bool{ + "https://server-a.com\x00search": true, + }, + want: []Tool{ + {Name: "jira__search", ServerOrigin: "https://server-a.com"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := append([]Tool(nil), tt.tools...) + got := FilterMCPToolsByAllowlist(tt.tools, tt.allowlist) + assert.Equal(t, tt.want, got) + // Input slice must not be mutated. + assert.Equal(t, input, tt.tools) + }) + } +} + +func TestToolStoreUnloadedMCPTools(t *testing.T) { + var nilStore *ToolStore + nilStore.SetUnloadedMCPTools([]Tool{{Name: "jira__get_issue"}}) + assert.False(t, nilStore.IsUnloadedMCPTool("jira__get_issue")) + _, ok := nilStore.GetUnloadedMCPToolInfo("jira__get_issue") + assert.False(t, ok) + + store := NewNoTools() + store.SetUnloadedMCPTools([]Tool{ + {Name: "jira__get_issue", Description: "Get a Jira issue", ServerOrigin: "https://jira.example.com", Schema: map[string]any{"type": "object"}}, + {Name: "", Description: "ignored"}, + }) + + assert.True(t, store.IsUnloadedMCPTool("jira__get_issue")) + info, ok := store.GetUnloadedMCPToolInfo("jira__get_issue") + require.True(t, ok) + assert.Equal(t, ToolInfo{Name: "jira__get_issue", Description: "Get a Jira issue"}, info) + + store.AddTools([]Tool{{Name: "jira__get_issue", Description: "loaded", ServerOrigin: "https://jira.example.com"}}) + assert.False(t, store.IsUnloadedMCPTool("jira__get_issue")) + _, ok = store.GetUnloadedMCPToolInfo("jira__get_issue") + assert.False(t, ok) + + store.SetUnloadedMCPTools([]Tool{{Name: "github__search", Description: "Search GitHub"}}) + assert.True(t, store.IsUnloadedMCPTool("github__search")) + assert.False(t, store.IsUnloadedMCPTool("jira__get_issue")) + + store.SetUnloadedMCPTools(nil) + assert.False(t, store.IsUnloadedMCPTool("github__search")) +} diff --git a/mcp/client.go b/mcp/client.go index e6c0a826f..89b5c9e63 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -11,9 +11,12 @@ import ( "net/http" "net/url" "strings" + "sync" + "sync/atomic" "time" "github.com/mattermost/mattermost-plugin-agents/config" + "github.com/mattermost/mattermost-plugin-agents/mmapi" "github.com/mattermost/mattermost-plugin-agents/telemetry" "github.com/mattermost/mattermost/server/public/pluginapi" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -46,22 +49,30 @@ type EmbeddedMCPServer interface { // EmbeddedServerClient handles connections to the embedded MCP server type EmbeddedServerClient struct { - server EmbeddedMCPServer - log pluginapi.LogService - pluginAPI *pluginapi.Client + server EmbeddedMCPServer + log pluginapi.LogService + pluginAPI *pluginapi.Client + toolsCache *ToolsCache } // Client represents the connection to a single MCP server type Client struct { - session *mcp.ClientSession - config ServerConfig - tools map[string]*mcp.Tool - userID string - log pluginapi.LogService - oauthManager *OAuthManager - httpClient *http.Client - embeddedClient *EmbeddedServerClient // for reconnection (nil for remote servers) - sessionID string // session ID for embedded server reconnection + session *mcp.ClientSession + config ServerConfig + toolsMu sync.RWMutex + discoveryMu sync.Mutex + tools map[string]*mcp.Tool + toolsDirty bool + toolsGeneration uint64 + notifyOwnerMu sync.RWMutex + notifyOwner *Client + userID string + log pluginapi.LogService + oauthManager *OAuthManager + httpClient *http.Client + toolsCache *ToolsCache + embeddedClient *EmbeddedServerClient // for reconnection (nil for remote servers) + sessionID string // session ID for embedded server reconnection } // staticOAuthCreds returns static OAuth credentials from a server config, or nil if not configured. @@ -122,6 +133,29 @@ func NewEmbeddedServerClient(server EmbeddedMCPServer, log pluginapi.LogService, } } +// NewEmbeddedServerClientWithCache is the same as NewEmbeddedServerClient but +// also wires up a shared tools cache. Pass a non-nil cache when callers want +// per-user tool listings to be cached across requests. +func NewEmbeddedServerClientWithCache(server EmbeddedMCPServer, log pluginapi.LogService, pluginAPI *pluginapi.Client, toolsCache *ToolsCache) *EmbeddedServerClient { + client := NewEmbeddedServerClient(server, log, pluginAPI) + client.toolsCache = toolsCache + return client +} + +func listAllTools(ctx context.Context, session *mcp.ClientSession) (map[string]*mcp.Tool, error) { + tools := make(map[string]*mcp.Tool) + for tool, err := range session.Tools(ctx, &mcp.ListToolsParams{}) { + if err != nil { + return nil, err + } + if tool == nil { + continue + } + tools[tool.Name] = tool + } + return tools, nil +} + // CreateClient creates an embedded MCP client using session ID for authentication. // If sessionID is empty, creates an unauthenticated client (used for tool discovery). func (c *EmbeddedServerClient) CreateClient(ctx context.Context, userID, sessionID string) (*Client, error) { @@ -145,13 +179,21 @@ func (c *EmbeddedServerClient) CreateClient(ctx context.Context, userID, session return nil, fmt.Errorf("failed to create in-memory transport: %w", err) } + var clientPtr atomic.Pointer[Client] + // Create MCP client mcpClient := mcp.NewClient( &mcp.Implementation{ Name: "mattermost-agents-embedded", Version: "1.0", }, - &mcp.ClientOptions{}, + &mcp.ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) { + if cl := clientPtr.Load(); cl != nil { + cl.notificationOwner().invalidateDiscoveredTools(ctx, c.toolsCache, EmbeddedClientKey, c.toolsCache != nil) + } + }, + }, ) // Connect to the embedded server using in-memory transport @@ -167,26 +209,30 @@ func (c *EmbeddedServerClient) CreateClient(ctx context.Context, userID, session tools: make(map[string]*mcp.Tool), userID: userID, log: c.log, - oauthManager: nil, // Embedded servers don't use OAuth + oauthManager: nil, // Embedded servers don't use OAuth + toolsCache: c.toolsCache, embeddedClient: c, // Store client helper for reconnection sessionID: sessionID, // Store session ID for reconnection } + clientPtr.Store(client) // Initialize tools - initResult, err := mcpSession.ListTools(ctx, &mcp.ListToolsParams{}) + discoveredTools, err := listAllTools(ctx, mcpSession) if err != nil { mcpSession.Close() return nil, fmt.Errorf("failed to list tools: %w", err) } - if len(initResult.Tools) == 0 { + if len(discoveredTools) == 0 { mcpSession.Close() return nil, fmt.Errorf("no tools found on MCP server %s for user %s", EmbeddedClientKey, userID) } // Store the tools for this server - for _, tool := range initResult.Tools { - client.tools[tool.Name] = tool + client.toolsMu.Lock() + client.tools = discoveredTools + client.toolsMu.Unlock() + for _, tool := range discoveredTools { c.log.Debug("Registered MCP tool", "userID", userID, "name", tool.Name, @@ -211,6 +257,7 @@ func NewClient(ctx context.Context, userID string, serverConfig ServerConfig, lo log: log, oauthManager: oauthManager, httpClient: httpClient, + toolsCache: toolsCache, } session, err := c.createSession(ctx, serverConfig) @@ -227,7 +274,9 @@ func NewClient(ctx context.Context, userID string, serverConfig ServerConfig, lo cachedTools := toolsCache.GetTools(serverID) if len(cachedTools) > 0 { // Cache hit - use cached tools + c.toolsMu.Lock() c.tools = cachedTools + c.toolsMu.Unlock() log.Debug("Using cached tools for MCP server", "userID", userID, "server", serverConfig.Name, @@ -238,7 +287,7 @@ func NewClient(ctx context.Context, userID string, serverConfig ServerConfig, lo } // Cache miss - fetch tools from server - initResult, err := session.ListTools(ctx, &mcp.ListToolsParams{}) + discoveredTools, err := listAllTools(ctx, session) if err != nil { session.Close() if oauthErr := c.oauthNeededError(err); oauthErr != nil { @@ -247,14 +296,16 @@ func NewClient(ctx context.Context, userID string, serverConfig ServerConfig, lo return nil, fmt.Errorf("failed to list tools: %w", err) } - if len(initResult.Tools) == 0 { + if len(discoveredTools) == 0 { session.Close() return nil, fmt.Errorf("no tools found on MCP server %s for user %s", serverConfig.Name, userID) } // Store the tools for this server - for _, tool := range initResult.Tools { - c.tools[tool.Name] = tool + c.toolsMu.Lock() + c.tools = discoveredTools + c.toolsMu.Unlock() + for _, tool := range discoveredTools { log.Debug("Registered MCP tool", "userID", userID, "name", tool.Name, @@ -264,7 +315,7 @@ func NewClient(ctx context.Context, userID string, serverConfig ServerConfig, lo // Update the global cache with fetched tools. if toolsCache != nil && useSharedToolsCache { - if err := toolsCache.SetTools(serverID, serverConfig.Name, serverConfig.BaseURL, c.tools, time.Now()); err != nil { + if err := toolsCache.SetTools(serverID, serverConfig.Name, serverConfig.BaseURL, discoveredTools, time.Now()); err != nil { log.Warn("Failed to update tools cache", "server", serverConfig.Name, "error", err) } } @@ -273,6 +324,88 @@ func NewClient(ctx context.Context, userID string, serverConfig ServerConfig, lo return c, nil } +// NewPluginClient creates a per-user MCP client for a plugin-registered server. +// Plugin clients use listAllTools and ToolListChangedHandler like other clients, +// but do not use the shared tools cache. +func NewPluginClient(ctx context.Context, userID string, cfg PluginServerConfig, sourcePluginAPI mmapi.Client, log pluginapi.LogService) (*Client, error) { + if sourcePluginAPI == nil { + return nil, fmt.Errorf("sourcePluginAPI is nil; plugin MCP server %s cannot be reached", cfg.PluginID) + } + + originKey := pluginServerOriginKey(cfg.PluginID) + roundTripper := NewPluginHTTPRoundTripper(cfg.PluginID, cfg.Path, sourcePluginAPI) + httpClient := &http.Client{ + Transport: &headerTransport{ + base: roundTripper, + headers: map[string]string{MMUserIDHeader: userID}, + }, + } + + pluginCfg := ServerConfig{ + Name: cfg.Name, + Enabled: true, + BaseURL: originKey, + } + + client := &Client{ + config: pluginCfg, + tools: make(map[string]*mcp.Tool), + userID: userID, + log: log, + httpClient: httpClient, + } + + var clientPtr atomic.Pointer[Client] + clientPtr.Store(client) + + mcpClient := mcp.NewClient( + &mcp.Implementation{ + Name: "mattermost-agents-plugin-bridge", + Version: "1.0", + }, + &mcp.ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) { + if cl := clientPtr.Load(); cl != nil { + cl.invalidateDiscoveredTools(ctx, nil, pluginCfg.Name, false) + } + }, + }, + ) + + session, err := mcpClient.Connect(ctx, &mcp.StreamableClientTransport{ + Endpoint: "http://plugin" + cfg.Path, + HTTPClient: httpClient, + }, nil) + if err != nil { + return nil, fmt.Errorf("failed to connect to plugin MCP server %s: %w", cfg.PluginID, err) + } + + discoveredTools, err := listAllTools(ctx, session) + if err != nil { + session.Close() + return nil, fmt.Errorf("failed to list tools on plugin MCP server %s: %w", cfg.PluginID, err) + } + if len(discoveredTools) == 0 { + session.Close() + return nil, fmt.Errorf("no tools found on plugin MCP server %s for user %s", cfg.PluginID, userID) + } + + client.session = session + client.toolsMu.Lock() + client.tools = discoveredTools + client.toolsMu.Unlock() + + for _, tool := range discoveredTools { + log.Debug("Registered MCP tool", + "userID", userID, + "name", tool.Name, + "description", tool.Description, + "server", originKey) + } + + return client, nil +} + // extractOAuthMetadataURL attempts to extract the OAuth metadata URL from an error message. // This is part of a temporary workaround // Returns the metadata URL and true if found, empty string and false otherwise. @@ -349,7 +482,11 @@ func (c *Client) createSession(ctx context.Context, serverConfig ServerConfig) ( Name: "mattermost-agents", Version: "1.0", }, - &mcp.ClientOptions{}, + &mcp.ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) { + c.invalidateDiscoveredTools(ctx, c.toolsCache, serverConfig.Name, shouldUseSharedToolsCache(serverConfig)) + }, + }, ) httpClient := c.httpClientForMCP(headers) @@ -389,6 +526,117 @@ func (c *Client) createSession(ctx context.Context, serverConfig ServerConfig) ( return nil, fmt.Errorf("failed to connect to MCP server %s, Streamable HTTP: %w, SSE: %w", c.config.Name, errStreamable, errSSE) } +func (c *Client) invalidateDiscoveredTools(ctx context.Context, toolsCache *ToolsCache, serverID string, useSharedToolsCache bool) { + c.toolsMu.Lock() + c.tools = make(map[string]*mcp.Tool) + c.toolsDirty = true + c.toolsGeneration++ + c.toolsMu.Unlock() + + if toolsCache != nil && useSharedToolsCache { + if err := toolsCache.InvalidateServer(serverID); err != nil { + c.log.Warn("Failed to invalidate MCP tools after list_changed notification", + "serverID", serverID, + "server", c.config.Name, + "userID", c.userID, + "error", err) + return + } + } + + c.log.Debug("Invalidated MCP tools after list_changed notification", + "serverID", serverID, + "server", c.config.Name, + "userID", c.userID) +} + +func (c *Client) notificationOwner() *Client { + c.notifyOwnerMu.RLock() + defer c.notifyOwnerMu.RUnlock() + if c.notifyOwner == nil { + return c + } + return c.notifyOwner +} + +func (c *Client) setNotificationOwner(owner *Client) { + c.notifyOwnerMu.Lock() + c.notifyOwner = owner + c.notifyOwnerMu.Unlock() +} + +func (c *Client) ensureDiscoveredTools(ctx context.Context) error { + c.toolsMu.RLock() + dirty := c.toolsDirty + c.toolsMu.RUnlock() + if !dirty { + return nil + } + + c.discoveryMu.Lock() + defer c.discoveryMu.Unlock() + + c.toolsMu.RLock() + dirty = c.toolsDirty + session := c.session + generation := c.toolsGeneration + c.toolsMu.RUnlock() + if !dirty { + return nil + } + if session == nil { + return fmt.Errorf("MCP client not connected") + } + + discoveredTools, err := listAllTools(ctx, session) + if err != nil { + return fmt.Errorf("failed to list tools: %w", err) + } + + c.toolsMu.Lock() + if c.toolsGeneration != generation { + c.toolsMu.Unlock() + c.log.Debug("MCP tools changed during rediscovery; leaving catalog dirty", + "server", c.config.Name, + "userID", c.userID) + return nil + } + c.tools = discoveredTools + c.toolsDirty = false + c.toolsMu.Unlock() + + if c.toolsCache != nil && shouldUseSharedToolsCache(c.config) { + if err := c.toolsCache.SetTools(c.config.Name, c.config.Name, c.config.BaseURL, discoveredTools, time.Now()); err != nil { + c.log.Warn("Failed to update tools cache after list_changed rediscovery", + "server", c.config.Name, + "userID", c.userID, + "error", err) + } + c.toolsMu.RLock() + cacheGenerationChanged := c.toolsGeneration != generation + c.toolsMu.RUnlock() + if cacheGenerationChanged { + if err := c.toolsCache.InvalidateServer(c.config.Name); err != nil { + c.log.Warn("Failed to invalidate MCP tools cache after concurrent list_changed notification", + "server", c.config.Name, + "userID", c.userID, + "error", err) + } + c.log.Debug("MCP tools changed during cache refresh; leaving catalog dirty", + "server", c.config.Name, + "userID", c.userID) + return nil + } + } + + c.log.Debug("Rediscovered MCP tools after list_changed notification", + "server", c.config.Name, + "userID", c.userID, + "toolCount", len(discoveredTools)) + + return nil +} + func (c *Client) oauthStartURL() string { if c.oauthManager == nil { return "" @@ -425,7 +673,9 @@ func (c *Client) Close() error { // Tools returns the tools available from this client func (c *Client) Tools() map[string]*mcp.Tool { - return c.tools + c.toolsMu.RLock() + defer c.toolsMu.RUnlock() + return maps.Clone(c.tools) } // CallTool calls a tool on this MCP server @@ -476,8 +726,12 @@ func (c *Client) CallToolWithMetadata(ctx context.Context, toolName string, args } // Update session and tools from the new client + newClient.setNotificationOwner(c) + c.toolsMu.Lock() c.session = newClient.session - c.tools = newClient.tools + c.tools = newClient.Tools() + c.toolsDirty = false + c.toolsMu.Unlock() c.log.Debug("Successfully reconnected to embedded MCP server", "userID", c.userID) } else { // Reconnect to remote server @@ -496,16 +750,14 @@ func (c *Client) CallToolWithMetadata(ctx context.Context, toolName string, args return "", fmt.Errorf("failed to call tool %s on server %s: %w", toolName, c.config.Name, err) } } - // Extract text content from the result - text := "" - if len(result.Content) > 0 { - for _, content := range result.Content { - // Use type assertion to extract text content - if textContent, ok := content.(*mcp.TextContent); ok { - text += textContent.Text + "\n" - } + var textBuilder strings.Builder + for _, content := range result.Content { + if textContent, ok := content.(*mcp.TextContent); ok { + textBuilder.WriteString(textContent.Text) + textBuilder.WriteByte('\n') } } + text := textBuilder.String() // MCP tools can return IsError=true without transport-level errors. // Surface this as a resolver error so tool-call status is set correctly. diff --git a/mcp/client_embedded_oauth_test.go b/mcp/client_embedded_oauth_test.go new file mode 100644 index 000000000..69dc60696 --- /dev/null +++ b/mcp/client_embedded_oauth_test.go @@ -0,0 +1,367 @@ +// Copyright (c) 2023-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package mcp + +import ( + "context" + "fmt" + "net/url" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" +) + +func TestEmbeddedCreateClientDiscoversPaginatedTools(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3", "tool_4", "tool_5") + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + embeddedClient := NewEmbeddedServerClient(&fakeEmbeddedMCPServer{ctx: ctx, server: server}, newTestLogService(), nil) + client, err := embeddedClient.CreateClient(context.Background(), "user-id", "") + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + + require.Len(t, client.Tools(), 5) + for _, toolName := range []string{"tool_1", "tool_2", "tool_3", "tool_4", "tool_5"} { + require.Contains(t, client.Tools(), toolName) + } +} + +func TestEmbeddedToolListChangedInvalidatesCacheAndClientTools(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + cache := newTestToolsCache() + require.NoError(t, cache.SetTools(EmbeddedClientKey, EmbeddedServerName, EmbeddedClientKey, map[string]*mcp.Tool{ + "cached_tool": { + Name: "cached_tool", + Description: "Cached tool", + InputSchema: map[string]any{"type": "object"}, + }, + }, time.Now())) + + embeddedClient := NewEmbeddedServerClientWithCache(&fakeEmbeddedMCPServer{ctx: ctx, server: server}, newTestLogService(), nil, cache) + client, err := embeddedClient.CreateClient(context.Background(), "user-id", "") + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + require.NotEmpty(t, client.Tools()) + require.NotNil(t, cache.GetTools(EmbeddedClientKey)) + + addTestMCPTool(server, "new_tool") + + require.Eventually(t, func() bool { + return len(client.Tools()) == 0 && cache.GetTools(EmbeddedClientKey) == nil + }, 5*time.Second, 10*time.Millisecond) +} + +func TestEmbeddedToolListChangedNextGetToolsForUserRediscoversTools(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2") + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + cache := newTestToolsCache() + pluginAPI := newTestPluginAPIForEmbeddedManager("user-id", "session-id") + embeddedClient := NewEmbeddedServerClientWithCache(&fakeEmbeddedMCPServer{ctx: ctx, server: server}, pluginAPI.Log, pluginAPI, cache) + manager := &ClientManager{ + config: Config{ + EmbeddedServer: EmbeddedServerConfig{Enabled: true}, + }, + log: pluginAPI.Log, + pluginAPI: pluginAPI, + clients: make(map[string]*UserClients), + activity: make(map[string]time.Time), + embeddedClient: embeddedClient, + toolsCache: cache, + } + t.Cleanup(func() { cleanupTestClientManager(manager) }) + + tools, mcpErrors := manager.GetToolsForUser("user-id") + require.Nil(t, mcpErrors) + requireToolNames(t, tools, "mattermost__tool_1", "mattermost__tool_2") + + addTestMCPTool(server, "new_tool") + + require.Eventually(t, func() bool { + manager.clientsMu.RLock() + userClient := manager.clients["user-id"] + manager.clientsMu.RUnlock() + if userClient == nil { + return false + } + client := userClient.clients[EmbeddedClientKey] + return client != nil && len(client.Tools()) == 0 + }, 5*time.Second, 10*time.Millisecond) + + tools, mcpErrors = manager.GetToolsForUser("user-id") + require.Nil(t, mcpErrors) + requireToolNames(t, tools, "mattermost__new_tool", "mattermost__tool_1", "mattermost__tool_2") + require.Len(t, cache.GetTools(EmbeddedClientKey), 3) +} + +func TestEmbeddedReconnectKeepsPaginatedDiscovery(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + pluginAPI := newTestPluginAPIWithSession("session-id") + + embeddedClient := NewEmbeddedServerClient(&fakeEmbeddedMCPServer{ctx: ctx, server: server}, pluginAPI.Log, pluginAPI) + client, err := embeddedClient.CreateClient(context.Background(), "test-user", "session-id") + require.NoError(t, err) + require.Len(t, client.Tools(), 3) + t.Cleanup(func() { _ = client.Close() }) + + require.NoError(t, client.session.Close()) + result, err := client.CallTool(context.Background(), "tool_1", map[string]any{}) + require.NoError(t, err) + require.Contains(t, result, "tool_1 ok") + require.Len(t, client.Tools(), 3) + + addTestMCPTool(server, "new_tool") + require.Eventually(t, func() bool { + return len(client.Tools()) == 0 + }, 5*time.Second, 10*time.Millisecond) +} + +func TestClientToolsReturnsCopyAndSurvivesConcurrentInvalidation(t *testing.T) { + client := &Client{ + config: ServerConfig{Name: "server", BaseURL: "https://example.com"}, + tools: map[string]*mcp.Tool{ + "tool_1": {Name: "tool_1"}, + }, + userID: "user-id", + log: newTestLogService(), + } + + tools := client.Tools() + delete(tools, "tool_1") + require.Contains(t, client.Tools(), "tool_1") + + done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < 100; i++ { + _ = client.Tools() + } + }() + for i := 0; i < 100; i++ { + client.invalidateDiscoveredTools(context.Background(), nil, "server", false) + } + <-done + require.Empty(t, client.Tools()) +} + +func TestExtractOAuthMetadataURL(t *testing.T) { + tests := []struct { + name string + errMsg string + wantURL string + wantFound bool + }{ + { + name: "nil error", + errMsg: "", + wantURL: "", + wantFound: false, + }, + { + name: "unrelated error", + errMsg: "connection refused", + wantURL: "", + wantFound: false, + }, + { + name: "metadata URL without wrapped error", + errMsg: "OAuth authentication needed for resource at https://api.githubcopilot.com/.well-known/oauth-protected-resource/mcp/", + wantURL: "https://api.githubcopilot.com/.well-known/oauth-protected-resource/mcp/", + wantFound: true, + }, + { + name: "metadata URL with wrapped error", + errMsg: "OAuth authentication needed for resource at https://example.com/.well-known/oauth-protected-resource: Got error: token refresh failed", + wantURL: "https://example.com/.well-known/oauth-protected-resource", + wantFound: true, + }, + { + name: "metadata URL embedded in longer error chain", + errMsg: "failed to connect: OAuth authentication needed for resource at https://api.githubcopilot.com/.well-known/oauth-protected-resource/mcp/", + wantURL: "https://api.githubcopilot.com/.well-known/oauth-protected-resource/mcp/", + wantFound: true, + }, + { + name: "empty metadata URL", + errMsg: "OAuth authentication needed for resource at ", + wantURL: "", + wantFound: false, + }, + { + name: "URL with port", + errMsg: "OAuth authentication needed for resource at https://example.com:8443/.well-known/oauth-protected-resource", + wantURL: "https://example.com:8443/.well-known/oauth-protected-resource", + wantFound: true, + }, + { + name: "URL with port and wrapped error", + errMsg: "OAuth authentication needed for resource at https://example.com:8443/.well-known/oauth-protected-resource: Got error: something failed", + wantURL: "https://example.com:8443/.well-known/oauth-protected-resource", + wantFound: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var err error + if tt.errMsg != "" { + err = fmt.Errorf("%s", tt.errMsg) + } + gotURL, gotFound := extractOAuthMetadataURL(err) + require.Equal(t, tt.wantFound, gotFound) + require.Equal(t, tt.wantURL, gotURL) + }) + } +} + +func TestClientOAuthNeededError(t *testing.T) { + client := &Client{ + config: ServerConfig{ + Name: "OAuth Server", + }, + oauthManager: &OAuthManager{ + callbackURL: "https://mattermost.example.com/plugins/mattermost-ai/oauth/callback", + }, + } + + tests := []struct { + name string + err error + }{ + { + name: "mcp unauthorized error", + err: &mcpUnauthorized{ + metadataURL: "https://oauth.example.com/.well-known/oauth-protected-resource", + }, + }, + { + name: "string matched oauth error", + err: fmt.Errorf("OAuth authentication needed for resource at https://oauth.example.com/.well-known/oauth-protected-resource"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := client.oauthNeededError(tt.err) + require.Error(t, err) + + var oauthErr *OAuthNeededError + require.ErrorAs(t, err, &oauthErr) + authURL, parseErr := url.Parse(oauthErr.AuthURL()) + require.NoError(t, parseErr) + require.Equal(t, "https://mattermost.example.com", authURL.Scheme+"://"+authURL.Host) + require.Equal(t, "/plugins/mattermost-ai/mcp/oauth/OAuth%20Server/start", authURL.EscapedPath()) + require.Equal(t, "https://oauth.example.com/.well-known/oauth-protected-resource", authURL.Query().Get("resource_metadata")) + }) + } +} + +// TestNilCacheHandling verifies that nil cache is handled gracefully in the cache code +func TestNilCacheHandling(t *testing.T) { + // This test documents that the cache code handles nil properly + // The actual NewClient function checks if toolsCache is nil before using it + kvAPI := newMockKVService() + log := &mockLogService{} + cache := NewToolsCache(kvAPI, log) + + // Verify cache can be created and used + require.NotNil(t, cache) + + // Test that GetTools returns nil for non-existent server (not a panic) + tools := cache.GetTools("nonexistent") + require.Nil(t, tools) +} + +func TestShouldUseSharedToolsCache(t *testing.T) { + tests := []struct { + name string + serverConfig ServerConfig + expected bool + }{ + { + name: "server without static oauth creds uses shared cache", + serverConfig: ServerConfig{ + Name: "no-oauth", + BaseURL: "https://example.com", + }, + expected: true, + }, + { + name: "server with static oauth creds skips shared cache", + serverConfig: ServerConfig{ + Name: "static-oauth", + BaseURL: "https://example.com", + ClientID: "client-id", + ClientSecret: "client-secret", + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, shouldUseSharedToolsCache(tt.serverConfig)) + }) + } +} + +func TestInvalidateSharedToolsCacheForOAuthDiscovery(t *testing.T) { + kvAPI := newMockKVService() + log := &mockLogService{} + cache := NewToolsCache(kvAPI, log) + + serverID := "oauth-server" + tools := map[string]*mcp.Tool{ + "search": { + Name: "search", + Description: "Searches data", + }, + } + + err := cache.SetTools(serverID, "OAuth Server", "https://example.com", tools, time.Now()) + require.NoError(t, err) + require.NotNil(t, cache.GetTools(serverID)) + + invalidateSharedToolsCacheForOAuthDiscovery(cache, log, "user-id", serverID, ServerConfig{ + Name: serverID, + BaseURL: "https://example.com", + ClientID: "client-id", + ClientSecret: "client-secret", + }, false) + + require.Nil(t, cache.GetTools(serverID)) +} + +func TestInvalidateSharedToolsCacheForOAuthDiscoveryKeepsCacheWithStoredToken(t *testing.T) { + kvAPI := newMockKVService() + log := &mockLogService{} + cache := NewToolsCache(kvAPI, log) + + serverID := "oauth-server" + tools := map[string]*mcp.Tool{ + "search": { + Name: "search", + Description: "Searches data", + }, + } + + err := cache.SetTools(serverID, "OAuth Server", "https://example.com", tools, time.Now()) + require.NoError(t, err) + + invalidateSharedToolsCacheForOAuthDiscovery(cache, log, "user-id", serverID, ServerConfig{ + Name: serverID, + BaseURL: "https://example.com", + ClientID: "client-id", + ClientSecret: "client-secret", + }, true) + + require.NotNil(t, cache.GetTools(serverID)) +} diff --git a/mcp/client_manager.go b/mcp/client_manager.go index 6f43177a0..ad992d59d 100644 --- a/mcp/client_manager.go +++ b/mcp/client_manager.go @@ -8,6 +8,7 @@ import ( "errors" "net/http" "sort" + "strings" "sync" "time" @@ -97,7 +98,7 @@ func (m *ClientManager) ReInit(config Config, embeddedServer EmbeddedMCPServer) // Update embedded server client if embeddedServer != nil { - m.embeddedClient = NewEmbeddedServerClient(embeddedServer, m.log, m.pluginAPI) + m.embeddedClient = NewEmbeddedServerClientWithCache(embeddedServer, m.log, m.pluginAPI, m.toolsCache) } else { m.embeddedClient = nil } @@ -140,23 +141,23 @@ func (m *ClientManager) Close() { } // createAndStoreUserClient creates a new UserClients instance and stores it in the manager -func (m *ClientManager) createAndStoreUserClient(userID string) (*UserClients, *Errors) { +func (m *ClientManager) createAndStoreUserClient(ctx context.Context, userID string) (*UserClients, *Errors) { + userClients := NewUserClients(userID, m.log, m.oauthManager, m.httpClient, m.toolsCache) + + // Connect outside the manager lock so remote MCP handshakes do not block other users. + mcpErrors := userClients.ConnectToRemoteServers(ctx, m.config.Servers) + userClients.setInitialRemoteConnectErrors(mcpErrors) + m.clientsMu.Lock() defer m.clientsMu.Unlock() - // Check again in case another goroutine created the client while we were waiting for the lock - client, exists := m.clients[userID] - if exists { + // Check again in case another goroutine created the client while we were connecting. + if client, exists := m.clients[userID]; exists { + userClients.Close() m.activity[userID] = time.Now() - return client, client.initialRemoteConnectErrors + return client, client.InitialRemoteConnectErrors() } - userClients := NewUserClients(userID, m.log, m.oauthManager, m.httpClient, m.toolsCache) - - // Let user client connect to remote servers only - mcpErrors := userClients.ConnectToRemoteServers(m.config.Servers) - userClients.initialRemoteConnectErrors = mcpErrors - // Store the client even if some servers failed to connect // This allows partial success - user gets tools from working servers m.clients[userID] = userClients @@ -166,23 +167,25 @@ func (m *ClientManager) createAndStoreUserClient(userID string) (*UserClients, * } // getClientForUser gets or creates an MCP client for a specific user -func (m *ClientManager) getClientForUser(userID string) (*UserClients, *Errors) { +func (m *ClientManager) getClientForUser(ctx context.Context, userID string) (*UserClients, *Errors) { m.clientsMu.Lock() client, exists := m.clients[userID] if exists { m.activity[userID] = time.Now() m.clientsMu.Unlock() - return client, client.initialRemoteConnectErrors + return client, client.InitialRemoteConnectErrors() } m.clientsMu.Unlock() - return m.createAndStoreUserClient(userID) + return m.createAndStoreUserClient(ctx, userID) } // GetToolsForUser returns the tools available for a specific user, connecting to embedded server if session ID provided. func (m *ClientManager) GetToolsForUser(userID string) ([]llm.Tool, *Errors) { + ctx := context.Background() + // Get or create client for this user (connects to remote servers only) - userClient, mcpErrors := m.getClientForUser(userID) + userClient, _ := m.getClientForUser(ctx, userID) // Connect to embedded server using a dedicated per-user session (stored/created in KV). if m.embeddedClient != nil && m.config.EmbeddedServer.Enabled { @@ -190,7 +193,7 @@ func (m *ClientManager) GetToolsForUser(userID string) ([]llm.Tool, *Errors) { if ensureErr != nil { m.log.Debug("Failed to ensure embedded session for user - embedded MCP tools will not be available", "userID", userID, "error", ensureErr) } else if ensuredSessionID != "" { - if embeddedErr := userClient.ConnectToEmbeddedServerIfAvailable(ensuredSessionID, m.embeddedClient, m.config.EmbeddedServer); embeddedErr != nil { + if embeddedErr := userClient.ConnectToEmbeddedServerIfAvailable(ctx, ensuredSessionID, m.embeddedClient, m.config.EmbeddedServer); embeddedErr != nil { m.log.Debug("Failed to connect to embedded server for user - embedded MCP tools will not be available", "userID", userID, "sessionID", ensuredSessionID, "error", embeddedErr) } } @@ -199,20 +202,59 @@ func (m *ClientManager) GetToolsForUser(userID string) ([]llm.Tool, *Errors) { // Snapshot under RLock, then release before PluginHTTP work. pluginSnap := m.snapshotEnabledPluginServers() for _, cfg := range pluginSnap { - if connectErr := userClient.ConnectToPluginServer(context.TODO(), cfg, m.sourcePluginAPI); connectErr != nil { + if connectErr := userClient.ConnectToPluginServer(ctx, cfg, m.sourcePluginAPI); connectErr != nil { m.log.Error("Failed to connect to plugin MCP server", "userID", userID, "pluginID", cfg.PluginID, "error", connectErr) - if mcpErrors == nil { - mcpErrors = &Errors{} - } - mcpErrors.Errors = append(mcpErrors.Errors, connectErr) - // Surface plugin connect failures on subsequent cached lookups. - userClient.initialRemoteConnectErrors = mcpErrors + userClient.appendInitialRemoteConnectError(connectErr) } } - rawTools := userClient.GetTools() + rawTools := userClient.GetTools(ctx) filtered := filterToolsByConfig(rawTools, m.config, m.embeddedClient, pluginSnap) - return filtered, mcpErrors + return filtered, userClient.InitialRemoteConnectErrors() +} + +func (m *ClientManager) GetToolRetrievalOverrides() map[string]ToolRetrievalOverride { + if m == nil { + return nil + } + + var overrides map[string]ToolRetrievalOverride + addOverride := func(serverOrigin string, toolConfig ToolConfig) { + summary := strings.TrimSpace(toolConfig.RetrievalDescriptionOverride) + if summary == "" { + return + } + if overrides == nil { + overrides = make(map[string]ToolRetrievalOverride) + } + overrides[ToolRetrievalOverrideKey(serverOrigin, toolConfig.Name)] = ToolRetrievalOverride{ + Summary: summary, + } + } + + for _, server := range m.config.Servers { + if !server.Enabled { + continue + } + for _, toolConfig := range server.ToolConfigs { + addOverride(server.BaseURL, toolConfig) + } + } + + for _, toolConfig := range m.config.EmbeddedServer.ToolConfigs { + addOverride(EmbeddedClientKey, toolConfig) + } + + for _, server := range m.config.PluginServers { + if !server.Enabled || server.PluginID == "" { + continue + } + for _, toolConfig := range server.ToolConfigs { + addOverride(pluginServerOriginKey(server.PluginID), toolConfig) + } + } + + return overrides } // snapshotEnabledPluginServers returns a copy of enabled plugin configs so @@ -446,7 +488,7 @@ func filterToolsByConfig(rawTools []llm.Tool, cfg Config, embeddedClient *Embedd var filtered []llm.Tool for _, t := range tools { - _, enabled := sc.GetToolPolicy(t.Name) + _, enabled := sc.GetToolPolicy(llm.BareMCPToolName(t.Name)) if enabled { filtered = append(filtered, t) } diff --git a/mcp/client_manager_filter_test.go b/mcp/client_manager_filter_test.go index 4e4cfc0f9..199785dfb 100644 --- a/mcp/client_manager_filter_test.go +++ b/mcp/client_manager_filter_test.go @@ -153,6 +153,56 @@ func TestFilterToolsByConfig(t *testing.T) { }, wantToolNames: []string{"create_post", "search_users"}, }, + { + name: "namespaced tool is denormalized before disabled admin policy lookup", + config: Config{ + Servers: []ServerConfig{ + { + Name: "Jira", + Enabled: true, + BaseURL: "https://mcp.atlassian.com", + ToolConfigs: []ToolConfig{ + {Name: "get_issue", Policy: ToolPolicyAsk, Enabled: false}, + }, + }, + }, + }, + rawTools: []llm.Tool{ + {Name: "jira__get_issue", Description: "Get issue", ServerOrigin: "https://mcp.atlassian.com"}, + }, + }, + { + name: "unconfigured namespaced tool defaults enabled by bare name", + config: Config{ + Servers: []ServerConfig{ + { + Name: "Jira", + Enabled: true, + BaseURL: "https://mcp.atlassian.com", + ToolConfigs: []ToolConfig{ + {Name: "get_issue", Policy: ToolPolicyAsk, Enabled: true}, + }, + }, + }, + }, + rawTools: []llm.Tool{ + {Name: "jira__new_tool", Description: "New tool", ServerOrigin: "https://mcp.atlassian.com"}, + }, + wantToolNames: []string{"jira__new_tool"}, + }, + { + name: "embedded namespaced tool is denormalized before admin policy lookup", + config: Config{ + EmbeddedServer: EmbeddedServerConfig{ + ToolConfigs: []ToolConfig{ + {Name: "search_users", Policy: ToolPolicyAsk, Enabled: false}, + }, + }, + }, + rawTools: []llm.Tool{ + {Name: "mattermost__search_users", Description: "Search users", ServerOrigin: EmbeddedClientKey}, + }, + }, { name: "plugin server enabled, tools flow through default-allow", config: Config{}, diff --git a/mcp/client_manager_test.go b/mcp/client_manager_test.go index 533db0516..00434d3e6 100644 --- a/mcp/client_manager_test.go +++ b/mcp/client_manager_test.go @@ -4,6 +4,7 @@ package mcp import ( + "context" "net/http" "net/http/httptest" "strings" @@ -565,6 +566,143 @@ func TestClientManager_PluginServerRegistry_RaceSafe(t *testing.T) { } } +func TestClientManagerGetToolRetrievalOverridesRemote(t *testing.T) { + manager := &ClientManager{ + config: Config{ + Servers: []ServerConfig{ + { + Name: "Jira", + Enabled: true, + BaseURL: "https://jira.example.com", + ToolConfigs: []ToolConfig{ + {Name: "get_issue", Policy: ToolPolicyAsk, Enabled: true, RetrievalDescriptionOverride: "Find Jira issues by key"}, + {Name: "create_issue", Policy: ToolPolicyAsk, Enabled: true}, + }, + }, + }, + }, + } + + overrides := manager.GetToolRetrievalOverrides() + + require.Equal(t, map[string]ToolRetrievalOverride{ + ToolRetrievalOverrideKey("https://jira.example.com", "get_issue"): { + Summary: "Find Jira issues by key", + }, + }, overrides) +} + +func TestClientManagerGetToolRetrievalOverridesEmbedded(t *testing.T) { + manager := &ClientManager{ + config: Config{ + EmbeddedServer: EmbeddedServerConfig{ + ToolConfigs: []ToolConfig{ + {Name: "search_users", Policy: ToolPolicyAsk, Enabled: true, RetrievalDescriptionOverride: "Find Mattermost people"}, + }, + }, + }, + } + + overrides := manager.GetToolRetrievalOverrides() + + require.Equal(t, map[string]ToolRetrievalOverride{ + ToolRetrievalOverrideKey(EmbeddedClientKey, "search_users"): { + Summary: "Find Mattermost people", + }, + }, overrides) +} + +func TestClientManagerGetToolRetrievalOverridesPlugin(t *testing.T) { + manager := &ClientManager{ + config: Config{ + PluginServers: []PluginServerConfig{ + { + PluginID: "com.example.mcp", + Enabled: true, + ToolConfigs: []ToolConfig{ + {Name: "lookup", Policy: ToolPolicyAsk, Enabled: true, RetrievalDescriptionOverride: "Find plugin records"}, + }, + }, + }, + }, + } + + overrides := manager.GetToolRetrievalOverrides() + + require.Equal(t, map[string]ToolRetrievalOverride{ + ToolRetrievalOverrideKey("plugin://com.example.mcp", "lookup"): { + Summary: "Find plugin records", + }, + }, overrides) +} + +func TestClientManagerGetToolRetrievalOverridesTrimsAndSkipsEmpty(t *testing.T) { + manager := &ClientManager{ + config: Config{ + Servers: []ServerConfig{ + { + Name: "Jira", + Enabled: true, + BaseURL: "https://jira.example.com", + ToolConfigs: []ToolConfig{ + {Name: "get_issue", RetrievalDescriptionOverride: " Find Jira issues "}, + {Name: "create_issue", RetrievalDescriptionOverride: " "}, + }, + }, + }, + }, + } + + overrides := manager.GetToolRetrievalOverrides() + + require.Equal(t, map[string]ToolRetrievalOverride{ + ToolRetrievalOverrideKey("https://jira.example.com", "get_issue"): { + Summary: "Find Jira issues", + }, + }, overrides) +} + +func TestClientManagerGetToolRetrievalOverridesLastDuplicateWins(t *testing.T) { + manager := &ClientManager{ + config: Config{ + Servers: []ServerConfig{ + { + Name: "Jira", + Enabled: true, + BaseURL: "https://jira.example.com", + ToolConfigs: []ToolConfig{ + {Name: "get_issue", RetrievalDescriptionOverride: "old summary"}, + {Name: "get_issue", RetrievalDescriptionOverride: "new summary"}, + }, + }, + }, + }, + } + + overrides := manager.GetToolRetrievalOverrides() + + require.Equal(t, "new summary", overrides[ToolRetrievalOverrideKey("https://jira.example.com", "get_issue")].Summary) +} + +func TestClientManagerGetToolRetrievalOverridesDisabledServer(t *testing.T) { + manager := &ClientManager{ + config: Config{ + Servers: []ServerConfig{ + { + Name: "Jira", + Enabled: false, + BaseURL: "https://jira.example.com", + ToolConfigs: []ToolConfig{ + {Name: "get_issue", RetrievalDescriptionOverride: "Find Jira issues"}, + }, + }, + }, + }, + } + + require.Empty(t, manager.GetToolRetrievalOverrides()) +} + func TestClientManagerInvalidateUserClients(t *testing.T) { now := time.Now() testCases := []struct { @@ -633,7 +771,7 @@ func TestClientManagerCreateAndStoreUserClientSetsInitialActivity(t *testing.T) } before := time.Now() - userClients, mcpErrors := manager.createAndStoreUserClient("user-1") + userClients, mcpErrors := manager.createAndStoreUserClient(context.Background(), "user-1") after := time.Now() require.NotNil(t, userClients) @@ -669,7 +807,7 @@ func TestClientManagerGetClientForUserExistingClientConcurrent(t *testing.T) { defer wg.Done() <-start for range iterations { - got, errs := manager.getClientForUser("user-1") + got, errs := manager.getClientForUser(context.Background(), "user-1") if got != userClients || errs != nil { t.Errorf("getClientForUser returned unexpected result: got=%p errs=%v", got, errs) return diff --git a/mcp/client_test.go b/mcp/client_test.go index 5d23b5f26..2271b59be 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -4,15 +4,240 @@ package mcp import ( + "context" "fmt" - "net/url" + "net/http" + "net/http/httptest" + "sync/atomic" "testing" "time" + "github.com/mattermost/mattermost-plugin-agents/llm" + "github.com/mattermost/mattermost-plugin-agents/mmapi" + "github.com/mattermost/mattermost/server/public/model" + plugintest "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/require" ) +const testListToolsMethod = "tools/list" + +type fixedPluginAPI struct { + plugintest.API + kvGet func(string) ([]byte, *model.AppError) + sessionByID map[string]*model.Session + userByID map[string]*model.User +} + +func (f *fixedPluginAPI) LogDebug(string, ...interface{}) {} + +func (f *fixedPluginAPI) LogInfo(string, ...interface{}) {} + +func (f *fixedPluginAPI) LogWarn(string, ...interface{}) {} + +func (f *fixedPluginAPI) LogError(string, ...interface{}) {} + +func (f *fixedPluginAPI) KVGet(key string) ([]byte, *model.AppError) { + if f.kvGet != nil { + return f.kvGet(key) + } + return nil, nil +} + +func (f *fixedPluginAPI) KVSet(string, []byte) *model.AppError { + return nil +} + +func (f *fixedPluginAPI) KVSetWithOptions(string, []byte, model.PluginKVSetOptions) (bool, *model.AppError) { + return true, nil +} + +func (f *fixedPluginAPI) KVDelete(string) *model.AppError { + return nil +} + +func (f *fixedPluginAPI) GetSession(sessionID string) (*model.Session, *model.AppError) { + if f.sessionByID == nil { + return nil, nil + } + return f.sessionByID[sessionID], nil +} + +func (f *fixedPluginAPI) GetUser(userID string) (*model.User, *model.AppError) { + if f.userByID == nil { + return nil, nil + } + return f.userByID[userID], nil +} + +type fakeEmbeddedMCPServer struct { + ctx context.Context + server *mcp.Server +} + +func (f *fakeEmbeddedMCPServer) CreateClientTransport(_ string, _ string, _ *pluginapi.Client) (*mcp.InMemoryTransport, error) { + serverTransport, clientTransport := mcp.NewInMemoryTransports() + go func() { + _ = f.server.Run(f.ctx, serverTransport) + }() + return clientTransport, nil +} + +func newTestMCPServer(pageSize int, toolNames ...string) *mcp.Server { + return newTestMCPServerWithCapabilities(pageSize, nil, toolNames...) +} + +func newTestMCPServerWithCapabilities(pageSize int, capabilities *mcp.ServerCapabilities, toolNames ...string) *mcp.Server { + var opts *mcp.ServerOptions + if pageSize > 0 || capabilities != nil { + opts = &mcp.ServerOptions{ + PageSize: pageSize, + Capabilities: capabilities, + } + } + server := mcp.NewServer(&mcp.Implementation{ + Name: "test-mcp-server", + Version: "1.0.0", + }, opts) + for _, toolName := range toolNames { + addTestMCPTool(server, toolName) + } + return server +} + +func newStaticToolListMCPServer(pageSize int, toolNames ...string) *mcp.Server { + return newTestMCPServerWithCapabilities(pageSize, &mcp.ServerCapabilities{ + Tools: &mcp.ToolCapabilities{ListChanged: false}, + }, toolNames...) +} + +func newEmptyToolsMCPServer() *mcp.Server { + return mcp.NewServer(&mcp.Implementation{ + Name: "test-empty-mcp-server", + Version: "1.0.0", + }, &mcp.ServerOptions{ + Capabilities: &mcp.ServerCapabilities{ + Tools: &mcp.ToolCapabilities{ListChanged: true}, + }, + }) +} + +func addTestMCPTool(server *mcp.Server, toolName string) { + server.AddTool(&mcp.Tool{ + Name: toolName, + Description: fmt.Sprintf("Test tool %s", toolName), + InputSchema: map[string]any{"type": "object"}, + }, func(context.Context, *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("%s ok", toolName)}}, + }, nil + }) +} + +func connectInMemoryTestSession(t *testing.T, server *mcp.Server) *mcp.ClientSession { + t.Helper() + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + go func() { + _ = server.Run(ctx, serverTransport) + }() + + client := mcp.NewClient(&mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, nil) + + session, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = session.Close() }) + return session +} + +func startStreamableMCPServer(t *testing.T, server *mcp.Server) *httptest.Server { + t.Helper() + + httpServer := httptest.NewServer(mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil)) + t.Cleanup(httpServer.Close) + return httpServer +} + +func newTestToolsCache() *ToolsCache { + return NewToolsCache(newMockKVService(), &mockLogService{}) +} + +func newTestLogService() pluginapi.LogService { + return newTestPluginAPIWithSession("").Log +} + +func newTestOAuthManager() *OAuthManager { + pluginAPI := newTestPluginAPIWithSession("") + return NewOAuthManager(mmapi.NewClient(pluginAPI), "https://mattermost.example.com/plugins/mattermost-ai/oauth/callback", http.DefaultClient, nil) +} + +func newTestPluginAPIWithSession(sessionID string) *pluginapi.Client { + fakeAPI := &fixedPluginAPI{ + sessionByID: map[string]*model.Session{ + sessionID: { + Id: sessionID, + UserId: "test-user", + Token: "test-token", + }, + }, + } + return pluginapi.NewClient(fakeAPI, nil) +} + +func newTestPluginAPIForEmbeddedManager(userID, sessionID string) *pluginapi.Client { + fakeAPI := &fixedPluginAPI{ + kvGet: func(key string) ([]byte, *model.AppError) { + if key == buildEmbeddedSessionKey(userID) { + return []byte(sessionID), nil + } + return nil, nil + }, + sessionByID: map[string]*model.Session{ + sessionID: { + Id: sessionID, + UserId: userID, + Token: "test-token", + ExpiresAt: time.Now().Add(time.Hour).UnixMilli(), + }, + }, + userByID: map[string]*model.User{ + userID: { + Id: userID, + Roles: "system_user", + }, + }, + } + return pluginapi.NewClient(fakeAPI, nil) +} + +func requireToolNames(t *testing.T, tools []llm.Tool, expectedNames ...string) { + t.Helper() + + names := make([]string, 0, len(tools)) + for _, tool := range tools { + names = append(names, tool.Name) + } + require.ElementsMatch(t, expectedNames, names) +} + +func cleanupTestClientManager(manager *ClientManager) { + manager.clientsMu.Lock() + defer manager.clientsMu.Unlock() + for _, userClient := range manager.clients { + userClient.Close() + } + manager.clients = make(map[string]*UserClients) +} + // TestCacheHitBehavior verifies that when tools are in cache, // they can be retrieved and reused correctly func TestCacheHitBehavior(t *testing.T) { @@ -99,216 +324,365 @@ func TestCacheUpdateOnNewTools(t *testing.T) { require.Contains(t, cachedTools, "file_write") } -func TestExtractOAuthMetadataURL(t *testing.T) { - tests := []struct { - name string - errMsg string - wantURL string - wantFound bool - }{ - { - name: "nil error", - errMsg: "", - wantURL: "", - wantFound: false, - }, - { - name: "unrelated error", - errMsg: "connection refused", - wantURL: "", - wantFound: false, - }, - { - name: "metadata URL without wrapped error", - errMsg: "OAuth authentication needed for resource at https://api.githubcopilot.com/.well-known/oauth-protected-resource/mcp/", - wantURL: "https://api.githubcopilot.com/.well-known/oauth-protected-resource/mcp/", - wantFound: true, - }, - { - name: "metadata URL with wrapped error", - errMsg: "OAuth authentication needed for resource at https://example.com/.well-known/oauth-protected-resource: Got error: token refresh failed", - wantURL: "https://example.com/.well-known/oauth-protected-resource", - wantFound: true, - }, - { - name: "metadata URL embedded in longer error chain", - errMsg: "failed to connect: OAuth authentication needed for resource at https://api.githubcopilot.com/.well-known/oauth-protected-resource/mcp/", - wantURL: "https://api.githubcopilot.com/.well-known/oauth-protected-resource/mcp/", - wantFound: true, - }, - { - name: "empty metadata URL", - errMsg: "OAuth authentication needed for resource at ", - wantURL: "", - wantFound: false, - }, - { - name: "URL with port", - errMsg: "OAuth authentication needed for resource at https://example.com:8443/.well-known/oauth-protected-resource", - wantURL: "https://example.com:8443/.well-known/oauth-protected-resource", - wantFound: true, - }, - { - name: "URL with port and wrapped error", - errMsg: "OAuth authentication needed for resource at https://example.com:8443/.well-known/oauth-protected-resource: Got error: something failed", - wantURL: "https://example.com:8443/.well-known/oauth-protected-resource", - wantFound: true, - }, +func TestListAllToolsCollectsPaginatedTools(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3", "tool_4", "tool_5") + session := connectInMemoryTestSession(t, server) + + tools, err := listAllTools(context.Background(), session) + require.NoError(t, err) + require.Len(t, tools, 5) + for _, toolName := range []string{"tool_1", "tool_2", "tool_3", "tool_4", "tool_5"} { + require.Contains(t, tools, toolName) } +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var err error - if tt.errMsg != "" { - err = fmt.Errorf("%s", tt.errMsg) +func TestListAllToolsSkipsNilTools(t *testing.T) { + server := newTestMCPServer(0, "tool_1") + server.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + result, err := next(ctx, method, req) + if err != nil || method != testListToolsMethod { + return result, err } - gotURL, gotFound := extractOAuthMetadataURL(err) - require.Equal(t, tt.wantFound, gotFound) - require.Equal(t, tt.wantURL, gotURL) - }) - } + listResult, ok := result.(*mcp.ListToolsResult) + require.True(t, ok) + listResult.Tools = append(listResult.Tools, nil) + return listResult, nil + } + }) + session := connectInMemoryTestSession(t, server) + + tools, err := listAllTools(context.Background(), session) + require.NoError(t, err) + require.Len(t, tools, 1) + require.Contains(t, tools, "tool_1") } -func TestClientOAuthNeededError(t *testing.T) { - client := &Client{ - config: ServerConfig{ - Name: "OAuth Server", - }, - oauthManager: &OAuthManager{ - callbackURL: "https://mattermost.example.com/plugins/mattermost-ai/oauth/callback", - }, +func TestNewClientDiscoversPaginatedRemoteTools(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3", "tool_4", "tool_5") + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "paged", + BaseURL: httpServer.URL, + Enabled: true, + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + + require.Len(t, client.Tools(), 5) + cachedTools := cache.GetTools("paged") + require.Len(t, cachedTools, 5) + for _, toolName := range []string{"tool_1", "tool_2", "tool_3", "tool_4", "tool_5"} { + require.Contains(t, client.Tools(), toolName) + require.Contains(t, cachedTools, toolName) } +} - tests := []struct { - name string - err error - }{ - { - name: "mcp unauthorized error", - err: &mcpUnauthorized{ - metadataURL: "https://oauth.example.com/.well-known/oauth-protected-resource", - }, - }, - { - name: "string matched oauth error", - err: fmt.Errorf("OAuth authentication needed for resource at https://oauth.example.com/.well-known/oauth-protected-resource"), +func TestNewClientUsesCacheWithoutPaginationCall(t *testing.T) { + var listCalls atomic.Int32 + server := newStaticToolListMCPServer(2, "server_tool") + server.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + if method == testListToolsMethod { + listCalls.Add(1) + return nil, fmt.Errorf("unexpected tools/list call on cache hit") + } + return next(ctx, method, req) + } + }) + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + cachedTools := map[string]*mcp.Tool{ + "cached_tool": { + Name: "cached_tool", + Description: "Cached tool", + InputSchema: map[string]any{"type": "object"}, }, } + require.NoError(t, cache.SetTools("paged", "Paged", httpServer.URL, cachedTools, time.Now())) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := client.oauthNeededError(tt.err) - require.Error(t, err) - - var oauthErr *OAuthNeededError - require.ErrorAs(t, err, &oauthErr) - authURL, parseErr := url.Parse(oauthErr.AuthURL()) - require.NoError(t, parseErr) - require.Equal(t, "https://mattermost.example.com", authURL.Scheme+"://"+authURL.Host) - require.Equal(t, "/plugins/mattermost-ai/mcp/oauth/OAuth%20Server/start", authURL.EscapedPath()) - require.Equal(t, "https://oauth.example.com/.well-known/oauth-protected-resource", authURL.Query().Get("resource_metadata")) - }) - } + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "paged", + BaseURL: httpServer.URL, + Enabled: true, + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + + require.Zero(t, listCalls.Load()) + require.Equal(t, cachedTools, client.Tools()) } -// TestNilCacheHandling verifies that nil cache is handled gracefully in the cache code -func TestNilCacheHandling(t *testing.T) { - // This test documents that the cache code handles nil properly - // The actual NewClient function checks if toolsCache is nil before using it - kvAPI := newMockKVService() - log := &mockLogService{} - cache := NewToolsCache(kvAPI, log) +func TestNewClientDoesNotCachePartialPaginationOnError(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") + server.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + if method == testListToolsMethod { + if params, ok := req.GetParams().(*mcp.ListToolsParams); ok && params.Cursor != "" { + return nil, fmt.Errorf("page 2 failed") + } + } + return next(ctx, method, req) + } + }) + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "paged", + BaseURL: httpServer.URL, + Enabled: true, + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) + require.Error(t, err) + require.Nil(t, client) + require.Nil(t, cache.GetTools("paged")) +} - // Verify cache can be created and used - require.NotNil(t, cache) - - // Test that GetTools returns nil for non-existent server (not a panic) - tools := cache.GetTools("nonexistent") - require.Nil(t, tools) -} - -func TestShouldUseSharedToolsCache(t *testing.T) { - tests := []struct { - name string - serverConfig ServerConfig - expected bool - }{ - { - name: "server without static oauth creds uses shared cache", - serverConfig: ServerConfig{ - Name: "no-oauth", - BaseURL: "https://example.com", - }, - expected: true, - }, - { - name: "server with static oauth creds skips shared cache", - serverConfig: ServerConfig{ - Name: "static-oauth", - BaseURL: "https://example.com", - ClientID: "client-id", - ClientSecret: "client-secret", - }, - expected: false, +func TestNewClientErrorsOnEmptyRemoteToolCatalog(t *testing.T) { + server := newEmptyToolsMCPServer() + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "empty", + BaseURL: httpServer.URL, + Enabled: true, + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) + require.Error(t, err) + require.Nil(t, client) + require.Contains(t, err.Error(), "no tools found") + require.Nil(t, cache.GetTools("empty")) +} + +func TestRemoteToolListChangedInvalidatesCacheAndClientTools(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "paged", + BaseURL: httpServer.URL, + Enabled: true, + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + require.NotEmpty(t, client.Tools()) + require.NotNil(t, cache.GetTools("paged")) + + addTestMCPTool(server, "new_tool") + + require.Eventually(t, func() bool { + return len(client.Tools()) == 0 && cache.GetTools("paged") == nil + }, 5*time.Second, 10*time.Millisecond) +} + +func TestRemoteToolListChangedNextGetToolsForUserRediscoversTools(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2") + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + manager := &ClientManager{ + config: Config{ + Servers: []ServerConfig{{ + Name: "paged", + BaseURL: httpServer.URL, + Enabled: true, + }}, }, + log: newTestLogService(), + clients: make(map[string]*UserClients), + activity: make(map[string]time.Time), + oauthManager: newTestOAuthManager(), + httpClient: httpServer.Client(), + toolsCache: cache, } + t.Cleanup(func() { cleanupTestClientManager(manager) }) + + var tools []llm.Tool + var mcpErrors *Errors + require.Eventually(t, func() bool { + tools, mcpErrors = manager.GetToolsForUser("user-id") + if mcpErrors != nil || len(cache.GetTools("paged")) != 2 { + return false + } + toolNames := make(map[string]bool, len(tools)) + for _, tool := range tools { + toolNames[tool.Name] = true + } + return len(tools) == 2 && toolNames["paged__tool_1"] && toolNames["paged__tool_2"] + }, 5*time.Second, 10*time.Millisecond) + require.Nil(t, mcpErrors) + requireToolNames(t, tools, "paged__tool_1", "paged__tool_2") + require.Len(t, cache.GetTools("paged"), 2) + + addTestMCPTool(server, "new_tool") + + require.Eventually(t, func() bool { + manager.clientsMu.RLock() + userClient := manager.clients["user-id"] + manager.clientsMu.RUnlock() + if userClient == nil { + return false + } + client := userClient.clients["paged"] + return client != nil && len(client.Tools()) == 0 && cache.GetTools("paged") == nil + }, 5*time.Second, 10*time.Millisecond) + + tools, mcpErrors = manager.GetToolsForUser("user-id") + require.Nil(t, mcpErrors) + requireToolNames(t, tools, "paged__new_tool", "paged__tool_1", "paged__tool_2") + require.Len(t, cache.GetTools("paged"), 3) +} + +func TestToolListChangedDuringRediscoveryKeepsClientDirty(t *testing.T) { + listBlocked := make(chan struct{}) + releaseList := make(chan struct{}) + var blocked atomic.Bool + server := newTestMCPServer(0, "tool_1") + server.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + result, err := next(ctx, method, req) + if err != nil || method != testListToolsMethod || !blocked.CompareAndSwap(false, true) { + return result, err + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require.Equal(t, tt.expected, shouldUseSharedToolsCache(tt.serverConfig)) - }) + close(listBlocked) + select { + case <-releaseList: + case <-ctx.Done(): + return nil, ctx.Err() + } + return result, nil + } + }) + session := connectInMemoryTestSession(t, server) + cache := newTestToolsCache() + client := &Client{ + session: session, + config: ServerConfig{Name: "server", BaseURL: "https://example.com"}, + tools: make(map[string]*mcp.Tool), + toolsDirty: true, + userID: "user-id", + log: newTestLogService(), + toolsCache: cache, } -} -func TestInvalidateSharedToolsCacheForOAuthDiscovery(t *testing.T) { - kvAPI := newMockKVService() - log := &mockLogService{} - cache := NewToolsCache(kvAPI, log) + errCh := make(chan error, 1) + go func() { + errCh <- client.ensureDiscoveredTools(context.Background()) + }() - serverID := "oauth-server" - tools := map[string]*mcp.Tool{ - "search": { - Name: "search", - Description: "Searches data", - }, + select { + case <-listBlocked: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for rediscovery to enter tools/list") } - err := cache.SetTools(serverID, "OAuth Server", "https://example.com", tools, time.Now()) + client.invalidateDiscoveredTools(context.Background(), cache, "server", true) + close(releaseList) + + select { + case err := <-errCh: + require.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for rediscovery to finish") + } + + client.toolsMu.RLock() + require.True(t, client.toolsDirty) + client.toolsMu.RUnlock() + require.Empty(t, client.Tools()) + require.Nil(t, cache.GetTools("server")) + + addTestMCPTool(server, "tool_2") + require.NoError(t, client.ensureDiscoveredTools(context.Background())) + + client.toolsMu.RLock() + require.False(t, client.toolsDirty) + client.toolsMu.RUnlock() + require.Contains(t, client.Tools(), "tool_1") + require.Contains(t, client.Tools(), "tool_2") + require.Len(t, cache.GetTools("server"), 2) +} + +func TestRemoteToolListChangedWithNilCacheClearsClientTools(t *testing.T) { + server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") + httpServer := startStreamableMCPServer(t, server) + + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "paged", + BaseURL: httpServer.URL, + Enabled: true, + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), nil) require.NoError(t, err) - require.NotNil(t, cache.GetTools(serverID)) + t.Cleanup(func() { _ = client.Close() }) + require.NotEmpty(t, client.Tools()) - invalidateSharedToolsCacheForOAuthDiscovery(cache, log, "user-id", serverID, ServerConfig{ - Name: serverID, - BaseURL: "https://example.com", - ClientID: "client-id", - ClientSecret: "client-secret", - }, false) + addTestMCPTool(server, "new_tool") - require.Nil(t, cache.GetTools(serverID)) + require.Eventually(t, func() bool { + return len(client.Tools()) == 0 + }, 5*time.Second, 10*time.Millisecond) } -func TestInvalidateSharedToolsCacheForOAuthDiscoveryKeepsCacheWithStoredToken(t *testing.T) { - kvAPI := newMockKVService() - log := &mockLogService{} - cache := NewToolsCache(kvAPI, log) +func TestRemoteToolListChangedForStaticOAuthSkipsSharedCacheInvalidation(t *testing.T) { + server := newTestMCPServer(2, "server_tool") + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + require.NoError(t, cache.SetTools("oauth-server", "OAuth Server", httpServer.URL, map[string]*mcp.Tool{ + "cached_tool": { + Name: "cached_tool", + Description: "Cached tool", + InputSchema: map[string]any{"type": "object"}, + }, + }, time.Now())) - serverID := "oauth-server" - tools := map[string]*mcp.Tool{ - "search": { - Name: "search", - Description: "Searches data", + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "oauth-server", + BaseURL: httpServer.URL, + Enabled: true, + ClientID: "client-id", + ClientSecret: "client-secret", + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + require.Contains(t, client.Tools(), "server_tool") + require.NoError(t, cache.SetTools("oauth-server", "OAuth Server", httpServer.URL, map[string]*mcp.Tool{ + "cached_after_connect": { + Name: "cached_after_connect", + Description: "Cached after connect", + InputSchema: map[string]any{"type": "object"}, }, - } + }, time.Now())) + require.NotNil(t, cache.GetTools("oauth-server")) - err := cache.SetTools(serverID, "OAuth Server", "https://example.com", tools, time.Now()) + addTestMCPTool(server, "new_tool") + + require.Eventually(t, func() bool { + return len(client.Tools()) == 0 + }, 5*time.Second, 10*time.Millisecond) + require.NotNil(t, cache.GetTools("oauth-server")) +} + +func TestRemoteToolListChangedNotificationStormIsIdempotent(t *testing.T) { + server := newTestMCPServer(2, "tool_1") + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "storm", + BaseURL: httpServer.URL, + Enabled: true, + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) - invalidateSharedToolsCacheForOAuthDiscovery(cache, log, "user-id", serverID, ServerConfig{ - Name: serverID, - BaseURL: "https://example.com", - ClientID: "client-id", - ClientSecret: "client-secret", - }, true) + addTestMCPTool(server, "storm_tool") + server.RemoveTools("storm_tool") + addTestMCPTool(server, "storm_tool") - require.NotNil(t, cache.GetTools(serverID)) + require.Eventually(t, func() bool { + return len(client.Tools()) == 0 && cache.GetTools("storm") == nil + }, 5*time.Second, 10*time.Millisecond) } diff --git a/mcp/retrieval_overrides.go b/mcp/retrieval_overrides.go new file mode 100644 index 000000000..cf325f303 --- /dev/null +++ b/mcp/retrieval_overrides.go @@ -0,0 +1,14 @@ +// Copyright (c) 2023-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package mcp + +import "github.com/mattermost/mattermost-plugin-agents/llm" + +type ToolRetrievalOverride struct { + Summary string +} + +func ToolRetrievalOverrideKey(serverOrigin, toolName string) string { + return serverOrigin + "\x00" + llm.BareMCPToolName(toolName) +} diff --git a/mcp/tools_cache_test.go b/mcp/tools_cache_test.go index 2fdaf28db..a99a67574 100644 --- a/mcp/tools_cache_test.go +++ b/mcp/tools_cache_test.go @@ -138,6 +138,7 @@ func (m *mockLogService) Debug(msg string, keyValuePairs ...interface{}) {} func (m *mockLogService) Info(msg string, keyValuePairs ...interface{}) {} func (m *mockLogService) Warn(msg string, keyValuePairs ...interface{}) {} func (m *mockLogService) Error(msg string, keyValuePairs ...interface{}) {} +func (m *mockLogService) Flush() error { return nil } func createTestTools() map[string]*mcp.Tool { return map[string]*mcp.Tool{ @@ -244,6 +245,15 @@ func TestInvalidateServer(t *testing.T) { require.Error(t, err) } +func TestInvalidateServerMissingKeyIsNoop(t *testing.T) { + kvAPI := newMockKVService() + log := &mockLogService{} + cache := NewToolsCache(kvAPI, log) + + require.NoError(t, cache.InvalidateServer("missing_server")) + require.Nil(t, cache.GetTools("missing_server")) +} + func TestBuildCacheKey(t *testing.T) { kvAPI := newMockKVService() log := &mockLogService{} diff --git a/mcp/user_clients.go b/mcp/user_clients.go index 9d433e4a4..fd20bb4aa 100644 --- a/mcp/user_clients.go +++ b/mcp/user_clients.go @@ -5,15 +5,20 @@ package mcp import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "net/http" + "net/url" + "sort" + "strings" + "sync" "time" "github.com/mattermost/mattermost-plugin-agents/llm" "github.com/mattermost/mattermost-plugin-agents/mmapi" "github.com/mattermost/mattermost/server/public/pluginapi" - gosdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" ) // ToolInfo represents a tool's metadata for discovery purposes @@ -25,6 +30,7 @@ type ToolInfo struct { // UserClients represents a per-user MCP client with multiple server connections type UserClients struct { + clientsMu sync.RWMutex clients map[string]*Client // serverID -> client (both remote and embedded) userID string log pluginapi.LogService @@ -38,6 +44,11 @@ type UserClients struct { initialRemoteConnectErrors *Errors } +type userClientSnapshot struct { + serverID string + client *Client +} + func NewUserClients(userID string, log pluginapi.LogService, oauthManager *OAuthManager, httpClient *http.Client, toolsCache *ToolsCache) *UserClients { return &UserClients{ log: log, @@ -50,7 +61,7 @@ func NewUserClients(userID string, log pluginapi.LogService, oauthManager *OAuth } // ConnectToRemoteServers initializes connections to remote MCP servers -func (c *UserClients) ConnectToRemoteServers(servers []ServerConfig) *Errors { +func (c *UserClients) ConnectToRemoteServers(ctx context.Context, servers []ServerConfig) *Errors { if len(servers) == 0 { c.log.Debug("No remote MCP servers provided for user", "userID", c.userID) return nil @@ -65,7 +76,7 @@ func (c *UserClients) ConnectToRemoteServers(servers []ServerConfig) *Errors { continue } - if err := c.connectToServer(context.TODO(), serverConfig.Name, serverConfig); err != nil { + if err := c.connectToServer(ctx, serverConfig.Name, serverConfig); err != nil { // Initialize errors struct if needed if mcpErrors == nil { mcpErrors = &Errors{} @@ -93,12 +104,12 @@ func (c *UserClients) ConnectToRemoteServers(servers []ServerConfig) *Errors { // ConnectToEmbeddedServerIfAvailable connects to the embedded server if session ID is provided. // If a connection already exists, it is reused. -func (c *UserClients) ConnectToEmbeddedServerIfAvailable(sessionID string, embeddedClient *EmbeddedServerClient, embeddedConfig EmbeddedServerConfig) error { +func (c *UserClients) ConnectToEmbeddedServerIfAvailable(ctx context.Context, sessionID string, embeddedClient *EmbeddedServerClient, embeddedConfig EmbeddedServerConfig) error { if !embeddedConfig.Enabled || embeddedClient == nil { return nil } - if _, exists := c.clients[EmbeddedClientKey]; exists { + if c.hasClient(EmbeddedClientKey) { return nil } @@ -106,12 +117,23 @@ func (c *UserClients) ConnectToEmbeddedServerIfAvailable(sessionID string, embed return nil } - ctxWithTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctxWithTimeout, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - if err := c.connectToEmbeddedServerWithClient(ctxWithTimeout, c.userID, sessionID, embeddedClient); err != nil { + + serverClient, err := embeddedClient.CreateClient(ctxWithTimeout, c.userID, sessionID) + if err != nil { c.log.Error("Failed to connect to embedded MCP server", "userID", c.userID, "error", err) return fmt.Errorf("failed to connect to embedded server: %w", err) } + + c.clientsMu.Lock() + defer c.clientsMu.Unlock() + if _, exists := c.clients[EmbeddedClientKey]; exists { + _ = serverClient.Close() + return nil + } + + c.clients[EmbeddedClientKey] = serverClient c.log.Debug("Successfully connected to embedded MCP server", "userID", c.userID) return nil @@ -123,22 +145,71 @@ func (c *UserClients) connectToServer(ctx context.Context, serverID string, serv if err != nil { return err } + c.clientsMu.Lock() + defer c.clientsMu.Unlock() c.clients[serverID] = serverClient return nil } -// connectToEmbeddedServerWithClient establishes a connection to the embedded server using the embedded client helper -func (c *UserClients) connectToEmbeddedServerWithClient(ctx context.Context, userID, sessionID string, embeddedClient *EmbeddedServerClient) error { - serverClient, err := embeddedClient.CreateClient(ctx, userID, sessionID) - if err != nil { - return err +func (c *UserClients) hasClient(serverID string) bool { + c.clientsMu.RLock() + defer c.clientsMu.RUnlock() + _, exists := c.clients[serverID] + return exists +} + +func (c *UserClients) snapshotClients() []userClientSnapshot { + c.clientsMu.RLock() + defer c.clientsMu.RUnlock() + if len(c.clients) == 0 { + return nil } - c.clients[EmbeddedClientKey] = serverClient - return nil + + serverIDs := make([]string, 0, len(c.clients)) + for serverID := range c.clients { + serverIDs = append(serverIDs, serverID) + } + sort.Strings(serverIDs) + + snapshot := make([]userClientSnapshot, 0, len(serverIDs)) + for _, serverID := range serverIDs { + snapshot = append(snapshot, userClientSnapshot{ + serverID: serverID, + client: c.clients[serverID], + }) + } + return snapshot +} + +func (c *UserClients) InitialRemoteConnectErrors() *Errors { + c.clientsMu.RLock() + defer c.clientsMu.RUnlock() + return c.initialRemoteConnectErrors +} + +func (c *UserClients) setInitialRemoteConnectErrors(mcpErrors *Errors) { + c.clientsMu.Lock() + defer c.clientsMu.Unlock() + c.initialRemoteConnectErrors = mcpErrors +} + +func (c *UserClients) appendInitialRemoteConnectError(err error) { + if err == nil { + return + } + c.clientsMu.Lock() + defer c.clientsMu.Unlock() + if c.initialRemoteConnectErrors == nil { + c.initialRemoteConnectErrors = &Errors{} + } + c.initialRemoteConnectErrors.Errors = append(c.initialRemoteConnectErrors.Errors, err) } // Close closes all server connections for a user client func (c *UserClients) Close() { + c.clientsMu.Lock() + defer c.clientsMu.Unlock() + // Close all MCP server clients (both remote and embedded) for serverID, client := range c.clients { if err := client.Close(); err != nil { @@ -151,32 +222,53 @@ func (c *UserClients) Close() { } // GetTools returns the tools available from the clients -func (c *UserClients) GetTools() []llm.Tool { - if len(c.clients) == 0 { +func (c *UserClients) GetTools(ctx context.Context) []llm.Tool { + clientSnapshot := c.snapshotClients() + if len(clientSnapshot) == 0 { return nil } var tools []llm.Tool - seenTools := make(map[string]string) // toolName -> serverID for conflict detection + seenTools := make(map[string]string) // runtime toolName -> serverID for conflict detection + usedSlugs := make(map[string]string) // slug -> server origin for collision suffixing + + // Iterate over a snapshot so callers do not hold clientsMu during network work. + for _, entry := range clientSnapshot { + serverID := entry.serverID + client := entry.client + if err := client.ensureDiscoveredTools(ctx); err != nil { + c.log.Warn("Failed to rediscover MCP tools after list_changed notification", + "userID", c.userID, + "serverID", serverID, + "server", client.config.Name, + "error", err) + continue + } - // Iterate over all clients and collect their tools - for serverID, client := range c.clients { clientTools := client.Tools() - for toolName, tool := range clientTools { - // Check for tool name conflicts across servers - if existingServerID, exists := seenTools[toolName]; exists { - c.log.Warn("Tool name conflict detected", + serverSlug := dedupeMCPServerSlug(mcpServerSlug(serverID, client), client.config.BaseURL, serverID, usedSlugs) + toolNames := make([]string, 0, len(clientTools)) + for toolName := range clientTools { + toolNames = append(toolNames, toolName) + } + sort.Strings(toolNames) + for _, toolName := range toolNames { + tool := clientTools[toolName] + runtimeToolName := llm.NamespaceMCPToolName(serverSlug, toolName) + // Namespacing should make cross-server duplicate bare names safe. A + // final collision means the slug de-dupe or upstream catalog is broken. + if existingServerID, exists := seenTools[runtimeToolName]; exists { + c.log.Warn("Namespaced MCP tool name conflict detected", "userID", c.userID, - "tool", toolName, + "tool", runtimeToolName, "server1", existingServerID, "server2", serverID) - // Skip duplicate tool (first server wins) continue } - seenTools[toolName] = serverID + seenTools[runtimeToolName] = serverID tools = append(tools, llm.Tool{ - Name: toolName, + Name: runtimeToolName, Description: tool.Description, Schema: tool.InputSchema, Resolver: c.createToolResolver(client, toolName), @@ -276,7 +368,12 @@ func (c *UserClients) createToolResolver(client *Client, toolName string) func(l metadata := c.prepareToolCallMetadata(client, toolName, llmContext) - result, err := client.CallToolWithMetadata(context.Background(), toolName, args, metadata) + callCtx := context.Background() + if llmContext != nil && llmContext.RequestContext != nil { + callCtx = llmContext.RequestContext + } + + result, err := client.CallToolWithMetadata(callCtx, toolName, args, metadata) if err != nil { c.rememberOAuthNeededForToolCall(client, err) return result, err @@ -287,6 +384,77 @@ func (c *UserClients) createToolResolver(client *Client, toolName string) func(l } } +func mcpServerSlug(serverID string, client *Client) string { + if client != nil && (client.config.BaseURL == EmbeddedClientKey || client.config.Name == EmbeddedClientKey || serverID == EmbeddedClientKey) { + return "mattermost" + } + + candidates := []string{} + if client != nil { + candidates = append(candidates, client.config.Name) + } + candidates = append(candidates, serverID) + if client != nil && client.config.BaseURL != "" { + if parsed, err := url.Parse(client.config.BaseURL); err == nil { + baseURLName := strings.Trim(strings.Trim(parsed.Host+parsed.Path, "/"), "_") + candidates = append(candidates, baseURLName) + } + } + candidates = append(candidates, "mcp") + + for _, candidate := range candidates { + if slug := sanitizeMCPServerSlug(candidate); slug != "" { + return slug + } + } + return "mcp" +} + +func dedupeMCPServerSlug(slug, serverOrigin, serverID string, usedSlugs map[string]string) string { + if slug == "" { + slug = "mcp" + } + if existingOrigin, exists := usedSlugs[slug]; !exists || existingOrigin == serverOrigin { + usedSlugs[slug] = serverOrigin + return slug + } + + hashInput := serverOrigin + if hashInput == "" { + hashInput = serverID + } + if hashInput == "" { + hashInput = slug + } + dedupedSlug := slug + "_" + shortSlugHash(hashInput) + usedSlugs[dedupedSlug] = serverOrigin + return dedupedSlug +} + +func sanitizeMCPServerSlug(value string) string { + value = strings.ToLower(value) + var b strings.Builder + lastWasSeparator := false + for _, r := range value { + isAllowed := (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') + if isAllowed { + b.WriteRune(r) + lastWasSeparator = false + continue + } + if b.Len() > 0 && !lastWasSeparator { + b.WriteByte('_') + lastWasSeparator = true + } + } + return strings.Trim(b.String(), "_") +} + +func shortSlugHash(value string) string { + sum := sha256.Sum256([]byte(value)) + return hex.EncodeToString(sum[:])[:8] +} + // pluginServerOriginKey returns the synthetic origin string for plugin-server // tools. Must match the key used by filterToolsByConfig. func pluginServerOriginKey(pluginID string) string { @@ -297,72 +465,24 @@ func pluginServerOriginKey(pluginID string) string { // over PluginHTTP, injecting X-Mattermost-UserID. Plugin servers use // inter-plugin auth, not user OAuth. func (c *UserClients) ConnectToPluginServer(ctx context.Context, cfg PluginServerConfig, sourcePluginAPI mmapi.Client) error { - if sourcePluginAPI == nil { - return fmt.Errorf("sourcePluginAPI is nil; plugin MCP server %s cannot be reached", cfg.PluginID) - } - originKey := pluginServerOriginKey(cfg.PluginID) - if _, exists := c.clients[originKey]; exists { + if c.hasClient(originKey) { return nil } - roundTripper := NewPluginHTTPRoundTripper(cfg.PluginID, cfg.Path, sourcePluginAPI) - httpClient := &http.Client{ - Transport: &headerTransport{ - base: roundTripper, - headers: map[string]string{MMUserIDHeader: c.userID}, - }, - } - - // Endpoint URL is a placeholder — PluginHTTPRoundTripper rewrites - // req.URL.Path on each round trip. go-sdk requires a parseable URL. - mcpClient := gosdkmcp.NewClient( - &gosdkmcp.Implementation{ - Name: "mattermost-agents-plugin-bridge", - Version: "1.0", - }, - &gosdkmcp.ClientOptions{}, - ) - session, err := mcpClient.Connect(ctx, &gosdkmcp.StreamableClientTransport{ - Endpoint: "http://plugin" + cfg.Path, - HTTPClient: httpClient, - }, nil) + client, err := NewPluginClient(ctx, c.userID, cfg, sourcePluginAPI, c.log) if err != nil { - return fmt.Errorf("failed to connect to plugin MCP server %s: %w", cfg.PluginID, err) - } - - initResult, err := session.ListTools(ctx, &gosdkmcp.ListToolsParams{}) - if err != nil { - _ = session.Close() - return fmt.Errorf("failed to list tools on plugin MCP server %s: %w", cfg.PluginID, err) - } - if len(initResult.Tools) == 0 { - _ = session.Close() - return fmt.Errorf("no tools found on plugin MCP server %s for user %s", cfg.PluginID, c.userID) - } - - // Synthetic ServerConfig: BaseURL == originKey ties the client into - // filterToolsByConfig via llm.Tool.ServerOrigin in GetTools. - pluginCfg := ServerConfig{ - Name: cfg.Name, - Enabled: true, - BaseURL: originKey, + return err } - client := &Client{ - session: session, - config: pluginCfg, - tools: make(map[string]*gosdkmcp.Tool, len(initResult.Tools)), - userID: c.userID, - log: c.log, - httpClient: httpClient, - // oauthManager/embeddedClient stay nil; reconnect reuses httpClient. - } - for _, tool := range initResult.Tools { - client.tools[tool.Name] = tool + c.clientsMu.Lock() + defer c.clientsMu.Unlock() + if _, exists := c.clients[originKey]; exists { + _ = client.Close() + return nil } c.clients[originKey] = client - c.log.Debug("Connected to plugin MCP server", "userID", c.userID, "pluginID", cfg.PluginID, "toolCount", len(client.tools)) + c.log.Debug("Connected to plugin MCP server", "userID", c.userID, "pluginID", cfg.PluginID, "toolCount", len(client.Tools())) return nil } diff --git a/mcp/user_clients_test.go b/mcp/user_clients_test.go index c2c32670e..f079ec6ce 100644 --- a/mcp/user_clients_test.go +++ b/mcp/user_clients_test.go @@ -13,7 +13,7 @@ import ( "github.com/mattermost/mattermost-plugin-agents/llm" plugintest "github.com/mattermost/mattermost/server/public/plugin/plugintest" "github.com/mattermost/mattermost/server/public/pluginapi" - gosdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + gomcp "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -42,7 +42,7 @@ func newFakePluginMCPServer(t *testing.T, toolCount int) *httptest.Server { // prefix; UserClients.GetTools dedupes by tool name across servers. func newFakePluginMCPServerWithPrefix(t *testing.T, prefix string, toolCount int) *httptest.Server { t.Helper() - srv := gosdkmcp.NewServer(&gosdkmcp.Implementation{Name: "fake", Version: "1.0"}, nil) + srv := gomcp.NewServer(&gomcp.Implementation{Name: "fake", Version: "1.0"}, nil) type echoIn struct { Message string `json:"message"` } @@ -51,13 +51,13 @@ func newFakePluginMCPServerWithPrefix(t *testing.T, prefix string, toolCount int } for i := 0; i < toolCount; i++ { name := fmt.Sprintf("%s_%d", prefix, i) - gosdkmcp.AddTool(srv, &gosdkmcp.Tool{Name: name, Description: "test"}, func(_ context.Context, _ *gosdkmcp.CallToolRequest, in echoIn) (*gosdkmcp.CallToolResult, echoOut, error) { + gomcp.AddTool(srv, &gomcp.Tool{Name: name, Description: "test"}, func(_ context.Context, _ *gomcp.CallToolRequest, in echoIn) (*gomcp.CallToolResult, echoOut, error) { return nil, echoOut{Echo: in.Message}, nil }) } - h := gosdkmcp.NewStreamableHTTPHandler( - func(*http.Request) *gosdkmcp.Server { return srv }, - &gosdkmcp.StreamableHTTPOptions{Stateless: true, JSONResponse: true}, + h := gomcp.NewStreamableHTTPHandler( + func(*http.Request) *gomcp.Server { return srv }, + &gomcp.StreamableHTTPOptions{Stateless: true, JSONResponse: true}, ) return httptest.NewServer(h) } @@ -76,6 +76,112 @@ func newPluginHTTPForwarder(t *testing.T, target *httptest.Server) *fakePluginHT } } +func TestUserClientsGetToolsNamespacesDuplicateBareNames(t *testing.T) { + userClients := &UserClients{ + userID: "user-id", + clients: map[string]*Client{ + "github": testClientWithTools("GitHub", "https://api.githubcopilot.com", "search"), + "jira": testClientWithTools("Jira", "https://mcp.atlassian.com", "search"), + }, + } + + tools := userClients.GetTools(context.Background()) + + requireToolNames(t, tools, "github__search", "jira__search") +} + +func TestUserClientsGetToolsResolverUsesBareToolName(t *testing.T) { + server := newTestMCPServer(0, "search") + session := connectInMemoryTestSession(t, server) + userClients := &UserClients{ + userID: "user-id", + clients: map[string]*Client{ + "jira": { + session: session, + config: ServerConfig{Name: "Jira", BaseURL: "https://mcp.atlassian.com", Enabled: true}, + tools: map[string]*gomcp.Tool{ + "search": { + Name: "search", + Description: "Search Jira", + }, + }, + }, + }, + } + + tools := userClients.GetTools(context.Background()) + requireToolNames(t, tools, "jira__search") + + result, err := tools[0].Resolver(&llm.Context{}, func(args any) error { + *(args.(*map[string]any)) = map[string]any{} + return nil + }) + + require.NoError(t, err) + require.Equal(t, "search ok\n", result) +} + +func TestUserClientsGetToolsEmbeddedToolNamesUseMattermostSlug(t *testing.T) { + userClients := &UserClients{ + userID: "user-id", + clients: map[string]*Client{ + EmbeddedClientKey: testClientWithTools(EmbeddedClientKey, EmbeddedClientKey, "search_users"), + }, + } + + tools := userClients.GetTools(context.Background()) + + requireToolNames(t, tools, "mattermost__search_users") +} + +func TestUserClientsGetToolsDeterministicSlugCollision(t *testing.T) { + userClients := &UserClients{ + userID: "user-id", + clients: map[string]*Client{ + "server-a": testClientWithTools("Jira!", "https://a.example.com", "search"), + "server-b": testClientWithTools("Jira", "https://b.example.com", "search"), + }, + } + expectedDedupedName := "jira_" + shortSlugHash("https://b.example.com") + "__search" + + first := userClients.GetTools(context.Background()) + second := userClients.GetTools(context.Background()) + + requireToolNames(t, first, "jira__search", expectedDedupedName) + requireToolNames(t, second, "jira__search", expectedDedupedName) +} + +func TestUserClientsGetToolsPreservesRediscoveryBeforeRead(t *testing.T) { + server := newTestMCPServer(0, "old_tool") + session := connectInMemoryTestSession(t, server) + client := &Client{ + session: session, + config: ServerConfig{Name: "Jira", BaseURL: "https://mcp.atlassian.com", Enabled: true}, + tools: make(map[string]*gomcp.Tool), + toolsDirty: true, + userID: "user-id", + log: newTestLogService(), + } + userClients := &UserClients{ + userID: "user-id", + log: newTestLogService(), + clients: map[string]*Client{ + "jira": client, + }, + } + + addTestMCPTool(server, "new_tool") + require.NoError(t, client.ensureDiscoveredTools(context.Background())) + client.toolsMu.Lock() + client.toolsDirty = true + client.tools = make(map[string]*gomcp.Tool) + client.toolsMu.Unlock() + + tools := userClients.GetTools(context.Background()) + + requireToolNames(t, tools, "jira__new_tool", "jira__old_tool") +} + func TestConnectToPluginServer_HappyPath(t *testing.T) { target := newFakePluginMCPServer(t, 2) t.Cleanup(target.Close) @@ -98,11 +204,11 @@ func TestConnectToPluginServer_HappyPath(t *testing.T) { require.NoError(t, err) originKey := "plugin://" + cfg.PluginID - c, ok := uc.clients[originKey] - require.True(t, ok, "expected client under origin key %s", originKey) - require.NotNil(t, c) - require.Equal(t, originKey, c.config.BaseURL) - require.Len(t, c.tools, 2) + require.True(t, uc.hasClient(originKey)) + snapshot := uc.snapshotClients() + require.Len(t, snapshot, 1) + require.Equal(t, originKey, snapshot[0].client.config.BaseURL) + require.Len(t, snapshot[0].client.Tools(), 2) } func TestConnectToPluginServer_Idempotent(t *testing.T) { @@ -132,10 +238,67 @@ func TestConnectToPluginServer_NilAPI(t *testing.T) { require.Error(t, err) } +func TestConnectToEmbeddedServerIfAvailable_Idempotent(t *testing.T) { + server := newTestMCPServer(0, "tool_1") + runCtx, cancelRun := context.WithCancel(context.Background()) + t.Cleanup(cancelRun) + + pluginAPI := newTestPluginAPIForEmbeddedManager("alice", "session-id") + embeddedClient := NewEmbeddedServerClient(&fakeEmbeddedMCPServer{ctx: runCtx, server: server}, pluginAPI.Log, pluginAPI) + uc := NewUserClients("alice", pluginAPI.Log, nil, nil, nil) + cfg := EmbeddedServerConfig{Enabled: true} + + require.NoError(t, uc.ConnectToEmbeddedServerIfAvailable(context.Background(), "session-id", embeddedClient, cfg)) + firstSnapshot := uc.snapshotClients() + require.Len(t, firstSnapshot, 1) + firstClient := firstSnapshot[0].client + + // Stop the embedded server so a second dial would fail if Connect re-created a client. + cancelRun() + require.NoError(t, uc.ConnectToEmbeddedServerIfAvailable(context.Background(), "session-id", embeddedClient, cfg)) + + secondSnapshot := uc.snapshotClients() + require.Len(t, secondSnapshot, 1) + require.Same(t, firstClient, secondSnapshot[0].client) +} + +func TestUserClientsGetToolsResolverUsesRequestContext(t *testing.T) { + callCtx, cancel := context.WithCancel(context.Background()) + cancel() + + server := newTestMCPServer(0, "search") + session := connectInMemoryTestSession(t, server) + userClients := &UserClients{ + userID: "user-id", + clients: map[string]*Client{ + "jira": { + session: session, + config: ServerConfig{Name: "Jira", BaseURL: "https://mcp.atlassian.com", Enabled: true}, + tools: map[string]*gomcp.Tool{ + "search": { + Name: "search", + Description: "Search Jira", + }, + }, + }, + }, + } + + tools := userClients.GetTools(context.Background()) + require.Len(t, tools, 1) + + _, err := tools[0].Resolver(&llm.Context{RequestContext: callCtx}, func(args any) error { + *(args.(*map[string]any)) = map[string]any{} + return nil + }) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + func TestPrepareToolCallMetadata_EmbeddedMergesCallMetadataAndBotUserID(t *testing.T) { llmContext := llm.NewContext() llmContext.BotUserID = "bot-user-id" - llmContext.Tools = llm.NewToolStore() + llmContext.Tools = llm.NewToolStore(nil, false) llmContext.Tools.AddTools([]llm.Tool{ llm.Tool{Name: "search_posts"}.WithCallMetadata(map[string]any{ "tool_hooks": map[string]any{ @@ -164,3 +327,21 @@ func TestPrepareToolCallMetadata_EmbeddedMergesCallMetadataAndBotUserID(t *testi remoteMeta := clients.prepareToolCallMetadata(remoteClient, "search_posts", llmContext) require.Nil(t, remoteMeta) } + +func testClientWithTools(name, baseURL string, toolNames ...string) *Client { + tools := make(map[string]*gomcp.Tool, len(toolNames)) + for _, toolName := range toolNames { + tools[toolName] = &gomcp.Tool{ + Name: toolName, + Description: "Test tool " + toolName, + } + } + return &Client{ + config: ServerConfig{ + Name: name, + BaseURL: baseURL, + Enabled: true, + }, + tools: tools, + } +} diff --git a/mcpserver/eval_helpers_test.go b/mcpserver/eval_helpers_test.go index fef0ad6d6..deb0d24ad 100644 --- a/mcpserver/eval_helpers_test.go +++ b/mcpserver/eval_helpers_test.go @@ -563,7 +563,7 @@ func setupAgenticEval(t *testing.T, e *evals.EvalT, suite *TestSuite, requesting allToolNames[i] = tool.Name } - toolStore := llm.NewToolStore() + toolStore := llm.NewToolStore(nil, false) toolStore.AddTools(mcpTools) llmContext := llm.NewContext() diff --git a/search/search_test.go b/search/search_test.go index 2873ef20a..56e3e9678 100644 --- a/search/search_test.go +++ b/search/search_test.go @@ -688,15 +688,14 @@ func TestRunSearch(t *testing.T) { }). Return(nil).Once() - // Second DM is for response post (async in goroutine) - use Maybe since test may finish before goroutine - mockClient.On("DM", "bot1", "user1", mock.Anything).Return(nil).Maybe() + // Second DM is for response post (async in goroutine). + mockClient.On("DM", "bot1", "user1", mock.Anything).Return(nil).Once() - // The goroutine may call LogError if the search fails - use Maybe to handle both cases - mockClient.On("LogError", mock.Anything, mock.Anything).Maybe() - - // The goroutine may call Search - set up to return empty results to avoid further processing + // Return empty results to exercise the async UpdatePost path, then wait + // for that update so this test does not leak background work into later + // tracing assertions. mockEmbedding.On("Search", mock.Anything, mock.Anything, mock.Anything). - Return([]embeddings.SearchResult{}, nil).Maybe() + Return([]embeddings.SearchResult{}, nil).Once() // If zero results, UpdatePost is called. Wait for it so the async // search goroutine cannot leak into following tests. diff --git a/telemetry/integration_test.go b/telemetry/integration_test.go index 2d78b56f0..4f0f26fde 100644 --- a/telemetry/integration_test.go +++ b/telemetry/integration_test.go @@ -252,7 +252,7 @@ func TestToolResolveSpan(t *testing.T) { exporter, cleanup := setupTestTracing(t) defer cleanup() - store := llm.NewToolStore() + store := llm.NewToolStore(nil, false) store.AddTools([]llm.Tool{ { Name: "test_tool", @@ -291,7 +291,7 @@ func TestToolResolveUnknownSpan(t *testing.T) { exporter, cleanup := setupTestTracing(t) defer cleanup() - store := llm.NewToolStore() + store := llm.NewToolStore(nil, false) _, err := store.ResolveTool(context.Background(), "nonexistent", func(args any) error { return nil @@ -380,7 +380,7 @@ func TestFullRequestTrace(t *testing.T) { llmSpan.End() // Tool resolution - store := llm.NewToolStore() + store := llm.NewToolStore(nil, false) store.AddTools([]llm.Tool{ { Name: "web_search", From e7d4c2ac9261fb24ed74b7f6151095276f717e47 Mon Sep 17 00:00:00 2001 From: Nick Misasi Date: Fri, 22 May 2026 23:04:51 -0400 Subject: [PATCH 2/7] dynamic mcp: address client catalog review feedback Co-authored-by: Cursor --- api/api.go | 2 +- api/api_channel.go | 1 + api/api_llm_bridge.go | 12 +++-- api/api_llm_bridge_test.go | 5 ++ api/api_mcp.go | 4 +- api/api_no_tools_test.go | 2 +- api/api_test.go | 2 +- conversations/bot_channel_tool_filter_test.go | 27 ++++++++++ conversations/conversations_test.go | 2 +- conversations/dm_conversation_test.go | 2 +- conversations/handle_messages.go | 2 + conversations/regeneration.go | 1 + conversations/tool_approval.go | 2 + conversations/tool_policy_test.go | 36 +++++++++++++ llm/tools.go | 3 ++ llm/tools_test.go | 4 ++ llmcontext/llm_context.go | 18 ++++++- llmcontext/llm_context_test.go | 7 ++- mcp/client.go | 3 ++ mcp/client_embedded_oauth_test.go | 17 ++++++- mcp/client_integration_test.go | 2 +- mcp/client_manager.go | 34 ++++++++++--- mcp/client_manager_test.go | 50 +++++++++++++++++-- mcp/client_test.go | 4 +- mcp/tool_policy.go | 27 ++++++++-- mcp/tool_policy_lookup_test.go | 20 ++++++++ mcp/user_clients.go | 18 ++----- mcp/user_clients_test.go | 31 +++++++++++- 28 files changed, 287 insertions(+), 51 deletions(-) diff --git a/api/api.go b/api/api.go index be56da48b..0a7d09149 100644 --- a/api/api.go +++ b/api/api.go @@ -63,7 +63,7 @@ type MCPClientManager interface { MarkOAuthNeeded(userID, serverName, authURL string) error GetEmbeddedServer() mcp.EmbeddedMCPServer EnsureMCPSessionID(userID string) (string, error) - GetToolsForUser(userID string) ([]llm.Tool, *mcp.Errors) + GetToolsForUser(ctx context.Context, userID string) ([]llm.Tool, *mcp.Errors) GetConfig() mcp.Config RegisterPluginServer(cfg mcp.PluginServerConfig) diff --git a/api/api_channel.go b/api/api_channel.go index 025af641d..c76f4de9d 100644 --- a/api/api_channel.go +++ b/api/api_channel.go @@ -87,6 +87,7 @@ func (a *API) handleChannelAnalysis(c *gin.Context) { } opts := []llm.ContextOption{ + a.contextBuilder.WithLLMContextRequestContext(c.Request.Context()), a.contextBuilder.WithLLMContextDefaultTools(bot), } diff --git a/api/api_llm_bridge.go b/api/api_llm_bridge.go index 8e4b2db47..37fbe8216 100644 --- a/api/api_llm_bridge.go +++ b/api/api_llm_bridge.go @@ -5,6 +5,7 @@ package api import ( "bytes" + stdcontext "context" "encoding/json" "errors" "fmt" @@ -163,13 +164,14 @@ func (a *API) buildLLMBridgeContext(bot *bots.Bot, req bridgeclient.CompletionRe return context, nil } -func (a *API) convertAgentBridgeRequestToInternal(bot *bots.Bot, req bridgeclient.CompletionRequest, includeTools bool, operation, operationSubType string) (llm.CompletionRequest, error) { +func (a *API) convertAgentBridgeRequestToInternal(ctx stdcontext.Context, bot *bots.Bot, req bridgeclient.CompletionRequest, includeTools bool, operation, operationSubType string) (llm.CompletionRequest, error) { posts, err := a.convertBridgePostsToInternal(req) if err != nil { return llm.CompletionRequest{}, err } bridgeContext := llm.NewContext() + bridgeContext.RequestContext = ctx bridgeContext.RequestingUser = &model.User{Id: req.UserID} if includeTools && a.contextBuilder != nil { a.contextBuilder.WithLLMContextTools(bot)(bridgeContext) @@ -270,6 +272,7 @@ func validateCompletionRequestIDs(req bridgeclient.CompletionRequest) (int, erro } func (a *API) prepareAgentBridgeCompletion( + ctx stdcontext.Context, agent string, req bridgeclient.CompletionRequest, pluginID string, @@ -314,7 +317,7 @@ func (a *API) prepareAgentBridgeCompletion( } toolsRequested := allowedToolNames != nil - llmRequest, err := a.convertAgentBridgeRequestToInternal(bot, req, toolsRequested, operation, operationSubType) + llmRequest, err := a.convertAgentBridgeRequestToInternal(ctx, bot, req, toolsRequested, operation, operationSubType) if err != nil { return nil, llm.CompletionRequest{}, nil, nil, nil, http.StatusBadRequest, fmt.Errorf("invalid request: %v", err) } @@ -687,6 +690,7 @@ func (a *API) handleGetAgentTools(c *gin.Context) { // Build a minimal context just to resolve the bot's available tools. toolContext := llm.NewContext() + toolContext.RequestContext = c.Request.Context() toolContext.RequestingUser = &model.User{Id: userID} if a.contextBuilder != nil { a.contextBuilder.WithLLMContextTools(bot)(toolContext) @@ -780,7 +784,7 @@ func (a *API) handleAgentCompletionStreaming(c *gin.Context) { return } - bot, llmRequest, opts, shouldExecute, beforeHookKeys, statusCode, err := a.prepareAgentBridgeCompletion(agent, req, c.GetHeader("Mattermost-Plugin-ID"), llm.OperationBridgeAgent, llm.SubTypeStreaming) + bot, llmRequest, opts, shouldExecute, beforeHookKeys, statusCode, err := a.prepareAgentBridgeCompletion(c.Request.Context(), agent, req, c.GetHeader("Mattermost-Plugin-ID"), llm.OperationBridgeAgent, llm.SubTypeStreaming) if err != nil { c.JSON(statusCode, bridgeclient.ErrorResponse{ Error: err.Error(), @@ -811,7 +815,7 @@ func (a *API) handleAgentCompletionNoStream(c *gin.Context) { return } - bot, llmRequest, opts, shouldExecute, beforeHookKeys, statusCode, err := a.prepareAgentBridgeCompletion(agent, req, c.GetHeader("Mattermost-Plugin-ID"), llm.OperationBridgeAgent, llm.SubTypeNoStream) + bot, llmRequest, opts, shouldExecute, beforeHookKeys, statusCode, err := a.prepareAgentBridgeCompletion(c.Request.Context(), agent, req, c.GetHeader("Mattermost-Plugin-ID"), llm.OperationBridgeAgent, llm.SubTypeNoStream) if err != nil { c.JSON(statusCode, bridgeclient.ErrorResponse{ Error: err.Error(), diff --git a/api/api_llm_bridge_test.go b/api/api_llm_bridge_test.go index efc8c69aa..0a375d24a 100644 --- a/api/api_llm_bridge_test.go +++ b/api/api_llm_bridge_test.go @@ -1690,6 +1690,7 @@ func TestPrepareAgentBridgeCompletionAllowedToolsRequiresUserID(t *testing.T) { defer e.Cleanup(t) _, _, _, _, _, statusCode, err := e.api.prepareAgentBridgeCompletion( + context.Background(), testBotUserID, bridgeclient.CompletionRequest{ Posts: []bridgeclient.Post{ @@ -1724,6 +1725,7 @@ func TestPrepareAgentBridgeCompletionToolHooksRequiresPluginID(t *testing.T) { e.setupTestBot(botConfig) _, _, _, _, _, statusCode, err := e.api.prepareAgentBridgeCompletion( + context.Background(), testBotUserID, bridgeclient.CompletionRequest{ Posts: []bridgeclient.Post{ @@ -1783,6 +1785,7 @@ func TestPrepareAgentBridgeCompletionStoresToolHookKeysInMCPMetadata(t *testing. ).Return(true, (*model.AppError)(nil)).Once() _, llmRequest, _, _, beforeHookKeys, statusCode, err := e.api.prepareAgentBridgeCompletion( + context.Background(), testBotUserID, bridgeclient.CompletionRequest{ Posts: []bridgeclient.Post{ @@ -1849,6 +1852,7 @@ func TestPrepareAgentBridgeCompletionToolHooksRequiresUserID(t *testing.T) { e.setupTestBot(botConfig) _, _, _, _, _, statusCode, err := e.api.prepareAgentBridgeCompletion( + context.Background(), testBotUserID, bridgeclient.CompletionRequest{ Posts: []bridgeclient.Post{ @@ -1886,6 +1890,7 @@ func TestPrepareAgentBridgeCompletionToolHooksRequiresAllowedTools(t *testing.T) e.setupTestBot(botConfig) _, _, _, _, _, statusCode, err := e.api.prepareAgentBridgeCompletion( + context.Background(), testBotUserID, bridgeclient.CompletionRequest{ Posts: []bridgeclient.Post{ diff --git a/api/api_mcp.go b/api/api_mcp.go index 793578c8b..dc21904aa 100644 --- a/api/api_mcp.go +++ b/api/api_mcp.go @@ -45,7 +45,7 @@ func (a *API) handleGetUserMCPTools(c *gin.Context) { mcpCfg := a.config.MCP() - tools, mcpErrors := a.mcpClientManager.GetToolsForUser(userID) + tools, mcpErrors := a.mcpClientManager.GetToolsForUser(c.Request.Context(), userID) // Group tools by ServerOrigin toolsByOrigin := make(map[string][]llm.Tool, len(tools)) @@ -139,7 +139,7 @@ func buildUserMCPServerInfo( ) UserMCPServerInfo { toolInfos := make([]UserMCPToolInfo, 0, len(originTools)) for _, t := range originTools { - policy, enabled := serverConfig.GetToolPolicy(t.Name) + policy, enabled := serverConfig.GetToolPolicy(mcp.ToolPolicyLookupName(serverConfig, t.Name)) toolInfos = append(toolInfos, UserMCPToolInfo{ Name: t.Name, Description: t.Description, diff --git a/api/api_no_tools_test.go b/api/api_no_tools_test.go index c15816653..e85e44d53 100644 --- a/api/api_no_tools_test.go +++ b/api/api_no_tools_test.go @@ -38,7 +38,7 @@ type noToolsTestMCPProvider struct { calls int } -func (p *noToolsTestMCPProvider) GetToolsForUser(string) ([]llm.Tool, *mcp.Errors) { +func (p *noToolsTestMCPProvider) GetToolsForUser(context.Context, string) ([]llm.Tool, *mcp.Errors) { p.calls++ return nil, nil } diff --git a/api/api_test.go b/api/api_test.go index 1d17e2a64..94d2ab30b 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -182,7 +182,7 @@ func (m *mockMCPClientManager) GetHTTPClient() *http.Client { return nil } -func (m *mockMCPClientManager) GetToolsForUser(userID string) ([]llm.Tool, *mcp.Errors) { +func (m *mockMCPClientManager) GetToolsForUser(context.Context, string) ([]llm.Tool, *mcp.Errors) { return m.tools, m.mcpErrors } diff --git a/conversations/bot_channel_tool_filter_test.go b/conversations/bot_channel_tool_filter_test.go index 075ed7995..9098eaf3f 100644 --- a/conversations/bot_channel_tool_filter_test.go +++ b/conversations/bot_channel_tool_filter_test.go @@ -22,6 +22,9 @@ func (m mapPolicyChecker) GetToolPolicy(serverOrigin, toolName string) (string, return mcp.ToolPolicyAsk, false } cfg, ok := byServer[toolName] + if !ok { + cfg, ok = byServer[llm.BareMCPToolName(toolName)] + } if !ok { return mcp.ToolPolicyAsk, true } @@ -58,6 +61,30 @@ func TestApplyBotChannelAutoEverywhereToolFilter(t *testing.T) { require.Len(t, llmContext.DisabledToolsInfo, 3) } +func TestApplyBotChannelAutoEverywhereToolFilter_NamespacedToolUsesBarePolicy(t *testing.T) { + origin := "https://mcp.example.com/mcp" + c := &Conversations{ + toolPolicyChecker: mapPolicyChecker{ + origin: { + "everywhere_tool": {policy: mcp.ToolPolicyAutoRunEverywhere, enabled: true}, + }, + }, + } + + llmContext := &llm.Context{ + Tools: llm.NewToolStore(), + } + llmContext.Tools.AddTools([]llm.Tool{ + {Name: "server__everywhere_tool", ServerOrigin: origin, Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + }) + + c.applyBotChannelAutoEverywhereToolFilter(llmContext) + + tools := llmContext.Tools.GetTools() + require.Len(t, tools, 1) + require.Equal(t, "server__everywhere_tool", tools[0].Name) +} + func TestApplyToolAvailabilityBeforeBotChannelFilterPreservesDisabledToolsInfo(t *testing.T) { origin := "https://mcp.example.com/mcp" c := &Conversations{ diff --git a/conversations/conversations_test.go b/conversations/conversations_test.go index 137da555e..3405e1bcb 100644 --- a/conversations/conversations_test.go +++ b/conversations/conversations_test.go @@ -46,7 +46,7 @@ func (m *mockToolProvider) GetTools(bot *bots.Bot) []llm.Tool { type mockMCPClientManager struct{} -func (m *mockMCPClientManager) GetToolsForUser(userID string) ([]llm.Tool, *mcp.Errors) { +func (m *mockMCPClientManager) GetToolsForUser(context.Context, string) ([]llm.Tool, *mcp.Errors) { return []llm.Tool{}, nil } diff --git a/conversations/dm_conversation_test.go b/conversations/dm_conversation_test.go index 2d8c89ff6..0b7742f47 100644 --- a/conversations/dm_conversation_test.go +++ b/conversations/dm_conversation_test.go @@ -457,7 +457,7 @@ func setupDMTestEnv(t *testing.T, llmResponses ...*llm.TextStreamResult) *dmTest // testMCPClientManager implements llmcontext.MCPClientManager for testing. type testMCPClientManager struct{} -func (m *testMCPClientManager) GetToolsForUser(string) ([]llm.Tool, *mcp.Errors) { +func (m *testMCPClientManager) GetToolsForUser(context.Context, string) ([]llm.Tool, *mcp.Errors) { return nil, nil } diff --git a/conversations/handle_messages.go b/conversations/handle_messages.go index a575fe2a2..43c811e5d 100644 --- a/conversations/handle_messages.go +++ b/conversations/handle_messages.go @@ -194,6 +194,7 @@ func (c *Conversations) handleMentionViaConversation( responseRootID string, ) error { contextOpts := []llm.ContextOption{ + c.contextBuilder.WithLLMContextRequestContext(ctx), c.contextBuilder.WithLLMContextTools(bot), } llmContext := c.contextBuilder.BuildLLMContextUserRequest(bot, postingUser, channel, contextOpts...) @@ -339,6 +340,7 @@ func (c *Conversations) handleDMs(ctx context.Context, bot *bots.Bot, channel *m // handleDMViaConversation processes a DM message using the conversation entity model. func (c *Conversations) handleDMViaConversation(ctx context.Context, bot *bots.Bot, channel *model.Channel, postingUser *model.User, post *model.Post) error { contextOpts := []llm.ContextOption{ + c.contextBuilder.WithLLMContextRequestContext(ctx), c.contextBuilder.WithLLMContextTools(bot), } webSearchParams := c.extractWebSearchContext(post) diff --git a/conversations/regeneration.go b/conversations/regeneration.go index b46b37cc7..830e83092 100644 --- a/conversations/regeneration.go +++ b/conversations/regeneration.go @@ -251,6 +251,7 @@ func (c *Conversations) regenerateViaConversation( } contextOpts := []llm.ContextOption{ + c.contextBuilder.WithLLMContextRequestContext(ctx), c.contextBuilder.WithLLMContextDefaultTools(bot), } llmContext := c.contextBuilder.BuildLLMContextUserRequest(bot, user, channel, contextOpts...) diff --git a/conversations/tool_approval.go b/conversations/tool_approval.go index fbeffc899..9de24e94c 100644 --- a/conversations/tool_approval.go +++ b/conversations/tool_approval.go @@ -98,6 +98,7 @@ func (c *Conversations) HandleToolCall(ctx context.Context, userID string, post // Build LLM context with tools for execution. contextOpts := []llm.ContextOption{ + c.contextBuilder.WithLLMContextRequestContext(ctx), c.contextBuilder.WithLLMContextDefaultTools(bot), } llmContext := c.contextBuilder.BuildLLMContextUserRequest(bot, user, channel, contextOpts...) @@ -422,6 +423,7 @@ func (c *Conversations) streamToolFollowUp( defer span.End() contextOpts := []llm.ContextOption{ + c.contextBuilder.WithLLMContextRequestContext(ctx), c.contextBuilder.WithLLMContextDefaultTools(bot), } llmContext := c.contextBuilder.BuildLLMContextUserRequest(bot, user, channel, contextOpts...) diff --git a/conversations/tool_policy_test.go b/conversations/tool_policy_test.go index 502fe15e9..5999b6c67 100644 --- a/conversations/tool_policy_test.go +++ b/conversations/tool_policy_test.go @@ -67,6 +67,23 @@ func TestShouldAutoExecuteTool_NilChecker(t *testing.T) { } } +func TestShouldAutoExecuteTool_NamespacedToolUsesBarePolicy(t *testing.T) { + const origin = "https://mcp.example.com/mcp" + + c := &Conversations{ + toolPolicyChecker: mapPolicyChecker{ + origin: { + "example_tool": {policy: mcp.ToolPolicyAutoRunEverywhere, enabled: true}, + }, + }, + } + llmCtx := &llm.Context{Tools: llm.NewToolStore()} + + got := c.shouldAutoExecuteTool(llmCtx, false)(llm.ToolCall{Name: "example__example_tool", ServerOrigin: origin}) + + assert.True(t, got) +} + // TestAllToolsAutoRunEverywhere_RespectsEnabledFlag pins the result-sharing // contract: a disabled tool must never drive results to shared=true, even if // its policy is auto_run_everywhere. The enabled flag is authoritative — @@ -92,3 +109,22 @@ func TestAllToolsAutoRunEverywhere_RespectsEnabledFlag(t *testing.T) { assert.False(t, c.allToolsAutoRunEverywhere(turns, llmCtx), "a disabled tool must not auto-share results even when the policy is auto_run_everywhere") } + +func TestAllToolsAutoRunEverywhere_NamespacedToolUsesBarePolicy(t *testing.T) { + const origin = "https://mcp.example.com/mcp" + + c := &Conversations{ + toolPolicyChecker: mapPolicyChecker{ + origin: { + "example_tool": {policy: mcp.ToolPolicyAutoRunEverywhere, enabled: true}, + }, + }, + } + llmCtx := &llm.Context{Tools: llm.NewToolStore()} + + turns := []toolrunner.ToolTurn{{ + AssistantToolCalls: []llm.ToolCall{{Name: "example__example_tool", ServerOrigin: origin}}, + }} + + assert.True(t, c.allToolsAutoRunEverywhere(turns, llmCtx)) +} diff --git a/llm/tools.go b/llm/tools.go index 1a80f3d68..0422222ef 100644 --- a/llm/tools.go +++ b/llm/tools.go @@ -663,6 +663,9 @@ func (s *ToolStore) LogUnknownToolWarning(name string, argsGetter ToolArgumentGe } func toolArgsForLog(argsGetter ToolArgumentGetter) string { + if argsGetter == nil { + return "{}" + } var raw json.RawMessage if err := argsGetter(&raw); err != nil { return fmt.Sprintf("failed to get tool args: %v", err) diff --git a/llm/tools_test.go b/llm/tools_test.go index 7df48d9a7..01e5eb451 100644 --- a/llm/tools_test.go +++ b/llm/tools_test.go @@ -110,6 +110,10 @@ func TestSanitizeNonPrintableChars(t *testing.T) { } } +func TestToolArgsForLogNilGetter(t *testing.T) { + assert.Equal(t, "{}", toolArgsForLog(nil)) +} + type logEntry struct { message string fields []any diff --git a/llmcontext/llm_context.go b/llmcontext/llm_context.go index 8417f62ec..7bddead63 100644 --- a/llmcontext/llm_context.go +++ b/llmcontext/llm_context.go @@ -4,6 +4,7 @@ package llmcontext import ( + stdcontext "context" "slices" "strings" "time" @@ -23,7 +24,7 @@ type ToolProvider interface { // MCPToolProvider provides MCP tools for a user type MCPToolProvider interface { - GetToolsForUser(userID string) ([]llm.Tool, *mcp.Errors) + GetToolsForUser(ctx stdcontext.Context, userID string) ([]llm.Tool, *mcp.Errors) } // ConfigProvider provides configuration access @@ -197,8 +198,13 @@ func (b *Builder) getToolsStoreForUser(c *llm.Context, bot *bots.Bot, userID str // so that GetToolsInfo() can inform the LLM about their availability. // Actual execution is controlled via WithToolsDisabled() based on channel type. if b.mcpToolProvider != nil { + if c.RequestContext == nil { + b.pluginAPI.Log.Error("Cannot add MCP tools to context: RequestContext is nil", "userID", userID) + return store + } + // Get tools from all connected servers - mcpTools, mcpErrors := b.mcpToolProvider.GetToolsForUser(userID) + mcpTools, mcpErrors := b.mcpToolProvider.GetToolsForUser(c.RequestContext, userID) // Add tools from successfully connected servers even if some had errors // These will be disabled in non-DM channels via WithToolsDisabled() @@ -249,6 +255,14 @@ func (b *Builder) WithLLMContextDefaultTools(bot *bots.Bot) llm.ContextOption { return b.WithLLMContextTools(bot) } +// WithLLMContextRequestContext threads request-scoped cancellation/deadlines into +// MCP discovery and tool execution. +func (b *Builder) WithLLMContextRequestContext(ctx stdcontext.Context) llm.ContextOption { + return func(c *llm.Context) { + c.RequestContext = ctx + } +} + // WithLLMContextNoTools explicitly disables tools for this context session only, // overriding the bot's DisableTools configuration. This allows inter-plugin requests // to work with tool-enabled bots by bypassing tools for non-streaming calls. diff --git a/llmcontext/llm_context_test.go b/llmcontext/llm_context_test.go index 0c3cd1138..3c587378a 100644 --- a/llmcontext/llm_context_test.go +++ b/llmcontext/llm_context_test.go @@ -4,6 +4,7 @@ package llmcontext import ( + stdcontext "context" "testing" "github.com/mattermost/mattermost-plugin-agents/bots" @@ -26,7 +27,7 @@ type countingMCPToolProvider struct { calls int } -func (p *countingMCPToolProvider) GetToolsForUser(string) ([]llm.Tool, *mcp.Errors) { +func (p *countingMCPToolProvider) GetToolsForUser(stdcontext.Context, string) ([]llm.Tool, *mcp.Errors) { p.calls++ return []llm.Tool{ { @@ -42,7 +43,7 @@ type staticMCPToolProvider struct { errors *mcp.Errors } -func (p *staticMCPToolProvider) GetToolsForUser(string) ([]llm.Tool, *mcp.Errors) { +func (p *staticMCPToolProvider) GetToolsForUser(stdcontext.Context, string) ([]llm.Tool, *mcp.Errors) { return p.tools, p.errors } @@ -86,6 +87,7 @@ func TestWithLLMContextDefaultToolsCallsMCPProvider(t *testing.T) { newTestBot(), user, channel, + builder.WithLLMContextRequestContext(stdcontext.Background()), builder.WithLLMContextDefaultTools(newTestBot()), ) @@ -161,6 +163,7 @@ func TestWithLLMContextDefaultToolsRetainsAuthErrorsForWildcardAllowlist(t *test bot, user, channel, + builder.WithLLMContextRequestContext(stdcontext.Background()), builder.WithLLMContextDefaultTools(bot), ) diff --git a/mcp/client.go b/mcp/client.go index 89b5c9e63..faa99d8d7 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -161,6 +161,9 @@ func listAllTools(ctx context.Context, session *mcp.ClientSession) (map[string]* func (c *EmbeddedServerClient) CreateClient(ctx context.Context, userID, sessionID string) (*Client, error) { // Validate session exists before creating transport (unless empty for tool discovery) if sessionID != "" { + if c.pluginAPI == nil { + return nil, fmt.Errorf("plugin API is required when sessionID is provided") + } mmSession, err := c.pluginAPI.Session.Get(sessionID) if err != nil { return nil, fmt.Errorf("failed to get session: %w", err) diff --git a/mcp/client_embedded_oauth_test.go b/mcp/client_embedded_oauth_test.go index 69dc60696..8faacc45d 100644 --- a/mcp/client_embedded_oauth_test.go +++ b/mcp/client_embedded_oauth_test.go @@ -30,6 +30,19 @@ func TestEmbeddedCreateClientDiscoversPaginatedTools(t *testing.T) { } } +func TestEmbeddedCreateClientRequiresPluginAPIForSessionValidation(t *testing.T) { + server := newTestMCPServer(0, "tool_1") + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + embeddedClient := NewEmbeddedServerClient(&fakeEmbeddedMCPServer{ctx: ctx, server: server}, newTestLogService(), nil) + + client, err := embeddedClient.CreateClient(context.Background(), "user-id", "session-id") + + require.Nil(t, client) + require.EqualError(t, err, "plugin API is required when sessionID is provided") +} + func TestEmbeddedToolListChangedInvalidatesCacheAndClientTools(t *testing.T) { server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") ctx, cancel := context.WithCancel(context.Background()) @@ -77,7 +90,7 @@ func TestEmbeddedToolListChangedNextGetToolsForUserRediscoversTools(t *testing.T } t.Cleanup(func() { cleanupTestClientManager(manager) }) - tools, mcpErrors := manager.GetToolsForUser("user-id") + tools, mcpErrors := manager.GetToolsForUser(context.Background(), "user-id") require.Nil(t, mcpErrors) requireToolNames(t, tools, "mattermost__tool_1", "mattermost__tool_2") @@ -94,7 +107,7 @@ func TestEmbeddedToolListChangedNextGetToolsForUserRediscoversTools(t *testing.T return client != nil && len(client.Tools()) == 0 }, 5*time.Second, 10*time.Millisecond) - tools, mcpErrors = manager.GetToolsForUser("user-id") + tools, mcpErrors = manager.GetToolsForUser(context.Background(), "user-id") require.Nil(t, mcpErrors) requireToolNames(t, tools, "mattermost__new_tool", "mattermost__tool_1", "mattermost__tool_2") require.Len(t, cache.GetTools(EmbeddedClientKey), 3) diff --git a/mcp/client_integration_test.go b/mcp/client_integration_test.go index 8670b2377..319b2751e 100644 --- a/mcp/client_integration_test.go +++ b/mcp/client_integration_test.go @@ -289,7 +289,7 @@ func TestClientManager_GetToolsForUser(t *testing.T) { defer manager.Close() // Call GetToolsForUser - tools, errors := manager.GetToolsForUser(user.Id) + tools, errors := manager.GetToolsForUser(context.Background(), user.Id) // Should succeed with no errors assert.Nil(t, errors, "Should have no errors") diff --git a/mcp/client_manager.go b/mcp/client_manager.go index ad992d59d..65e340bed 100644 --- a/mcp/client_manager.go +++ b/mcp/client_manager.go @@ -181,11 +181,10 @@ func (m *ClientManager) getClientForUser(ctx context.Context, userID string) (*U } // GetToolsForUser returns the tools available for a specific user, connecting to embedded server if session ID provided. -func (m *ClientManager) GetToolsForUser(userID string) ([]llm.Tool, *Errors) { - ctx := context.Background() - +func (m *ClientManager) GetToolsForUser(ctx context.Context, userID string) ([]llm.Tool, *Errors) { // Get or create client for this user (connects to remote servers only) - userClient, _ := m.getClientForUser(ctx, userID) + userClient, initialErrors := m.getClientForUser(ctx, userID) + mcpErrors := cloneMCPErrors(initialErrors) // Connect to embedded server using a dedicated per-user session (stored/created in KV). if m.embeddedClient != nil && m.config.EmbeddedServer.Enabled { @@ -204,13 +203,34 @@ func (m *ClientManager) GetToolsForUser(userID string) ([]llm.Tool, *Errors) { for _, cfg := range pluginSnap { if connectErr := userClient.ConnectToPluginServer(ctx, cfg, m.sourcePluginAPI); connectErr != nil { m.log.Error("Failed to connect to plugin MCP server", "userID", userID, "pluginID", cfg.PluginID, "error", connectErr) - userClient.appendInitialRemoteConnectError(connectErr) + mcpErrors = appendMCPError(mcpErrors, connectErr) } } rawTools := userClient.GetTools(ctx) filtered := filterToolsByConfig(rawTools, m.config, m.embeddedClient, pluginSnap) - return filtered, userClient.InitialRemoteConnectErrors() + return filtered, mcpErrors +} + +func cloneMCPErrors(src *Errors) *Errors { + if src == nil || (len(src.ToolAuthErrors) == 0 && len(src.Errors) == 0) { + return nil + } + return &Errors{ + ToolAuthErrors: append([]llm.ToolAuthError(nil), src.ToolAuthErrors...), + Errors: append([]error(nil), src.Errors...), + } +} + +func appendMCPError(mcpErrors *Errors, err error) *Errors { + if err == nil { + return mcpErrors + } + if mcpErrors == nil { + mcpErrors = &Errors{} + } + mcpErrors.Errors = append(mcpErrors.Errors, err) + return mcpErrors } func (m *ClientManager) GetToolRetrievalOverrides() map[string]ToolRetrievalOverride { @@ -488,7 +508,7 @@ func filterToolsByConfig(rawTools []llm.Tool, cfg Config, embeddedClient *Embedd var filtered []llm.Tool for _, t := range tools { - _, enabled := sc.GetToolPolicy(llm.BareMCPToolName(t.Name)) + _, enabled := sc.GetToolPolicy(ToolPolicyLookupName(sc, t.Name)) if enabled { filtered = append(filtered, t) } diff --git a/mcp/client_manager_test.go b/mcp/client_manager_test.go index 00434d3e6..5f95928ea 100644 --- a/mcp/client_manager_test.go +++ b/mcp/client_manager_test.go @@ -374,7 +374,7 @@ func TestClientManager_GetToolsForUser_PluginEnabled(t *testing.T) { } m.RegisterPluginServer(cfg) - tools, mcpErrors := m.GetToolsForUser("alice") + tools, mcpErrors := m.GetToolsForUser(context.Background(), "alice") require.Nil(t, mcpErrors, "no errors expected on happy path") require.Len(t, tools, 2, "expected 2 tools from plugin server") for _, tool := range tools { @@ -404,7 +404,7 @@ func TestClientManager_GetToolsForUser_PluginDisabled_ZeroTools(t *testing.T) { } m.RegisterPluginServer(cfg) - tools, mcpErrors := m.GetToolsForUser("alice") + tools, mcpErrors := m.GetToolsForUser(context.Background(), "alice") require.Nil(t, mcpErrors, "no errors expected when plugin is simply disabled") require.Empty(t, tools, "disabled plugin must contribute zero tools") @@ -454,7 +454,7 @@ func TestClientManager_GetToolsForUser_PluginEnabled_HTTPFailure(t *testing.T) { Enabled: true, }) - tools, mcpErrors := m.GetToolsForUser("alice") + tools, mcpErrors := m.GetToolsForUser(context.Background(), "alice") require.NotNil(t, mcpErrors, "plugin connection failure must be surfaced") require.NotEmpty(t, mcpErrors.Errors, "plugin connection failure must populate generic MCP errors") require.Empty(t, mcpErrors.ToolAuthErrors, "plugin HTTP failures should not be treated as OAuth errors") @@ -466,6 +466,48 @@ func TestClientManager_GetToolsForUser_PluginEnabled_HTTPFailure(t *testing.T) { } } +func TestClientManager_GetToolsForUser_PluginConnectErrorsAreRequestScoped(t *testing.T) { + target := newFakePluginMCPServer(t, 1) + t.Cleanup(target.Close) + + var calls atomic.Int32 + mockAPI := &fakePluginHTTPClient{ + pluginHTTP: func(req *http.Request) *http.Response { + if calls.Add(1) == 1 { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusInternalServerError) + return rec.Result() + } + + rec := httptest.NewRecorder() + target.Config.Handler.ServeHTTP(rec, req) + return rec.Result() + }, + } + + pluginTestAPI := &plugintest.API{} + setupTestLogger(pluginTestAPI) + client := pluginapi.NewClient(pluginTestAPI, nil) + + m := NewClientManager(Config{IdleTimeoutMinutes: 30}, client.Log, client, nil, nil, nil, mockAPI) + t.Cleanup(m.Close) + m.RegisterPluginServer(PluginServerConfig{ + PluginID: "com.example.mcp", + Name: "Example", + Path: "/mcp", + Enabled: true, + }) + + tools, mcpErrors := m.GetToolsForUser(context.Background(), "alice") + require.Empty(t, tools) + require.NotNil(t, mcpErrors) + require.NotEmpty(t, mcpErrors.Errors) + + tools, mcpErrors = m.GetToolsForUser(context.Background(), "alice") + require.Nil(t, mcpErrors, "successful plugin reconnect must not return the prior transient error") + require.Len(t, tools, 1) +} + func TestClientManager_GetToolsForUser_MultiplePluginServers(t *testing.T) { targetA := newFakePluginMCPServerWithPrefix(t, "tool_a", 2) t.Cleanup(targetA.Close) @@ -498,7 +540,7 @@ func TestClientManager_GetToolsForUser_MultiplePluginServers(t *testing.T) { m.RegisterPluginServer(PluginServerConfig{PluginID: "com.example.a", Name: "A", Path: "/mcp", Enabled: true}) m.RegisterPluginServer(PluginServerConfig{PluginID: "com.example.b", Name: "B", Path: "/mcp", Enabled: true}) - tools, mcpErrors := m.GetToolsForUser("alice") + tools, mcpErrors := m.GetToolsForUser(context.Background(), "alice") require.Nil(t, mcpErrors) require.Len(t, tools, 3, "expected 2 tools from A + 1 tool from B") diff --git a/mcp/client_test.go b/mcp/client_test.go index 2271b59be..db07ca35d 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -502,7 +502,7 @@ func TestRemoteToolListChangedNextGetToolsForUserRediscoversTools(t *testing.T) var tools []llm.Tool var mcpErrors *Errors require.Eventually(t, func() bool { - tools, mcpErrors = manager.GetToolsForUser("user-id") + tools, mcpErrors = manager.GetToolsForUser(context.Background(), "user-id") if mcpErrors != nil || len(cache.GetTools("paged")) != 2 { return false } @@ -529,7 +529,7 @@ func TestRemoteToolListChangedNextGetToolsForUserRediscoversTools(t *testing.T) return client != nil && len(client.Tools()) == 0 && cache.GetTools("paged") == nil }, 5*time.Second, 10*time.Millisecond) - tools, mcpErrors = manager.GetToolsForUser("user-id") + tools, mcpErrors = manager.GetToolsForUser(context.Background(), "user-id") require.Nil(t, mcpErrors) requireToolNames(t, tools, "paged__new_tool", "paged__tool_1", "paged__tool_2") require.Len(t, cache.GetTools("paged"), 3) diff --git a/mcp/tool_policy.go b/mcp/tool_policy.go index a5a340dd3..bb24a708c 100644 --- a/mcp/tool_policy.go +++ b/mcp/tool_policy.go @@ -3,7 +3,11 @@ package mcp -import "strings" +import ( + "strings" + + "github.com/mattermost/mattermost-plugin-agents/llm" +) // ToolPolicyChecker looks up the per-tool policy for a given MCP server/tool. type ToolPolicyChecker interface { @@ -18,6 +22,21 @@ func (f ToolPolicyFunc) GetToolPolicy(serverBaseURL string, toolName string) (st return f(serverBaseURL, toolName) } +// ToolPolicyLookupName returns the configured name to use for a runtime tool name. +// Runtime MCP tools may be namespaced while persisted policy config is usually +// stored by the server's bare tool name. An exact configured name still wins. +func ToolPolicyLookupName(sc *ServerConfig, toolName string) string { + if sc == nil || toolName == "" || llm.IsBareMCPToolName(toolName) { + return toolName + } + for _, toolConfig := range sc.ToolConfigs { + if toolConfig.Name == toolName { + return toolName + } + } + return llm.BareMCPToolName(toolName) +} + // LookupToolPolicy resolves a tool's policy for embedded, remote, and plugin // origins. Unknown or disabled origins never auto-execute. func LookupToolPolicy(cfg Config, serverBaseURL, toolName string) (string, bool) { @@ -32,12 +51,12 @@ func LookupToolPolicy(cfg Config, serverBaseURL, toolName string) (string, bool) BaseURL: EmbeddedClientKey, ToolConfigs: toolConfigs, } - return embeddedCfg.GetToolPolicy(toolName) + return embeddedCfg.GetToolPolicy(ToolPolicyLookupName(embeddedCfg, toolName)) } for i := range cfg.Servers { if cfg.Servers[i].BaseURL == serverBaseURL { - return cfg.Servers[i].GetToolPolicy(toolName) + return cfg.Servers[i].GetToolPolicy(ToolPolicyLookupName(&cfg.Servers[i], toolName)) } } @@ -60,7 +79,7 @@ func LookupToolPolicy(cfg Config, serverBaseURL, toolName string) (string, bool) BaseURL: serverBaseURL, ToolConfigs: ps.ToolConfigs, } - return synthetic.GetToolPolicy(toolName) + return synthetic.GetToolPolicy(ToolPolicyLookupName(synthetic, toolName)) } return ToolPolicyAsk, false diff --git a/mcp/tool_policy_lookup_test.go b/mcp/tool_policy_lookup_test.go index 4f45000f3..62c296c61 100644 --- a/mcp/tool_policy_lookup_test.go +++ b/mcp/tool_policy_lookup_test.go @@ -136,6 +136,26 @@ func TestLookupToolPolicy(t *testing.T) { require.True(t, enabled) }) + t.Run("remote namespaced tool matches bare configured policy", func(t *testing.T) { + cfg := Config{ + Servers: []ServerConfig{{ + Name: "Remote", + Enabled: true, + BaseURL: remoteURL, + ToolConfigs: []ToolConfig{{ + Name: remoteToolName, + Policy: ToolPolicyAutoRunEverywhere, + Enabled: true, + }}, + }}, + } + + policy, enabled := LookupToolPolicy(cfg, remoteURL, "remote__"+remoteToolName) + + require.Equal(t, ToolPolicyAutoRunEverywhere, policy) + require.True(t, enabled) + }) + t.Run("embedded server with empty tool configs falls back to vetted seed", func(t *testing.T) { cfg := Config{ EmbeddedServer: EmbeddedServerConfig{ diff --git a/mcp/user_clients.go b/mcp/user_clients.go index fd20bb4aa..d2ff35490 100644 --- a/mcp/user_clients.go +++ b/mcp/user_clients.go @@ -193,18 +193,6 @@ func (c *UserClients) setInitialRemoteConnectErrors(mcpErrors *Errors) { c.initialRemoteConnectErrors = mcpErrors } -func (c *UserClients) appendInitialRemoteConnectError(err error) { - if err == nil { - return - } - c.clientsMu.Lock() - defer c.clientsMu.Unlock() - if c.initialRemoteConnectErrors == nil { - c.initialRemoteConnectErrors = &Errors{} - } - c.initialRemoteConnectErrors.Errors = append(c.initialRemoteConnectErrors.Errors, err) -} - // Close closes all server connections for a user client func (c *UserClients) Close() { c.clientsMu.Lock() @@ -368,10 +356,10 @@ func (c *UserClients) createToolResolver(client *Client, toolName string) func(l metadata := c.prepareToolCallMetadata(client, toolName, llmContext) - callCtx := context.Background() - if llmContext != nil && llmContext.RequestContext != nil { - callCtx = llmContext.RequestContext + if llmContext == nil || llmContext.RequestContext == nil { + return "", errors.New("missing request context for MCP tool call") } + callCtx := llmContext.RequestContext result, err := client.CallToolWithMetadata(callCtx, toolName, args, metadata) if err != nil { diff --git a/mcp/user_clients_test.go b/mcp/user_clients_test.go index f079ec6ce..d6ceb80a5 100644 --- a/mcp/user_clients_test.go +++ b/mcp/user_clients_test.go @@ -112,7 +112,7 @@ func TestUserClientsGetToolsResolverUsesBareToolName(t *testing.T) { tools := userClients.GetTools(context.Background()) requireToolNames(t, tools, "jira__search") - result, err := tools[0].Resolver(&llm.Context{}, func(args any) error { + result, err := tools[0].Resolver(&llm.Context{RequestContext: context.Background()}, func(args any) error { *(args.(*map[string]any)) = map[string]any{} return nil }) @@ -295,6 +295,35 @@ func TestUserClientsGetToolsResolverUsesRequestContext(t *testing.T) { require.ErrorIs(t, err, context.Canceled) } +func TestUserClientsGetToolsResolverRequiresRequestContext(t *testing.T) { + server := newTestMCPServer(0, "search") + session := connectInMemoryTestSession(t, server) + userClients := &UserClients{ + userID: "user-id", + clients: map[string]*Client{ + "jira": { + session: session, + config: ServerConfig{Name: "Jira", BaseURL: "https://mcp.atlassian.com", Enabled: true}, + tools: map[string]*gomcp.Tool{ + "search": { + Name: "search", + Description: "Search Jira", + }, + }, + }, + }, + } + + tools := userClients.GetTools(context.Background()) + require.Len(t, tools, 1) + + _, err := tools[0].Resolver(&llm.Context{}, func(args any) error { + *(args.(*map[string]any)) = map[string]any{} + return nil + }) + require.EqualError(t, err, "missing request context for MCP tool call") +} + func TestPrepareToolCallMetadata_EmbeddedMergesCallMetadataAndBotUserID(t *testing.T) { llmContext := llm.NewContext() llmContext.BotUserID = "bot-user-id" From 3593c37052ac547ec6da74b3c2ee1772325f9051 Mon Sep 17 00:00:00 2001 From: Nick Misasi Date: Sat, 23 May 2026 11:20:48 -0400 Subject: [PATCH 3/7] dynamic mcp: tolerate bare tool calls Co-authored-by: Cursor --- .../real-api/disabled-tool.spec.ts | 5 +- e2e/tests/tool-config/tool-toggle.spec.ts | 27 ++++--- e2e/tests/tool-config/vetted-seed.spec.ts | 8 +- llm/tools.go | 65 ++++++++++++++-- llm/tools_test.go | 74 +++++++++++++++++++ 5 files changed, 153 insertions(+), 26 deletions(-) diff --git a/e2e/tests/tool-config/real-api/disabled-tool.spec.ts b/e2e/tests/tool-config/real-api/disabled-tool.spec.ts index 77e235fc3..23a04d116 100644 --- a/e2e/tests/tool-config/real-api/disabled-tool.spec.ts +++ b/e2e/tests/tool-config/real-api/disabled-tool.spec.ts @@ -38,6 +38,7 @@ const VETTED_EMBEDDED_TOOLS = [ ]; const TARGET_TOOL_NAME = 'read_post'; +const TARGET_RUNTIME_TOOL_NAME = `mattermost__${TARGET_TOOL_NAME}`; const TARGET_TOOL_LABEL = 'Read Post'; const SEEDED_POST_MESSAGE = 'Disabled tool e2e seed message: cobalt narwhal orchard 4821.'; @@ -107,8 +108,8 @@ for (const provider of providers) { expect(embeddedServer).toBeDefined(); const names = embeddedServer.tools.map((t: any) => t.name); - expect(names).not.toContain(TARGET_TOOL_NAME); - expect(names).toContain('get_channel_info'); + expect(names).not.toContain(TARGET_RUNTIME_TOOL_NAME); + expect(names).toContain('mattermost__get_channel_info'); await mmPage.login(mattermost.url(), 'regularuser', 'regularuser'); await aiPlugin.openRHS(); diff --git a/e2e/tests/tool-config/tool-toggle.spec.ts b/e2e/tests/tool-config/tool-toggle.spec.ts index 4594119cf..03738c79a 100644 --- a/e2e/tests/tool-config/tool-toggle.spec.ts +++ b/e2e/tests/tool-config/tool-toggle.spec.ts @@ -17,6 +17,9 @@ import { adminUsername, adminPassword } from 'helpers/system-console-container'; let mattermost: MattermostContainer; let openAIMock: OpenAIMockContainer; +const READ_POST_TOOL_NAME = 'read_post'; +const READ_POST_RUNTIME_TOOL_NAME = 'mattermost__read_post'; + test.describe('Per-Tool Enable/Disable', () => { test.beforeAll(async () => { mattermost = await RunToolConfigContainer(); @@ -43,12 +46,12 @@ test.describe('Per-Tool Enable/Disable', () => { await page.waitForTimeout(500); // Find read_post tool - should be enabled - await expect(page.getByText('read_post')).toBeVisible({ timeout: 5000 }); - const toggle = toolConfig.getToolToggle('read_post'); + await expect(page.getByText(READ_POST_TOOL_NAME)).toBeVisible({ timeout: 5000 }); + const toggle = toolConfig.getToolToggle(READ_POST_TOOL_NAME); await expect(toggle).toBeChecked(); // Disable the tool - await toolConfig.toggleTool('read_post', false); + await toolConfig.toggleTool(READ_POST_TOOL_NAME, false); await expect(toggle).not.toBeChecked(); // Save @@ -63,12 +66,12 @@ test.describe('Per-Tool Enable/Disable', () => { await page.waitForTimeout(500); // Verify tool shows as disabled - await expect(page.getByText('read_post')).toBeVisible({ timeout: 5000 }); - const toggleAfter = toolConfig.getToolToggle('read_post'); + await expect(page.getByText(READ_POST_TOOL_NAME)).toBeVisible({ timeout: 5000 }); + const toggleAfter = toolConfig.getToolToggle(READ_POST_TOOL_NAME); await expect(toggleAfter).not.toBeChecked(); // Re-enable the tool for subsequent tests - await toolConfig.toggleTool('read_post', true); + await toolConfig.toggleTool(READ_POST_TOOL_NAME, true); await toolConfig.clickSave(); }); @@ -91,7 +94,7 @@ test.describe('Per-Tool Enable/Disable', () => { // Verify tools are returned expect(toolsBefore.servers).toBeDefined(); const serverBefore = toolsBefore.servers?.find((s: any) => - s.tools?.some((t: any) => t.name === 'read_post'), + s.tools?.some((t: any) => t.name === READ_POST_RUNTIME_TOOL_NAME), ); expect(serverBefore).toBeDefined(); @@ -100,14 +103,14 @@ test.describe('Per-Tool Enable/Disable', () => { const serverHeader = page.getByText(/\d+\/\d+ tools? enabled/).first(); await serverHeader.click(); await page.waitForTimeout(500); - await expect(page.getByText('read_post')).toBeVisible({ timeout: 5000 }); - await toolConfig.toggleTool('read_post', false); + await expect(page.getByText(READ_POST_TOOL_NAME)).toBeVisible({ timeout: 5000 }); + await toolConfig.toggleTool(READ_POST_TOOL_NAME, false); await toolConfig.clickSave(); // Verify the API no longer returns read_post const toolsAfter = await apiHelper.getUserMCPTools(token); const serverAfter = toolsAfter.servers?.find((s: any) => - s.tools?.some((t: any) => t.name === 'read_post'), + s.tools?.some((t: any) => t.name === READ_POST_RUNTIME_TOOL_NAME), ); expect(serverAfter).toBeUndefined(); @@ -116,8 +119,8 @@ test.describe('Per-Tool Enable/Disable', () => { const serverHeader2 = page.getByText(/\d+\/\d+ tools? enabled/).first(); await serverHeader2.click(); await page.waitForTimeout(500); - await expect(page.getByText('read_post')).toBeVisible({ timeout: 5000 }); - await toolConfig.toggleTool('read_post', true); + await expect(page.getByText(READ_POST_TOOL_NAME)).toBeVisible({ timeout: 5000 }); + await toolConfig.toggleTool(READ_POST_TOOL_NAME, true); await toolConfig.clickSave(); }); }); diff --git a/e2e/tests/tool-config/vetted-seed.spec.ts b/e2e/tests/tool-config/vetted-seed.spec.ts index a4c57707f..d108d7875 100644 --- a/e2e/tests/tool-config/vetted-seed.spec.ts +++ b/e2e/tests/tool-config/vetted-seed.spec.ts @@ -28,6 +28,7 @@ const VETTED_READ_TOOLS = [ 'search_users', 'get_user_channels', ]; +const VETTED_READ_RUNTIME_TOOLS = VETTED_READ_TOOLS.map((name) => `mattermost__${name}`); test.describe('Vetted Server Seed', () => { test.beforeAll(async () => { @@ -85,16 +86,15 @@ test.describe('Vetted Server Seed', () => { expect(toolsResponse.servers).toBeDefined(); expect(toolsResponse.servers.length).toBeGreaterThan(0); - // Find the embedded server (identified by having Mattermost tools) - const embeddedServer = toolsResponse.servers.find((s: any) => - s.tools?.some((t: any) => VETTED_READ_TOOLS.includes(t.name)), + const embeddedServer = toolsResponse.servers.find( + (s: any) => s.serverOrigin === 'embedded://mattermost', ); // If embedded server tools are in the response, verify they are present if (embeddedServer) { const toolNames = embeddedServer.tools.map((t: any) => t.name); // At least some vetted tools should be present - const foundVettedTools = VETTED_READ_TOOLS.filter((name) => + const foundVettedTools = VETTED_READ_RUNTIME_TOOLS.filter((name) => toolNames.includes(name), ); expect(foundVettedTools.length).toBeGreaterThan(0); diff --git a/llm/tools.go b/llm/tools.go index 0422222ef..cec9422a1 100644 --- a/llm/tools.go +++ b/llm/tools.go @@ -382,13 +382,39 @@ func (s *ToolStore) AddTools(tools []Tool) { } } +func (s *ToolStore) lookupTool(name string) (Tool, bool) { + if s == nil || name == "" { + return Tool{}, false + } + if tool, ok := s.tools[name]; ok { + return tool, true + } + if !IsBareMCPToolName(name) { + return Tool{}, false + } + + var matched Tool + found := false + for toolName, tool := range s.tools { + if tool.ServerOrigin == "" || BareMCPToolName(toolName) != name { + continue + } + if found { + return Tool{}, false + } + matched = tool + found = true + } + return matched, found +} + func (s *ToolStore) ResolveTool(ctx context.Context, name string, argsGetter ToolArgumentGetter, llmCtx *Context) (string, error) { _, span := telemetry.Tracer().Start(ctx, "resolve tool", trace.WithAttributes(telemetry.ToolName.String(name)), ) defer span.End() - tool, ok := s.tools[name] + tool, ok := s.lookupTool(name) if !ok { s.LogUnknownToolWarning(name, argsGetter) s.TraceUnknown(name, argsGetter) @@ -416,10 +442,7 @@ func (s *ToolStore) GetTools() []Tool { // GetTool returns a pointer to a tool by name, or nil if not found func (s *ToolStore) GetTool(name string) *Tool { - if s == nil { - return nil - } - if tool, ok := s.tools[name]; ok { + if tool, ok := s.lookupTool(name); ok { return &tool } return nil @@ -453,7 +476,7 @@ func (s *ToolStore) IsUnloadedMCPTool(name string) bool { if s == nil || s.GetTool(name) != nil { return false } - _, ok := s.unloadedMCPTools[name] + _, ok := s.lookupUnloadedMCPTool(name) return ok } @@ -461,14 +484,40 @@ func (s *ToolStore) GetUnloadedMCPToolInfo(name string) (ToolInfo, bool) { if s == nil || s.GetTool(name) != nil { return ToolInfo{}, false } + return s.lookupUnloadedMCPTool(name) +} + +func (s *ToolStore) lookupUnloadedMCPTool(name string) (ToolInfo, bool) { info, ok := s.unloadedMCPTools[name] - return info, ok + if ok { + return info, true + } + if !IsBareMCPToolName(name) { + return ToolInfo{}, false + } + + var matched ToolInfo + found := false + for toolName, info := range s.unloadedMCPTools { + if BareMCPToolName(toolName) != name { + continue + } + if found { + return ToolInfo{}, false + } + matched = info + found = true + } + if !found { + return ToolInfo{}, false + } + return matched, true } // GetServerOrigin returns the ServerOrigin for a tool by name. // Returns empty string if the tool is not found or has no server origin (built-in tools). func (s *ToolStore) GetServerOrigin(toolName string) string { - if tool, ok := s.tools[toolName]; ok { + if tool, ok := s.lookupTool(toolName); ok { return tool.ServerOrigin } return "" diff --git a/llm/tools_test.go b/llm/tools_test.go index 01e5eb451..8df33b633 100644 --- a/llm/tools_test.go +++ b/llm/tools_test.go @@ -212,6 +212,46 @@ func TestResolveToolUnknownLogsArgumentGetterError(t *testing.T) { assert.Equal(t, "failed to get tool args: bad arguments", logFields(log.infos[0])["args"]) } +func TestResolveToolUsesUniqueBareMCPToolName(t *testing.T) { + store := NewToolStore(nil, false) + store.AddTools([]Tool{{ + Name: "jira__get_issue", + ServerOrigin: "https://mcp.atlassian.com", + Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + return "issue result", nil + }, + }}) + + result, err := store.ResolveTool(context.Background(), "get_issue", rawArgsGetter(`{}`), &Context{}) + + require.NoError(t, err) + assert.Equal(t, "issue result", result) +} + +func TestResolveToolBareMCPToolNameAmbiguous(t *testing.T) { + store := NewToolStore(nil, false) + store.AddTools([]Tool{ + { + Name: "jira__search", + ServerOrigin: "https://mcp.atlassian.com", + Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + return "jira", nil + }, + }, + { + Name: "github__search", + ServerOrigin: "https://api.githubcopilot.com", + Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + return "github", nil + }, + }, + }) + + _, err := store.ResolveTool(context.Background(), "search", rawArgsGetter(`{}`), &Context{}) + + require.EqualError(t, err, "unknown tool search") +} + func TestGetToolKnownAndUnknown(t *testing.T) { store := NewToolStore(nil, false) store.AddTools([]Tool{{ @@ -225,6 +265,19 @@ func TestGetToolKnownAndUnknown(t *testing.T) { assert.Nil(t, store.GetTool("ghost")) } +func TestGetToolUsesUniqueBareMCPToolName(t *testing.T) { + store := NewToolStore(nil, false) + store.AddTools([]Tool{{ + Name: "jira__get_issue", + ServerOrigin: "https://mcp.atlassian.com", + }}) + + tool := store.GetTool("get_issue") + + require.NotNil(t, tool) + assert.Equal(t, "jira__get_issue", tool.Name) +} + func TestToolCall_SanitizeArguments(t *testing.T) { tests := []struct { name string @@ -297,6 +350,23 @@ func TestGetServerOrigin(t *testing.T) { lookupName: "unknown_tool", expectedURL: "", }, + { + name: "unique bare MCP tool name returns server origin", + tools: []Tool{ + {Name: "jira__get_issue", ServerOrigin: "https://mcp.atlassian.com/v2"}, + }, + lookupName: "get_issue", + expectedURL: "https://mcp.atlassian.com/v2", + }, + { + name: "ambiguous bare MCP tool name returns empty", + tools: []Tool{ + {Name: "jira__search", ServerOrigin: "https://mcp.atlassian.com/v2"}, + {Name: "github__search", ServerOrigin: "https://api.githubcopilot.com"}, + }, + lookupName: "search", + expectedURL: "", + }, { name: "empty store returns empty", tools: []Tool{}, @@ -809,6 +879,10 @@ func TestToolStoreUnloadedMCPTools(t *testing.T) { info, ok := store.GetUnloadedMCPToolInfo("jira__get_issue") require.True(t, ok) assert.Equal(t, ToolInfo{Name: "jira__get_issue", Description: "Get a Jira issue"}, info) + assert.True(t, store.IsUnloadedMCPTool("get_issue")) + info, ok = store.GetUnloadedMCPToolInfo("get_issue") + require.True(t, ok) + assert.Equal(t, ToolInfo{Name: "jira__get_issue", Description: "Get a Jira issue"}, info) store.AddTools([]Tool{{Name: "jira__get_issue", Description: "loaded", ServerOrigin: "https://jira.example.com"}}) assert.False(t, store.IsUnloadedMCPTool("jira__get_issue")) From 5225170dd0bea4160201c570c2cad9ce9f0bb91d Mon Sep 17 00:00:00 2001 From: Nick Misasi Date: Sat, 23 May 2026 12:08:30 -0400 Subject: [PATCH 4/7] dynamic mcp: keep channel analysis tools discoverable Co-authored-by: Cursor --- api/api_channel.go | 22 ++++++++++++++++------ api/api_no_tools_test.go | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/api/api_channel.go b/api/api_channel.go index c76f4de9d..28b0b8e3c 100644 --- a/api/api_channel.go +++ b/api/api_channel.go @@ -59,6 +59,7 @@ func (a *API) handleChannelAnalysis(c *gin.Context) { userID := c.GetHeader("Mattermost-User-Id") channel := c.MustGet(ContextChannelKey).(*model.Channel) bot := c.MustGet(ContextBotKey).(*bots.Bot) + toolBot := channelAnalysisToolBot(bot) var data struct { AnalysisType string `json:"analysis_type" binding:"required"` @@ -88,7 +89,7 @@ func (a *API) handleChannelAnalysis(c *gin.Context) { opts := []llm.ContextOption{ a.contextBuilder.WithLLMContextRequestContext(c.Request.Context()), - a.contextBuilder.WithLLMContextDefaultTools(bot), + a.contextBuilder.WithLLMContextDefaultTools(toolBot), } // If the channel is a DM/GM and we have a team ID from the client, use it for context @@ -121,16 +122,12 @@ func (a *API) handleChannelAnalysis(c *gin.Context) { // Check if read_channel tool is available availableTools := llmContext.Tools.GetTools() - hasReadChannel := false var toolNames []string for _, tool := range availableTools { toolNames = append(toolNames, tool.Name) - if tool.Name == "read_channel" { - hasReadChannel = true - } } - if !hasReadChannel { + if llmContext.Tools.GetTool("read_channel") == nil { a.pluginAPI.Log.Error("Channel analysis failed: read_channel tool not available", "userID", userID, "channelID", channel.Id, @@ -179,6 +176,19 @@ func (a *API) handleChannelAnalysis(c *gin.Context) { }) } +func channelAnalysisToolBot(bot *bots.Bot) *bots.Bot { + cfg := bot.GetConfig() + if cfg.AutoEnableNewMCPTools { + return bot + } + + // Channel analysis immediately scopes the catalog to bound read-only tools, + // so load the MCP catalog even when the agent uses a narrower allowlist. + cfg.AutoEnableNewMCPTools = true + cfg.EnabledMCPTools = nil + return bots.NewBot(cfg, bot.GetService(), bot.GetMMBot(), bot.LLM()) +} + func (a *API) handleInterval(c *gin.Context) { userID := c.GetHeader("Mattermost-User-Id") channel := c.MustGet(ContextChannelKey).(*model.Channel) diff --git a/api/api_no_tools_test.go b/api/api_no_tools_test.go index e85e44d53..822c35ed6 100644 --- a/api/api_no_tools_test.go +++ b/api/api_no_tools_test.go @@ -36,11 +36,12 @@ func (p *noToolsTestToolProvider) GetTools(*bots.Bot) []llm.Tool { type noToolsTestMCPProvider struct { calls int + tools []llm.Tool } func (p *noToolsTestMCPProvider) GetToolsForUser(context.Context, string) ([]llm.Tool, *mcp.Errors) { p.calls++ - return nil, nil + return p.tools, nil } type noToolsTestContextConfigProvider struct{} @@ -283,6 +284,39 @@ func TestHandleIntervalDoesNotLoadToolsWhenToolsAreDisabled(t *testing.T) { require.Equal(t, 0, mcpProvider.calls, "channel interval should not build MCP tools when the LLM call disables tools") } +func TestHandleChannelAnalysisAcceptsNamespacedEmbeddedTools(t *testing.T) { + gin.SetMode(gin.ReleaseMode) + gin.DefaultWriter = io.Discard + + mcpProvider := &noToolsTestMCPProvider{ + tools: []llm.Tool{ + {Name: "mattermost__read_channel", ServerOrigin: mcp.EmbeddedClientKey}, + {Name: "mattermost__get_channel_info", ServerOrigin: mcp.EmbeddedClientKey}, + }, + } + mmClient := mmapimocks.NewMockClient(t) + e, streamingService, _ := setupNoToolsAPI(t, mcpProvider, mmClient) + defer e.Cleanup(t) + + requestingUser := &model.User{Id: testUserID, Username: "requester", Locale: "en"} + channel := &model.Channel{Id: testChannelID, Type: model.ChannelTypeOpen, TeamId: "teamid"} + + e.mockAPI.On("GetChannel", testChannelID).Return(channel, nil) + e.mockAPI.On("GetTeam", "teamid").Return(&model.Team{Id: "teamid", Name: "team"}, nil) + e.mockAPI.On("HasPermissionToChannel", testUserID, testChannelID, model.PermissionReadChannel).Return(true) + e.mockAPI.On("GetUser", testUserID).Return(requestingUser, nil) + + request := httptest.NewRequest(http.MethodPost, "/channel/"+testChannelID+"/analyze", strings.NewReader(`{"analysis_type":"custom","prompt":"summarize this channel"}`)) + request.Header.Add("Mattermost-User-ID", testUserID) + + recorder := httptest.NewRecorder() + e.api.ServeHTTP(&plugin.Context{}, recorder, request) + + require.Equal(t, http.StatusOK, recorder.Result().StatusCode, recorder.Body.String()) + require.Equal(t, 1, streamingService.newDMCalls) + require.Equal(t, 1, mcpProvider.calls) +} + // TestHandleIntervalSetsConversationRootPostID verifies that after an // interval summary completes, the conversation's RootPostID is set to the // newly-created response post. Without this, interval summaries appear as From e9e6f798f0bd1b5af3fd818e4d603223de6138fa Mon Sep 17 00:00:00 2001 From: Nick Misasi Date: Wed, 27 May 2026 09:25:41 -0400 Subject: [PATCH 5/7] dynamic mcp: drop legacy tool-trace plumbing Removed upstream in PR #542 (5ab61bf3) when OpenTelemetry replaced log-based tracing. Restore master's no-arg NewToolStore signature and drop TraceLog, TraceUnknown, TraceResolved, and LogUnknownToolWarning while keeping the dynamic-MCP additions. --- llm/tools.go | 70 +-------------------- llm/tools_test.go | 111 +++------------------------------ mcp/user_clients_test.go | 2 +- mcpserver/eval_helpers_test.go | 2 +- telemetry/integration_test.go | 6 +- 5 files changed, 13 insertions(+), 178 deletions(-) diff --git a/llm/tools.go b/llm/tools.go index cec9422a1..73fc5e354 100644 --- a/llm/tools.go +++ b/llm/tools.go @@ -322,19 +322,9 @@ type ToolAuthError struct { type ToolStore struct { tools map[string]Tool unloadedMCPTools map[string]ToolInfo - log TraceLog - doTrace bool authErrors []ToolAuthError } -type TraceLog interface { - Info(message string, keyValuePairs ...any) -} - -type warnTraceLog interface { - Warn(message string, keyValuePairs ...any) -} - // NewJSONSchemaFromStruct creates a JSONSchema from a Go struct using generics // It's a helper function for tool providers that currently define schemas as structs func NewJSONSchemaFromStruct[T any]() *jsonschema.Schema { @@ -349,26 +339,13 @@ func NewJSONSchemaFromStruct[T any]() *jsonschema.Schema { func NewNoTools() *ToolStore { return &ToolStore{ tools: make(map[string]Tool), - log: nil, - doTrace: false, authErrors: []ToolAuthError{}, } } -func NewToolStore(options ...any) *ToolStore { - var log TraceLog - var doTrace bool - if len(options) > 0 { - log, _ = options[0].(TraceLog) - } - if len(options) > 1 { - doTrace, _ = options[1].(bool) - } - +func NewToolStore() *ToolStore { return &ToolStore{ tools: make(map[string]Tool), - log: log, - doTrace: doTrace, authErrors: []ToolAuthError{}, } } @@ -416,15 +393,12 @@ func (s *ToolStore) ResolveTool(ctx context.Context, name string, argsGetter Too tool, ok := s.lookupTool(name) if !ok { - s.LogUnknownToolWarning(name, argsGetter) - s.TraceUnknown(name, argsGetter) err := errors.New("unknown tool " + name) span.RecordError(err) span.SetStatus(otelcodes.Error, err.Error()) return "", err } result, err := tool.Resolver(llmCtx, argsGetter) - s.TraceResolved(name, argsGetter, result, err) if err != nil { span.RecordError(err) span.SetStatus(otelcodes.Error, err.Error()) @@ -683,48 +657,6 @@ func (s *ToolStore) GetToolsInfo() []ToolInfo { return result } -func (s *ToolStore) TraceUnknown(name string, argsGetter ToolArgumentGetter) { - if s.log != nil && s.doTrace { - s.log.Info("unknown tool called", "name", name, "args", toolArgsForLog(argsGetter)) - } -} - -func (s *ToolStore) TraceResolved(name string, argsGetter ToolArgumentGetter, result string, err error) { - if s.log != nil && s.doTrace { - s.log.Info("tool resolved", "name", name, "args", toolArgsForLog(argsGetter), "result", result, "error", err) - } -} - -// maxToolArgsLogBytes caps the size of the JSON arg snippet we emit to logs. -// Tool calls (especially failures) can carry large payloads; truncating keeps -// log output bounded without losing the diagnostic head of the args. -const maxToolArgsLogBytes = 512 - -func (s *ToolStore) LogUnknownToolWarning(name string, argsGetter ToolArgumentGetter) { - if s == nil || s.log == nil { - return - } - warnLog, ok := s.log.(warnTraceLog) - if !ok { - return - } - warnLog.Warn("unknown tool called", "name", name, "args", toolArgsForLog(argsGetter), "available_tool_count", len(s.tools)) -} - -func toolArgsForLog(argsGetter ToolArgumentGetter) string { - if argsGetter == nil { - return "{}" - } - var raw json.RawMessage - if err := argsGetter(&raw); err != nil { - return fmt.Sprintf("failed to get tool args: %v", err) - } - if len(raw) > maxToolArgsLogBytes { - return string(raw[:maxToolArgsLogBytes]) + "...(truncated)" - } - return string(raw) -} - // AddAuthError adds an authentication error to the tool store func (s *ToolStore) AddAuthError(authError ToolAuthError) { s.authErrors = append(s.authErrors, authError) diff --git a/llm/tools_test.go b/llm/tools_test.go index 8df33b633..4d1614458 100644 --- a/llm/tools_test.go +++ b/llm/tools_test.go @@ -6,7 +6,6 @@ package llm import ( "context" "encoding/json" - "errors" "sort" "testing" @@ -110,110 +109,14 @@ func TestSanitizeNonPrintableChars(t *testing.T) { } } -func TestToolArgsForLogNilGetter(t *testing.T) { - assert.Equal(t, "{}", toolArgsForLog(nil)) -} - -type logEntry struct { - message string - fields []any -} - -type captureToolLog struct { - infos []logEntry - warns []logEntry -} - -func (l *captureToolLog) Info(message string, keyValuePairs ...any) { - l.infos = append(l.infos, logEntry{message: message, fields: keyValuePairs}) -} - -func (l *captureToolLog) Warn(message string, keyValuePairs ...any) { - l.warns = append(l.warns, logEntry{message: message, fields: keyValuePairs}) -} - -type infoOnlyToolLog struct { - infos []logEntry -} - -func (l *infoOnlyToolLog) Info(message string, keyValuePairs ...any) { - l.infos = append(l.infos, logEntry{message: message, fields: keyValuePairs}) -} - -func logFields(entry logEntry) map[string]any { - fields := make(map[string]any, len(entry.fields)/2) - for i := 0; i+1 < len(entry.fields); i += 2 { - key, ok := entry.fields[i].(string) - if ok { - fields[key] = entry.fields[i+1] - } - } - return fields -} - func rawArgsGetter(raw string) ToolArgumentGetter { return func(args any) error { return json.Unmarshal([]byte(raw), args) } } -func TestResolveToolUnknownWarnsWithoutTrace(t *testing.T) { - log := &captureToolLog{} - store := NewToolStore(log, false) - - _, err := store.ResolveTool(context.Background(), "ghost_tool", rawArgsGetter(`{"query":"hello"}`), &Context{}) - - require.EqualError(t, err, "unknown tool ghost_tool") - require.Len(t, log.warns, 1) - assert.Empty(t, log.infos) - assert.Equal(t, "unknown tool called", log.warns[0].message) - fields := logFields(log.warns[0]) - assert.Equal(t, "ghost_tool", fields["name"]) - assert.Equal(t, `{"query":"hello"}`, fields["args"]) - assert.Equal(t, 0, fields["available_tool_count"]) -} - -func TestResolveToolUnknownPreservesTrace(t *testing.T) { - log := &captureToolLog{} - store := NewToolStore(log, true) - - _, err := store.ResolveTool(context.Background(), "ghost_tool", rawArgsGetter(`{"query":"hello"}`), &Context{}) - - require.EqualError(t, err, "unknown tool ghost_tool") - require.Len(t, log.warns, 1) - require.Len(t, log.infos, 1) - assert.Equal(t, "unknown tool called", log.warns[0].message) - assert.Equal(t, "unknown tool called", log.infos[0].message) - assert.Equal(t, `{"query":"hello"}`, logFields(log.infos[0])["args"]) -} - -func TestResolveToolUnknownWithInfoOnlyLoggerStillTracesWhenEnabled(t *testing.T) { - log := &infoOnlyToolLog{} - store := NewToolStore(log, true) - - _, err := store.ResolveTool(context.Background(), "ghost_tool", rawArgsGetter(`{"query":"hello"}`), &Context{}) - - require.EqualError(t, err, "unknown tool ghost_tool") - require.Len(t, log.infos, 1) - assert.Equal(t, "unknown tool called", log.infos[0].message) -} - -func TestResolveToolUnknownLogsArgumentGetterError(t *testing.T) { - log := &captureToolLog{} - store := NewToolStore(log, true) - argsErr := errors.New("bad arguments") - - _, err := store.ResolveTool(context.Background(), "ghost_tool", func(any) error { return argsErr }, &Context{}) - - require.EqualError(t, err, "unknown tool ghost_tool") - require.Len(t, log.warns, 1) - require.Len(t, log.infos, 1) - assert.Equal(t, "failed to get tool args: bad arguments", logFields(log.warns[0])["args"]) - assert.Equal(t, "failed to get tool args: bad arguments", logFields(log.infos[0])["args"]) -} - func TestResolveToolUsesUniqueBareMCPToolName(t *testing.T) { - store := NewToolStore(nil, false) + store := NewToolStore() store.AddTools([]Tool{{ Name: "jira__get_issue", ServerOrigin: "https://mcp.atlassian.com", @@ -229,7 +132,7 @@ func TestResolveToolUsesUniqueBareMCPToolName(t *testing.T) { } func TestResolveToolBareMCPToolNameAmbiguous(t *testing.T) { - store := NewToolStore(nil, false) + store := NewToolStore() store.AddTools([]Tool{ { Name: "jira__search", @@ -253,7 +156,7 @@ func TestResolveToolBareMCPToolNameAmbiguous(t *testing.T) { } func TestGetToolKnownAndUnknown(t *testing.T) { - store := NewToolStore(nil, false) + store := NewToolStore() store.AddTools([]Tool{{ Name: "known", Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { @@ -266,7 +169,7 @@ func TestGetToolKnownAndUnknown(t *testing.T) { } func TestGetToolUsesUniqueBareMCPToolName(t *testing.T) { - store := NewToolStore(nil, false) + store := NewToolStore() store.AddTools([]Tool{{ Name: "jira__get_issue", ServerOrigin: "https://mcp.atlassian.com", @@ -377,7 +280,7 @@ func TestGetServerOrigin(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - store := NewToolStore(nil, false) + store := NewToolStore() store.AddTools(tc.tools) result := store.GetServerOrigin(tc.lookupName) assert.Equal(t, tc.expectedURL, result) @@ -569,7 +472,7 @@ func TestRemoveToolsByServerOrigin(t *testing.T) { t.Run(tc.name, func(t *testing.T) { var store *ToolStore if tc.tools != nil { - store = NewToolStore(nil, false) + store = NewToolStore() store.AddTools(tc.tools) } @@ -755,7 +658,7 @@ func TestRetainOnlyMCPTools(t *testing.T) { return } - s := NewToolStore(nil, false) + s := NewToolStore() s.AddTools(tt.tools) s.RetainOnlyMCPTools(tt.allowlist) diff --git a/mcp/user_clients_test.go b/mcp/user_clients_test.go index d6ceb80a5..c8104a4bb 100644 --- a/mcp/user_clients_test.go +++ b/mcp/user_clients_test.go @@ -327,7 +327,7 @@ func TestUserClientsGetToolsResolverRequiresRequestContext(t *testing.T) { func TestPrepareToolCallMetadata_EmbeddedMergesCallMetadataAndBotUserID(t *testing.T) { llmContext := llm.NewContext() llmContext.BotUserID = "bot-user-id" - llmContext.Tools = llm.NewToolStore(nil, false) + llmContext.Tools = llm.NewToolStore() llmContext.Tools.AddTools([]llm.Tool{ llm.Tool{Name: "search_posts"}.WithCallMetadata(map[string]any{ "tool_hooks": map[string]any{ diff --git a/mcpserver/eval_helpers_test.go b/mcpserver/eval_helpers_test.go index deb0d24ad..fef0ad6d6 100644 --- a/mcpserver/eval_helpers_test.go +++ b/mcpserver/eval_helpers_test.go @@ -563,7 +563,7 @@ func setupAgenticEval(t *testing.T, e *evals.EvalT, suite *TestSuite, requesting allToolNames[i] = tool.Name } - toolStore := llm.NewToolStore(nil, false) + toolStore := llm.NewToolStore() toolStore.AddTools(mcpTools) llmContext := llm.NewContext() diff --git a/telemetry/integration_test.go b/telemetry/integration_test.go index 4f0f26fde..2d78b56f0 100644 --- a/telemetry/integration_test.go +++ b/telemetry/integration_test.go @@ -252,7 +252,7 @@ func TestToolResolveSpan(t *testing.T) { exporter, cleanup := setupTestTracing(t) defer cleanup() - store := llm.NewToolStore(nil, false) + store := llm.NewToolStore() store.AddTools([]llm.Tool{ { Name: "test_tool", @@ -291,7 +291,7 @@ func TestToolResolveUnknownSpan(t *testing.T) { exporter, cleanup := setupTestTracing(t) defer cleanup() - store := llm.NewToolStore(nil, false) + store := llm.NewToolStore() _, err := store.ResolveTool(context.Background(), "nonexistent", func(args any) error { return nil @@ -380,7 +380,7 @@ func TestFullRequestTrace(t *testing.T) { llmSpan.End() // Tool resolution - store := llm.NewToolStore(nil, false) + store := llm.NewToolStore() store.AddTools([]llm.Tool{ { Name: "web_search", From d317501612d9cee6737214d390359f27175a67de Mon Sep 17 00:00:00 2001 From: Nick Misasi Date: Wed, 27 May 2026 14:42:13 -0400 Subject: [PATCH 6/7] dynamic mcp: address client catalog follow-up feedback Co-authored-by: Cursor --- api/api_llm_bridge_test.go | 16 +- channels/analysis_conversation_test.go | 4 +- conversations/bot_channel_tool_filter_test.go | 21 +- conversations/conversations_test.go | 2 +- conversations/dm_conversation_test.go | 4 +- llm/context.go | 105 +++++-- llm/context_test.go | 24 +- llm/tools.go | 14 +- llm/tools_test.go | 10 +- llmcontext/llm_context.go | 2 +- mcp/client.go | 204 +++----------- mcp/client_embedded_oauth_test.go | 85 +----- mcp/client_test.go | 265 ++---------------- mcp/user_clients.go | 20 +- mcp/user_clients_test.go | 42 +-- mcpserver/eval_helpers_test.go | 6 +- mmtools/web_search.go | 18 +- mmtools/web_search_test.go | 6 +- ...tandard_personality_without_locale_test.go | 7 +- telemetry/integration_test.go | 4 +- toolrunner/toolrunner_test.go | 2 +- 21 files changed, 263 insertions(+), 598 deletions(-) diff --git a/api/api_llm_bridge_test.go b/api/api_llm_bridge_test.go index 0a375d24a..7deb07d2c 100644 --- a/api/api_llm_bridge_test.go +++ b/api/api_llm_bridge_test.go @@ -1332,7 +1332,7 @@ func (e *TestEnvironment) setupMCPWithEligibleTools(t *testing.T, toolNames []st ServerOrigin: server.URL, Description: name, Schema: llm.NewJSONSchemaFromStruct[struct{}](), - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "ok", nil }, } @@ -1415,7 +1415,7 @@ func TestBridgeGetAgentToolsReturnsEligibleOnly(t *testing.T) { ServerOrigin: server.URL, Description: "eligible from context", Schema: llm.NewJSONSchemaFromStruct[struct{}](), - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "ok", nil }, }, @@ -1424,7 +1424,7 @@ func TestBridgeGetAgentToolsReturnsEligibleOnly(t *testing.T) { ServerOrigin: server.URL, Description: "should be filtered out", Schema: llm.NewJSONSchemaFromStruct[struct{}](), - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "ok", nil }, }, @@ -1478,7 +1478,7 @@ func TestBridgeGetAgentToolsReturnsEmbeddedServerTools(t *testing.T) { ServerOrigin: mcp.EmbeddedClientKey, Description: "tool from embedded server", Schema: llm.NewJSONSchemaFromStruct[struct{}](), - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "ok", nil }, }, @@ -1542,7 +1542,7 @@ func TestBridgeGetAgentToolsSkipsUnreachableEligibleServer(t *testing.T) { ServerOrigin: server.URL, Description: "eligible from context", Schema: llm.NewJSONSchemaFromStruct[struct{}](), - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "ok", nil }, }, @@ -2015,7 +2015,7 @@ func TestBridgeClientAgentCompletionRejectsBuiltinToolInAllowedTools(t *testing. ServerOrigin: server.URL, Description: "eligible_tool", Schema: llm.NewJSONSchemaFromStruct[struct{}](), - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "ok", nil }, }, @@ -2023,7 +2023,7 @@ func TestBridgeClientAgentCompletionRejectsBuiltinToolInAllowedTools(t *testing. Name: "builtin_only", Description: "built-in tool with no MCP origin", Schema: llm.NewJSONSchemaFromStruct[struct{}](), - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "ok", nil }, }, @@ -2191,7 +2191,7 @@ func TestBridgeGetAgentToolsReturnsEmptyWhenMCPDisabled(t *testing.T) { Name: "context_only_tool", Description: "should not be bridge-eligible without MCP", Schema: llm.NewJSONSchemaFromStruct[struct{}](), - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "ok", nil }, }, diff --git a/channels/analysis_conversation_test.go b/channels/analysis_conversation_test.go index 746a38212..20e0d26b3 100644 --- a/channels/analysis_conversation_test.go +++ b/channels/analysis_conversation_test.go @@ -229,7 +229,7 @@ func makeTool(name, result string) llm.Tool { return llm.Tool{ Name: name, Description: "test tool", - Resolver: func(_ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { return result, nil }, } @@ -240,7 +240,7 @@ func makeToolWithError(name, errMsg string) llm.Tool { return llm.Tool{ Name: name, Description: "test tool that errors", - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "", fmt.Errorf("%s", errMsg) }, } diff --git a/conversations/bot_channel_tool_filter_test.go b/conversations/bot_channel_tool_filter_test.go index 9098eaf3f..50fadb308 100644 --- a/conversations/bot_channel_tool_filter_test.go +++ b/conversations/bot_channel_tool_filter_test.go @@ -4,6 +4,7 @@ package conversations import ( + "context" "testing" "github.com/mattermost/mattermost-plugin-agents/llm" @@ -47,10 +48,10 @@ func TestApplyBotChannelAutoEverywhereToolFilter(t *testing.T) { Tools: llm.NewToolStore(), } llmContext.Tools.AddTools([]llm.Tool{ - {Name: "builtin", ServerOrigin: "", Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, - {Name: "everywhere_tool", ServerOrigin: origin, Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, - {Name: "auto_run_tool", ServerOrigin: origin, Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, - {Name: "ask_tool", ServerOrigin: origin, Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "builtin", ServerOrigin: "", Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "everywhere_tool", ServerOrigin: origin, Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "auto_run_tool", ServerOrigin: origin, Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "ask_tool", ServerOrigin: origin, Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, }) c.applyBotChannelAutoEverywhereToolFilter(llmContext) @@ -75,7 +76,7 @@ func TestApplyBotChannelAutoEverywhereToolFilter_NamespacedToolUsesBarePolicy(t Tools: llm.NewToolStore(), } llmContext.Tools.AddTools([]llm.Tool{ - {Name: "server__everywhere_tool", ServerOrigin: origin, Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "server__everywhere_tool", ServerOrigin: origin, Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, }) c.applyBotChannelAutoEverywhereToolFilter(llmContext) @@ -100,9 +101,9 @@ func TestApplyToolAvailabilityBeforeBotChannelFilterPreservesDisabledToolsInfo(t Tools: llm.NewToolStore(), } llmContext.Tools.AddTools([]llm.Tool{ - {Name: "builtin", Description: "builtin tool", ServerOrigin: "", Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, - {Name: "everywhere_tool", Description: "auto everywhere", ServerOrigin: origin, Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, - {Name: "ask_tool", Description: "needs approval", ServerOrigin: origin, Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "builtin", Description: "builtin tool", ServerOrigin: "", Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "everywhere_tool", Description: "auto everywhere", ServerOrigin: origin, Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "ask_tool", Description: "needs approval", ServerOrigin: origin, Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, }) toolsDisabled := applyToolAvailability(llmContext, false, true) @@ -129,8 +130,8 @@ func TestApplyBotChannelAutoEverywhereToolFilter_nilCheckerFailClosed(t *testing Tools: llm.NewToolStore(), } llmContext.Tools.AddTools([]llm.Tool{ - {Name: "builtin", ServerOrigin: "", Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, - {Name: "mcp_tool", ServerOrigin: origin, Resolver: func(*llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "builtin", ServerOrigin: "", Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, + {Name: "mcp_tool", ServerOrigin: origin, Resolver: func(context.Context, *llm.Context, llm.ToolArgumentGetter) (string, error) { return "", nil }}, }) c.applyBotChannelAutoEverywhereToolFilter(llmContext) diff --git a/conversations/conversations_test.go b/conversations/conversations_test.go index 3405e1bcb..2f894ab03 100644 --- a/conversations/conversations_test.go +++ b/conversations/conversations_test.go @@ -37,7 +37,7 @@ func (m *mockToolProvider) GetTools(bot *bots.Bot) []llm.Tool { Name: "WebSearch", Description: "Search the web for information.", Schema: llm.NewJSONSchemaFromStruct[struct{ Term string }](), - Resolver: func(context *llm.Context, args llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, args llm.ToolArgumentGetter) (string, error) { return "No results found.", nil }, }, diff --git a/conversations/dm_conversation_test.go b/conversations/dm_conversation_test.go index 0b7742f47..021c0f975 100644 --- a/conversations/dm_conversation_test.go +++ b/conversations/dm_conversation_test.go @@ -602,7 +602,7 @@ func TestDMAutoRunTools_ToolRunnerExecutesAndWritesTurns(t *testing.T) { Name: "get_weather", Description: "Gets the weather", ServerOrigin: "https://mcp.example.com", - Resolver: func(ctx *llm.Context, args llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, args llm.ToolArgumentGetter) (string, error) { return "72F and sunny", nil }, }, @@ -893,7 +893,7 @@ func TestDMToolSharedFlag_AlwaysTrue(t *testing.T) { Name: "tool_a", Description: "A tool", ServerOrigin: "https://example.com", - Resolver: func(ctx *llm.Context, args llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, args llm.ToolArgumentGetter) (string, error) { return "result", nil }, }, diff --git a/llm/context.go b/llm/context.go index 9981c8c8c..1f3b2dfb9 100644 --- a/llm/context.go +++ b/llm/context.go @@ -20,7 +20,7 @@ type ToolInfo struct { ServerOrigin string } -// Context represents the data necessary to build the context of the LLM. +// Context represents the per-turn data necessary to build the context of the LLM. // For consumers none of the fields can be assumed to be present. type Context struct { // Server @@ -56,6 +56,13 @@ type Context struct { DisabledToolsInfo []ToolInfo // Info about tools that are unavailable in the current context (e.g., DM-only tools in a channel) Parameters map[string]interface{} + // ToolRuntime holds non-prompt tool execution state for this turn. + ToolRuntime ToolRuntimeContext +} + +// ToolRuntimeContext holds request-scoped tool runtime state that should not be +// rendered into the prompt. +type ToolRuntimeContext struct { // MCPDynamicToolLoading indicates this context uses strict MCP dynamic loading. MCPDynamicToolLoading bool // MCPDynamicToolTelemetry receives low-cardinality dynamic MCP tool events. @@ -77,16 +84,7 @@ type Context struct { // catalog and are request scoped. PreloadedMCPTools []EnabledMCPTool - // MCPToolRegistry holds the strict MCP tool registry that was built - // alongside Tools, when MCP dynamic tool loading is enabled. It is stashed - // here so callers can replay loaded-tool restoration after the conversation - // row exists without rebuilding the entire tool store. - // - // Stored as `any` to avoid an llm -> mcp import cycle: the mcp package - // already imports llm, and the only consumer that needs the concrete type - // is the llmcontext package, which can import both. Type-assert to - // *mcp.ToolRegistry there. - MCPToolRegistry any + restoreMCPDynamicTools func(names []string) } type MCPDynamicToolTelemetry interface { @@ -144,7 +142,7 @@ func (c *Context) CustomPromptVars() map[string]string { } func (c *Context) ObserveMCPDynamicToolEvent(event, result string) { - if c == nil || c.MCPDynamicToolTelemetry == nil { + if c == nil { return } @@ -156,37 +154,98 @@ func (c *Context) ObserveMCPDynamicToolEvent(event, result string) { botName = "unknown" } - c.MCPDynamicToolTelemetry.ObserveMCPDynamicToolEvent(botName, event, result) + c.ToolRuntime.ObserveMCPDynamicToolEvent(botName, event, result) +} + +func (t *ToolRuntimeContext) ObserveMCPDynamicToolEvent(botName, event, result string) { + if t == nil || t.MCPDynamicToolTelemetry == nil { + return + } + + t.MCPDynamicToolTelemetry.ObserveMCPDynamicToolEvent(botName, event, result) } func (c *Context) MarkMCPDynamicToolSearch() { if c == nil { return } - c.MCPDynamicToolSearchUsed = true + c.ToolRuntime.MarkMCPDynamicToolSearch() +} + +func (t *ToolRuntimeContext) MarkMCPDynamicToolSearch() { + if t == nil { + return + } + t.MCPDynamicToolSearchUsed = true } func (c *Context) MarkMCPDynamicToolLoaded(name string) { - if c == nil || name == "" { + if c == nil { + return + } + c.ToolRuntime.MarkMCPDynamicToolLoaded(name) +} + +func (t *ToolRuntimeContext) MarkMCPDynamicToolLoaded(name string) { + if t == nil || name == "" { + return + } + if t.MCPDynamicLoadedToolNames == nil { + t.MCPDynamicLoadedToolNames = make(map[string]bool) + } + t.MCPDynamicLoadedToolNames[name] = true +} + +// RestoreMCPDynamicTools materializes the named MCP tools into c.Tools. +func (c *Context) RestoreMCPDynamicTools(names []string) { + if c == nil { + return + } + c.ToolRuntime.RestoreMCPDynamicTools(names) +} + +// RestoreMCPDynamicTools materializes the named MCP tools into the active tool store. +func (t *ToolRuntimeContext) RestoreMCPDynamicTools(names []string) { + if t == nil || t.restoreMCPDynamicTools == nil || len(names) == 0 { + return + } + t.restoreMCPDynamicTools(names) +} + +// SetMCPDynamicToolRestorer installs the strict MCP tool restorer. +func (c *Context) SetMCPDynamicToolRestorer(fn func(names []string)) { + if c == nil { return } - if c.MCPDynamicLoadedToolNames == nil { - c.MCPDynamicLoadedToolNames = make(map[string]bool) + c.ToolRuntime.SetMCPDynamicToolRestorer(fn) +} + +// SetMCPDynamicToolRestorer installs the strict MCP tool restorer. +func (t *ToolRuntimeContext) SetMCPDynamicToolRestorer(fn func(names []string)) { + if t == nil { + return } - c.MCPDynamicLoadedToolNames[name] = true + t.restoreMCPDynamicTools = fn } func (c *Context) ShouldRecordMCPDynamicSearchLoadCallSuccess(name string) bool { - if c == nil || name == "" || !c.MCPDynamicToolSearchUsed || !c.MCPDynamicLoadedToolNames[name] { + if c == nil { + return false + } + return c.ToolRuntime.ShouldRecordMCPDynamicSearchLoadCallSuccess(name) +} + +func (t *ToolRuntimeContext) ShouldRecordMCPDynamicSearchLoadCallSuccess(name string) bool { + if t == nil || name == "" || !t.MCPDynamicToolSearchUsed || !t.MCPDynamicLoadedToolNames[name] { return false } - if c.MCPDynamicSearchLoadCallSuccessRecorded == nil { - c.MCPDynamicSearchLoadCallSuccessRecorded = make(map[string]bool) + if t.MCPDynamicSearchLoadCallSuccessRecorded == nil { + t.MCPDynamicSearchLoadCallSuccessRecorded = make(map[string]bool) } - if c.MCPDynamicSearchLoadCallSuccessRecorded[name] { + if t.MCPDynamicSearchLoadCallSuccessRecorded[name] { return false } - c.MCPDynamicSearchLoadCallSuccessRecorded[name] = true + t.MCPDynamicSearchLoadCallSuccessRecorded[name] = true return true } diff --git a/llm/context_test.go b/llm/context_test.go index 1e99e4830..540597a9f 100644 --- a/llm/context_test.go +++ b/llm/context_test.go @@ -148,7 +148,7 @@ func TestContextObserveMCPDynamicToolEventBotLabelFallbacks(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { telemetry := &fakeMCPDynamicTelemetry{} - tt.context.MCPDynamicToolTelemetry = telemetry + tt.context.ToolRuntime.MCPDynamicToolTelemetry = telemetry tt.context.ObserveMCPDynamicToolEvent("search", "success") @@ -169,3 +169,25 @@ func TestContextMCPDynamicSearchLoadCallSuccessState(t *testing.T) { assert.True(t, c.ShouldRecordMCPDynamicSearchLoadCallSuccess("jira__get_issue")) assert.False(t, c.ShouldRecordMCPDynamicSearchLoadCallSuccess("jira__get_issue")) } + +func TestContextRestoreMCPDynamicTools(t *testing.T) { + var nilContext *Context + nilContext.RestoreMCPDynamicTools([]string{"jira__get_issue"}) + nilContext.SetMCPDynamicToolRestorer(func([]string) { + t.Fatal("nil context should not install a restorer") + }) + + c := &Context{} + c.RestoreMCPDynamicTools([]string{"jira__get_issue"}) + + var restored []string + c.SetMCPDynamicToolRestorer(func(names []string) { + restored = append(restored, names...) + }) + + c.RestoreMCPDynamicTools(nil) + assert.Empty(t, restored) + + c.RestoreMCPDynamicTools([]string{"jira__get_issue"}) + assert.Equal(t, []string{"jira__get_issue"}, restored) +} diff --git a/llm/tools.go b/llm/tools.go index 73fc5e354..be84b1b24 100644 --- a/llm/tools.go +++ b/llm/tools.go @@ -25,7 +25,7 @@ import ( // It is the Resolver function that implements the actual functionality. // // The Schema field should contain a JSONSchema that defines the expected structure of the tool's arguments. -// The Resolver function receives the conversation context and a way to access the parsed arguments, +// The Resolver function receives the request context, conversation context, and parsed arguments, // and returns either a result that will be passed to the LLM or an error. type Tool struct { Name string @@ -45,7 +45,7 @@ type Tool struct { CallMetadata map[string]any } -type ToolResolver func(context *Context, argsGetter ToolArgumentGetter) (string, error) +type ToolResolver func(ctx context.Context, llmCtx *Context, argsGetter ToolArgumentGetter) (string, error) // WithBoundParams creates a new Tool with parameters bound to fixed values. // Bound parameters are: @@ -121,7 +121,7 @@ func wrapResolverWithBoundParams(original ToolResolver, params map[string]interf return original } - return func(context *Context, argsGetter ToolArgumentGetter) (string, error) { + return func(ctx context.Context, llmCtx *Context, argsGetter ToolArgumentGetter) (string, error) { wrappedGetter := func(args any) error { // First unmarshal the original args if err := argsGetter(args); err != nil { @@ -130,7 +130,7 @@ func wrapResolverWithBoundParams(original ToolResolver, params map[string]interf // Then inject bound params return injectBoundParams(args, params) } - return original(context, wrappedGetter) + return original(ctx, llmCtx, wrappedGetter) } } @@ -353,9 +353,7 @@ func NewToolStore() *ToolStore { func (s *ToolStore) AddTools(tools []Tool) { for _, tool := range tools { s.tools[tool.Name] = tool - if s.unloadedMCPTools != nil { - delete(s.unloadedMCPTools, tool.Name) - } + delete(s.unloadedMCPTools, tool.Name) } } @@ -398,7 +396,7 @@ func (s *ToolStore) ResolveTool(ctx context.Context, name string, argsGetter Too span.SetStatus(otelcodes.Error, err.Error()) return "", err } - result, err := tool.Resolver(llmCtx, argsGetter) + result, err := tool.Resolver(ctx, llmCtx, argsGetter) if err != nil { span.RecordError(err) span.SetStatus(otelcodes.Error, err.Error()) diff --git a/llm/tools_test.go b/llm/tools_test.go index 4d1614458..abf2ced78 100644 --- a/llm/tools_test.go +++ b/llm/tools_test.go @@ -120,7 +120,7 @@ func TestResolveToolUsesUniqueBareMCPToolName(t *testing.T) { store.AddTools([]Tool{{ Name: "jira__get_issue", ServerOrigin: "https://mcp.atlassian.com", - Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *Context, _ ToolArgumentGetter) (string, error) { return "issue result", nil }, }}) @@ -137,14 +137,14 @@ func TestResolveToolBareMCPToolNameAmbiguous(t *testing.T) { { Name: "jira__search", ServerOrigin: "https://mcp.atlassian.com", - Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *Context, _ ToolArgumentGetter) (string, error) { return "jira", nil }, }, { Name: "github__search", ServerOrigin: "https://api.githubcopilot.com", - Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *Context, _ ToolArgumentGetter) (string, error) { return "github", nil }, }, @@ -159,7 +159,7 @@ func TestGetToolKnownAndUnknown(t *testing.T) { store := NewToolStore() store.AddTools([]Tool{{ Name: "known", - Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *Context, _ ToolArgumentGetter) (string, error) { return "ok", nil }, }}) @@ -354,7 +354,7 @@ func TestWithBoundParamsPreservesServerOrigin(t *testing.T) { Name: "test_tool", Description: "A test tool", ServerOrigin: "https://mcp.example.com", - Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *Context, _ ToolArgumentGetter) (string, error) { return "result", nil }, } diff --git a/llmcontext/llm_context.go b/llmcontext/llm_context.go index 7bddead63..f5e9e0c37 100644 --- a/llmcontext/llm_context.go +++ b/llmcontext/llm_context.go @@ -256,7 +256,7 @@ func (b *Builder) WithLLMContextDefaultTools(bot *bots.Bot) llm.ContextOption { } // WithLLMContextRequestContext threads request-scoped cancellation/deadlines into -// MCP discovery and tool execution. +// MCP tool discovery. func (b *Builder) WithLLMContextRequestContext(ctx stdcontext.Context) llm.ContextOption { return func(c *llm.Context) { c.RequestContext = ctx diff --git a/mcp/client.go b/mcp/client.go index faa99d8d7..b411962ab 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -12,7 +12,6 @@ import ( "net/url" "strings" "sync" - "sync/atomic" "time" "github.com/mattermost/mattermost-plugin-agents/config" @@ -57,22 +56,17 @@ type EmbeddedServerClient struct { // Client represents the connection to a single MCP server type Client struct { - session *mcp.ClientSession - config ServerConfig - toolsMu sync.RWMutex - discoveryMu sync.Mutex - tools map[string]*mcp.Tool - toolsDirty bool - toolsGeneration uint64 - notifyOwnerMu sync.RWMutex - notifyOwner *Client - userID string - log pluginapi.LogService - oauthManager *OAuthManager - httpClient *http.Client - toolsCache *ToolsCache - embeddedClient *EmbeddedServerClient // for reconnection (nil for remote servers) - sessionID string // session ID for embedded server reconnection + session *mcp.ClientSession + config ServerConfig + toolsMu sync.RWMutex + tools map[string]*mcp.Tool + userID string + log pluginapi.LogService + oauthManager *OAuthManager + httpClient *http.Client + toolsCache *ToolsCache + embeddedClient *EmbeddedServerClient // for reconnection (nil for remote servers) + sessionID string // session ID for embedded server reconnection } // staticOAuthCreds returns static OAuth credentials from a server config, or nil if not configured. @@ -182,21 +176,13 @@ func (c *EmbeddedServerClient) CreateClient(ctx context.Context, userID, session return nil, fmt.Errorf("failed to create in-memory transport: %w", err) } - var clientPtr atomic.Pointer[Client] - // Create MCP client mcpClient := mcp.NewClient( &mcp.Implementation{ Name: "mattermost-agents-embedded", Version: "1.0", }, - &mcp.ClientOptions{ - ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) { - if cl := clientPtr.Load(); cl != nil { - cl.notificationOwner().invalidateDiscoveredTools(ctx, c.toolsCache, EmbeddedClientKey, c.toolsCache != nil) - } - }, - }, + nil, ) // Connect to the embedded server using in-memory transport @@ -217,8 +203,6 @@ func (c *EmbeddedServerClient) CreateClient(ctx context.Context, userID, session embeddedClient: c, // Store client helper for reconnection sessionID: sessionID, // Store session ID for reconnection } - clientPtr.Store(client) - // Initialize tools discoveredTools, err := listAllTools(ctx, mcpSession) if err != nil { @@ -328,8 +312,7 @@ func NewClient(ctx context.Context, userID string, serverConfig ServerConfig, lo } // NewPluginClient creates a per-user MCP client for a plugin-registered server. -// Plugin clients use listAllTools and ToolListChangedHandler like other clients, -// but do not use the shared tools cache. +// Plugin clients list tools at connect time and do not use the shared tools cache. func NewPluginClient(ctx context.Context, userID string, cfg PluginServerConfig, sourcePluginAPI mmapi.Client, log pluginapi.LogService) (*Client, error) { if sourcePluginAPI == nil { return nil, fmt.Errorf("sourcePluginAPI is nil; plugin MCP server %s cannot be reached", cfg.PluginID) @@ -358,21 +341,12 @@ func NewPluginClient(ctx context.Context, userID string, cfg PluginServerConfig, httpClient: httpClient, } - var clientPtr atomic.Pointer[Client] - clientPtr.Store(client) - mcpClient := mcp.NewClient( &mcp.Implementation{ Name: "mattermost-agents-plugin-bridge", Version: "1.0", }, - &mcp.ClientOptions{ - ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) { - if cl := clientPtr.Load(); cl != nil { - cl.invalidateDiscoveredTools(ctx, nil, pluginCfg.Name, false) - } - }, - }, + nil, ) session, err := mcpClient.Connect(ctx, &mcp.StreamableClientTransport{ @@ -485,11 +459,7 @@ func (c *Client) createSession(ctx context.Context, serverConfig ServerConfig) ( Name: "mattermost-agents", Version: "1.0", }, - &mcp.ClientOptions{ - ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) { - c.invalidateDiscoveredTools(ctx, c.toolsCache, serverConfig.Name, shouldUseSharedToolsCache(serverConfig)) - }, - }, + nil, ) httpClient := c.httpClientForMCP(headers) @@ -529,117 +499,6 @@ func (c *Client) createSession(ctx context.Context, serverConfig ServerConfig) ( return nil, fmt.Errorf("failed to connect to MCP server %s, Streamable HTTP: %w, SSE: %w", c.config.Name, errStreamable, errSSE) } -func (c *Client) invalidateDiscoveredTools(ctx context.Context, toolsCache *ToolsCache, serverID string, useSharedToolsCache bool) { - c.toolsMu.Lock() - c.tools = make(map[string]*mcp.Tool) - c.toolsDirty = true - c.toolsGeneration++ - c.toolsMu.Unlock() - - if toolsCache != nil && useSharedToolsCache { - if err := toolsCache.InvalidateServer(serverID); err != nil { - c.log.Warn("Failed to invalidate MCP tools after list_changed notification", - "serverID", serverID, - "server", c.config.Name, - "userID", c.userID, - "error", err) - return - } - } - - c.log.Debug("Invalidated MCP tools after list_changed notification", - "serverID", serverID, - "server", c.config.Name, - "userID", c.userID) -} - -func (c *Client) notificationOwner() *Client { - c.notifyOwnerMu.RLock() - defer c.notifyOwnerMu.RUnlock() - if c.notifyOwner == nil { - return c - } - return c.notifyOwner -} - -func (c *Client) setNotificationOwner(owner *Client) { - c.notifyOwnerMu.Lock() - c.notifyOwner = owner - c.notifyOwnerMu.Unlock() -} - -func (c *Client) ensureDiscoveredTools(ctx context.Context) error { - c.toolsMu.RLock() - dirty := c.toolsDirty - c.toolsMu.RUnlock() - if !dirty { - return nil - } - - c.discoveryMu.Lock() - defer c.discoveryMu.Unlock() - - c.toolsMu.RLock() - dirty = c.toolsDirty - session := c.session - generation := c.toolsGeneration - c.toolsMu.RUnlock() - if !dirty { - return nil - } - if session == nil { - return fmt.Errorf("MCP client not connected") - } - - discoveredTools, err := listAllTools(ctx, session) - if err != nil { - return fmt.Errorf("failed to list tools: %w", err) - } - - c.toolsMu.Lock() - if c.toolsGeneration != generation { - c.toolsMu.Unlock() - c.log.Debug("MCP tools changed during rediscovery; leaving catalog dirty", - "server", c.config.Name, - "userID", c.userID) - return nil - } - c.tools = discoveredTools - c.toolsDirty = false - c.toolsMu.Unlock() - - if c.toolsCache != nil && shouldUseSharedToolsCache(c.config) { - if err := c.toolsCache.SetTools(c.config.Name, c.config.Name, c.config.BaseURL, discoveredTools, time.Now()); err != nil { - c.log.Warn("Failed to update tools cache after list_changed rediscovery", - "server", c.config.Name, - "userID", c.userID, - "error", err) - } - c.toolsMu.RLock() - cacheGenerationChanged := c.toolsGeneration != generation - c.toolsMu.RUnlock() - if cacheGenerationChanged { - if err := c.toolsCache.InvalidateServer(c.config.Name); err != nil { - c.log.Warn("Failed to invalidate MCP tools cache after concurrent list_changed notification", - "server", c.config.Name, - "userID", c.userID, - "error", err) - } - c.log.Debug("MCP tools changed during cache refresh; leaving catalog dirty", - "server", c.config.Name, - "userID", c.userID) - return nil - } - } - - c.log.Debug("Rediscovered MCP tools after list_changed notification", - "server", c.config.Name, - "userID", c.userID, - "toolCount", len(discoveredTools)) - - return nil -} - func (c *Client) oauthStartURL() string { if c.oauthManager == nil { return "" @@ -728,20 +587,41 @@ func (c *Client) CallToolWithMetadata(ctx context.Context, toolName string, args return "", fmt.Errorf("failed to reconnect to embedded MCP server: %w", reconnectErr) } - // Update session and tools from the new client - newClient.setNotificationOwner(c) c.toolsMu.Lock() c.session = newClient.session c.tools = newClient.Tools() - c.toolsDirty = false c.toolsMu.Unlock() c.log.Debug("Successfully reconnected to embedded MCP server", "userID", c.userID) } else { // Reconnect to remote server - c.session, err = c.createSession(ctx, c.config) - if err != nil { - return "", fmt.Errorf("failed to reconnect to MCP server %s: %w", c.config.Name, err) + newSession, reconnectErr := c.createSession(ctx, c.config) + if reconnectErr != nil { + return "", fmt.Errorf("failed to reconnect to MCP server %s: %w", c.config.Name, reconnectErr) + } + discoveredTools, listErr := listAllTools(ctx, newSession) + if listErr != nil { + newSession.Close() + return "", fmt.Errorf("failed to list tools after reconnecting to MCP server %s: %w", c.config.Name, listErr) + } + if len(discoveredTools) == 0 { + newSession.Close() + return "", fmt.Errorf("no tools found after reconnecting to MCP server %s for user %s", c.config.Name, c.userID) + } + + c.toolsMu.Lock() + c.session = newSession + c.tools = discoveredTools + c.toolsMu.Unlock() + + if c.toolsCache != nil && shouldUseSharedToolsCache(c.config) { + if cacheErr := c.toolsCache.SetTools(c.config.Name, c.config.Name, c.config.BaseURL, discoveredTools, time.Now()); cacheErr != nil { + c.log.Warn("Failed to update tools cache after MCP reconnect", + "server", c.config.Name, + "userID", c.userID, + "error", cacheErr) + } } + c.log.Debug("Successfully reconnected to MCP server", "userID", c.userID, "server", c.config.Name) } // Retry the tool call after reconnecting diff --git a/mcp/client_embedded_oauth_test.go b/mcp/client_embedded_oauth_test.go index 8faacc45d..7a4f38a92 100644 --- a/mcp/client_embedded_oauth_test.go +++ b/mcp/client_embedded_oauth_test.go @@ -43,76 +43,6 @@ func TestEmbeddedCreateClientRequiresPluginAPIForSessionValidation(t *testing.T) require.EqualError(t, err, "plugin API is required when sessionID is provided") } -func TestEmbeddedToolListChangedInvalidatesCacheAndClientTools(t *testing.T) { - server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - cache := newTestToolsCache() - require.NoError(t, cache.SetTools(EmbeddedClientKey, EmbeddedServerName, EmbeddedClientKey, map[string]*mcp.Tool{ - "cached_tool": { - Name: "cached_tool", - Description: "Cached tool", - InputSchema: map[string]any{"type": "object"}, - }, - }, time.Now())) - - embeddedClient := NewEmbeddedServerClientWithCache(&fakeEmbeddedMCPServer{ctx: ctx, server: server}, newTestLogService(), nil, cache) - client, err := embeddedClient.CreateClient(context.Background(), "user-id", "") - require.NoError(t, err) - t.Cleanup(func() { _ = client.Close() }) - require.NotEmpty(t, client.Tools()) - require.NotNil(t, cache.GetTools(EmbeddedClientKey)) - - addTestMCPTool(server, "new_tool") - - require.Eventually(t, func() bool { - return len(client.Tools()) == 0 && cache.GetTools(EmbeddedClientKey) == nil - }, 5*time.Second, 10*time.Millisecond) -} - -func TestEmbeddedToolListChangedNextGetToolsForUserRediscoversTools(t *testing.T) { - server := newTestMCPServer(2, "tool_1", "tool_2") - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - cache := newTestToolsCache() - pluginAPI := newTestPluginAPIForEmbeddedManager("user-id", "session-id") - embeddedClient := NewEmbeddedServerClientWithCache(&fakeEmbeddedMCPServer{ctx: ctx, server: server}, pluginAPI.Log, pluginAPI, cache) - manager := &ClientManager{ - config: Config{ - EmbeddedServer: EmbeddedServerConfig{Enabled: true}, - }, - log: pluginAPI.Log, - pluginAPI: pluginAPI, - clients: make(map[string]*UserClients), - activity: make(map[string]time.Time), - embeddedClient: embeddedClient, - toolsCache: cache, - } - t.Cleanup(func() { cleanupTestClientManager(manager) }) - - tools, mcpErrors := manager.GetToolsForUser(context.Background(), "user-id") - require.Nil(t, mcpErrors) - requireToolNames(t, tools, "mattermost__tool_1", "mattermost__tool_2") - - addTestMCPTool(server, "new_tool") - - require.Eventually(t, func() bool { - manager.clientsMu.RLock() - userClient := manager.clients["user-id"] - manager.clientsMu.RUnlock() - if userClient == nil { - return false - } - client := userClient.clients[EmbeddedClientKey] - return client != nil && len(client.Tools()) == 0 - }, 5*time.Second, 10*time.Millisecond) - - tools, mcpErrors = manager.GetToolsForUser(context.Background(), "user-id") - require.Nil(t, mcpErrors) - requireToolNames(t, tools, "mattermost__new_tool", "mattermost__tool_1", "mattermost__tool_2") - require.Len(t, cache.GetTools(EmbeddedClientKey), 3) -} - func TestEmbeddedReconnectKeepsPaginatedDiscovery(t *testing.T) { server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") ctx, cancel := context.WithCancel(context.Background()) @@ -132,12 +62,15 @@ func TestEmbeddedReconnectKeepsPaginatedDiscovery(t *testing.T) { require.Len(t, client.Tools(), 3) addTestMCPTool(server, "new_tool") - require.Eventually(t, func() bool { - return len(client.Tools()) == 0 - }, 5*time.Second, 10*time.Millisecond) + require.NoError(t, client.session.Close()) + result, err = client.CallTool(context.Background(), "tool_1", map[string]any{}) + require.NoError(t, err) + require.Contains(t, result, "tool_1 ok") + require.Contains(t, client.Tools(), "new_tool") + require.Len(t, client.Tools(), 4) } -func TestClientToolsReturnsCopyAndSurvivesConcurrentInvalidation(t *testing.T) { +func TestClientToolsReturnsCopyAndSurvivesConcurrentUpdate(t *testing.T) { client := &Client{ config: ServerConfig{Name: "server", BaseURL: "https://example.com"}, tools: map[string]*mcp.Tool{ @@ -159,7 +92,9 @@ func TestClientToolsReturnsCopyAndSurvivesConcurrentInvalidation(t *testing.T) { } }() for i := 0; i < 100; i++ { - client.invalidateDiscoveredTools(context.Background(), nil, "server", false) + client.toolsMu.Lock() + client.tools = make(map[string]*mcp.Tool) + client.toolsMu.Unlock() } <-done require.Empty(t, client.Tools()) diff --git a/mcp/client_test.go b/mcp/client_test.go index db07ca35d..a690c720b 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -229,15 +229,6 @@ func requireToolNames(t *testing.T, tools []llm.Tool, expectedNames ...string) { require.ElementsMatch(t, expectedNames, names) } -func cleanupTestClientManager(manager *ClientManager) { - manager.clientsMu.Lock() - defer manager.clientsMu.Unlock() - for _, userClient := range manager.clients { - userClient.Close() - } - manager.clients = make(map[string]*UserClients) -} - // TestCacheHitBehavior verifies that when tools are in cache, // they can be retrieved and reused correctly func TestCacheHitBehavior(t *testing.T) { @@ -380,6 +371,31 @@ func TestNewClientDiscoversPaginatedRemoteTools(t *testing.T) { } } +func TestRemoteReconnectRefreshesToolCatalog(t *testing.T) { + server := newTestMCPServer(0, "tool_1") + httpServer := startStreamableMCPServer(t, server) + cache := newTestToolsCache() + + client, err := NewClient(context.Background(), "user-id", ServerConfig{ + Name: "remote", + BaseURL: httpServer.URL, + Enabled: true, + }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + require.Len(t, client.Tools(), 1) + + addTestMCPTool(server, "new_tool") + require.NoError(t, client.session.Close()) + + result, err := client.CallTool(context.Background(), "tool_1", map[string]any{}) + require.NoError(t, err) + require.Contains(t, result, "tool_1 ok") + require.Contains(t, client.Tools(), "new_tool") + require.Len(t, client.Tools(), 2) + require.Len(t, cache.GetTools("remote"), 2) +} + func TestNewClientUsesCacheWithoutPaginationCall(t *testing.T) { var listCalls atomic.Int32 server := newStaticToolListMCPServer(2, "server_tool") @@ -455,234 +471,3 @@ func TestNewClientErrorsOnEmptyRemoteToolCatalog(t *testing.T) { require.Contains(t, err.Error(), "no tools found") require.Nil(t, cache.GetTools("empty")) } - -func TestRemoteToolListChangedInvalidatesCacheAndClientTools(t *testing.T) { - server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") - httpServer := startStreamableMCPServer(t, server) - cache := newTestToolsCache() - - client, err := NewClient(context.Background(), "user-id", ServerConfig{ - Name: "paged", - BaseURL: httpServer.URL, - Enabled: true, - }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) - require.NoError(t, err) - t.Cleanup(func() { _ = client.Close() }) - require.NotEmpty(t, client.Tools()) - require.NotNil(t, cache.GetTools("paged")) - - addTestMCPTool(server, "new_tool") - - require.Eventually(t, func() bool { - return len(client.Tools()) == 0 && cache.GetTools("paged") == nil - }, 5*time.Second, 10*time.Millisecond) -} - -func TestRemoteToolListChangedNextGetToolsForUserRediscoversTools(t *testing.T) { - server := newTestMCPServer(2, "tool_1", "tool_2") - httpServer := startStreamableMCPServer(t, server) - cache := newTestToolsCache() - manager := &ClientManager{ - config: Config{ - Servers: []ServerConfig{{ - Name: "paged", - BaseURL: httpServer.URL, - Enabled: true, - }}, - }, - log: newTestLogService(), - clients: make(map[string]*UserClients), - activity: make(map[string]time.Time), - oauthManager: newTestOAuthManager(), - httpClient: httpServer.Client(), - toolsCache: cache, - } - t.Cleanup(func() { cleanupTestClientManager(manager) }) - - var tools []llm.Tool - var mcpErrors *Errors - require.Eventually(t, func() bool { - tools, mcpErrors = manager.GetToolsForUser(context.Background(), "user-id") - if mcpErrors != nil || len(cache.GetTools("paged")) != 2 { - return false - } - toolNames := make(map[string]bool, len(tools)) - for _, tool := range tools { - toolNames[tool.Name] = true - } - return len(tools) == 2 && toolNames["paged__tool_1"] && toolNames["paged__tool_2"] - }, 5*time.Second, 10*time.Millisecond) - require.Nil(t, mcpErrors) - requireToolNames(t, tools, "paged__tool_1", "paged__tool_2") - require.Len(t, cache.GetTools("paged"), 2) - - addTestMCPTool(server, "new_tool") - - require.Eventually(t, func() bool { - manager.clientsMu.RLock() - userClient := manager.clients["user-id"] - manager.clientsMu.RUnlock() - if userClient == nil { - return false - } - client := userClient.clients["paged"] - return client != nil && len(client.Tools()) == 0 && cache.GetTools("paged") == nil - }, 5*time.Second, 10*time.Millisecond) - - tools, mcpErrors = manager.GetToolsForUser(context.Background(), "user-id") - require.Nil(t, mcpErrors) - requireToolNames(t, tools, "paged__new_tool", "paged__tool_1", "paged__tool_2") - require.Len(t, cache.GetTools("paged"), 3) -} - -func TestToolListChangedDuringRediscoveryKeepsClientDirty(t *testing.T) { - listBlocked := make(chan struct{}) - releaseList := make(chan struct{}) - var blocked atomic.Bool - server := newTestMCPServer(0, "tool_1") - server.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - result, err := next(ctx, method, req) - if err != nil || method != testListToolsMethod || !blocked.CompareAndSwap(false, true) { - return result, err - } - - close(listBlocked) - select { - case <-releaseList: - case <-ctx.Done(): - return nil, ctx.Err() - } - return result, nil - } - }) - session := connectInMemoryTestSession(t, server) - cache := newTestToolsCache() - client := &Client{ - session: session, - config: ServerConfig{Name: "server", BaseURL: "https://example.com"}, - tools: make(map[string]*mcp.Tool), - toolsDirty: true, - userID: "user-id", - log: newTestLogService(), - toolsCache: cache, - } - - errCh := make(chan error, 1) - go func() { - errCh <- client.ensureDiscoveredTools(context.Background()) - }() - - select { - case <-listBlocked: - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for rediscovery to enter tools/list") - } - - client.invalidateDiscoveredTools(context.Background(), cache, "server", true) - close(releaseList) - - select { - case err := <-errCh: - require.NoError(t, err) - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for rediscovery to finish") - } - - client.toolsMu.RLock() - require.True(t, client.toolsDirty) - client.toolsMu.RUnlock() - require.Empty(t, client.Tools()) - require.Nil(t, cache.GetTools("server")) - - addTestMCPTool(server, "tool_2") - require.NoError(t, client.ensureDiscoveredTools(context.Background())) - - client.toolsMu.RLock() - require.False(t, client.toolsDirty) - client.toolsMu.RUnlock() - require.Contains(t, client.Tools(), "tool_1") - require.Contains(t, client.Tools(), "tool_2") - require.Len(t, cache.GetTools("server"), 2) -} - -func TestRemoteToolListChangedWithNilCacheClearsClientTools(t *testing.T) { - server := newTestMCPServer(2, "tool_1", "tool_2", "tool_3") - httpServer := startStreamableMCPServer(t, server) - - client, err := NewClient(context.Background(), "user-id", ServerConfig{ - Name: "paged", - BaseURL: httpServer.URL, - Enabled: true, - }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), nil) - require.NoError(t, err) - t.Cleanup(func() { _ = client.Close() }) - require.NotEmpty(t, client.Tools()) - - addTestMCPTool(server, "new_tool") - - require.Eventually(t, func() bool { - return len(client.Tools()) == 0 - }, 5*time.Second, 10*time.Millisecond) -} - -func TestRemoteToolListChangedForStaticOAuthSkipsSharedCacheInvalidation(t *testing.T) { - server := newTestMCPServer(2, "server_tool") - httpServer := startStreamableMCPServer(t, server) - cache := newTestToolsCache() - require.NoError(t, cache.SetTools("oauth-server", "OAuth Server", httpServer.URL, map[string]*mcp.Tool{ - "cached_tool": { - Name: "cached_tool", - Description: "Cached tool", - InputSchema: map[string]any{"type": "object"}, - }, - }, time.Now())) - - client, err := NewClient(context.Background(), "user-id", ServerConfig{ - Name: "oauth-server", - BaseURL: httpServer.URL, - Enabled: true, - ClientID: "client-id", - ClientSecret: "client-secret", - }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) - require.NoError(t, err) - t.Cleanup(func() { _ = client.Close() }) - require.Contains(t, client.Tools(), "server_tool") - require.NoError(t, cache.SetTools("oauth-server", "OAuth Server", httpServer.URL, map[string]*mcp.Tool{ - "cached_after_connect": { - Name: "cached_after_connect", - Description: "Cached after connect", - InputSchema: map[string]any{"type": "object"}, - }, - }, time.Now())) - require.NotNil(t, cache.GetTools("oauth-server")) - - addTestMCPTool(server, "new_tool") - - require.Eventually(t, func() bool { - return len(client.Tools()) == 0 - }, 5*time.Second, 10*time.Millisecond) - require.NotNil(t, cache.GetTools("oauth-server")) -} - -func TestRemoteToolListChangedNotificationStormIsIdempotent(t *testing.T) { - server := newTestMCPServer(2, "tool_1") - httpServer := startStreamableMCPServer(t, server) - cache := newTestToolsCache() - - client, err := NewClient(context.Background(), "user-id", ServerConfig{ - Name: "storm", - BaseURL: httpServer.URL, - Enabled: true, - }, newTestLogService(), newTestOAuthManager(), httpServer.Client(), cache) - require.NoError(t, err) - t.Cleanup(func() { _ = client.Close() }) - - addTestMCPTool(server, "storm_tool") - server.RemoveTools("storm_tool") - addTestMCPTool(server, "storm_tool") - - require.Eventually(t, func() bool { - return len(client.Tools()) == 0 && cache.GetTools("storm") == nil - }, 5*time.Second, 10*time.Millisecond) -} diff --git a/mcp/user_clients.go b/mcp/user_clients.go index d2ff35490..13a6f092e 100644 --- a/mcp/user_clients.go +++ b/mcp/user_clients.go @@ -224,15 +224,6 @@ func (c *UserClients) GetTools(ctx context.Context) []llm.Tool { for _, entry := range clientSnapshot { serverID := entry.serverID client := entry.client - if err := client.ensureDiscoveredTools(ctx); err != nil { - c.log.Warn("Failed to rediscover MCP tools after list_changed notification", - "userID", c.userID, - "serverID", serverID, - "server", client.config.Name, - "error", err) - continue - } - clientTools := client.Tools() serverSlug := dedupeMCPServerSlug(mcpServerSlug(serverID, client), client.config.BaseURL, serverID, usedSlugs) toolNames := make([]string, 0, len(clientTools)) @@ -347,8 +338,8 @@ func (c *UserClients) rememberOAuthNeededForToolCall(client *Client, err error) } // createToolResolver creates a resolver function for the given tool -func (c *UserClients) createToolResolver(client *Client, toolName string) func(llmContext *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { - return func(llmContext *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { +func (c *UserClients) createToolResolver(client *Client, toolName string) llm.ToolResolver { + return func(ctx context.Context, llmContext *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { var args map[string]any if err := argsGetter(&args); err != nil { return "", fmt.Errorf("failed to get arguments for tool %s: %w", toolName, err) @@ -356,12 +347,7 @@ func (c *UserClients) createToolResolver(client *Client, toolName string) func(l metadata := c.prepareToolCallMetadata(client, toolName, llmContext) - if llmContext == nil || llmContext.RequestContext == nil { - return "", errors.New("missing request context for MCP tool call") - } - callCtx := llmContext.RequestContext - - result, err := client.CallToolWithMetadata(callCtx, toolName, args, metadata) + result, err := client.CallToolWithMetadata(ctx, toolName, args, metadata) if err != nil { c.rememberOAuthNeededForToolCall(client, err) return result, err diff --git a/mcp/user_clients_test.go b/mcp/user_clients_test.go index c8104a4bb..81534afb9 100644 --- a/mcp/user_clients_test.go +++ b/mcp/user_clients_test.go @@ -112,7 +112,7 @@ func TestUserClientsGetToolsResolverUsesBareToolName(t *testing.T) { tools := userClients.GetTools(context.Background()) requireToolNames(t, tools, "jira__search") - result, err := tools[0].Resolver(&llm.Context{RequestContext: context.Background()}, func(args any) error { + result, err := tools[0].Resolver(context.Background(), &llm.Context{}, func(args any) error { *(args.(*map[string]any)) = map[string]any{} return nil }) @@ -151,16 +151,21 @@ func TestUserClientsGetToolsDeterministicSlugCollision(t *testing.T) { requireToolNames(t, second, "jira__search", expectedDedupedName) } -func TestUserClientsGetToolsPreservesRediscoveryBeforeRead(t *testing.T) { +func TestUserClientsGetToolsUsesCachedCatalog(t *testing.T) { server := newTestMCPServer(0, "old_tool") session := connectInMemoryTestSession(t, server) + addTestMCPTool(server, "new_tool") client := &Client{ - session: session, - config: ServerConfig{Name: "Jira", BaseURL: "https://mcp.atlassian.com", Enabled: true}, - tools: make(map[string]*gomcp.Tool), - toolsDirty: true, - userID: "user-id", - log: newTestLogService(), + session: session, + config: ServerConfig{Name: "Jira", BaseURL: "https://mcp.atlassian.com", Enabled: true}, + tools: map[string]*gomcp.Tool{ + "old_tool": { + Name: "old_tool", + Description: "Old tool", + }, + }, + userID: "user-id", + log: newTestLogService(), } userClients := &UserClients{ userID: "user-id", @@ -170,16 +175,9 @@ func TestUserClientsGetToolsPreservesRediscoveryBeforeRead(t *testing.T) { }, } - addTestMCPTool(server, "new_tool") - require.NoError(t, client.ensureDiscoveredTools(context.Background())) - client.toolsMu.Lock() - client.toolsDirty = true - client.tools = make(map[string]*gomcp.Tool) - client.toolsMu.Unlock() - tools := userClients.GetTools(context.Background()) - requireToolNames(t, tools, "jira__new_tool", "jira__old_tool") + requireToolNames(t, tools, "jira__old_tool") } func TestConnectToPluginServer_HappyPath(t *testing.T) { @@ -262,7 +260,7 @@ func TestConnectToEmbeddedServerIfAvailable_Idempotent(t *testing.T) { require.Same(t, firstClient, secondSnapshot[0].client) } -func TestUserClientsGetToolsResolverUsesRequestContext(t *testing.T) { +func TestUserClientsGetToolsResolverUsesResolverContext(t *testing.T) { callCtx, cancel := context.WithCancel(context.Background()) cancel() @@ -287,7 +285,7 @@ func TestUserClientsGetToolsResolverUsesRequestContext(t *testing.T) { tools := userClients.GetTools(context.Background()) require.Len(t, tools, 1) - _, err := tools[0].Resolver(&llm.Context{RequestContext: callCtx}, func(args any) error { + _, err := tools[0].Resolver(callCtx, &llm.Context{}, func(args any) error { *(args.(*map[string]any)) = map[string]any{} return nil }) @@ -295,7 +293,7 @@ func TestUserClientsGetToolsResolverUsesRequestContext(t *testing.T) { require.ErrorIs(t, err, context.Canceled) } -func TestUserClientsGetToolsResolverRequiresRequestContext(t *testing.T) { +func TestUserClientsGetToolsResolverDoesNotRequireRequestContext(t *testing.T) { server := newTestMCPServer(0, "search") session := connectInMemoryTestSession(t, server) userClients := &UserClients{ @@ -317,11 +315,13 @@ func TestUserClientsGetToolsResolverRequiresRequestContext(t *testing.T) { tools := userClients.GetTools(context.Background()) require.Len(t, tools, 1) - _, err := tools[0].Resolver(&llm.Context{}, func(args any) error { + result, err := tools[0].Resolver(context.Background(), &llm.Context{}, func(args any) error { *(args.(*map[string]any)) = map[string]any{} return nil }) - require.EqualError(t, err, "missing request context for MCP tool call") + + require.NoError(t, err) + require.Equal(t, "search ok\n", result) } func TestPrepareToolCallMetadata_EmbeddedMergesCallMetadataAndBotUserID(t *testing.T) { diff --git a/mcpserver/eval_helpers_test.go b/mcpserver/eval_helpers_test.go index fef0ad6d6..30d112cb2 100644 --- a/mcpserver/eval_helpers_test.go +++ b/mcpserver/eval_helpers_test.go @@ -588,8 +588,7 @@ func setupAgenticEval(t *testing.T, e *evals.EvalT, suite *TestSuite, requesting } // mcpToolsToLLMTools converts MCP server tools into llm.Tool instances with resolvers -// that call through the MCP protocol. This mirrors the production pattern in -// mcp/user_clients.go:201 (createToolResolver). +// that call through the MCP protocol. func mcpToolsToLLMTools(t *testing.T, mcpServer *gomcp.Server) []llm.Tool { t.Helper() @@ -606,8 +605,7 @@ func mcpToolsToLLMTools(t *testing.T, mcpServer *gomcp.Server) []llm.Tool { Name: tool.Name, Description: tool.Description, Schema: tool.InputSchema, - Resolver: func(_ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { - // Same pattern as production createToolResolver (mcp/user_clients.go:201) + Resolver: func(ctx context.Context, _ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { var args map[string]any if err := argsGetter(&args); err != nil { return "", err diff --git a/mmtools/web_search.go b/mmtools/web_search.go index e752e10ca..8a18a195f 100644 --- a/mmtools/web_search.go +++ b/mmtools/web_search.go @@ -209,14 +209,14 @@ func (s *webSearchService) SourceTool(bot *bots.Bot) *llm.Tool { } t := *s.sourceTool - t.Resolver = func(ctx *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { - return s.resolveSource(bot, ctx, argsGetter) + t.Resolver = func(ctx context.Context, llmCtx *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { + return s.resolveSource(ctx, bot, llmCtx, argsGetter) } return &t } -func (s *webSearchService) resolve(llmContext *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { +func (s *webSearchService) resolve(ctx context.Context, llmContext *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { var args WebSearchToolArgs if err := argsGetter(&args); err != nil { return "invalid parameters to function", fmt.Errorf("failed to get arguments for WebSearch tool: %w", err) @@ -301,7 +301,7 @@ func (s *webSearchService) resolve(llmContext *llm.Context, argsGetter llm.ToolA } // Perform the search - searchResp, err := provider.Search(context.Background(), query, resultLimit) + searchResp, err := provider.Search(ctx, query, resultLimit) if err != nil { return "unable to perform web search", err } @@ -432,7 +432,7 @@ func (s *webSearchService) resolve(llmContext *llm.Context, argsGetter llm.ToolA return builder.String(), nil } -func (s *webSearchService) resolveSource(bot *bots.Bot, llmContext *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { +func (s *webSearchService) resolveSource(ctx context.Context, bot *bots.Bot, llmContext *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { var args WebSearchSourceArgs if err := argsGetter(&args); err != nil { return "invalid parameters to function", fmt.Errorf("failed to get arguments for WebSearchFetchSource tool: %w", err) @@ -503,7 +503,7 @@ func (s *webSearchService) resolveSource(bot *bots.Bot, llmContext *llm.Context, return "web search is not properly configured", errors.New("web search http client is not configured") } - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, pageURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, pageURL, nil) if err != nil { s.logError("failed to create source fetch request", "error", err) return "unable to create request", err @@ -557,7 +557,7 @@ func (s *webSearchService) resolveSource(bot *bots.Bot, llmContext *llm.Context, } // Perform recursive summarization - summary, err := s.summarizeContent(bot, textContent) + summary, err := s.summarizeContent(ctx, bot, textContent) if err != nil { s.logWarn("recursive summarization failed, falling back to raw content with warnings", "error", err) return s.wrapSourceContentWithContext(textContent, matchedResult, llmContext), nil @@ -566,7 +566,7 @@ func (s *webSearchService) resolveSource(bot *bots.Bot, llmContext *llm.Context, return s.formatSummarizedContent(summary, matchedResult), nil } -func (s *webSearchService) summarizeContent(bot *bots.Bot, content string) (string, error) { +func (s *webSearchService) summarizeContent(ctx context.Context, bot *bots.Bot, content string) (string, error) { if bot == nil { return "", errors.New("bot instance is nil") } @@ -606,7 +606,7 @@ func (s *webSearchService) summarizeContent(bot *bots.Bot, content string) (stri } // Use a reasonable token limit for the summary (e.g. 4000 tokens) - return languageModel.ChatCompletionNoStream(context.Background(), req, llm.WithMaxGeneratedTokens(4000)) + return languageModel.ChatCompletionNoStream(ctx, req, llm.WithMaxGeneratedTokens(4000)) } func (s *webSearchService) formatSummarizedContent(summary string, matchedResult *WebSearchResult) string { diff --git a/mmtools/web_search_test.go b/mmtools/web_search_test.go index 8b39041e7..ca8030c2e 100644 --- a/mmtools/web_search_test.go +++ b/mmtools/web_search_test.go @@ -594,7 +594,7 @@ func TestWebSearchSourceWhitelist(t *testing.T) { } // Should succeed (return content) - resp, err := service.resolveSource(mockBot, ctx, argsGetter) + resp, err := service.resolveSource(context.Background(), mockBot, ctx, argsGetter) require.NoError(t, err) require.Contains(t, resp, "Summarized content") }) @@ -614,7 +614,7 @@ func TestWebSearchSourceWhitelist(t *testing.T) { return nil } - resp, err := service.resolveSource(mockBot, ctx, argsGetter) + resp, err := service.resolveSource(context.Background(), mockBot, ctx, argsGetter) require.Error(t, err) require.Equal(t, "url not in whitelist", err.Error()) require.Contains(t, resp, "you can only fetch URLs that were returned from web search results") @@ -633,7 +633,7 @@ func TestWebSearchSourceWhitelist(t *testing.T) { return nil } - resp, err := service.resolveSource(mockBot, ctx, argsGetter) + resp, err := service.resolveSource(context.Background(), mockBot, ctx, argsGetter) require.Error(t, err) require.Equal(t, "no whitelist in context", err.Error()) require.Contains(t, resp, "you can only fetch URLs that were returned from web search results") diff --git a/prompts/standard_personality_without_locale_test.go b/prompts/standard_personality_without_locale_test.go index a3b16ce90..eed6e391a 100644 --- a/prompts/standard_personality_without_locale_test.go +++ b/prompts/standard_personality_without_locale_test.go @@ -4,6 +4,7 @@ package prompts_test import ( + "context" "fmt" "regexp" "testing" @@ -31,7 +32,7 @@ func TestStandardPersonalityWithoutLocaleWhitespaceGating(t *testing.T) { store.AddTools([]llm.Tool{{ Name: name, Description: "test tool", - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "", nil }, }}) @@ -127,14 +128,14 @@ func TestStandardPersonalityWithoutLocaleListsAvailableToolsForGeminiAndVertexOn { Name: "search_users", Description: "Look up users by name", - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "", nil }, }, { Name: "read_channel", Description: "Read channel history", - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return "", nil }, }, diff --git a/telemetry/integration_test.go b/telemetry/integration_test.go index 2d78b56f0..5ed43559d 100644 --- a/telemetry/integration_test.go +++ b/telemetry/integration_test.go @@ -257,7 +257,7 @@ func TestToolResolveSpan(t *testing.T) { { Name: "test_tool", Description: "A test tool", - Resolver: func(_ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { return "tool result", nil }, }, @@ -384,7 +384,7 @@ func TestFullRequestTrace(t *testing.T) { store.AddTools([]llm.Tool{ { Name: "web_search", - Resolver: func(_ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) { var args struct { Query string `json:"query"` } diff --git a/toolrunner/toolrunner_test.go b/toolrunner/toolrunner_test.go index 7e9c90ce0..2babe3b8c 100644 --- a/toolrunner/toolrunner_test.go +++ b/toolrunner/toolrunner_test.go @@ -83,7 +83,7 @@ func newTestToolStore(tools ...testToolDef) *llm.ToolStore { Name: t.name, Description: "test tool", ServerOrigin: t.serverOrigin, - Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { + Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) { return result, toolErr }, } From c60dec5c0168c36304956cc6b5f8ba6889726dab Mon Sep 17 00:00:00 2001 From: Nick Misasi Date: Fri, 29 May 2026 11:31:24 -0400 Subject: [PATCH 7/7] dynamic mcp: thread context into tool discovery Co-authored-by: Cursor --- api/api_channel.go | 3 +-- api/api_llm_bridge.go | 6 ++---- conversations/handle_messages.go | 6 ++---- conversations/regeneration.go | 3 +-- conversations/tool_approval.go | 6 ++---- llm/context.go | 5 ----- llmcontext/llm_context.go | 28 +++++++++------------------- llmcontext/llm_context_test.go | 6 ++---- mcp/user_clients_test.go | 2 +- 9 files changed, 20 insertions(+), 45 deletions(-) diff --git a/api/api_channel.go b/api/api_channel.go index 28b0b8e3c..15c85ec7f 100644 --- a/api/api_channel.go +++ b/api/api_channel.go @@ -88,8 +88,7 @@ func (a *API) handleChannelAnalysis(c *gin.Context) { } opts := []llm.ContextOption{ - a.contextBuilder.WithLLMContextRequestContext(c.Request.Context()), - a.contextBuilder.WithLLMContextDefaultTools(toolBot), + a.contextBuilder.WithLLMContextDefaultTools(c.Request.Context(), toolBot), } // If the channel is a DM/GM and we have a team ID from the client, use it for context diff --git a/api/api_llm_bridge.go b/api/api_llm_bridge.go index 37fbe8216..341b13773 100644 --- a/api/api_llm_bridge.go +++ b/api/api_llm_bridge.go @@ -171,10 +171,9 @@ func (a *API) convertAgentBridgeRequestToInternal(ctx stdcontext.Context, bot *b } bridgeContext := llm.NewContext() - bridgeContext.RequestContext = ctx bridgeContext.RequestingUser = &model.User{Id: req.UserID} if includeTools && a.contextBuilder != nil { - a.contextBuilder.WithLLMContextTools(bot)(bridgeContext) + a.contextBuilder.WithLLMContextTools(ctx, bot)(bridgeContext) } resolvedOperation := operation @@ -690,10 +689,9 @@ func (a *API) handleGetAgentTools(c *gin.Context) { // Build a minimal context just to resolve the bot's available tools. toolContext := llm.NewContext() - toolContext.RequestContext = c.Request.Context() toolContext.RequestingUser = &model.User{Id: userID} if a.contextBuilder != nil { - a.contextBuilder.WithLLMContextTools(bot)(toolContext) + a.contextBuilder.WithLLMContextTools(c.Request.Context(), bot)(toolContext) } var tools []bridgeclient.BridgeToolInfo diff --git a/conversations/handle_messages.go b/conversations/handle_messages.go index 43c811e5d..6dab77941 100644 --- a/conversations/handle_messages.go +++ b/conversations/handle_messages.go @@ -194,8 +194,7 @@ func (c *Conversations) handleMentionViaConversation( responseRootID string, ) error { contextOpts := []llm.ContextOption{ - c.contextBuilder.WithLLMContextRequestContext(ctx), - c.contextBuilder.WithLLMContextTools(bot), + c.contextBuilder.WithLLMContextTools(ctx, bot), } llmContext := c.contextBuilder.BuildLLMContextUserRequest(bot, postingUser, channel, contextOpts...) @@ -340,8 +339,7 @@ func (c *Conversations) handleDMs(ctx context.Context, bot *bots.Bot, channel *m // handleDMViaConversation processes a DM message using the conversation entity model. func (c *Conversations) handleDMViaConversation(ctx context.Context, bot *bots.Bot, channel *model.Channel, postingUser *model.User, post *model.Post) error { contextOpts := []llm.ContextOption{ - c.contextBuilder.WithLLMContextRequestContext(ctx), - c.contextBuilder.WithLLMContextTools(bot), + c.contextBuilder.WithLLMContextTools(ctx, bot), } webSearchParams := c.extractWebSearchContext(post) if len(webSearchParams) > 0 { diff --git a/conversations/regeneration.go b/conversations/regeneration.go index 830e83092..fa91e0c5c 100644 --- a/conversations/regeneration.go +++ b/conversations/regeneration.go @@ -251,8 +251,7 @@ func (c *Conversations) regenerateViaConversation( } contextOpts := []llm.ContextOption{ - c.contextBuilder.WithLLMContextRequestContext(ctx), - c.contextBuilder.WithLLMContextDefaultTools(bot), + c.contextBuilder.WithLLMContextDefaultTools(ctx, bot), } llmContext := c.contextBuilder.BuildLLMContextUserRequest(bot, user, channel, contextOpts...) diff --git a/conversations/tool_approval.go b/conversations/tool_approval.go index 9de24e94c..9fca3283d 100644 --- a/conversations/tool_approval.go +++ b/conversations/tool_approval.go @@ -98,8 +98,7 @@ func (c *Conversations) HandleToolCall(ctx context.Context, userID string, post // Build LLM context with tools for execution. contextOpts := []llm.ContextOption{ - c.contextBuilder.WithLLMContextRequestContext(ctx), - c.contextBuilder.WithLLMContextDefaultTools(bot), + c.contextBuilder.WithLLMContextDefaultTools(ctx, bot), } llmContext := c.contextBuilder.BuildLLMContextUserRequest(bot, user, channel, contextOpts...) @@ -423,8 +422,7 @@ func (c *Conversations) streamToolFollowUp( defer span.End() contextOpts := []llm.ContextOption{ - c.contextBuilder.WithLLMContextRequestContext(ctx), - c.contextBuilder.WithLLMContextDefaultTools(bot), + c.contextBuilder.WithLLMContextDefaultTools(ctx, bot), } llmContext := c.contextBuilder.BuildLLMContextUserRequest(bot, user, channel, contextOpts...) diff --git a/llm/context.go b/llm/context.go index 1f3b2dfb9..96495973c 100644 --- a/llm/context.go +++ b/llm/context.go @@ -4,7 +4,6 @@ package llm import ( - stdcontext "context" "fmt" "strings" "time" @@ -37,10 +36,6 @@ type Context struct { // User that is making the request RequestingUser *model.User - // RequestContext carries the caller's request-scoped context for downstream - // work such as MCP tool discovery. May be nil in tests. - RequestContext stdcontext.Context - // ConversationID identifies the conversation whose context is being built. ConversationID string diff --git a/llmcontext/llm_context.go b/llmcontext/llm_context.go index f5e9e0c37..1bd960bfa 100644 --- a/llmcontext/llm_context.go +++ b/llmcontext/llm_context.go @@ -167,9 +167,8 @@ func sanitizeUserProfileField(s string) string { // WithLLMContextSessionID removed: embedded MCP manages its own session lifecycle -// getToolsStoreForUser returns a tool store for a specific user, including MCP tools -// Session information is extracted from the llm.Context -func (b *Builder) getToolsStoreForUser(c *llm.Context, bot *bots.Bot, userID string) *llm.ToolStore { +// getToolsStoreForUser returns a tool store for a specific user, including MCP tools. +func (b *Builder) getToolsStoreForUser(ctx stdcontext.Context, bot *bots.Bot, userID string) *llm.ToolStore { // Check for nil bot, which is unexpected if bot == nil { b.pluginAPI.Log.Error("Unexpected nil bot when getting tool store for user", "userID", userID) @@ -198,13 +197,13 @@ func (b *Builder) getToolsStoreForUser(c *llm.Context, bot *bots.Bot, userID str // so that GetToolsInfo() can inform the LLM about their availability. // Actual execution is controlled via WithToolsDisabled() based on channel type. if b.mcpToolProvider != nil { - if c.RequestContext == nil { - b.pluginAPI.Log.Error("Cannot add MCP tools to context: RequestContext is nil", "userID", userID) + if ctx == nil { + b.pluginAPI.Log.Error("Cannot add MCP tools to context: request context is nil", "userID", userID) return store } // Get tools from all connected servers - mcpTools, mcpErrors := b.mcpToolProvider.GetToolsForUser(c.RequestContext, userID) + mcpTools, mcpErrors := b.mcpToolProvider.GetToolsForUser(ctx, userID) // Add tools from successfully connected servers even if some had errors // These will be disabled in non-DM channels via WithToolsDisabled() @@ -238,29 +237,20 @@ func (b *Builder) getToolsStoreForUser(c *llm.Context, bot *bots.Bot, userID str // WithLLMContextTools adds tools to the LLM context the requester can access. // Tools are always added for LLM awareness; execution is controlled via WithToolsDisabled() // based on the context (e.g., DM vs channel). -func (b *Builder) WithLLMContextTools(bot *bots.Bot) llm.ContextOption { +func (b *Builder) WithLLMContextTools(ctx stdcontext.Context, bot *bots.Bot) llm.ContextOption { return func(c *llm.Context) { if c.RequestingUser == nil { b.pluginAPI.Log.Error("Cannot add tools to context: RequestingUser is nil") return } - // Get tools using session info from llm.Context - c.Tools = b.getToolsStoreForUser(c, bot, c.RequestingUser.Id) + c.Tools = b.getToolsStoreForUser(ctx, bot, c.RequestingUser.Id) } } // WithLLMContextDefaultTools adds default tools to the LLM context for the requesting user -func (b *Builder) WithLLMContextDefaultTools(bot *bots.Bot) llm.ContextOption { - return b.WithLLMContextTools(bot) -} - -// WithLLMContextRequestContext threads request-scoped cancellation/deadlines into -// MCP tool discovery. -func (b *Builder) WithLLMContextRequestContext(ctx stdcontext.Context) llm.ContextOption { - return func(c *llm.Context) { - c.RequestContext = ctx - } +func (b *Builder) WithLLMContextDefaultTools(ctx stdcontext.Context, bot *bots.Bot) llm.ContextOption { + return b.WithLLMContextTools(ctx, bot) } // WithLLMContextNoTools explicitly disables tools for this context session only, diff --git a/llmcontext/llm_context_test.go b/llmcontext/llm_context_test.go index 3c587378a..98da5d839 100644 --- a/llmcontext/llm_context_test.go +++ b/llmcontext/llm_context_test.go @@ -87,8 +87,7 @@ func TestWithLLMContextDefaultToolsCallsMCPProvider(t *testing.T) { newTestBot(), user, channel, - builder.WithLLMContextRequestContext(stdcontext.Background()), - builder.WithLLMContextDefaultTools(newTestBot()), + builder.WithLLMContextDefaultTools(stdcontext.Background(), newTestBot()), ) require.Equal(t, 1, mcpProvider.calls) @@ -163,8 +162,7 @@ func TestWithLLMContextDefaultToolsRetainsAuthErrorsForWildcardAllowlist(t *test bot, user, channel, - builder.WithLLMContextRequestContext(stdcontext.Background()), - builder.WithLLMContextDefaultTools(bot), + builder.WithLLMContextDefaultTools(stdcontext.Background(), bot), ) require.Empty(t, context.Tools.GetTools()) diff --git a/mcp/user_clients_test.go b/mcp/user_clients_test.go index 81534afb9..381574fd9 100644 --- a/mcp/user_clients_test.go +++ b/mcp/user_clients_test.go @@ -293,7 +293,7 @@ func TestUserClientsGetToolsResolverUsesResolverContext(t *testing.T) { require.ErrorIs(t, err, context.Canceled) } -func TestUserClientsGetToolsResolverDoesNotRequireRequestContext(t *testing.T) { +func TestUserClientsGetToolsResolverWorksWithEmptyLLMContext(t *testing.T) { server := newTestMCPServer(0, "search") session := connectInMemoryTestSession(t, server) userClients := &UserClients{