Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions lua/codecompanion/interactions/chat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
---@field context CodeCompanion.Chat.Context
---@field context_items? table<CodeCompanion.Chat.Context> Context which is sent to the LLM e.g. buffers, slash command output
---@field current_request table|nil The current request being executed
---@field current_tool table The current tool being executed
---@field cycle number Records the number of turn-based interactions (User -> LLM) that have taken place
---@field create_buf fun(): number The function that creates a new buffer for the chat
---@field edit_tracker? CodeCompanion.Chat.EditTracker Edit tracking information for the chat
Expand All @@ -31,6 +30,8 @@
---@field title? string The title of the chat buffer
---@field tokens? nil|number The number of tokens in the chat
---@field tools CodeCompanion.Tools The tools coordinator that executes available tools
---@field tool_orchestrator CodeCompanion.Tools.Orchestrator The current tool orchestrator

---@field tool_registry CodeCompanion.Chat.ToolRegistry Methods for handling interactions between the chat buffer and tools
---@field ui CodeCompanion.Chat.UI The UI of the chat buffer
---@field variables? CodeCompanion.Variables The variables available to the user
Expand Down Expand Up @@ -1445,13 +1446,11 @@ function Chat:stop()
self:dispatch("on_cancelled")
utils.fire("ChatStopped", { bufnr = self.bufnr, id = self.id })

if self.current_tool then
local tool_job = self.current_tool
self.current_tool = nil

if self.tool_orchestrator then
pcall(function()
tool_job.cancel()
self.tool_orchestrator:cancel()
end)
self.tool_orchestrator = nil
end

if self.current_request then
Expand Down
2 changes: 1 addition & 1 deletion lua/codecompanion/interactions/chat/keymaps/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ M.close = {

M.stop = {
callback = function(chat)
if chat.current_request then
if chat.current_request or chat.tool_orchestrator then
chat:stop()
end
end,
Expand Down
6 changes: 3 additions & 3 deletions lua/codecompanion/interactions/chat/tools/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ function Tools:execute(chat, tools)
-- NOTE: Set autocmds early so that errors can be handled properly
self:set_autocmds()

local orchestrator = Orchestrator.new(self, id)
chat.tool_orchestrator = Orchestrator.new(self, id)

for _, tool in ipairs(tools) do
local resolved_tool, error_msg, is_json_error = self:_resolve_and_prepare_tool(tool, id)
Expand All @@ -304,11 +304,11 @@ function Tools:execute(chat, tools)
end

self.tool = resolved_tool --[[@as CodeCompanion.Tools.Tool]]
orchestrator.queue:push(resolved_tool)
chat.tool_orchestrator.queue:push(resolved_tool)
end

utils.fire("ToolsStarted", { id = id, bufnr = self.bufnr })
orchestrator:setup_next_tool()
chat.tool_orchestrator:setup_next_tool()
end

local ok, err = pcall(safe_execute)
Expand Down
94 changes: 67 additions & 27 deletions lua/codecompanion/interactions/chat/tools/orchestrator.lua
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
---@class CodeCompanion.Tools.Orchestrator
---@field cancelled boolean? Whether the tool execution has been cancelled
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit but I prefer these to be in alphabetical order as my eyes can find things faster

---@field handlers table<string, function>
---@field id number The id of the tools coordinator
---@field index number The index of the current command
---@field output table<string, function>
---@field queue CodeCompanion.Tools.Orchestrator.Queue
---@field status string The status of the tool execution "success" | "error"
---@field tool CodeCompanion.Tools.Tool The current tool being executed
---@field tool_handle vim.SystemObj? SystemObj of the current tool execution
---@field tool_output table? The output collected from the tool
---@field tools CodeCompanion.Tools

local Approvals = require("codecompanion.interactions.chat.tools.approvals")
local Queue = require("codecompanion.interactions.chat.tools.runtime.queue")
local Runner = require("codecompanion.interactions.chat.tools.runtime.runner")
Expand All @@ -7,6 +20,7 @@ local ui_utils = require("codecompanion.utils.ui")
local utils = require("codecompanion.utils")

local fmt = string.format
local Orchestrator = {}

---Strip any ANSI color codes which don't render in the chat buffer
---@param tbl table
Expand All @@ -19,33 +33,32 @@ local function strip_ansi(tbl)
end

---Add a response to the chat buffer regarding a tool's execution
---@param exec CodeCompanion.Tools.Orchestrator
---@param llm_message string
---@param user_message? string
local send_response_to_chat = function(exec, llm_message, user_message)
exec.tools.chat:add_tool_output(exec.tool, llm_message, user_message)
function Orchestrator:_send_response_to_chat(llm_message, user_message)
self.tools.chat:add_tool_output(self.tool, llm_message, user_message)
end

---Execute a shell command with platform-specific handling
---@param cmd table
---@param callback function
local function execute_shell_command(cmd, callback)
function Orchestrator:_execute_shell_command(cmd, callback)
if vim.fn.has("win32") == 1 then
-- See PR #2186
local shell_cmd = table.concat(cmd, " ") .. "\r\nEXIT %ERRORLEVEL%\r\n"
vim.system({ "cmd.exe", "/Q", "/K" }, {
self.tool_handle = vim.system({ "cmd.exe", "/Q", "/K" }, {
stdin = shell_cmd,
env = { PROMPT = "\r\n" },
}, callback)
else
vim.system(os_utils.build_shell_command(cmd), {}, callback)
self.tool_handle = vim.system(os_utils.build_shell_command(cmd), {}, callback)
end
end

---Converts a cmd-based tool to a function-based tool.
---@param tool CodeCompanion.Tools.Tool
---@return CodeCompanion.Tools.Tool
local function cmd_to_func_tool(tool)
function Orchestrator:_cmd_to_func_tool(tool)
tool.cmds = vim
.iter(tool.cmds)
:map(function(cmd)
Expand All @@ -62,7 +75,7 @@ local function cmd_to_func_tool(tool)
---@param tools CodeCompanion.Tools
return function(tools, _, _, cb)
cb = vim.schedule_wrap(cb)
execute_shell_command(cmd, function(out)
self:_execute_shell_command(cmd, function(out)
if flag then
tools.chat.tool_registry.flags = tools.chat.tool_registry.flags or {}
tools.chat.tool_registry.flags[flag] = (out.code == 0)
Expand Down Expand Up @@ -93,18 +106,6 @@ local function cmd_to_func_tool(tool)
return tool
end

---@class CodeCompanion.Tools.Orchestrator
---@field id number The id of the tools coordinator
---@field index number The index of the current command
---@field handlers table<string, function>
---@field output table<string, function>
---@field queue CodeCompanion.Tools.Orchestrator.Queue
---@field status string The status of the tool execution "success" | "error"
---@field tool CodeCompanion.Tools.Tool The current tool being executed
---@field tool_output table? The output collected from the tool
---@field tools CodeCompanion.Tools
local Orchestrator = {}

---@param tools CodeCompanion.Tools
---@param id number
function Orchestrator.new(tools, id)
Expand Down Expand Up @@ -184,7 +185,7 @@ function Orchestrator:_setup_handlers()
rejection = rejection .. fmt(': "%s"', opts.reason)
end
-- If no handler is set then return a default message
send_response_to_chat(self, rejection)
self:_send_response_to_chat(rejection)
end
end,
error = function(cmd)
Expand All @@ -195,7 +196,7 @@ function Orchestrator:_setup_handlers()
if self.tool.output and self.tool.output.error then
self.tool.output.error(self.tool, self.tools, cmd, self.tools.stderr)
else
send_response_to_chat(self, fmt("Error calling `%s`", self.tool.name))
self:_send_response_to_chat(fmt("Error calling `%s`", self.tool.name))
end
end,
cancelled = function(cmd)
Expand All @@ -206,8 +207,7 @@ function Orchestrator:_setup_handlers()
if self.tool.output and self.tool.output.cancelled then
self.tool.output.cancelled(self.tool, self.tools, cmd)
else
send_response_to_chat(
self,
self:_send_response_to_chat(
fmt("The user cancelled the execution of the %s tool", self.tool.name),
fmt("Cancelled `%s`", self.tool.name)
)
Expand All @@ -221,17 +221,18 @@ function Orchestrator:_setup_handlers()
if self.tool.output and self.tool.output.success then
self.tool.output.success(self.tool, self.tools, cmd, self.tool_output or {})
else
send_response_to_chat(self, fmt("Executed `%s`", self.tool.name))
self:_send_response_to_chat(fmt("Executed `%s`", self.tool.name))
end
end,
}
end

---When the tools coordinator is finished, finalize it via an autocmd
---@param self CodeCompanion.Tools.Orchestrator
---@return nil
function Orchestrator:_finalize_tools()
self.tools.tool = nil
self.tools.chat.tool_orchestrator = nil

return utils.fire("ToolsFinished", { id = self.id, bufnr = self.tools.bufnr })
end

Expand All @@ -251,7 +252,7 @@ function Orchestrator:setup_next_tool(input)
self.handlers.setup() -- Call this early as cmd_runner needs to setup its cmds dynamically

-- Transform cmd-based tools to func-based
self.tool = cmd_to_func_tool(self.tool)
self.tool = self:_cmd_to_func_tool(self.tool)

-- Get the first command to run
local cmd = self.tool.cmds[1]
Expand Down Expand Up @@ -429,4 +430,43 @@ function Orchestrator:finalize_tool()
end
end

---Cancel the currently running tool execution
---@return nil
function Orchestrator:cancel()
if self.cancelled then
return
end

log:debug("Orchestrator:cancel")
self.cancelled = true

-- Kill the running system command if one exists
if self.tool_handle then
if vim.fn.has("win32") == 1 then
-- /F flag forces process to end
-- /T ends child process (required since we are wrapping the child
-- process in a parent cmd.exe process)
vim.system({ "taskkill", "/F", "/T", "/PID", tostring(self.tool_handle.pid) })
else
self.tool_handle:kill("sigkill")
end

self.tool_handle = nil
end

-- Output the cancellation message to the chat buffer.
if self.tool and self.output then
self.output.cancelled(self.tool.cmds[1])
end

-- Clean up the cancelled tool.
self:finalize_tool()

-- Cancel any pending tools in the queue.
self:cancel_pending_tools()

-- Close the current tool.
self:_finalize_tools()
end

return Orchestrator
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function Runner:run_tool(cmd_func, action, args)

---@param msg {status:"success"|"error", data:any}
local function output_handler(msg)
if tool_finished then
if tool_finished or self.orchestrator.cancelled then
return
end
tool_finished = true
Expand Down
9 changes: 7 additions & 2 deletions tests/interactions/chat/tools/runtime/test_queue_async.lua
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ T["Tools"]["queue"]["can queue multiple async functions"] = function()
},
}
tools:execute(chat, tool_call)
vim.wait(2100)

while (chat.tool_orchestrator) do
vim.wait(100)
end
]])

-- Test order
Expand Down Expand Up @@ -88,7 +91,9 @@ T["Tools"]["queue"]["can queue async function with sync function"] = function()
},
}
tools:execute(chat, tool_call)
vim.wait(1100)
while (chat.tool_orchestrator) do
vim.wait(100)
end
]])

-- Test order
Expand Down
Loading