Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions config/mcp_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,31 @@
package config

const (
MCPToolPolicyAsk = "ask"
MCPToolPolicyAutoRun = "auto_run"
MCPToolPolicyAsk = "ask"
MCPToolPolicyAutoRun = "auto_run"
MCPToolPolicyAutoRunEverywhere = "auto_run_everywhere"
)

// MCPToolConfig represents per-tool configuration for an MCP server.
type MCPToolConfig struct {
Name string `json:"name"`
Policy string `json:"policy"` // "auto_run" | "ask"
Policy string `json:"policy"` // "auto_run" | "auto_run_everywhere" | "ask"
Enabled bool `json:"enabled"`
}

// IsToolPolicyAutoRun reports whether policy is one of the auto-run policies,
// i.e. either MCPToolPolicyAutoRun or MCPToolPolicyAutoRunEverywhere. Any other
// value, including the empty string, yields false. It does not distinguish
// between the two auto-run modes; use IsToolPolicyAutoRunEverywhere for that.
func IsToolPolicyAutoRun(policy string) bool {
	switch policy {
	case MCPToolPolicyAutoRun, MCPToolPolicyAutoRunEverywhere:
		return true
	default:
		return false
	}
}

// IsToolPolicyAutoRunEverywhere reports whether policy is exactly
// MCPToolPolicyAutoRunEverywhere, the policy intended to run to completion
// without further approval in any conversation context. The plain
// MCPToolPolicyAutoRun policy does not qualify.
func IsToolPolicyAutoRunEverywhere(policy string) bool {
	runsEverywhere := policy == MCPToolPolicyAutoRunEverywhere
	return runsEverywhere
}

// MCPEmbeddedServerConfig contains configuration for the embedded MCP server
type MCPEmbeddedServerConfig struct {
Enabled bool `json:"enabled"`
Expand Down Expand Up @@ -71,15 +85,15 @@ func (s *MCPServerConfig) GetToolPolicy(toolName string) (string, bool) {
return MCPToolPolicyAsk, true
}

if policy != MCPToolPolicyAutoRun && policy != MCPToolPolicyAsk {
if !IsToolPolicyAutoRun(policy) && policy != MCPToolPolicyAsk {
policy = MCPToolPolicyAsk
}

return policy, enabled
}

// IsToolAutoRun returns true only when the tool has policy "auto_run" and is enabled.
// IsToolAutoRun returns true when the tool is enabled and configured for any auto-run mode.
func (s *MCPServerConfig) IsToolAutoRun(toolName string) bool {
policy, enabled := s.GetToolPolicy(toolName)
return policy == MCPToolPolicyAutoRun && enabled
return IsToolPolicyAutoRun(policy) && enabled
}
19 changes: 13 additions & 6 deletions conversations/conversations.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ func (c *Conversations) SetToolPolicyChecker(checker streaming.ToolPolicyChecker
c.toolPolicyChecker = checker
}

// isToolAutoRunnable reports whether the named tool on the given server origin
// is both enabled and configured with a policy that mcp.IsToolPolicyAutoRun
// accepts. When no tool policy checker has been set it returns false.
func (c *Conversations) isToolAutoRunnable(serverOrigin, toolName string) bool {
	checker := c.toolPolicyChecker
	if checker == nil {
		// No checker configured: nothing can be treated as auto-runnable.
		return false
	}

	policy, enabled := checker.GetToolPolicy(serverOrigin, toolName)
	autoRun := mcp.IsToolPolicyAutoRun(policy)
	return autoRun && enabled
}

func (c *Conversations) appendDMAutoRunOptions(isDM bool, llmContext *llm.Context, opts []llm.LanguageModelOption) []llm.LanguageModelOption {
if !isDM || c.toolPolicyChecker == nil || llmContext == nil || llmContext.Tools == nil {
return opts
Expand All @@ -112,8 +121,7 @@ func (c *Conversations) appendDMAutoRunOptions(isDM bool, llmContext *llm.Contex
allTools := llmContext.Tools.GetTools()
var autoRunNames []string
for _, t := range allTools {
policy, enabled := c.toolPolicyChecker.GetToolPolicy(t.ServerOrigin, t.Name)
if policy == mcp.ToolPolicyAutoRun && enabled {
if c.isToolAutoRunnable(t.ServerOrigin, t.Name) {
autoRunNames = append(autoRunNames, llm.ToolAutoRunKey(t.ServerOrigin, t.Name))
}
}
Expand Down Expand Up @@ -202,10 +210,9 @@ func (c *Conversations) ProcessUserRequestWithContext(bot *bots.Bot, postingUser
result = mmtools.DecorateStreamWithAnnotations(result, webSearchData, nil)
}

// Wrap stream with MCP auto-approval when tools are active (DM or channel with
// tool calling enabled). DMs pass allowToolsInChannel=false but toolsDisabled is
// false, so we key off toolsDisabled rather than allowToolsInChannel alone.
if !toolsDisabled && context != nil && context.Tools != nil && c.toolPolicyChecker != nil {
// Wrap stream with MCP auto-approval only for channels. DMs use the model-level
// auto-run wrapper via WithAutoRunTools and should not be pre-executed twice.
if !isDM && !toolsDisabled && context != nil && context.Tools != nil && c.toolPolicyChecker != nil {
result = wrapStreamWithMCPAutoApproval(result, context, c.toolPolicyChecker)
}

Expand Down
4 changes: 2 additions & 2 deletions conversations/mcp_auto_approval.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

// wrapStreamWithMCPAutoApproval wraps a text stream to automatically execute
// MCP tool calls whose per-tool policy is ToolPolicyAutoRun + enabled.
// MCP tool calls whose per-tool policy satisfies mcp.IsToolPolicyAutoRun + enabled.
//
// When ALL tool calls in a batch are auto-runnable, the wrapper:
// 1. Executes each tool via the ToolStore
Expand Down Expand Up @@ -57,7 +57,7 @@ func wrapStreamWithMCPAutoApproval(
toolCalls[i].ServerOrigin = tool.ServerOrigin
}
policy, enabled := policyChecker.GetToolPolicy(toolCalls[i].ServerOrigin, toolCalls[i].Name)
if policy != mcp.ToolPolicyAutoRun || !enabled {
if !mcp.IsToolPolicyAutoRun(policy) || !enabled {
allAutoRun = false
}
}
Expand Down
36 changes: 33 additions & 3 deletions conversations/tool_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ func (c *Conversations) HandleToolCall(userID string, post *model.Post, channel
}

if !isDM {
post.DelProp(streaming.AutoShareToolResultProp)
hasReviewableResult := slices.ContainsFunc(tools, func(tc llm.ToolCall) bool {
return tc.Status == llm.ToolCallStatusSuccess || tc.Status == llm.ToolCallStatusError
})
Expand All @@ -230,6 +231,7 @@ func (c *Conversations) HandleToolCall(userID string, post *model.Post, channel
post.AddProp(streaming.ToolCallProp, string(resolvedToolsJSON))
post.AddProp(streaming.ToolCallRedactedProp, "true")
post.DelProp(streaming.PendingToolResultProp)
post.DelProp(streaming.AutoShareToolResultProp)
if updateErr := c.mmClient.UpdatePost(post); updateErr != nil {
return fmt.Errorf("failed to update post with tool call results: %w", updateErr)
}
Expand All @@ -252,6 +254,7 @@ func (c *Conversations) HandleToolCall(userID string, post *model.Post, channel
post.AddProp(streaming.ToolCallProp, string(resolvedToolsJSON))
post.AddProp(streaming.ToolCallRedactedProp, "true")
post.AddProp(streaming.PendingToolResultProp, "true")
post.DelProp(streaming.AutoShareToolResultProp)
// Persist web search context so HandleToolResult and subsequent messages can find it
if params := llmContext.Parameters; len(params) > 0 {
if _, hasWebSearch := params[mmtools.WebSearchContextKey]; hasWebSearch {
Expand Down Expand Up @@ -279,6 +282,7 @@ func (c *Conversations) HandleToolCall(userID string, post *model.Post, channel
return fmt.Errorf("failed to marshal tool call results: %w", err)
}
post.AddProp(streaming.ToolCallProp, string(resolvedToolsJSON))
post.DelProp(streaming.AutoShareToolResultProp)

// Persist web search context if it exists (so it's available for subsequent tool calls)
if webSearchParams := llmContext.Parameters; len(webSearchParams) > 0 {
Expand Down Expand Up @@ -368,6 +372,7 @@ func (c *Conversations) HandleToolResult(userID string, post *model.Post, channe
post.AddProp(streaming.ToolCallProp, string(redactedToolsJSON))
post.AddProp(streaming.ToolCallRedactedProp, "true")
post.DelProp(streaming.PendingToolResultProp)
post.DelProp(streaming.AutoShareToolResultProp)
if updateErr := c.mmClient.UpdatePost(post); updateErr != nil {
return fmt.Errorf("failed to update post after tool result rejection: %w", updateErr)
}
Expand Down Expand Up @@ -418,6 +423,7 @@ func (c *Conversations) HandleToolResult(userID string, post *model.Post, channe
post.AddProp(streaming.ToolCallProp, string(resolvedToolsJSON))
post.DelProp(streaming.ToolCallRedactedProp)
post.DelProp(streaming.PendingToolResultProp)
post.DelProp(streaming.AutoShareToolResultProp)
// Persist web search context so subsequent messages in the thread preserve citations
if params := llmContext.Parameters; len(params) > 0 {
if _, hasWebSearch := params[mmtools.WebSearchContextKey]; hasWebSearch {
Expand Down Expand Up @@ -500,9 +506,9 @@ func (c *Conversations) completeAndStreamToolResponse(
result = mmtools.DecorateStreamWithAnnotations(result, webSearchData, nil)
}

// Same MCP auto_run execution as ProcessUserRequestWithContext (DMs use
// allowToolsInChannel=false; toolsDisabled reflects whether tools are off).
if !toolsDisabled && llmContext != nil && llmContext.Tools != nil && c.toolPolicyChecker != nil {
// Same channel-only MCP auto-approval as ProcessUserRequestWithContext. DMs
// use the model-level auto-run wrapper via WithAutoRunTools.
if !mmapi.IsDMWith(bot.GetMMBot().UserId, channel) && !toolsDisabled && llmContext != nil && llmContext.Tools != nil && c.toolPolicyChecker != nil {
result = wrapStreamWithMCPAutoApproval(result, llmContext, c.toolPolicyChecker)
}

Expand All @@ -528,6 +534,7 @@ func (c *Conversations) AutoExecuteApprovedToolCalls(postID string, requesterID
c.mmClient.LogError("Auto-execute: failed to get post", "error", err, "post_id", postID)
return
}
autoShareResults := post.GetProp(streaming.AutoShareToolResultProp) != nil

channel, err := c.mmClient.GetChannel(post.ChannelId)
if err != nil {
Expand All @@ -537,6 +544,29 @@ func (c *Conversations) AutoExecuteApprovedToolCalls(postID string, requesterID

if err := c.HandleToolCall(requesterID, post, channel, approvedToolIDs); err != nil {
c.mmClient.LogError("Auto-execute: HandleToolCall failed", "error", err, "post_id", postID)
return
}

if !autoShareResults {
return
}

var tools []llm.ToolCall
resultKVKey := streaming.ToolResultPrivateKVKey(postID, requesterID)
if kvErr := c.mmClient.KVGet(resultKVKey, &tools); kvErr != nil {
c.mmClient.LogError("Auto-execute: failed to load tool results for auto-share", "error", kvErr, "post_id", postID, "kv_key", resultKVKey)
return
}

approvedResultIDs := make([]string, 0, len(tools))
for _, tool := range tools {
if tool.Status == llm.ToolCallStatusSuccess || tool.Status == llm.ToolCallStatusError {
approvedResultIDs = append(approvedResultIDs, tool.ID)
}
}

if err := c.HandleToolResult(requesterID, post, channel, approvedResultIDs); err != nil {
c.mmClient.LogError("Auto-execute: HandleToolResult failed", "error", err, "post_id", postID)
}
}

Expand Down
116 changes: 114 additions & 2 deletions conversations/tool_handling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ func TestAutoExecuteApprovedToolCalls(t *testing.T) {
requesterID = "requester-id"
)

t.Run("happy path - all tools execute successfully", func(t *testing.T) {
t.Run("auto approved channel tools still require result review by default", func(t *testing.T) {
mockAPI := &plugintest.API{}
client := pluginapi.NewClient(mockAPI, nil)
licenseChecker := enterprise.NewLicenseChecker(client)
Expand Down Expand Up @@ -857,7 +857,8 @@ func TestAutoExecuteApprovedToolCalls(t *testing.T) {
toolCallingConfig := &testToolCallingConfig{enableChannelMentionToolCalling: true}
conversationService := conversations.New(nil, fakeClient, nil, contextBuilder, botService, nil, licenseChecker, i18n.Init(), nil, toolCallingConfig)

// Call AutoExecuteApprovedToolCalls with pre-approved tool IDs
// Call AutoExecuteApprovedToolCalls with pre-approved tool IDs.
// Without the everywhere policy, channels should still land in result review.
conversationService.AutoExecuteApprovedToolCalls(postID, requesterID, []string{"tool-1"})

// Verify results stored in KV
Expand All @@ -874,13 +875,124 @@ func TestAutoExecuteApprovedToolCalls(t *testing.T) {
require.NotEmpty(t, fakeClient.updatedPosts)
lastUpdated := fakeClient.updatedPosts[len(fakeClient.updatedPosts)-1]
require.Equal(t, "true", lastUpdated.GetProp(streaming.PendingToolResultProp))
require.Nil(t, lastUpdated.GetProp(streaming.AutoShareToolResultProp))

// Verify tool calls on post are still redacted
toolCallProp, ok := lastUpdated.GetProp(streaming.ToolCallProp).(string)
require.True(t, ok)
require.NotContains(t, toolCallProp, "auto-data")
})

t.Run("auto run everywhere auto shares channel tool results", func(t *testing.T) {
mockAPI := &plugintest.API{}
client := pluginapi.NewClient(mockAPI, nil)
licenseChecker := enterprise.NewLicenseChecker(client)

siteName := "Mattermost"
mockAPI.On("GetConfig").Return(&model.Config{TeamSettings: model.TeamSettings{SiteName: &siteName}}).Maybe()
mockAPI.On("GetLicense").Return(&model.License{SkuShortName: "advanced"}).Maybe()
mockAPI.On("GetTeam", teamID).Return(&model.Team{Id: teamID}, nil).Maybe()

tool := llm.Tool{
Name: "test_tool",
Description: "test tool",
ServerOrigin: mcp.EmbeddedClientKey,
Schema: llm.NewJSONSchemaFromStruct[toolArgs](),
Resolver: func(_ *llm.Context, args llm.ToolArgumentGetter) (string, error) {
var parsed toolArgs
if err := args(&parsed); err != nil {
return "", err
}
return "result:" + parsed.Value, nil
},
}
contextBuilder := llmcontext.NewLLMContextBuilder(client, &testToolProvider{tools: []llm.Tool{tool}}, nil, &testConfigProvider{})
promptSet, err := llm.NewPrompts(prompts.PromptsFolder)
require.NoError(t, err)

streamingService := &fakeStreamingService{}
capturingLLM := &capturingLanguageModel{}

botService := bots.New(mockAPI, client, licenseChecker, nil, &http.Client{}, nil)
bot := bots.NewBot(llm.BotConfig{ID: botID, Name: "test-bot"}, llm.ServiceConfig{}, &model.Bot{UserId: botID, Username: "test-bot"}, capturingLLM)
botService.SetBotsForTesting([]*bots.Bot{bot})

post := &model.Post{
Id: postID,
UserId: botID,
ChannelId: channelID,
CreateAt: 1,
}
post.AddProp(streaming.LLMRequesterUserID, requesterID)
post.AddProp(streaming.AllowToolsInChannelProp, "true")
post.AddProp(streaming.AutoApprovedToolCallProp, "true")
post.AddProp(streaming.AutoShareToolResultProp, "true")

toolCalls := []llm.ToolCall{
{
ID: "tool-1",
Name: "test_tool",
ServerOrigin: mcp.EmbeddedClientKey,
Arguments: json.RawMessage(`{"value":"auto-data"}`),
},
}
redactedToolCalls := streaming.RedactToolCalls(toolCalls)
redactedJSON, err := json.Marshal(redactedToolCalls)
require.NoError(t, err)
post.AddProp(streaming.ToolCallProp, string(redactedJSON))
post.AddProp(streaming.ToolCallRedactedProp, "true")
post.AddProp(streaming.PendingToolResultProp, "true")

postList := &model.PostList{
Order: []string{postID},
Posts: map[string]*model.Post{postID: post},
}

channel := &model.Channel{
Id: channelID,
Type: model.ChannelTypeOpen,
TeamId: teamID,
}

fakeClient := &fakeMMClient{
users: map[string]*model.User{
requesterID: {Id: requesterID, Username: "user", Locale: "en"},
botID: {Id: botID, Username: "bot", Locale: "en"},
},
posts: map[string]*model.Post{postID: post},
channels: map[string]*model.Channel{channelID: channel},
postThreads: map[string]*model.PostList{postID: postList},
kv: map[string]interface{}{},
}

toolCallKVKey := streaming.ToolCallPrivateKVKey(postID, requesterID)
fakeClient.kv[toolCallKVKey] = toolCalls

toolCallingConfig := &testToolCallingConfig{enableChannelMentionToolCalling: true}
conversationService := conversations.New(promptSet, fakeClient, streamingService, contextBuilder, botService, nil, licenseChecker, i18n.Init(), nil, toolCallingConfig)
conversationService.SetToolPolicyChecker(streaming.ToolPolicyFunc(func(serverBaseURL, toolName string) (string, bool) {
if serverBaseURL == mcp.EmbeddedClientKey && toolName == "test_tool" {
return mcp.ToolPolicyAutoRunEverywhere, true
}
return mcp.ToolPolicyAsk, true
}))

conversationService.AutoExecuteApprovedToolCalls(postID, requesterID, []string{"tool-1"})

require.NotEmpty(t, fakeClient.updatedPosts)
lastUpdated := fakeClient.updatedPosts[len(fakeClient.updatedPosts)-1]
require.Nil(t, lastUpdated.GetProp(streaming.PendingToolResultProp))
require.Nil(t, lastUpdated.GetProp(streaming.AutoShareToolResultProp))
require.Nil(t, lastUpdated.GetProp(streaming.ToolCallRedactedProp))

require.Len(t, streamingService.streamedPosts, 1)
require.NotNil(t, capturingLLM.autoRunTools)

resultKVKey := streaming.ToolResultPrivateKVKey(postID, requesterID)
require.Contains(t, fakeClient.kvDeletes, resultKVKey)
require.Contains(t, fakeClient.kvDeletes, toolCallKVKey)
})

t.Run("tool execution error - result still stored", func(t *testing.T) {
mockAPI := &plugintest.API{}
client := pluginapi.NewClient(mockAPI, nil)
Expand Down
2 changes: 2 additions & 0 deletions e2e/helpers/system-console-container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export interface SystemConsolePluginConfig {
enableUserRestrictions?: boolean;
enableVectorIndex?: boolean;
enableTokenUsageLogging?: boolean;
enableChannelMentionToolCalling?: boolean;
defaultBotName?: string;
allowedUpstreamHostnames?: string;
allowUnsafeLinks?: boolean;
Expand Down Expand Up @@ -102,6 +103,7 @@ export async function RunSystemConsoleContainer(config: SystemConsolePluginConfi
enableUserRestrictions: config.enableUserRestrictions ?? false,
enableVectorIndex: config.enableVectorIndex ?? false,
enableTokenUsageLogging: config.enableTokenUsageLogging,
enableChannelMentionToolCalling: config.enableChannelMentionToolCalling ?? false,
defaultBotName: config.defaultBotName,
allowedUpstreamHostnames: config.allowedUpstreamHostnames,
allowUnsafeLinks: config.allowUnsafeLinks,
Expand Down
1 change: 1 addition & 0 deletions e2e/helpers/tool-config-container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export async function RunToolConfigContainer(): Promise<MattermostContainer> {
*/
export async function RunToolConfigContainerWithPolicies(): Promise<MattermostContainer> {
return RunSystemConsoleContainer({
enableChannelMentionToolCalling: true,
services: [
{
id: 'mock-service',
Expand Down
2 changes: 1 addition & 1 deletion e2e/helpers/tool-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ export class ToolConfigUIHelper {
}

/** Set tool policy via dropdown */
async setToolPolicy(toolName: string, policy: 'Auto Run' | 'Ask Every Time'): Promise<void> {
async setToolPolicy(toolName: string, policy: 'Auto Run (DM)' | 'Auto Run (Everywhere)' | 'Ask Every Time'): Promise<void> {
const dropdown = this.getToolPolicyDropdown(toolName);
await dropdown.selectOption({ label: policy });
}
Expand Down
Loading
Loading