--- Build the request body for the Responses API.
--- System messages are merged (blank-line separated) into `instructions`,
--- tool results become `function_call_output` items, and every tool call
--- attached to a message is emitted as a `function_call` item right after
--- that message.
---@param inputs table
---@param opts CopilotChat.config.providers.Options
---@return table
local function prepare_responses_input(inputs, opts)
  local instructions = nil
  local input_messages = {}

  for _, msg in ipairs(inputs) do
    if msg.role == constants.ROLE.SYSTEM then
      instructions = instructions and (instructions .. '\n\n' .. msg.content) or msg.content
    elseif msg.role == constants.ROLE.TOOL then
      table.insert(input_messages, {
        type = 'function_call_output',
        call_id = msg.tool_call_id,
        output = msg.content,
      })
    else
      table.insert(input_messages, {
        role = msg.role,
        content = msg.content,
      })

      if msg.tool_calls then
        for _, tool_call in ipairs(msg.tool_calls) do
          table.insert(input_messages, {
            type = 'function_call',
            call_id = tool_call.id,
            name = tool_call.name,
            arguments = tool_call.arguments or '',
          })
        end
      end
    end
  end

  local out = {
    model = opts.model.id,
    -- Streaming is on unless the model explicitly sets `streaming = false`.
    stream = opts.model.streaming ~= false,
    input = input_messages,
  }

  if instructions then
    out.instructions = instructions
  end

  if opts.tools and opts.model.tools then
    -- Tool definitions are flat here (no nested `function` table, unlike the
    -- Chat Completions body built by prepare_chat_input).
    out.tools = vim.tbl_map(function(tool)
      return {
        type = 'function',
        name = tool.name,
        description = tool.description,
        parameters = tool.schema,
      }
    end, opts.tools)
  end

  return out
end

--- Build the request body for the Chat Completions API.
--- For models whose id starts with `o1`, system messages are sent with the
--- user role and the sampling parameters (`n`, `top_p`, `temperature`) are
--- left out.
---@param inputs table
---@param opts CopilotChat.config.providers.Options
---@return table
local function prepare_chat_input(inputs, opts)
  local is_o1 = vim.startswith(opts.model.id, 'o1')

  local messages = vim.tbl_map(function(input)
    local message = {
      role = (is_o1 and input.role == constants.ROLE.SYSTEM) and constants.ROLE.USER or input.role,
      content = input.content,
    }

    if input.tool_call_id then
      message.tool_call_id = input.tool_call_id
    end

    if input.tool_calls then
      message.tool_calls = vim.tbl_map(function(tool_call)
        return {
          id = tool_call.id,
          type = 'function',
          ['function'] = {
            name = tool_call.name,
            arguments = tool_call.arguments,
          },
        }
      end, input.tool_calls)
    end

    return message
  end, inputs)

  local out = {
    messages = messages,
    model = opts.model.id,
    stream = opts.model.streaming or false,
  }

  if opts.tools and opts.model.tools then
    out.tools = vim.tbl_map(function(tool)
      return {
        type = 'function',
        ['function'] = {
          name = tool.name,
          description = tool.description,
          parameters = tool.schema,
        },
      }
    end, opts.tools)
  end

  if not is_o1 then
    out.n = 1
    out.top_p = 1
    out.temperature = opts.temperature
  end

  if opts.model.max_output_tokens then
    out.max_tokens = opts.model.max_output_tokens
  end

  return out
end

--- Concatenate the text carried by a list of content parts.
--- A part is either a plain string or a table with a string `text` field
--- (whatever its `type`). Anything else is skipped instead of raising on
--- concatenation.
---@param parts table Array of content parts
---@return string The concatenated text content
local function extract_text_from_parts(parts)
  if type(parts) ~= 'table' then
    return ''
  end

  local pieces = {}
  for _, part in ipairs(parts) do
    local text = part
    if type(part) == 'table' then
      text = part.text
    end
    if type(text) == 'string' then
      pieces[#pieces + 1] = text
    end
  end
  return table.concat(pieces)
end

--- Return `value`, or nil when it is nil or the JSON null sentinel.
--- NOTE(review): whether nulls reach this module as `vim.NIL` depends on the
--- decode options used by the caller, which are not visible here - confirm.
---@param value any
---@return any
local function non_null(value)
  if value == nil or value == vim.NIL then
    return nil
  end
  return value
end

--- Turn an error payload (string or table with `message`) into a finish reason.
---@param err any
---@return string
local function format_error(err)
  if type(err) == 'table' then
    err = non_null(err.message) or vim.inspect(err)
  end
  return 'error: ' .. tostring(err)
end

--- Read the reasoning summary and token usage from a response object.
--- NOTE(review): assumes `reasoning.summary` holds displayable text - confirm
--- against the API; non-string values are ignored.
---@param response table
---@return string reasoning
---@return number? total_tokens
local function read_response_meta(response)
  local reasoning = ''
  local total_tokens = nil

  local reasoning_info = non_null(response.reasoning)
  if type(reasoning_info) == 'table' and type(reasoning_info.summary) == 'string' then
    reasoning = reasoning_info.summary
  end

  local usage = non_null(response.usage)
  if type(usage) == 'table' then
    total_tokens = non_null(usage.total_tokens)
  end

  return reasoning, total_tokens
end

--- Append a normalized tool call to `tool_calls`.
---@param tool_calls table
---@param call table Raw call with `call_id`, `name`, `arguments`
---@param index number Index reported for this call
local function push_tool_call(tool_calls, call, index)
  table.insert(tool_calls, {
    id = non_null(call.call_id) or ('tooluse_' .. index),
    index = index,
    name = non_null(call.name) or '',
    arguments = non_null(call.arguments) or '',
  })
end

--- Parse Responses API output (both streaming and non-streaming)
--- Streaming events are recognised by their `type` field; a non-streaming
--- body is accepted either wrapped in `response` or as the bare response
--- object (a table carrying `output`).
---@param output table Raw API response
---@return CopilotChat.config.providers.Output
local function prepare_responses_output(output)
  local content = ''
  local reasoning = ''
  local finish_reason = nil
  local total_tokens = nil
  local tool_calls = {}

  local wrapped = non_null(output.response)

  -- Errors may sit at the top level or inside the wrapped response. A JSON
  -- null `error` must not count as an error.
  local err = non_null(output.error) or (type(wrapped) == 'table' and non_null(wrapped.error)) or nil
  if err then
    return {
      content = '',
      reasoning = '',
      finish_reason = format_error(err),
      total_tokens = nil,
      tool_calls = {},
    }
  end

  local event = non_null(output.type)
  if event then
    if event == 'response.output_text.delta' then
      local delta = output.delta
      if type(delta) == 'string' then
        content = delta
      elseif type(delta) == 'table' and type(delta.text) == 'string' then
        content = delta.text
      end
    elseif event == 'response.output_item.done' then
      local item = output.item
      if type(item) == 'table' and item.type == 'function_call' then
        -- `tool_calls` is rebuilt for every event, so its length cannot tell
        -- calls apart; use the event's position in the response instead.
        -- Falls back to 1 when the event carries no `output_index`.
        -- NOTE(review): assumes the consumer keys streamed calls by `index`
        -- and tolerates gaps - confirm.
        local index = (tonumber(output.output_index) or 0) + 1
        push_tool_call(tool_calls, item, index)
      end
    elseif event == 'response.completed' or event == 'response.done' then
      -- Text was already delivered through deltas; only finish the stream
      -- and pick up usage/reasoning.
      if type(wrapped) == 'table' then
        reasoning, total_tokens = read_response_meta(wrapped)
        finish_reason = 'stop'
      end
    elseif event == 'response.failed' then
      -- Any error payload was already returned above, so none is left here.
      finish_reason = 'error: unknown error'
    elseif event == 'error' then
      finish_reason = format_error(non_null(output.message) or 'unknown error')
    end
  else
    local response = wrapped
    if type(response) ~= 'table' and type(output.output) == 'table' then
      response = output
    end

    if type(response) == 'table' then
      local items = non_null(response.output)
      if type(items) == 'table' then
        for _, item in ipairs(items) do
          if item.type == 'function_call' then
            push_tool_call(tool_calls, item, #tool_calls + 1)
          end
          content = content .. extract_text_from_parts(item.content)
          if type(item.tool_calls) == 'table' then
            for _, tool_call in ipairs(item.tool_calls) do
              push_tool_call(tool_calls, tool_call, #tool_calls + 1)
            end
          end
        end
      end

      reasoning, total_tokens = read_response_meta(response)
      finish_reason = response.status == 'completed' and 'stop' or nil
    end
  end

  return {
    content = content,
    reasoning = reasoning,
    finish_reason = finish_reason,
    total_tokens = total_tokens,
    tool_calls = tool_calls,
  }
end

--- Parse Chat Completions API output (both streaming and non-streaming)
--- Tool calls are gathered from every choice; content, reasoning, usage and
--- finish reason are read from the first choice, or from `output` itself
--- when there is no `choices` array.
---@param output table Raw API response
---@return CopilotChat.config.providers.Output
local function prepare_chat_output(output)
  local tool_calls = {}

  local choice
  if output.choices and #output.choices > 0 then
    for _, c in ipairs(output.choices) do
      local message = c.message or c.delta
      if message and message.tool_calls then
        for i, tool_call in ipairs(message.tool_calls) do
          local fn = tool_call['function']
          if fn then
            local index = tool_call.index or i
            local id = utils.empty(tool_call.id) and ('tooluse_' .. index) or tool_call.id
            table.insert(tool_calls, {
              id = id,
              index = index,
              name = fn.name,
              arguments = fn.arguments or '',
            })
          end
        end
      end
    end
    choice = output.choices[1]
  else
    choice = output
  end

  local message = choice.message or choice.delta
  local content = message and message.content
  local reasoning = message and (message.reasoning or message.reasoning_content)
  -- Usage may be reported on the choice or on the enclosing payload.
  local usage = choice.usage and choice.usage.total_tokens or output.usage and output.usage.total_tokens
  local finish_reason = choice.finish_reason or choice.done_reason or output.finish_reason or output.done_reason

  return {
    content = content,
    reasoning = reasoning,
    finish_reason = finish_reason,
    total_tokens = usage,
    tool_calls = tool_calls,
  }
end

--[[
The remaining hunks of this change touch definitions that begin outside the
visible text (the `M.copilot` provider table and the `M.github_models` model
mapper). Their patched text is kept verbatim for reference:

  prepare_input = function(inputs, opts)
    if opts.model.use_responses then
      return prepare_responses_input(inputs, opts)
    end
    return prepare_chat_input(inputs, opts)
  end,

  prepare_output = function(output, opts)
    if opts and opts.model and opts.model.use_responses then
      return prepare_responses_output(output)
    end
    return prepare_chat_output(output)
  end,

  get_url = function(opts)
    if opts and opts.model and opts.model.use_responses then
      return 'https://api.githubcopilot.com/responses'
    end
    return 'https://api.githubcopilot.com/chat/completions'
  end,

  -- M.github_models model mapper:
        return {
          id = model.id,
          name = model.name,
          tokenizer = 'o200k_base', -- GitHub Models doesn't expose tokenizer info
          max_input_tokens = model.limits and model.limits.max_input_tokens,
          max_output_tokens = model.limits and model.limits.max_output_tokens,
          streaming = model.capabilities and vim.tbl_contains(model.capabilities, 'streaming') or false,
          tools = model.capabilities and vim.tbl_contains(model.capabilities, 'tool-calling') or false,
          reasoning = model.capabilities and vim.tbl_contains(model.capabilities, 'reasoning') or false,
          version = model.version,
        }
]]