Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions conversations/tool_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,10 @@ func (c *Conversations) HandleToolCall(userID string, post *model.Post, channel
return fmt.Errorf("failed to update post with tool call results: %w", updateErr)
}

// Only continue if at least one tool call was successful
// Continue when the agent has any actionable tool result, including errors
// it may be able to recover from on the next turn.
if !slices.ContainsFunc(tools, func(tc llm.ToolCall) bool {
return tc.Status == llm.ToolCallStatusSuccess
return tc.Status == llm.ToolCallStatusSuccess || tc.Status == llm.ToolCallStatusError
}) {
return nil
}
Expand Down Expand Up @@ -433,21 +434,22 @@ func (c *Conversations) HandleToolResult(userID string, post *model.Post, channe
return fmt.Errorf("failed to update post after tool result approval: %w", updateErr)
}

// Do not continue streaming when no tool call succeeded (all errors/rejections).
// Re-invoking completeAndStreamToolResponse would cause a channel loop.
hasSuccessfulResult := slices.ContainsFunc(tools, func(tc llm.ToolCall) bool {
return tc.Status == llm.ToolCallStatusSuccess
// Continue when the agent has any actionable tool result, including errors
// it may be able to recover from on the next turn.
hasActionableResult := slices.ContainsFunc(tools, func(tc llm.ToolCall) bool {
return tc.Status == llm.ToolCallStatusSuccess || tc.Status == llm.ToolCallStatusError
})
if !hasSuccessfulResult {
if !hasActionableResult {
c.deleteToolCallKVEntries(post.Id, resultKVKey, toolCallKVKey)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return nil
}

defer c.deleteToolCallKVEntries(post.Id, resultKVKey, toolCallKVKey)

if err := c.completeAndStreamToolResponse(bot, user, channel, toolCallPostCopy, llmContext, toolsDisabled, allowToolsInChannel); err != nil {
return err
}

c.deleteToolCallKVEntries(post.Id, resultKVKey, toolCallKVKey)
return nil
}

Expand Down Expand Up @@ -483,6 +485,10 @@ func (c *Conversations) completeAndStreamToolResponse(
OperationSubType: llm.SubTypeToolCall,
}
var opts []llm.LanguageModelOption
if llm.CountTrailingFailedToolCalls(completionRequest.Posts) >= llm.MaxConsecutiveToolCallFailures {
completionRequest.Posts = llm.EnsureToolRetryLimitSystemMessage(completionRequest.Posts)
toolsDisabled = true
}
if toolsDisabled {
opts = append(opts, llm.WithToolsDisabled())
}
Expand Down
129 changes: 121 additions & 8 deletions conversations/tool_handling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,23 @@ func (c *fakeMMClient) GetFile(string) (io.ReadCloser, error) {
func (c *fakeMMClient) SendEphemeralPost(string, *model.Post) {}

type capturingLanguageModel struct {
autoRunTools []string
autoRunTools []string
requests []llm.CompletionRequest
chatCompletionErr error
}

func (m *capturingLanguageModel) ChatCompletion(_ llm.CompletionRequest, opts ...llm.LanguageModelOption) (*llm.TextStreamResult, error) {
func (m *capturingLanguageModel) ChatCompletion(request llm.CompletionRequest, opts ...llm.LanguageModelOption) (*llm.TextStreamResult, error) {
var cfg llm.LanguageModelConfig
for _, opt := range opts {
opt(&cfg)
}
m.autoRunTools = append([]string{}, cfg.AutoRunTools...)
requestCopy := request
requestCopy.Posts = append([]llm.Post(nil), request.Posts...)
m.requests = append(m.requests, requestCopy)
if m.chatCompletionErr != nil {
return nil, m.chatCompletionErr
}
return llm.NewStreamFromString("follow-up response"), nil
}

Expand Down Expand Up @@ -1040,7 +1048,7 @@ func TestAutoExecuteApprovedToolCalls(t *testing.T) {
})
}

func TestHandleToolResultDoesNotContinueWhenNoToolCallSucceeded(t *testing.T) {
func TestHandleToolResultContinuesWhenToolCallErrors(t *testing.T) {
const (
postID = "post-id"
channelID = "channel-id"
Expand All @@ -1060,8 +1068,10 @@ func TestHandleToolResultDoesNotContinueWhenNoToolCallSucceeded(t *testing.T) {

contextBuilder := llmcontext.NewLLMContextBuilder(client, &testToolProvider{tools: []llm.Tool{}}, nil, &testConfigProvider{})

streamingService := &fakeStreamingService{}
capturingLLM := &capturingLanguageModel{}
botService := bots.New(mockAPI, client, licenseChecker, nil, &http.Client{}, nil)
bot := bots.NewBot(llm.BotConfig{ID: botID, Name: "test-bot"}, llm.ServiceConfig{}, &model.Bot{UserId: botID, Username: "test-bot"}, nil)
bot := bots.NewBot(llm.BotConfig{ID: botID, Name: "test-bot"}, llm.ServiceConfig{}, &model.Bot{UserId: botID, Username: "test-bot"}, capturingLLM)
botService.SetBotsForTesting([]*bots.Bot{bot})

post := &model.Post{
Expand Down Expand Up @@ -1117,18 +1127,121 @@ func TestHandleToolResultDoesNotContinueWhenNoToolCallSucceeded(t *testing.T) {
}

toolCallingConfig := &testToolCallingConfig{enableChannelMentionToolCalling: true}
// Nil streaming service: completeAndStreamToolResponse would panic if invoked
conversationService := conversations.New(nil, fakeClient, nil, contextBuilder, botService, nil, licenseChecker, i18n.Init(), nil, toolCallingConfig)
promptSet, err := llm.NewPrompts(prompts.PromptsFolder)
require.NoError(t, err)
conversationService := conversations.New(promptSet, fakeClient, streamingService, contextBuilder, botService, nil, licenseChecker, i18n.Init(), nil, toolCallingConfig)

err = conversationService.HandleToolResult(requesterID, post, channel, []string{"tool-1"})
require.NoError(t, err)

// Post was updated with final tool results
// Post was updated with final tool results before the follow-up turn.
require.Len(t, fakeClient.updatedPosts, 1)
updatedPost := fakeClient.updatedPosts[0]
require.Nil(t, updatedPost.GetProp(streaming.PendingToolResultProp))

// KV entries were cleaned up (no continuation = no need to keep them)
// The continuation request includes the actionable tool error.
require.NotEmpty(t, capturingLLM.requests)
lastRequest := capturingLLM.requests[len(capturingLLM.requests)-1]
require.NotEmpty(t, lastRequest.Posts)
lastPost := lastRequest.Posts[len(lastRequest.Posts)-1]
require.Len(t, lastPost.ToolUse, 1)
require.Equal(t, llm.ToolCallStatusError, lastPost.ToolUse[0].Status)
require.Equal(t, "tool execution failed", lastPost.ToolUse[0].Result)

// A follow-up turn is streamed so the agent can inspect the error and recover.
require.Len(t, streamingService.streamedPosts, 1)

// KV entries are still cleaned up after the continuation runs.
require.Contains(t, fakeClient.kvDeletes, resultKVKey)
require.Contains(t, fakeClient.kvDeletes, toolCallKVKey)
}

func TestHandleToolResultCleansUpKVWhenContinuationFails(t *testing.T) {
const (
postID = "post-id"
channelID = "channel-id"
teamID = "team-id"
botID = "bot-id"
requesterID = "requester-id"
)

mockAPI := &plugintest.API{}
client := pluginapi.NewClient(mockAPI, nil)
licenseChecker := enterprise.NewLicenseChecker(client)

siteName := "Mattermost"
mockAPI.On("GetConfig").Return(&model.Config{TeamSettings: model.TeamSettings{SiteName: &siteName}}).Maybe()
mockAPI.On("GetLicense").Return(&model.License{SkuShortName: "advanced"}).Maybe()
mockAPI.On("GetTeam", teamID).Return(&model.Team{Id: teamID}, nil).Maybe()

contextBuilder := llmcontext.NewLLMContextBuilder(client, &testToolProvider{tools: []llm.Tool{}}, nil, &testConfigProvider{})

streamingService := &fakeStreamingService{}
capturingLLM := &capturingLanguageModel{chatCompletionErr: errors.New("follow-up failed")}
botService := bots.New(mockAPI, client, licenseChecker, nil, &http.Client{}, nil)
bot := bots.NewBot(llm.BotConfig{ID: botID, Name: "test-bot"}, llm.ServiceConfig{}, &model.Bot{UserId: botID, Username: "test-bot"}, capturingLLM)
botService.SetBotsForTesting([]*bots.Bot{bot})

post := &model.Post{
Id: postID,
UserId: botID,
ChannelId: channelID,
CreateAt: 1,
}
post.AddProp(streaming.LLMRequesterUserID, requesterID)
post.AddProp(streaming.AllowToolsInChannelProp, "true")
post.AddProp(streaming.PendingToolResultProp, "true")

toolsWithErrors := []llm.ToolCall{
{
ID: "tool-1",
Name: "failing_tool",
Arguments: json.RawMessage(`{"value":"test"}`),
Result: "tool execution failed",
Status: llm.ToolCallStatusError,
},
}
toolsJSON, err := json.Marshal(toolsWithErrors)
require.NoError(t, err)
post.AddProp(streaming.ToolCallProp, string(toolsJSON))

postList := &model.PostList{
Order: []string{postID},
Posts: map[string]*model.Post{postID: post},
}

channel := &model.Channel{
Id: channelID,
Type: model.ChannelTypeOpen,
TeamId: teamID,
}

resultKVKey := streaming.ToolResultPrivateKVKey(postID, requesterID)
toolCallKVKey := streaming.ToolCallPrivateKVKey(postID, requesterID)

fakeClient := &fakeMMClient{
users: map[string]*model.User{
requesterID: {Id: requesterID, Locale: "en"},
botID: {Id: botID, Locale: "en"},
},
posts: map[string]*model.Post{postID: post},
channels: map[string]*model.Channel{channelID: channel},
postThreads: map[string]*model.PostList{postID: postList},
kv: map[string]interface{}{
resultKVKey: toolsWithErrors,
toolCallKVKey: toolsWithErrors,
},
}

toolCallingConfig := &testToolCallingConfig{enableChannelMentionToolCalling: true}
promptSet, err := llm.NewPrompts(prompts.PromptsFolder)
require.NoError(t, err)
conversationService := conversations.New(promptSet, fakeClient, streamingService, contextBuilder, botService, nil, licenseChecker, i18n.Init(), nil, toolCallingConfig)

err = conversationService.HandleToolResult(requesterID, post, channel, []string{"tool-1"})
require.ErrorContains(t, err, "failed to get chat completion")

require.Contains(t, fakeClient.kvDeletes, resultKVKey)
require.Contains(t, fakeClient.kvDeletes, toolCallKVKey)
require.Empty(t, streamingService.streamedPosts)
}
18 changes: 14 additions & 4 deletions llm/auto_run_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ func (w *AutoRunToolsWrapper) ChatCompletion(request CompletionRequest, opts ...
opt(&cfg)
}

// If auto-run is not configured or no tools context, delegate directly
if len(cfg.AutoRunTools) == 0 || request.Context == nil || request.Context.Tools == nil {
// If auto-run is not configured, tools are disabled, or no tools context exists,
// delegate directly.
if cfg.ToolsDisabled || len(cfg.AutoRunTools) == 0 || request.Context == nil || request.Context.Tools == nil {
return w.inner.ChatCompletion(request, opts...)
}

Expand All @@ -47,8 +48,11 @@ func (w *AutoRunToolsWrapper) ChatCompletion(request CompletionRequest, opts ...
// runToolLoop runs the tool resolution loop, forwarding events and re-invoking
// the LLM when auto-runnable tool calls are received.
func (w *AutoRunToolsWrapper) runToolLoop(request CompletionRequest, opts []LanguageModelOption, autoRunTools []string, output chan<- TextStreamEvent) {
currentOpts := append([]LanguageModelOption(nil), opts...)
currentAutoRunTools := append([]string(nil), autoRunTools...)

for i := 0; i < MaxToolResolutionDepth; i++ {
result, err := w.inner.ChatCompletion(request, opts...)
result, err := w.inner.ChatCompletion(request, currentOpts...)
if err != nil {
output <- TextStreamEvent{Type: EventTypeError, Value: err}
return
Expand Down Expand Up @@ -85,7 +89,7 @@ func (w *AutoRunToolsWrapper) runToolLoop(request CompletionRequest, opts []Lang
return
}

if !ShouldAutoRunTools(toolCalls, autoRunTools) {
if !ShouldAutoRunTools(toolCalls, currentAutoRunTools) {
// Tool calls are not all auto-runnable: forward them and return
output <- TextStreamEvent{Type: EventTypeToolCalls, Value: toolCalls}
return
Expand Down Expand Up @@ -123,6 +127,12 @@ func (w *AutoRunToolsWrapper) runToolLoop(request CompletionRequest, opts []Lang
Message: accumulatedText,
ToolUse: resolvedToolCalls,
})

if CountTrailingFailedToolCalls(request.Posts) >= MaxConsecutiveToolCallFailures {
request.Posts = EnsureToolRetryLimitSystemMessage(request.Posts)
currentOpts = append(currentOpts, WithToolsDisabled())
currentAutoRunTools = nil
}
}

// If we've exhausted MaxToolResolutionDepth, send end event
Expand Down
62 changes: 62 additions & 0 deletions llm/auto_run_tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,65 @@ func TestAutoRunToolsPreservesServerOrigin(t *testing.T) {
assert.Equal(t, "mcp_result", lastPost.ToolUse[0].Result)
assert.Equal(t, ToolCallStatusSuccess, lastPost.ToolUse[0].Status)
}

func TestAutoRunToolsWrapper_DisablesToolsAfterThreeConsecutiveFailures(t *testing.T) {
failingStore := NewNoTools()
failingStore.AddTools([]Tool{
{
Name: "test_tool",
Description: "A failing test tool",
Resolver: func(_ *Context, _ ToolArgumentGetter) (string, error) {
return "", fmt.Errorf("bad request: missing required field")
},
},
})

var capturedPosts []Post
inner := &testLLM{
responses: []testResponse{
{events: []TextStreamEvent{
{Type: EventTypeToolCalls, Value: []ToolCall{{ID: "tc1", Name: "test_tool", Arguments: json.RawMessage(`{}`)}}},
{Type: EventTypeEnd},
}},
{events: []TextStreamEvent{
{Type: EventTypeToolCalls, Value: []ToolCall{{ID: "tc2", Name: "test_tool", Arguments: json.RawMessage(`{}`)}}},
{Type: EventTypeEnd},
}},
{events: []TextStreamEvent{
{Type: EventTypeToolCalls, Value: []ToolCall{{ID: "tc3", Name: "test_tool", Arguments: json.RawMessage(`{}`)}}},
{Type: EventTypeEnd},
}},
{events: []TextStreamEvent{
{Type: EventTypeText, Value: "please provide the missing field"},
{Type: EventTypeEnd},
}},
},
}

capturingInner := &capturingLLM{inner: inner, capturedPosts: &capturedPosts}
wrapper := NewAutoRunToolsWrapper(capturingInner)

request := CompletionRequest{
Posts: []Post{{Role: PostRoleUser, Message: "run the tool"}},
Context: &Context{Tools: failingStore},
}

result, err := wrapper.ChatCompletion(request, WithAutoRunTools([]string{ToolAutoRunKey("", "test_tool")}))
require.NoError(t, err)

var texts []string
for event := range result.Stream {
if event.Type == EventTypeText {
texts = append(texts, event.Value.(string))
}
}

require.Len(t, capturedPosts, 5)
assert.Equal(t, PostRoleSystem, capturedPosts[0].Role)
assert.Contains(t, capturedPosts[0].Message, ToolRetryLimitSystemMessage)
finalRequestPost := capturedPosts[len(capturedPosts)-1]
assert.Equal(t, PostRoleBot, finalRequestPost.Role)
assert.Equal(t, ToolCallStatusError, finalRequestPost.ToolUse[0].Status)
assert.Equal(t, "please provide the missing field", texts[len(texts)-1])
assert.Equal(t, 4, inner.callCount)
}
Loading
Loading