Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type MCPClientManager interface {
MarkOAuthNeeded(userID, serverName, authURL string) error
GetEmbeddedServer() mcp.EmbeddedMCPServer
EnsureMCPSessionID(userID string) (string, error)
GetToolsForUser(userID string) ([]llm.Tool, *mcp.Errors)
GetToolsForUser(ctx context.Context, userID string) ([]llm.Tool, *mcp.Errors)
GetConfig() mcp.Config

RegisterPluginServer(cfg mcp.PluginServerConfig)
Expand Down
22 changes: 16 additions & 6 deletions api/api_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func (a *API) handleChannelAnalysis(c *gin.Context) {
userID := c.GetHeader("Mattermost-User-Id")
channel := c.MustGet(ContextChannelKey).(*model.Channel)
bot := c.MustGet(ContextBotKey).(*bots.Bot)
toolBot := channelAnalysisToolBot(bot)

var data struct {
AnalysisType string `json:"analysis_type" binding:"required"`
Expand Down Expand Up @@ -87,7 +88,7 @@ func (a *API) handleChannelAnalysis(c *gin.Context) {
}

opts := []llm.ContextOption{
a.contextBuilder.WithLLMContextDefaultTools(bot),
a.contextBuilder.WithLLMContextDefaultTools(c.Request.Context(), toolBot),
}

// If the channel is a DM/GM and we have a team ID from the client, use it for context
Expand Down Expand Up @@ -120,16 +121,12 @@ func (a *API) handleChannelAnalysis(c *gin.Context) {

// Check if read_channel tool is available
availableTools := llmContext.Tools.GetTools()
hasReadChannel := false
var toolNames []string
for _, tool := range availableTools {
toolNames = append(toolNames, tool.Name)
if tool.Name == "read_channel" {
hasReadChannel = true
}
}

if !hasReadChannel {
if llmContext.Tools.GetTool("read_channel") == nil {
a.pluginAPI.Log.Error("Channel analysis failed: read_channel tool not available",
"userID", userID,
"channelID", channel.Id,
Expand Down Expand Up @@ -178,6 +175,19 @@ func (a *API) handleChannelAnalysis(c *gin.Context) {
})
}

func channelAnalysisToolBot(bot *bots.Bot) *bots.Bot {
cfg := bot.GetConfig()
if cfg.AutoEnableNewMCPTools {
return bot
}

// Channel analysis immediately scopes the catalog to bound read-only tools,
// so load the MCP catalog even when the agent uses a narrower allowlist.
cfg.AutoEnableNewMCPTools = true
cfg.EnabledMCPTools = nil
return bots.NewBot(cfg, bot.GetService(), bot.GetMMBot(), bot.LLM())
}

func (a *API) handleInterval(c *gin.Context) {
userID := c.GetHeader("Mattermost-User-Id")
channel := c.MustGet(ContextChannelKey).(*model.Channel)
Expand Down
14 changes: 8 additions & 6 deletions api/api_llm_bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package api

import (
"bytes"
stdcontext "context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -163,7 +164,7 @@ func (a *API) buildLLMBridgeContext(bot *bots.Bot, req bridgeclient.CompletionRe
return context, nil
}

func (a *API) convertAgentBridgeRequestToInternal(bot *bots.Bot, req bridgeclient.CompletionRequest, includeTools bool, operation, operationSubType string) (llm.CompletionRequest, error) {
func (a *API) convertAgentBridgeRequestToInternal(ctx stdcontext.Context, bot *bots.Bot, req bridgeclient.CompletionRequest, includeTools bool, operation, operationSubType string) (llm.CompletionRequest, error) {
posts, err := a.convertBridgePostsToInternal(req)
if err != nil {
return llm.CompletionRequest{}, err
Expand All @@ -172,7 +173,7 @@ func (a *API) convertAgentBridgeRequestToInternal(bot *bots.Bot, req bridgeclien
bridgeContext := llm.NewContext()
bridgeContext.RequestingUser = &model.User{Id: req.UserID}
if includeTools && a.contextBuilder != nil {
a.contextBuilder.WithLLMContextTools(bot)(bridgeContext)
a.contextBuilder.WithLLMContextTools(ctx, bot)(bridgeContext)
}

resolvedOperation := operation
Expand Down Expand Up @@ -270,6 +271,7 @@ func validateCompletionRequestIDs(req bridgeclient.CompletionRequest) (int, erro
}

func (a *API) prepareAgentBridgeCompletion(
ctx stdcontext.Context,
agent string,
req bridgeclient.CompletionRequest,
pluginID string,
Expand Down Expand Up @@ -314,7 +316,7 @@ func (a *API) prepareAgentBridgeCompletion(
}

toolsRequested := allowedToolNames != nil
llmRequest, err := a.convertAgentBridgeRequestToInternal(bot, req, toolsRequested, operation, operationSubType)
llmRequest, err := a.convertAgentBridgeRequestToInternal(ctx, bot, req, toolsRequested, operation, operationSubType)
if err != nil {
return nil, llm.CompletionRequest{}, nil, nil, nil, http.StatusBadRequest, fmt.Errorf("invalid request: %v", err)
}
Expand Down Expand Up @@ -689,7 +691,7 @@ func (a *API) handleGetAgentTools(c *gin.Context) {
toolContext := llm.NewContext()
toolContext.RequestingUser = &model.User{Id: userID}
if a.contextBuilder != nil {
a.contextBuilder.WithLLMContextTools(bot)(toolContext)
a.contextBuilder.WithLLMContextTools(c.Request.Context(), bot)(toolContext)
}

var tools []bridgeclient.BridgeToolInfo
Expand Down Expand Up @@ -780,7 +782,7 @@ func (a *API) handleAgentCompletionStreaming(c *gin.Context) {
return
}

bot, llmRequest, opts, shouldExecute, beforeHookKeys, statusCode, err := a.prepareAgentBridgeCompletion(agent, req, c.GetHeader("Mattermost-Plugin-ID"), llm.OperationBridgeAgent, llm.SubTypeStreaming)
bot, llmRequest, opts, shouldExecute, beforeHookKeys, statusCode, err := a.prepareAgentBridgeCompletion(c.Request.Context(), agent, req, c.GetHeader("Mattermost-Plugin-ID"), llm.OperationBridgeAgent, llm.SubTypeStreaming)
if err != nil {
c.JSON(statusCode, bridgeclient.ErrorResponse{
Error: err.Error(),
Expand Down Expand Up @@ -811,7 +813,7 @@ func (a *API) handleAgentCompletionNoStream(c *gin.Context) {
return
}

bot, llmRequest, opts, shouldExecute, beforeHookKeys, statusCode, err := a.prepareAgentBridgeCompletion(agent, req, c.GetHeader("Mattermost-Plugin-ID"), llm.OperationBridgeAgent, llm.SubTypeNoStream)
bot, llmRequest, opts, shouldExecute, beforeHookKeys, statusCode, err := a.prepareAgentBridgeCompletion(c.Request.Context(), agent, req, c.GetHeader("Mattermost-Plugin-ID"), llm.OperationBridgeAgent, llm.SubTypeNoStream)
if err != nil {
c.JSON(statusCode, bridgeclient.ErrorResponse{
Error: err.Error(),
Expand Down
21 changes: 13 additions & 8 deletions api/api_llm_bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ func (e *TestEnvironment) setupMCPWithEligibleTools(t *testing.T, toolNames []st
ServerOrigin: server.URL,
Description: name,
Schema: llm.NewJSONSchemaFromStruct[struct{}](),
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "ok", nil
},
}
Expand Down Expand Up @@ -1415,7 +1415,7 @@ func TestBridgeGetAgentToolsReturnsEligibleOnly(t *testing.T) {
ServerOrigin: server.URL,
Description: "eligible from context",
Schema: llm.NewJSONSchemaFromStruct[struct{}](),
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "ok", nil
},
},
Expand All @@ -1424,7 +1424,7 @@ func TestBridgeGetAgentToolsReturnsEligibleOnly(t *testing.T) {
ServerOrigin: server.URL,
Description: "should be filtered out",
Schema: llm.NewJSONSchemaFromStruct[struct{}](),
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "ok", nil
},
},
Expand Down Expand Up @@ -1478,7 +1478,7 @@ func TestBridgeGetAgentToolsReturnsEmbeddedServerTools(t *testing.T) {
ServerOrigin: mcp.EmbeddedClientKey,
Description: "tool from embedded server",
Schema: llm.NewJSONSchemaFromStruct[struct{}](),
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "ok", nil
},
},
Expand Down Expand Up @@ -1542,7 +1542,7 @@ func TestBridgeGetAgentToolsSkipsUnreachableEligibleServer(t *testing.T) {
ServerOrigin: server.URL,
Description: "eligible from context",
Schema: llm.NewJSONSchemaFromStruct[struct{}](),
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "ok", nil
},
},
Expand Down Expand Up @@ -1690,6 +1690,7 @@ func TestPrepareAgentBridgeCompletionAllowedToolsRequiresUserID(t *testing.T) {
defer e.Cleanup(t)

_, _, _, _, _, statusCode, err := e.api.prepareAgentBridgeCompletion(
context.Background(),
testBotUserID,
bridgeclient.CompletionRequest{
Posts: []bridgeclient.Post{
Expand Down Expand Up @@ -1724,6 +1725,7 @@ func TestPrepareAgentBridgeCompletionToolHooksRequiresPluginID(t *testing.T) {
e.setupTestBot(botConfig)

_, _, _, _, _, statusCode, err := e.api.prepareAgentBridgeCompletion(
context.Background(),
testBotUserID,
bridgeclient.CompletionRequest{
Posts: []bridgeclient.Post{
Expand Down Expand Up @@ -1783,6 +1785,7 @@ func TestPrepareAgentBridgeCompletionStoresToolHookKeysInMCPMetadata(t *testing.
).Return(true, (*model.AppError)(nil)).Once()

_, llmRequest, _, _, beforeHookKeys, statusCode, err := e.api.prepareAgentBridgeCompletion(
context.Background(),
testBotUserID,
bridgeclient.CompletionRequest{
Posts: []bridgeclient.Post{
Expand Down Expand Up @@ -1849,6 +1852,7 @@ func TestPrepareAgentBridgeCompletionToolHooksRequiresUserID(t *testing.T) {
e.setupTestBot(botConfig)

_, _, _, _, _, statusCode, err := e.api.prepareAgentBridgeCompletion(
context.Background(),
testBotUserID,
bridgeclient.CompletionRequest{
Posts: []bridgeclient.Post{
Expand Down Expand Up @@ -1886,6 +1890,7 @@ func TestPrepareAgentBridgeCompletionToolHooksRequiresAllowedTools(t *testing.T)
e.setupTestBot(botConfig)

_, _, _, _, _, statusCode, err := e.api.prepareAgentBridgeCompletion(
context.Background(),
testBotUserID,
bridgeclient.CompletionRequest{
Posts: []bridgeclient.Post{
Expand Down Expand Up @@ -2010,15 +2015,15 @@ func TestBridgeClientAgentCompletionRejectsBuiltinToolInAllowedTools(t *testing.
ServerOrigin: server.URL,
Description: "eligible_tool",
Schema: llm.NewJSONSchemaFromStruct[struct{}](),
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "ok", nil
},
},
{
Name: "builtin_only",
Description: "built-in tool with no MCP origin",
Schema: llm.NewJSONSchemaFromStruct[struct{}](),
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "ok", nil
},
},
Expand Down Expand Up @@ -2186,7 +2191,7 @@ func TestBridgeGetAgentToolsReturnsEmptyWhenMCPDisabled(t *testing.T) {
Name: "context_only_tool",
Description: "should not be bridge-eligible without MCP",
Schema: llm.NewJSONSchemaFromStruct[struct{}](),
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "ok", nil
},
},
Expand Down
4 changes: 2 additions & 2 deletions api/api_mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (a *API) handleGetUserMCPTools(c *gin.Context) {

mcpCfg := a.config.MCP()

tools, mcpErrors := a.mcpClientManager.GetToolsForUser(userID)
tools, mcpErrors := a.mcpClientManager.GetToolsForUser(c.Request.Context(), userID)

// Group tools by ServerOrigin
toolsByOrigin := make(map[string][]llm.Tool, len(tools))
Expand Down Expand Up @@ -139,7 +139,7 @@ func buildUserMCPServerInfo(
) UserMCPServerInfo {
toolInfos := make([]UserMCPToolInfo, 0, len(originTools))
for _, t := range originTools {
policy, enabled := serverConfig.GetToolPolicy(t.Name)
policy, enabled := serverConfig.GetToolPolicy(mcp.ToolPolicyLookupName(serverConfig, t.Name))
toolInfos = append(toolInfos, UserMCPToolInfo{
Name: t.Name,
Description: t.Description,
Expand Down
38 changes: 36 additions & 2 deletions api/api_no_tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ func (p *noToolsTestToolProvider) GetTools(*bots.Bot) []llm.Tool {

type noToolsTestMCPProvider struct {
calls int
tools []llm.Tool
}

func (p *noToolsTestMCPProvider) GetToolsForUser(string) ([]llm.Tool, *mcp.Errors) {
func (p *noToolsTestMCPProvider) GetToolsForUser(context.Context, string) ([]llm.Tool, *mcp.Errors) {
p.calls++
return nil, nil
return p.tools, nil
}

type noToolsTestContextConfigProvider struct{}
Expand Down Expand Up @@ -283,6 +284,39 @@ func TestHandleIntervalDoesNotLoadToolsWhenToolsAreDisabled(t *testing.T) {
require.Equal(t, 0, mcpProvider.calls, "channel interval should not build MCP tools when the LLM call disables tools")
}

func TestHandleChannelAnalysisAcceptsNamespacedEmbeddedTools(t *testing.T) {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard

mcpProvider := &noToolsTestMCPProvider{
tools: []llm.Tool{
{Name: "mattermost__read_channel", ServerOrigin: mcp.EmbeddedClientKey},
{Name: "mattermost__get_channel_info", ServerOrigin: mcp.EmbeddedClientKey},
},
}
mmClient := mmapimocks.NewMockClient(t)
e, streamingService, _ := setupNoToolsAPI(t, mcpProvider, mmClient)
defer e.Cleanup(t)

requestingUser := &model.User{Id: testUserID, Username: "requester", Locale: "en"}
channel := &model.Channel{Id: testChannelID, Type: model.ChannelTypeOpen, TeamId: "teamid"}

e.mockAPI.On("GetChannel", testChannelID).Return(channel, nil)
e.mockAPI.On("GetTeam", "teamid").Return(&model.Team{Id: "teamid", Name: "team"}, nil)
e.mockAPI.On("HasPermissionToChannel", testUserID, testChannelID, model.PermissionReadChannel).Return(true)
e.mockAPI.On("GetUser", testUserID).Return(requestingUser, nil)

request := httptest.NewRequest(http.MethodPost, "/channel/"+testChannelID+"/analyze", strings.NewReader(`{"analysis_type":"custom","prompt":"summarize this channel"}`))
request.Header.Add("Mattermost-User-ID", testUserID)

recorder := httptest.NewRecorder()
e.api.ServeHTTP(&plugin.Context{}, recorder, request)

require.Equal(t, http.StatusOK, recorder.Result().StatusCode, recorder.Body.String())
require.Equal(t, 1, streamingService.newDMCalls)
require.Equal(t, 1, mcpProvider.calls)
}

// TestHandleIntervalSetsConversationRootPostID verifies that after an
// interval summary completes, the conversation's RootPostID is set to the
// newly-created response post. Without this, interval summaries appear as
Expand Down
2 changes: 1 addition & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (m *mockMCPClientManager) GetHTTPClient() *http.Client {
return nil
}

func (m *mockMCPClientManager) GetToolsForUser(userID string) ([]llm.Tool, *mcp.Errors) {
func (m *mockMCPClientManager) GetToolsForUser(context.Context, string) ([]llm.Tool, *mcp.Errors) {
return m.tools, m.mcpErrors
}

Expand Down
4 changes: 2 additions & 2 deletions channels/analysis_conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func makeTool(name, result string) llm.Tool {
return llm.Tool{
Name: name,
Description: "test tool",
Resolver: func(_ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, argsGetter llm.ToolArgumentGetter) (string, error) {
return result, nil
},
}
Expand All @@ -240,7 +240,7 @@ func makeToolWithError(name, errMsg string) llm.Tool {
return llm.Tool{
Name: name,
Description: "test tool that errors",
Resolver: func(_ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
Resolver: func(_ context.Context, _ *llm.Context, _ llm.ToolArgumentGetter) (string, error) {
return "", fmt.Errorf("%s", errMsg)
},
}
Expand Down
Loading
Loading