Skip to content

Commit 405ebc7

Browse files
committed
fix(tools): Support cancelling tools.
1 parent c9d74dd commit 405ebc7

5 files changed

Lines changed: 84 additions & 37 deletions

File tree

lua/codecompanion/interactions/chat/init.lua

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
---@field context CodeCompanion.Chat.Context
1616
---@field context_items? table<CodeCompanion.Chat.Context> Context which is sent to the LLM e.g. buffers, slash command output
1717
---@field current_request table|nil The current request being executed
18-
---@field current_tool table The current tool being executed
1918
---@field cycle number Records the number of turn-based interactions (User -> LLM) that have taken place
2019
---@field create_buf fun(): number The function that creates a new buffer for the chat
2120
---@field edit_tracker? CodeCompanion.Chat.EditTracker Edit tracking information for the chat
@@ -31,6 +30,8 @@
3130
---@field title? string The title of the chat buffer
3231
---@field tokens? nil|number The number of tokens in the chat
3332
---@field tools CodeCompanion.Tools The tools coordinator that executes available tools
33+
---@field tool_orchestrator CodeCompanion.Tools.Orchestrator The current tool orchestrator
34+
3435
---@field tool_registry CodeCompanion.Chat.ToolRegistry Methods for handling interactions between the chat buffer and tools
3536
---@field ui CodeCompanion.Chat.UI The UI of the chat buffer
3637
---@field variables? CodeCompanion.Variables The variables available to the user
@@ -1445,13 +1446,11 @@ function Chat:stop()
14451446
self:dispatch("on_cancelled")
14461447
utils.fire("ChatStopped", { bufnr = self.bufnr, id = self.id })
14471448

1448-
if self.current_tool then
1449-
local tool_job = self.current_tool
1450-
self.current_tool = nil
1451-
1449+
if self.tool_orchestrator then
14521450
pcall(function()
1453-
tool_job.cancel()
1451+
self.tool_orchestrator:cancel()
14541452
end)
1453+
self.tool_orchestrator = nil
14551454
end
14561455

14571456
if self.current_request then

lua/codecompanion/interactions/chat/keymaps/init.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ M.close = {
304304

305305
M.stop = {
306306
callback = function(chat)
307-
if chat.current_request then
307+
if chat.current_request or chat.tool_orchestrator then
308308
chat:stop()
309309
end
310310
end,

lua/codecompanion/interactions/chat/tools/init.lua

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ function Tools:execute(chat, tools)
289289
-- NOTE: Set autocmds early so that errors can be handled properly
290290
self:set_autocmds()
291291

292-
local orchestrator = Orchestrator.new(self, id)
292+
chat.tool_orchestrator = Orchestrator.new(self, id)
293293

294294
for _, tool in ipairs(tools) do
295295
local resolved_tool, error_msg, is_json_error = self:_resolve_and_prepare_tool(tool, id)
@@ -304,11 +304,11 @@ function Tools:execute(chat, tools)
304304
end
305305

306306
self.tool = resolved_tool --[[@as CodeCompanion.Tools.Tool]]
307-
orchestrator.queue:push(resolved_tool)
307+
chat.tool_orchestrator.queue:push(resolved_tool)
308308
end
309309

310310
utils.fire("ToolsStarted", { id = id, bufnr = self.bufnr })
311-
orchestrator:setup_next_tool()
311+
chat.tool_orchestrator:setup_next_tool()
312312
end
313313

314314
local ok, err = pcall(safe_execute)

lua/codecompanion/interactions/chat/tools/orchestrator.lua

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
---@class CodeCompanion.Tools.Orchestrator
2+
---@field cancelled boolean? Whether the tool execution has been cancelled
3+
---@field id number The id of the tools coordinator
4+
---@field index number The index of the current command
5+
---@field handlers table<string, function>
6+
---@field output table<string, function>
7+
---@field queue CodeCompanion.Tools.Orchestrator.Queue
8+
---@field status string The status of the tool execution "success" | "error"
9+
---@field tool CodeCompanion.Tools.Tool The current tool being executed
10+
---@field tool_output table? The output collected from the tool
11+
---@field tools CodeCompanion.Tools
12+
---@field tool_handle vim.SystemObj? SystemObj of the current tool execution
13+
114
local Approvals = require("codecompanion.interactions.chat.tools.approvals")
215
local Queue = require("codecompanion.interactions.chat.tools.runtime.queue")
316
local Runner = require("codecompanion.interactions.chat.tools.runtime.runner")
@@ -7,6 +20,11 @@ local ui_utils = require("codecompanion.utils.ui")
720
local utils = require("codecompanion.utils")
821

922
local fmt = string.format
23+
local Orchestrator = {}
24+
25+
--=============================================================================
26+
-- Private methods
27+
--=============================================================================
1028

1129
---Strip any ANSI color codes which don't render in the chat buffer
1230
---@param tbl table
@@ -19,33 +37,32 @@ local function strip_ansi(tbl)
1937
end
2038

2139
---Add a response to the chat buffer regarding a tool's execution
22-
---@param exec CodeCompanion.Tools.Orchestrator
2340
---@param llm_message string
2441
---@param user_message? string
25-
local send_response_to_chat = function(exec, llm_message, user_message)
26-
exec.tools.chat:add_tool_output(exec.tool, llm_message, user_message)
42+
function Orchestrator:_send_response_to_chat(llm_message, user_message)
43+
self.tools.chat:add_tool_output(self.tool, llm_message, user_message)
2744
end
2845

2946
---Execute a shell command with platform-specific handling
3047
---@param cmd table
3148
---@param callback function
32-
local function execute_shell_command(cmd, callback)
49+
function Orchestrator:_execute_shell_command(cmd, callback)
3350
if vim.fn.has("win32") == 1 then
3451
-- See PR #2186
3552
local shell_cmd = table.concat(cmd, " ") .. "\r\nEXIT %ERRORLEVEL%\r\n"
36-
vim.system({ "cmd.exe", "/Q", "/K" }, {
53+
self.tool_handle = vim.system({ "cmd.exe", "/Q", "/K" }, {
3754
stdin = shell_cmd,
3855
env = { PROMPT = "\r\n" },
3956
}, callback)
4057
else
41-
vim.system(os_utils.build_shell_command(cmd), {}, callback)
58+
self.tool_handle = vim.system(os_utils.build_shell_command(cmd), {}, callback)
4259
end
4360
end
4461

4562
---Converts a cmd-based tool to a function-based tool.
4663
---@param tool CodeCompanion.Tools.Tool
4764
---@return CodeCompanion.Tools.Tool
48-
local function cmd_to_func_tool(tool)
65+
function Orchestrator:_cmd_to_func_tool(tool)
4966
tool.cmds = vim
5067
.iter(tool.cmds)
5168
:map(function(cmd)
@@ -62,7 +79,7 @@ local function cmd_to_func_tool(tool)
6279
---@param tools CodeCompanion.Tools
6380
return function(tools, _, _, cb)
6481
cb = vim.schedule_wrap(cb)
65-
execute_shell_command(cmd, function(out)
82+
self:_execute_shell_command(cmd, function(out)
6683
if flag then
6784
tools.chat.tool_registry.flags = tools.chat.tool_registry.flags or {}
6885
tools.chat.tool_registry.flags[flag] = (out.code == 0)
@@ -93,17 +110,9 @@ local function cmd_to_func_tool(tool)
93110
return tool
94111
end
95112

96-
---@class CodeCompanion.Tools.Orchestrator
97-
---@field id number The id of the tools coordinator
98-
---@field index number The index of the current command
99-
---@field handlers table<string, function>
100-
---@field output table<string, function>
101-
---@field queue CodeCompanion.Tools.Orchestrator.Queue
102-
---@field status string The status of the tool execution "success" | "error"
103-
---@field tool CodeCompanion.Tools.Tool The current tool being executed
104-
---@field tool_output table? The output collected from the tool
105-
---@field tools CodeCompanion.Tools
106-
local Orchestrator = {}
113+
--=============================================================================
114+
-- Public methods
115+
--=============================================================================
107116

108117
---@param tools CodeCompanion.Tools
109118
---@param id number
@@ -184,7 +193,7 @@ function Orchestrator:_setup_handlers()
184193
rejection = rejection .. fmt(': "%s"', opts.reason)
185194
end
186195
-- If no handler is set then return a default message
187-
send_response_to_chat(self, rejection)
196+
self:_send_response_to_chat(rejection)
188197
end
189198
end,
190199
error = function(cmd)
@@ -195,7 +204,7 @@ function Orchestrator:_setup_handlers()
195204
if self.tool.output and self.tool.output.error then
196205
self.tool.output.error(self.tool, self.tools, cmd, self.tools.stderr)
197206
else
198-
send_response_to_chat(self, fmt("Error calling `%s`", self.tool.name))
207+
self:_send_response_to_chat(fmt("Error calling `%s`", self.tool.name))
199208
end
200209
end,
201210
cancelled = function(cmd)
@@ -206,8 +215,7 @@ function Orchestrator:_setup_handlers()
206215
if self.tool.output and self.tool.output.cancelled then
207216
self.tool.output.cancelled(self.tool, self.tools, cmd)
208217
else
209-
send_response_to_chat(
210-
self,
218+
self:_send_response_to_chat(
211219
fmt("The user cancelled the execution of the %s tool", self.tool.name),
212220
fmt("Cancelled `%s`", self.tool.name)
213221
)
@@ -221,17 +229,18 @@ function Orchestrator:_setup_handlers()
221229
if self.tool.output and self.tool.output.success then
222230
self.tool.output.success(self.tool, self.tools, cmd, self.tool_output or {})
223231
else
224-
send_response_to_chat(self, fmt("Executed `%s`", self.tool.name))
232+
self:_send_response_to_chat(fmt("Executed `%s`", self.tool.name))
225233
end
226234
end,
227235
}
228236
end
229237

230238
---When the tools coordinator is finished, finalize it via an autocmd
231-
---@param self CodeCompanion.Tools.Orchestrator
232239
---@return nil
233240
function Orchestrator:_finalize_tools()
234241
self.tools.tool = nil
242+
self.tools.chat.tool_orchestrator = nil
243+
235244
return utils.fire("ToolsFinished", { id = self.id, bufnr = self.tools.bufnr })
236245
end
237246

@@ -251,7 +260,7 @@ function Orchestrator:setup_next_tool(input)
251260
self.handlers.setup() -- Call this early as cmd_runner needs to setup its cmds dynamically
252261

253262
-- Transform cmd-based tools to func-based
254-
self.tool = cmd_to_func_tool(self.tool)
263+
self.tool = self:_cmd_to_func_tool(self.tool)
255264

256265
-- Get the first command to run
257266
local cmd = self.tool.cmds[1]
@@ -429,4 +438,43 @@ function Orchestrator:finalize_tool()
429438
end
430439
end
431440

441+
---Cancel the currently running tool execution
442+
---@return nil
443+
function Orchestrator:cancel()
444+
if self.cancelled then
445+
return
446+
end
447+
448+
log:debug("Orchestrator:cancel")
449+
self.cancelled = true
450+
451+
-- Kill the running system command if one exists
452+
if self.tool_handle then
453+
if vim.fn.has("win32") == 1 then
454+
-- /F flag forces process to end
455+
-- /T ends child process (required since we are wrapping the child
456+
-- process in a parent cmd.exe process)
457+
vim.system({ "taskkill", "/F", "/T", "/PID", tostring(self.tool_handle.pid) })
458+
else
459+
self.tool_handle:kill("sigkill")
460+
end
461+
462+
self.tool_handle = nil
463+
end
464+
465+
-- Output the cancellation message to the chat buffer.
466+
if self.tool and self.output then
467+
self.output.cancelled(self.tool.cmds[1])
468+
end
469+
470+
-- Clean up the cancelled tool.
471+
self:finalize_tool()
472+
473+
-- Cancel any pending tools in the queue.
474+
self:cancel_pending_tools()
475+
476+
-- Close the current tool.
477+
self:_finalize_tools()
478+
end
479+
432480
return Orchestrator

lua/codecompanion/interactions/chat/tools/runtime/runner.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function Runner:run_tool(cmd_func, action, args)
8282

8383
---@param msg {status:"success"|"error", data:any}
8484
local function output_handler(msg)
85-
if tool_finished then
85+
if tool_finished or self.orchestrator.cancelled then
8686
return
8787
end
8888
tool_finished = true

0 commit comments

Comments
 (0)