diff --git a/lua/codecompanion/interactions/chat/init.lua b/lua/codecompanion/interactions/chat/init.lua index 3bf2e482e..eb7ce4d56 100644 --- a/lua/codecompanion/interactions/chat/init.lua +++ b/lua/codecompanion/interactions/chat/init.lua @@ -15,7 +15,6 @@ ---@field context CodeCompanion.Chat.Context ---@field context_items? table Context which is sent to the LLM e.g. buffers, slash command output ---@field current_request table|nil The current request being executed ----@field current_tool table The current tool being executed ---@field cycle number Records the number of turn-based interactions (User -> LLM) that have taken place ---@field create_buf fun(): number The function that creates a new buffer for the chat ---@field edit_tracker? CodeCompanion.Chat.EditTracker Edit tracking information for the chat @@ -31,6 +30,8 @@ ---@field title? string The title of the chat buffer ---@field tokens? nil|number The number of tokens in the chat ---@field tools CodeCompanion.Tools The tools coordinator that executes available tools +---@field tool_orchestrator CodeCompanion.Tools.Orchestrator The current tool orchestrator + ---@field tool_registry CodeCompanion.Chat.ToolRegistry Methods for handling interactions between the chat buffer and tools ---@field ui CodeCompanion.Chat.UI The UI of the chat buffer ---@field variables? CodeCompanion.Variables The variables available to the user @@ -1445,13 +1446,11 @@ function Chat:stop() self:dispatch("on_cancelled") utils.fire("ChatStopped", { bufnr = self.bufnr, id = self.id }) - if self.current_tool then - local tool_job = self.current_tool - self.current_tool = nil - + if self.tool_orchestrator then pcall(function() - tool_job.cancel() + self.tool_orchestrator:cancel() end) + self.tool_orchestrator = nil end if self.current_request then diff --git a/lua/codecompanion/interactions/chat/keymaps/init.lua b/lua/codecompanion/interactions/chat/keymaps/init.lua index 33a1ddc6b..709627b7f 100644 --- a/lua/codecompanion/interactions/chat/keymaps/init.lua +++ b/lua/codecompanion/interactions/chat/keymaps/init.lua @@ -304,7 +304,7 @@ M.close = { M.stop = { callback = function(chat) - if chat.current_request then + if chat.current_request or chat.tool_orchestrator then chat:stop() end end, diff --git a/lua/codecompanion/interactions/chat/tools/init.lua b/lua/codecompanion/interactions/chat/tools/init.lua index ad5fe4468..bd13b9468 100644 --- a/lua/codecompanion/interactions/chat/tools/init.lua +++ b/lua/codecompanion/interactions/chat/tools/init.lua @@ -289,7 +289,7 @@ function Tools:execute(chat, tools) -- NOTE: Set autocmds early so that errors can be handled properly self:set_autocmds() - local orchestrator = Orchestrator.new(self, id) + chat.tool_orchestrator = Orchestrator.new(self, id) for _, tool in ipairs(tools) do local resolved_tool, error_msg, is_json_error = self:_resolve_and_prepare_tool(tool, id) @@ -304,11 +304,11 @@ function Tools:execute(chat, tools) end self.tool = resolved_tool --[[@as CodeCompanion.Tools.Tool]] - orchestrator.queue:push(resolved_tool) + chat.tool_orchestrator.queue:push(resolved_tool) end utils.fire("ToolsStarted", { id = id, bufnr = self.bufnr }) - orchestrator:setup_next_tool() + chat.tool_orchestrator:setup_next_tool() end local ok, err = pcall(safe_execute) diff --git a/lua/codecompanion/interactions/chat/tools/orchestrator.lua b/lua/codecompanion/interactions/chat/tools/orchestrator.lua index f0e0a0baf..c96f01a40 100644 --- a/lua/codecompanion/interactions/chat/tools/orchestrator.lua +++ b/lua/codecompanion/interactions/chat/tools/orchestrator.lua @@ -1,3 +1,16 @@ +---@class CodeCompanion.Tools.Orchestrator +---@field cancelled boolean? Whether the tool execution has been cancelled +---@field handlers table +---@field id number The id of the tools coordinator +---@field index number The index of the current command +---@field output table +---@field queue CodeCompanion.Tools.Orchestrator.Queue +---@field status string The status of the tool execution "success" | "error" +---@field tool CodeCompanion.Tools.Tool The current tool being executed +---@field tool_handle vim.SystemObj? SystemObj of the current tool execution +---@field tool_output table? The output collected from the tool +---@field tools CodeCompanion.Tools + local Approvals = require("codecompanion.interactions.chat.tools.approvals") local Queue = require("codecompanion.interactions.chat.tools.runtime.queue") local Runner = require("codecompanion.interactions.chat.tools.runtime.runner") @@ -7,6 +20,7 @@ local ui_utils = require("codecompanion.utils.ui") local utils = require("codecompanion.utils") local fmt = string.format +local Orchestrator = {} ---Strip any ANSI color codes which don't render in the chat buffer ---@param tbl table @@ -19,33 +33,32 @@ local function strip_ansi(tbl) end ---Add a response to the chat buffer regarding a tool's execution ----@param exec CodeCompanion.Tools.Orchestrator ---@param llm_message string ---@param user_message? string -local send_response_to_chat = function(exec, llm_message, user_message) - exec.tools.chat:add_tool_output(exec.tool, llm_message, user_message) +function Orchestrator:_send_response_to_chat(llm_message, user_message) + self.tools.chat:add_tool_output(self.tool, llm_message, user_message) end ---Execute a shell command with platform-specific handling ---@param cmd table ---@param callback function -local function execute_shell_command(cmd, callback) +function Orchestrator:_execute_shell_command(cmd, callback) if vim.fn.has("win32") == 1 then -- See PR #2186 local shell_cmd = table.concat(cmd, " ") .. "\r\nEXIT %ERRORLEVEL%\r\n" - vim.system({ "cmd.exe", "/Q", "/K" }, { + self.tool_handle = vim.system({ "cmd.exe", "/Q", "/K" }, { stdin = shell_cmd, env = { PROMPT = "\r\n" }, }, callback) else - vim.system(os_utils.build_shell_command(cmd), {}, callback) + self.tool_handle = vim.system(os_utils.build_shell_command(cmd), {}, callback) end end ---Converts a cmd-based tool to a function-based tool. ---@param tool CodeCompanion.Tools.Tool ---@return CodeCompanion.Tools.Tool -local function cmd_to_func_tool(tool) +function Orchestrator:_cmd_to_func_tool(tool) tool.cmds = vim .iter(tool.cmds) :map(function(cmd) @@ -62,7 +75,7 @@ local function cmd_to_func_tool(tool) ---@param tools CodeCompanion.Tools return function(tools, _, _, cb) cb = vim.schedule_wrap(cb) - execute_shell_command(cmd, function(out) + self:_execute_shell_command(cmd, function(out) if flag then tools.chat.tool_registry.flags = tools.chat.tool_registry.flags or {} tools.chat.tool_registry.flags[flag] = (out.code == 0) @@ -93,18 +106,6 @@ local function cmd_to_func_tool(tool) return tool end ----@class CodeCompanion.Tools.Orchestrator ----@field id number The id of the tools coordinator ----@field index number The index of the current command ----@field handlers table ----@field output table ----@field queue CodeCompanion.Tools.Orchestrator.Queue ----@field status string The status of the tool execution "success" | "error" ----@field tool CodeCompanion.Tools.Tool The current tool being executed ----@field tool_output table? The output collected from the tool ----@field tools CodeCompanion.Tools -local Orchestrator = {} - ---@param tools CodeCompanion.Tools ---@param id number function Orchestrator.new(tools, id) @@ -184,7 +185,7 @@ function Orchestrator:_setup_handlers() rejection = rejection .. fmt(': "%s"', opts.reason) end -- If no handler is set then return a default message - send_response_to_chat(self, rejection) + self:_send_response_to_chat(rejection) end end, error = function(cmd) @@ -195,7 +196,7 @@ function Orchestrator:_setup_handlers() if self.tool.output and self.tool.output.error then self.tool.output.error(self.tool, self.tools, cmd, self.tools.stderr) else - send_response_to_chat(self, fmt("Error calling `%s`", self.tool.name)) + self:_send_response_to_chat(fmt("Error calling `%s`", self.tool.name)) end end, cancelled = function(cmd) @@ -206,8 +207,7 @@ function Orchestrator:_setup_handlers() if self.tool.output and self.tool.output.cancelled then self.tool.output.cancelled(self.tool, self.tools, cmd) else - send_response_to_chat( - self, + self:_send_response_to_chat( fmt("The user cancelled the execution of the %s tool", self.tool.name), fmt("Cancelled `%s`", self.tool.name) ) @@ -221,17 +221,18 @@ function Orchestrator:_setup_handlers() if self.tool.output and self.tool.output.success then self.tool.output.success(self.tool, self.tools, cmd, self.tool_output or {}) else - send_response_to_chat(self, fmt("Executed `%s`", self.tool.name)) + self:_send_response_to_chat(fmt("Executed `%s`", self.tool.name)) end end, } end ---When the tools coordinator is finished, finalize it via an autocmd ----@param self CodeCompanion.Tools.Orchestrator ---@return nil function Orchestrator:_finalize_tools() self.tools.tool = nil + self.tools.chat.tool_orchestrator = nil + return utils.fire("ToolsFinished", { id = self.id, bufnr = self.tools.bufnr }) end @@ -251,7 +252,7 @@ function Orchestrator:setup_next_tool(input) self.handlers.setup() -- Call this early as cmd_runner needs to setup its cmds dynamically -- Transform cmd-based tools to func-based - self.tool = cmd_to_func_tool(self.tool) + self.tool = self:_cmd_to_func_tool(self.tool) -- Get the first command to run local cmd = self.tool.cmds[1] @@ -429,4 +430,43 @@ function Orchestrator:finalize_tool() end end +---Cancel the currently running tool execution +---@return nil +function Orchestrator:cancel() + if self.cancelled then + return + end + + log:debug("Orchestrator:cancel") + self.cancelled = true + + -- Kill the running system command if one exists + if self.tool_handle then + if vim.fn.has("win32") == 1 then + -- /F flag forces process to end + -- /T ends child process (required since we are wrapping the child + -- process in a parent cmd.exe process) + vim.system({ "taskkill", "/F", "/T", "/PID", tostring(self.tool_handle.pid) }) + else + self.tool_handle:kill("sigkill") + end + + self.tool_handle = nil + end + + -- Output the cancellation message to the chat buffer. + if self.tool and self.output then + self.output.cancelled(self.tool.cmds[1]) + end + + -- Clean up the cancelled tool. + self:finalize_tool() + + -- Cancel any pending tools in the queue. + self:cancel_pending_tools() + + -- Close the current tool. + self:_finalize_tools() +end + return Orchestrator diff --git a/lua/codecompanion/interactions/chat/tools/runtime/runner.lua b/lua/codecompanion/interactions/chat/tools/runtime/runner.lua index 0c4bfbce5..4301e74c8 100644 --- a/lua/codecompanion/interactions/chat/tools/runtime/runner.lua +++ b/lua/codecompanion/interactions/chat/tools/runtime/runner.lua @@ -82,7 +82,7 @@ function Runner:run_tool(cmd_func, action, args) ---@param msg {status:"success"|"error", data:any} local function output_handler(msg) - if tool_finished then + if tool_finished or self.orchestrator.cancelled then return end tool_finished = true diff --git a/tests/interactions/chat/tools/runtime/test_queue_async.lua b/tests/interactions/chat/tools/runtime/test_queue_async.lua index cc80d1584..1721f7d7b 100644 --- a/tests/interactions/chat/tools/runtime/test_queue_async.lua +++ b/tests/interactions/chat/tools/runtime/test_queue_async.lua @@ -56,7 +56,10 @@ T["Tools"]["queue"]["can queue multiple async functions"] = function() }, } tools:execute(chat, tool_call) - vim.wait(2100) + + while (chat.tool_orchestrator) do + vim.wait(100) + end ]]) -- Test order @@ -88,7 +91,9 @@ T["Tools"]["queue"]["can queue async function with sync function"] = function() }, } tools:execute(chat, tool_call) - vim.wait(1100) + while (chat.tool_orchestrator) do + vim.wait(100) + end ]]) -- Test order