Skip to content

Commit 26865a9

Browse files
authored
feat: Retry model response after failed tool calls (#2329)
1 parent d5dec4e commit 26865a9

1 file changed

Lines changed: 29 additions & 7 deletions

File tree

model/mcp.go

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ type ToolCall struct {
5656
IsError bool `json:"isError"`
5757
}
5858

59+
const toolErrorRecoveryPrompt = "The previous tool call failed. Do not finish as if the task is complete. If possible, recover by calling the appropriate tool again with corrected arguments or a different tool. If recovery is not possible, clearly explain what is blocked and what input or condition is needed."
60+
5961
type ToolCallDelta struct {
6062
Index int `json:"index"`
6163
ID string `json:"id,omitempty"`
@@ -188,6 +190,7 @@ func QueryTextWithTools(p ModelProvider, question string, writer io.Writer, hist
188190
fmt.Printf(" Tool %d: [%s] args: %s\n", i+1, tc.Function.Name, tc.Function.Arguments)
189191
}
190192

193+
roundHasToolError := false
191194
for _, toolCall := range toolCalls {
192195
serverName, toolName := mcp.GetServerNameAndToolNameFromId(toolCall.Function.Name)
193196

@@ -198,10 +201,14 @@ func QueryTextWithTools(p ModelProvider, question string, writer io.Writer, hist
198201
ToolCall: toolCall,
199202
})
200203

201-
messages, err = callMcpTool(toolCall, serverName, toolName, toolSession.McpToolSet, messages, writer, lang)
204+
var toolFailed bool
205+
messages, toolFailed, err = callMcpTool(toolCall, serverName, toolName, toolSession.McpToolSet, messages, writer, lang)
202206
if err != nil {
203207
return nil, err
204208
}
209+
if toolFailed {
210+
roundHasToolError = true
211+
}
205212
}
206213

207214
toolSession.ToolMessages.Messages = messages
@@ -212,6 +219,21 @@ func QueryTextWithTools(p ModelProvider, question string, writer io.Writer, hist
212219
}
213220

214221
toolCalls = normalizeToolCalls(toolSession)
222+
if len(toolCalls) == 0 && roundHasToolError {
223+
messages = append(messages, &RawMessage{
224+
Text: toolErrorRecoveryPrompt,
225+
Author: "System",
226+
})
227+
toolSession.ToolMessages.Messages = messages
228+
229+
fmt.Printf("\n--- LLM Call (Round %d recovery) | Tool error recovery prompt added ---\n", round)
230+
modelResult, err = p.QueryText(question, writer, history, prompt, knowledgeMessages, toolSession, lang)
231+
if err != nil {
232+
return nil, err
233+
}
234+
235+
toolCalls = normalizeToolCalls(toolSession)
236+
}
215237
}
216238

217239
fmt.Printf("LLM Decision: [Final Answer — no more tool calls after round %d]\n", round)
@@ -252,12 +274,12 @@ func startHeartbeat(writer io.Writer, mu *sync.Mutex) chan<- struct{} {
252274
return stop
253275
}
254276

255-
func callMcpTool(toolCall openai.ToolCall, serverName, toolName string, mcpToolSet *mcp.ToolSet, messages []*RawMessage, writer io.Writer, lang string) ([]*RawMessage, error) {
277+
func callMcpTool(toolCall openai.ToolCall, serverName, toolName string, mcpToolSet *mcp.ToolSet, messages []*RawMessage, writer io.Writer, lang string) ([]*RawMessage, bool, error) {
256278
var arguments map[string]interface{}
257279
ctx := context.Background()
258280

259281
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil {
260-
return nil, fmt.Errorf(i18n.Translate(lang, "model:failed to parse tool arguments: %v"), err)
282+
return nil, false, fmt.Errorf(i18n.Translate(lang, "model:failed to parse tool arguments: %v"), err)
261283
}
262284

263285
// Send tool-start event immediately so the frontend can show the tool call before execution
@@ -281,14 +303,14 @@ func callMcpTool(toolCall openai.ToolCall, serverName, toolName string, mcpToolS
281303
if serverName == "" {
282304
// builtin tools
283305
if mcpToolSet.BuiltinTools == nil {
284-
return messages, nil
306+
return messages, false, nil
285307
}
286308
result, err = mcpToolSet.BuiltinTools.ExecuteTool(ctx, toolName, arguments)
287309
} else {
288310
// MCP server tools
289311
conn, ok := mcpToolSet.Connections[serverName]
290312
if !ok {
291-
return messages, nil
313+
return messages, false, nil
292314
}
293315
req := &protocol.CallToolRequest{
294316
Name: toolName,
@@ -324,7 +346,7 @@ func callMcpTool(toolCall openai.ToolCall, serverName, toolName string, mcpToolS
324346

325347
responseJson, err := json.Marshal(response)
326348
if err != nil {
327-
return nil, fmt.Errorf(i18n.Translate(lang, "model:failed to marshal tool response: %v"), err)
349+
return nil, false, fmt.Errorf(i18n.Translate(lang, "model:failed to marshal tool response: %v"), err)
328350
}
329351

330352
var contentStr string
@@ -352,7 +374,7 @@ func callMcpTool(toolCall openai.ToolCall, serverName, toolName string, mcpToolS
352374
}
353375

354376
messages = append(messages, createToolMessage(toolCall, string(responseJson)))
355-
return messages, nil
377+
return messages, !response.Success, nil
356378
}
357379

358380
func GetToolCallsFromWriter(toolMessage string) []ToolCall {

0 commit comments

Comments
 (0)