Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions lua/eca/approve.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
local M = {}

---@param tool_call eca.ToolCallRun
function M.get_preview_lines(tool_call)
if not tool_call.details then
local arguments = vim.split(vim.inspect(tool_call.arguments), "\n")
local messages = {}
if tool_call.summary then
table.insert(messages, "Summary: " .. tool_call.summary)
end
table.insert(messages, "Tool Name: " .. tool_call.name)
table.insert(messages, "Tool Type: " .. tool_call.origin)
table.insert(messages, "Tool Arguments: ")
for _, v in pairs(arguments) do
table.insert(messages, v)
end
return messages
end
local lines = vim.split(tool_call.details.diff, "\n")
return { tool_call.details.path, unpack(lines) }
end

---@param lines string[]
---@return {row: number, col: number, width: number, height: number}
local function get_position(lines)
local gheight = math.floor(
vim.api.nvim_list_uis() and vim.api.nvim_list_uis()[1] and vim.api.nvim_list_uis()[1].height or vim.o.lines
)
local gwidth = math.floor(
vim.api.nvim_list_uis() and vim.api.nvim_list_uis()[1] and vim.api.nvim_list_uis()[1].width or vim.o.columns
)
local height = #lines > 10 and 35 or #lines
local width = 0
for _, line in ipairs(lines) do
if #line > width then
width = #line
end
end
return {
row = (gheight - height) * 0.5,
col = (gwidth - width) * 0.5,
width = math.floor(width * 1.5),
height = height,
}
end

---@param tool_call eca.ToolCallRun
---@param on_accept function
---@param on_deny function
function M.display_preview_lines(tool_call, on_accept, on_deny)
local lines = M.get_preview_lines(tool_call)
local buf = vim.api.nvim_create_buf(false, false)
vim.api.nvim_buf_set_lines(buf, 0, -1, false, lines)
vim.api.nvim_set_option_value("modifiable", false, { buf = buf })
local position = get_position(lines)
local title = tool_call.summary or tool_call.name
local win = vim.api.nvim_open_win(buf, true, {
border = "single",
title = "Approve Tool Call(y/n): " .. title,
relative = "editor",
row = position.row,
col = position.col,
width = position.width,
height = position.height,
})
if tool_call.details then
vim.api.nvim_set_option_value("filetype", "diff", { buf = buf })
else
vim.api.nvim_set_option_value("number", false, { win = win })
vim.api.nvim_set_option_value("relativenumber", false, { win = win })
end

vim.keymap.set({ "n", "i" }, "y", "", {
buffer = buf,
callback = function()
vim.api.nvim_win_close(win, true)
vim.api.nvim_buf_delete(buf, { force = true })
if on_accept then
on_accept()
end
end,
})
vim.keymap.set({ "n", "i" }, "n", "", {
buffer = buf,
callback = function()
vim.api.nvim_win_close(win, true)
vim.api.nvim_buf_delete(buf, { force = true })
if on_deny then
on_deny()
end
end,
})
end

---@param tool_call eca.ToolCallRun
---@param on_accept function
---@param on_deny function
function M.approve_tool_call(tool_call, on_accept, on_deny)
M.display_preview_lines(tool_call, on_accept, on_deny)
end
return M
8 changes: 5 additions & 3 deletions lua/eca/mediator.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ end

---@param method string
---@param params eca.MessageParams
---@param callback fun(err: string, result: table)
---@param callback? fun(err?: string, result?: table)
function mediator:send(method, params, callback)
if not self.server:is_running() then
callback("Server is not running, please start the server", nil)
return
if callback then
callback("Server is not running, please start the server", nil)
end
require("eca.logger").notify("Server is not rnning, please start the server", vim.log.levels.WARN)
end
self.server:send_request(method, params, callback)
end
Expand Down
15 changes: 14 additions & 1 deletion lua/eca/sidebar.lua
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ function M:_set_welcome_content()
"- **RepoMap**: Use `:EcaAddRepoMap` to add repository structure context",
"",
"---",
""
"",
}

Logger.debug("Setting welcome content for new chat")
Expand Down Expand Up @@ -1204,6 +1204,7 @@ function M:handle_chat_content_received(params)
end

local content = params.content
local chat_id = params.chatId

if content.type == "text" then
-- Handle streaming text content
Expand Down Expand Up @@ -1234,6 +1235,8 @@ function M:handle_chat_content_received(params)
self:_handle_tool_call_prepare(content)
-- IMPORTANT: Return immediately - do NOT display anything for toolCallPrepare
return
elseif content.type == "toolCallRun" then
self:render_tool_call(content, chat_id)
elseif content.type == "toolCallRunning" then
-- Show the accumulated tool call
self:_display_tool_call(content)
Expand Down Expand Up @@ -1274,6 +1277,16 @@ function M:handle_chat_content_received(params)
end
end

function M:render_tool_call(tool_content, chat_id)
if tool_content.type == "toolCallRun" and tool_content.manualApproval then
return require("eca.approve").approve_tool_call(tool_content, function()
self.mediator:send("chat/toolCallApprove", { chatId = chat_id, toolCallId = tool_content.id }, nil)
end, function()
self.mediator:send("chat/toolCallReject", { chatId = chat_id, toolCallId = tool_content.id }, nil)
end)
end
end

---@param text string
function M:_handle_streaming_text(text)
-- Only check for empty text
Expand Down
41 changes: 32 additions & 9 deletions lua/eca/types.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
---@meta

---@alias eca.ChatModel string
---@alias eca.ChatBehavior 'agent'|'plan'
---@class eca.ServerCapabilities
---@field welcomeMessage string
---@field models eca.ChatModel[]
---@field defaultModel eca.ChatModel
---@field behaviors eca.ChatBehavior[]
---@field defaultBehavior eca.ChatBehavior

---@class eca.ChatContext
---@field type string
---@field path? string
Expand All @@ -9,7 +18,7 @@
---@field name? string
---@field description? string
---@field mimeType? string
---@field server string
---@field server? string

---@class eca.ChatCommand
---@field name string
Expand All @@ -33,24 +42,30 @@
---@field linesAdded integer the count of lines added in this change
---@field linesRemoved integer the count of lines removed in this change

---@class eca.ToolCallRun
---@field type 'toolCallRun'
---@class eca.ToolCallPrepare
---@field type 'toolCallPrepare'
---@field origin eca.ToolCallOrigin
---@field id string the id of the tool call
---@field name string name of the tool
---@field arguments {[string]: string} arguments of the tool call
---@field argumentsText {[string]: string} arguments of the tool call
---@field manualApproval boolean whether the call requires manual approval from the user
---@field summary string summary text to present about this tool call
---@field details eca.ToolCallDetails extra details for the call. clients may use this to present a different UX for this tool call.
--- extra details for the call. clients may use this to present a different UX
--- for this tool call.
---@field details eca.ToolCallDetails
---

---@class eca.ToolCallRunning
---@field type 'toolCallRunning'
---@class eca.ToolCallRun
---@field type 'toolCallRun'
---@field origin eca.ToolCallOrigin
---@field id string the id of the tool call
---@field name string name of the tool
---@field arguments {[string]: string} arguments of the tool call
---@field summary? string summary text to present about this tool call
---@field details? eca.ToolCallDetails extra details for the call. clients may use this to present a different UX for this tool call.
---@field manualApproval boolean whether the call requires manual approval from the user
---@field summary string summary text to present about this tool call
--- extra details for the call. clients may use this to present a different UX
--- for this tool call.
---@field details eca.ToolCallDetails

---@class eca.ToolCalled
---@field type 'toolCalled'
Expand All @@ -61,3 +76,11 @@
---@field outputs {type: 'text', text: string}[] the result of the tool call
---@field summary? string summary text to present about the tool call
---@field details? eca.ToolCallDetails extra details about the call

---@class eca.UsageContent
---@field type 'usage'
---@field messageInputTokens number
---@field messageOutputTokens number
---@field sessionTokens number
---@field messageCost? string
---@field sessionCost? string
48 changes: 48 additions & 0 deletions tests/stubs/tool_calls.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
local stubs = {}

stubs.read_file = {
arguments = {
path = "/Users/tgeorge/git/eca-nvim/hack/messages.lua",
},
id = "toolu_013zj73SHzZNoeE7kzD7qzb4",
manualApproval = true,
name = "eca_read_file",
origin = "native",
summary = "Reading file messages.lua",
type = "toolCallRun",
}

stubs.edit_file = {
arguments = {
new_content = 'local M = {}\n\n--- Show ECA messages using snacks.picker\nfunction M.show()\n local has_snacks, picker = pcall(require, "snacks.picker")\n if not has_snacks then\n vim.notify("snacks.picker is not available", vim.log.levels.ERROR)\n return\n end\n\n Snacks.picker(\n ---@type snacks.picker.Config\n {\n source = "eca messages",\n finder = function(opts, ctx)\n ---@type snacks.picker.finder.Item[]\n local items = {}\n for msg in vim.iter(require("eca").server.messages) do\n local decoded = vim.json.decode(msg.content)\n table.insert(items, {\n text = decoded.method,\n idx = decoded.id,\n preview = {\n text = vim.inspect(decoded),\n ft = "lua",\n},\n})\n end\n return items\n end,\n preview = "preview",\n format = "text",\n confirm = function(self, item, _)\n vim.fn.setreg("", item.preview.text)\n self:close()\n end,\n }\n )\nend\n\nreturn M',
original_content = 'local has_snacks, picker = pcall(require, "snacks.picker")\nif has_snacks then\n Snacks.picker(\n ---@type snacks.picker.Config\n {\n source = "eca messages",\n finder = function(opts, ctx)\n ---@type snacks.picker.finder.Item[]\n local items = {}\n for msg in vim.iter(require("eca").server.messages) do\n local decoded = vim.json.decode(msg.content)\n table.insert(items, {\n text = decoded.method,\n idx = decoded.id,\n preview = {\n text = vim.inspect(decoded),\n ft = "lua",\n },\n })\n end\n return items\n end,\n preview = "preview",\n format = "text",\n confirm = function(self, item, _)\n vim.fn.setreg("", item.preview.text)\n self:close()\n end,\n }\n )\nend',
path = "/Users/tgeorge/git/eca-nvim/hack/messages.lua",
},
details = {
diff = '@@ -1, 5 +1, 13 @@\n-local has_snacks, picker = pcall(require, "snacks.picker")\n-if has_snacks then\n+local M = {}\n+\n+--- Show ECA messages using snacks.picker\n+function M.show()\n+ local has_snacks, picker = pcall(require, "snacks.picker")\n+ if not has_snacks then\n+ vim.notify("snacks.picker is not available", vim.log.levels.ERROR)\n+ return\n+ end\n+\n Snacks.picker(\n ---@type snacks.picker.Config\n {\n@@ -29, 3 +37, 5 @@\n }\n )\n end\n+\n+return M',
linesAdded = 12,
linesRemoved = 10,
path = "/Users/tgeorge/git/eca-nvim/hack/messages.lua",
type = "fileChange",
},
id = "toolu_01KAVb3qpJDcSnbnJmpUndQF",
manualApproval = true,
name = "eca_edit_file",
origin = "native",
summary = "Editting file",
type = "toolCallRun",
}

stubs.mcp = {
arguments = {
content = 'return "hello world"',
path = "/Users/tgeorge/git/eca-nvim/hack/test_mcp_write_file.lua",
},
id = "toolu_01B8xcb7csLRHvqrnAZTgzPi",
manualApproval = true,
name = "write_file",
origin = "mcp",
type = "toolCallRun",
}

return stubs
91 changes: 91 additions & 0 deletions tests/test_approve.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
local MiniTest = require("mini.test")
local eq = MiniTest.expect.equality
local child = MiniTest.new_child_neovim()
local stubs = require("tests.stubs.tool_calls")

local T = MiniTest.new_set({
hooks = {
pre_case = function()
child.restart({ "-u", "scripts/minimal_init.lua" })
child.lua([[
_G.notifications = {}
_G.on_accept = function() table.insert(_G.notifications, "accept") end
_G.on_reject = function() table.insert(_G.notifications, "reject") end
]])
end,
post_once = child.stop,
},
})

T["preview lines"] = function()
local test_cases = {
{
input = stubs.read_file,
want = {
"Summary: Reading file messages.lua",
"Tool Name: eca_read_file",
"Tool Type: native",
"Tool Arguments: ",
"{",
' path = "/Users/tgeorge/git/eca-nvim/hack/messages.lua"',
"}",
},
},
{
input = stubs.edit_file,
want = {
"/Users/tgeorge/git/eca-nvim/hack/messages.lua",
"@@ -1, 5 +1, 13 @@",
'-local has_snacks, picker = pcall(require, "snacks.picker")',
"-if has_snacks then",
"+local M = {}",
"+",
"+--- Show ECA messages using snacks.picker",
"+function M.show()",
'+ local has_snacks, picker = pcall(require, "snacks.picker")',
"+ if not has_snacks then",
'+ vim.notify("snacks.picker is not available", vim.log.levels.ERROR)',
"+ return",
"+ end",
"+",
" Snacks.picker(",
" ---@type snacks.picker.Config",
" {",
"@@ -29, 3 +37, 5 @@",
" }",
" )",
" end",
"+",
"+return M",
},
},
{
input = stubs.mcp,
want = {
"Tool Name: write_file",
"Tool Type: mcp",
"Tool Arguments: ",
"{",
" content = 'return \"hello world\"',",
' path = "/Users/tgeorge/git/eca-nvim/hack/test_mcp_write_file.lua"',
"}",
},
},
}
for _, test_case in pairs(test_cases) do
local got = child.lua_get('require("eca.approve").get_preview_lines(...)', { test_case.input })
eq(got, test_case.want)
end
end

T["tool approval calls callback"] = function()
child.lua("_G.tool_call = " .. vim.inspect(stubs.read_file))
child.lua('require("eca.approve").approve_tool_call(_G.tool_call, _G.on_accept, _G.on_reject)')
child.type_keys("y")
eq(child.lua_get("_G.notifications"), { "accept" })
child.lua('require("eca.approve").approve_tool_call(_G.tool_call, _G.on_accept, _G.on_reject)')
child.type_keys("n")
eq(child.lua_get("_G.notifications"), { "accept", "reject" })
end

return T