|
1 | 1 | ---@module "codecompanion" |
2 | 2 |
|
3 | 3 | local cc_common = require("vectorcode.integrations.codecompanion.common") |
| 4 | +local cc_config = require("codecompanion.config").config |
| 5 | +local cc_schema = require("codecompanion.schema") |
| 6 | +local http_client = require("codecompanion.http") |
4 | 7 | local vc_config = require("vectorcode.config") |
5 | 8 | local check_cli_wrap = vc_config.check_cli_wrap |
6 | 9 | local logger = vc_config.logger |
@@ -80,6 +83,91 @@ local filter_results = function(results, chat) |
80 | 83 | return filtered_results |
81 | 84 | end |
82 | 85 |
|
| 86 | +---@alias ChatMessage {role: string, content:string} |
| 87 | + |
| 88 | +---@param adapter CodeCompanion.Adapter |
| 89 | +---@param system_prompt string |
| 90 | +---@param user_messages string|string[] |
| 91 | +---@return {messages: ChatMessage[], tools:table?} |
| 92 | +local function make_oneshot_payload(adapter, system_prompt, user_messages) |
| 93 | + if type(user_messages) == "string" then |
| 94 | + user_messages = { user_messages } |
| 95 | + end |
| 96 | + local messages = |
| 97 | + { { role = cc_config.constants.SYSTEM_ROLE, content = system_prompt } } |
| 98 | + for _, m in pairs(user_messages) do |
| 99 | + table.insert(messages, { role = cc_config.constants.USER_ROLE, content = m }) |
| 100 | + end |
| 101 | + return { messages = adapter:map_roles(messages) } |
| 102 | +end |
| 103 | + |
| 104 | +---@param result VectorCode.QueryResult[] |
| 105 | +---@param cmd QueryToolArgs |
| 106 | +---@param summarise_opts VectorCode.CodeCompanion.SummariseOpts |
| 107 | +---@param callback fun(summary:string) |
| 108 | +local function generate_summary(result, summarise_opts, cmd, callback) |
| 109 | + assert(vim.islist(result), "result should be a list of VectorCode.QueryResult") |
| 110 | + local result_xml = table.concat(vim |
| 111 | + .iter(result) |
| 112 | + :map(function(res) |
| 113 | + return cc_common.process_result(res) |
| 114 | + end) |
| 115 | + :totable()) |
| 116 | + |
| 117 | + if summarise_opts.enabled and type(callback) == "function" then |
| 118 | + ---@type CodeCompanion.Adapter |
| 119 | + local adapter = |
| 120 | + vim.deepcopy(require("codecompanion.adapters").resolve(summarise_opts.adapter)) |
| 121 | + |
| 122 | + local system_prompt = summarise_opts.system_prompt |
| 123 | + if type(system_prompt) == "function" then |
| 124 | + system_prompt = system_prompt( |
| 125 | + cc_common.get_query_tool_opts().summarise.system_prompt --[[@as string]] |
| 126 | + ) |
| 127 | + end |
| 128 | + |
| 129 | + assert( |
| 130 | + type(system_prompt) == "string", |
| 131 | + "`system_prompt` should have been converted to a string." |
| 132 | + ) |
| 133 | + if summarise_opts.query_augmented then |
| 134 | + system_prompt = string.format( |
| 135 | + [[%s |
| 136 | + |
| 137 | +The code provided to you is the result of a search in a codebase from the following query: %s. |
| 138 | +When summarising the code, pay extra attention on information related to the queries. |
| 139 | + ]], |
| 140 | + system_prompt, |
| 141 | + table.concat(cmd.query, ", ") |
| 142 | + ) |
| 143 | + end |
| 144 | + local payload = make_oneshot_payload(adapter, system_prompt, result_xml) |
| 145 | + local settings = |
| 146 | + vim.deepcopy(adapter:map_schema_to_params(cc_schema.get_default(adapter))) |
| 147 | + settings.opts.stream = false |
| 148 | + |
| 149 | + ---@type CodeCompanion.Client |
| 150 | + local client = http_client.new({ adapter = settings }) |
| 151 | + client:request(payload, { |
| 152 | + ---@param _adapter CodeCompanion.Adapter |
| 153 | + callback = function(_, data, _adapter) |
| 154 | + if data then |
| 155 | + local res = _adapter.handlers.chat_output(_adapter, data) |
| 156 | + if res and res.status == "success" then |
| 157 | + local gen_summary = vim.trim(res.output.content or "") |
| 158 | + if gen_summary ~= "" then |
| 159 | + return callback(gen_summary) |
| 160 | + end |
| 161 | + end |
| 162 | + end |
| 163 | + return callback(result_xml) |
| 164 | + end, |
| 165 | + }, { silent = true }) |
| 166 | + else |
| 167 | + callback(result_xml) |
| 168 | + end |
| 169 | +end |
| 170 | + |
83 | 171 | ---@param opts VectorCode.CodeCompanion.QueryToolOpts? |
84 | 172 | ---@return CodeCompanion.Agent.Tool |
85 | 173 | return check_cli_wrap(function(opts) |
@@ -181,7 +269,27 @@ return check_cli_wrap(function(opts) |
181 | 269 |
|
182 | 270 | job_runner.run_async(args, function(result, error) |
183 | 271 | if vim.islist(result) and #result > 0 and result[1].path ~= nil then ---@cast result VectorCode.QueryResult[] |
184 | | - cb({ status = "success", data = result }) |
| 272 | + if opts.no_duplicate then |
| 273 | + result = filter_results(result, agent.chat) |
| 274 | + end |
| 275 | + local max_result = #result |
| 276 | + if opts.max_num > 0 then |
| 277 | + max_result = math.min(tonumber(opts.max_num) or 1, max_result) |
| 278 | + end |
| 279 | + while #result > max_result do |
| 280 | + table.remove(result) |
| 281 | + end |
| 282 | + local summary_opts = vim.deepcopy(opts.summarise) or {} |
| 283 | + if type(summary_opts.enabled) == "function" then |
| 284 | + summary_opts.enabled = summary_opts.enabled(agent.chat, result) |
| 285 | + end |
| 286 | + generate_summary(result, summary_opts, action, function(s) |
| 287 | + cb({ |
| 288 | + status = "success", |
| 289 | + ---@type VectorCode.CodeCompanion.QueryToolResult |
| 290 | + data = { raw_results = result, count = #result, summary = s }, |
| 291 | + }) |
| 292 | + end) |
185 | 293 | else |
186 | 294 | if type(error) == "table" then |
187 | 295 | error = cc_common.flatten_table_to_string(error) |
@@ -280,50 +388,33 @@ If a query returned empty or repeated results, you should avoid using these quer |
280 | 388 | end, |
281 | 389 | ---@param agent CodeCompanion.Agent |
282 | 390 | ---@param cmd QueryToolArgs |
283 | | - ---@param stdout VectorCode.QueryResult[][] |
| 391 | + ---@param stdout VectorCode.CodeCompanion.QueryToolResult[] |
284 | 392 | success = function(self, agent, cmd, stdout) |
285 | 393 | stdout = stdout[1] |
286 | 394 | logger.info( |
287 | 395 | ("CodeCompanion tool with command %s finished."):format(vim.inspect(cmd)) |
288 | 396 | ) |
289 | | - local user_message |
290 | | - local max_result = #stdout |
291 | | - if opts.max_num > 0 then |
292 | | - max_result = math.min(opts.max_num or 1, max_result) |
293 | | - end |
294 | | - if opts.no_duplicate then |
295 | | - stdout = filter_results(stdout, agent.chat) |
296 | | - end |
297 | | - for i, file in pairs(stdout) do |
298 | | - if i <= max_result then |
299 | | - if i == 1 then |
300 | | - user_message = string.format( |
301 | | - "**VectorCode Tool**: Retrieved %d %s(s)", |
302 | | - max_result, |
303 | | - mode |
304 | | - ) |
305 | | - if cmd.project_root then |
306 | | - user_message = user_message .. " from " .. cmd.project_root |
307 | | - end |
308 | | - user_message = user_message .. "\n" |
309 | | - else |
310 | | - user_message = "" |
311 | | - end |
312 | | - agent.chat:add_tool_output( |
313 | | - self, |
314 | | - cc_common.process_result(file), |
315 | | - user_message |
316 | | - ) |
317 | | - if not opts.chunk_mode then |
318 | | - -- only add to reference if running in full document mode |
319 | | - local ref = { |
320 | | - source = cc_common.tool_result_source, |
321 | | - id = file.path, |
322 | | - path = file.path, |
323 | | - opts = { visible = false }, |
324 | | - } |
325 | | - agent.chat.references:add(ref) |
326 | | - end |
| 397 | + agent.chat:add_tool_output( |
| 398 | + self, |
| 399 | + stdout.summary |
| 400 | + or table.concat(vim |
| 401 | + .iter(stdout.raw_results or {}) |
| 402 | + :map(function(res) |
| 403 | + return cc_common.process_result(res) |
| 404 | + end) |
| 405 | + :totable()), |
| 406 | + string.format("**VectorCode Tool**: Retrieved %d %s(s)", stdout.count, mode) |
| 407 | + ) |
| 408 | + for _, file in pairs(stdout) do |
| 409 | + if not opts.chunk_mode then |
| 410 | + -- skip referencing because there will be multiple chunks with the same path (id). |
| 411 | + -- TODO: figure out a way to deduplicate. |
| 412 | + agent.chat.references:add({ |
| 413 | + source = cc_common.tool_result_source, |
| 414 | + id = file.path, |
| 415 | + path = file.path, |
| 416 | + opts = { visible = false }, |
| 417 | + }) |
327 | 418 | end |
328 | 419 | end |
329 | 420 | end, |
|
0 commit comments