Skip to content

Commit 4d5dc1c

Browse files
pxwgdeathbeam
andauthored
feat(copilot): Support "Auto" model mode (Smart Model Selection) similar to VS Code (#1518)
* update: add auto into model completions * update: add auto model selection module * fix: remove unnecessary notifications * ui: show current model at header * refactor: remove function to provider configs * fix: rename duplicate function name * fix: remove route_model provider to a better position * fix: add the 'auto' model option as an model in copilot provider * refactor: move model selector to specific provider * cleanup some stuff Signed-off-by: Tomas Slusny <slusnucky@gmail.com> --------- Signed-off-by: Tomas Slusny <slusnucky@gmail.com> Co-authored-by: Tomas Slusny <slusnucky@gmail.com>
1 parent 99a1190 commit 4d5dc1c

3 files changed

Lines changed: 63 additions & 3 deletions

File tree

lua/CopilotChat/client.lua

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,16 @@ function Client:ask(opts)
319319
error('Provider not found: ' .. provider_name)
320320
end
321321

322+
if provider.resolve_model then
323+
local headers = self:authenticate(provider_name)
324+
local resolved_model = provider.resolve_model(headers, opts.model)
325+
opts.model = resolved_model
326+
model_config = models[opts.model]
327+
if not model_config then
328+
error('Resolved model not found: ' .. opts.model)
329+
end
330+
end
331+
322332
local options = {
323333
model = vim.tbl_extend('force', model_config, {
324334
id = opts.model:gsub(':' .. provider_name .. '$', ''),
@@ -389,6 +399,7 @@ function Client:ask(opts)
389399
local errored = nil
390400
local finished = false
391401
local token_count = 0
402+
local out_model = nil
392403
local response_content_buffer = stringbuffer()
393404
local response_reasoning_buffer = stringbuffer()
394405

@@ -451,6 +462,10 @@ function Client:ask(opts)
451462
response_reasoning_buffer:put(out.reasoning)
452463
end
453464

465+
if out.model then
466+
out_model = out.model
467+
end
468+
454469
if opts.on_progress then
455470
opts.on_progress({
456471
role = constants.ROLE.ASSISTANT,
@@ -589,6 +604,7 @@ function Client:ask(opts)
589604
content = response_text,
590605
reasoning = response_reasoning,
591606
tool_calls = #tool_calls:values() > 0 and tool_calls:values() or nil,
607+
model = out_model,
592608
},
593609
token_count = token_count,
594610
token_max_count = max_tokens,

lua/CopilotChat/config/providers.lua

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ local function get_github_models_token(tag)
197197
end
198198

199199
--- Prepare input for Responses API
200-
---@param inputs table<CopilotChat.client.Message>
200+
---@param inputs CopilotChat.client.Message[]
201201
---@param opts CopilotChat.config.providers.Options
202202
---@return table
203203
local function prepare_responses_input(inputs, opts)
@@ -257,7 +257,7 @@ local function prepare_responses_input(inputs, opts)
257257
end
258258

259259
--- Prepare input for Chat Completions API
260-
---@param inputs table<CopilotChat.client.Message>
260+
---@param inputs CopilotChat.client.Message[]
261261
---@param opts CopilotChat.config.providers.Options
262262
---@return table
263263
local function prepare_chat_input(inputs, opts)
@@ -353,6 +353,7 @@ local function prepare_responses_output(output)
353353
local finish_reason = nil
354354
local total_tokens = nil
355355
local tool_calls = {}
356+
local model = nil
356357

357358
-- Handle errors
358359
local error_msg = output.error or (output.response and output.response.error)
@@ -366,6 +367,7 @@ local function prepare_responses_output(output)
366367
finish_reason = 'error: ' .. tostring(error_msg),
367368
total_tokens = nil,
368369
tool_calls = {},
370+
model = nil,
369371
}
370372
end
371373

@@ -398,6 +400,9 @@ local function prepare_responses_output(output)
398400
if response.usage then
399401
total_tokens = response.usage.total_tokens
400402
end
403+
if response.model then
404+
model = response.model
405+
end
401406
finish_reason = 'stop'
402407
end
403408
elseif output.type == 'response.failed' then
@@ -429,6 +434,9 @@ local function prepare_responses_output(output)
429434
if response.usage then
430435
total_tokens = response.usage.total_tokens
431436
end
437+
if response.model then
438+
model = response.model
439+
end
432440
finish_reason = response.status == 'completed' and 'stop' or nil
433441
end
434442

@@ -438,6 +446,7 @@ local function prepare_responses_output(output)
438446
finish_reason = finish_reason,
439447
total_tokens = total_tokens,
440448
tool_calls = tool_calls,
449+
model = model,
441450
}
442451
end
443452

@@ -477,13 +486,15 @@ local function prepare_chat_output(output)
477486
local reasoning = message and (message.reasoning or message.reasoning_content)
478487
local usage = choice.usage and choice.usage.total_tokens or output.usage and output.usage.total_tokens
479488
local finish_reason = choice.finish_reason or choice.done_reason or output.finish_reason or output.done_reason
489+
local model = choice.model or output.model
480490

481491
return {
482492
content = content,
483493
reasoning = reasoning,
484494
finish_reason = finish_reason,
485495
total_tokens = usage,
486496
tool_calls = tool_calls,
497+
model = model,
487498
}
488499
end
489500

@@ -498,13 +509,15 @@ end
498509
---@field finish_reason string?
499510
---@field total_tokens number?
500511
---@field tool_calls table<CopilotChat.client.ToolCall>
512+
---@field model string?
501513

502514
---@class CopilotChat.config.providers.Provider
503515
---@field disabled nil|boolean
504516
---@field get_headers nil|fun():table<string, string>,number?
505517
---@field get_info nil|fun(headers:table):string[]
506518
---@field get_models nil|fun(headers:table):table<CopilotChat.client.Model>
507-
---@field prepare_input nil|fun(inputs:table<CopilotChat.client.Message>, opts:CopilotChat.config.providers.Options):table
519+
---@field resolve_model nil|fun(headers:table, model: string):string
520+
---@field prepare_input nil|fun(inputs:CopilotChat.client.Message[], opts:CopilotChat.config.providers.Options):table,table?
508521
---@field prepare_output nil|fun(output:table, opts:CopilotChat.config.providers.Options):CopilotChat.config.providers.Output
509522
---@field get_url nil|fun(opts:CopilotChat.config.providers.Options):string
510523

@@ -529,6 +542,7 @@ M.copilot = {
529542
['Editor-Version'] = EDITOR_VERSION,
530543
['Editor-Plugin-Version'] = 'CopilotChat.nvim/*',
531544
['Copilot-Integration-Id'] = 'vscode-chat',
545+
['x-github-api-version'] = '2025-10-01',
532546
},
533547
response.body.expires_at
534548
end,
@@ -637,9 +651,36 @@ M.copilot = {
637651
end
638652
end
639653

654+
-- Auto model selector
655+
table.insert(models, {
656+
id = 'auto',
657+
name = 'Auto (Copilot)',
658+
description = 'Auto selects the best model for your request.',
659+
})
660+
640661
return models
641662
end,
642663

664+
resolve_model = function(headers, model)
665+
if model ~= 'auto' then
666+
return model
667+
end
668+
669+
local url = 'https://api.githubcopilot.com/models/session'
670+
local response, err = curl.post(url, {
671+
headers = headers,
672+
body = { auto_mode = { model_hints = { 'auto' } } },
673+
json_response = true,
674+
json_request = true,
675+
})
676+
677+
if err then
678+
error(err)
679+
end
680+
681+
return response.body.selected_model
682+
end,
683+
643684
prepare_input = function(inputs, opts)
644685
local request
645686
if opts.model.use_responses then

lua/CopilotChat/ui/chat.lua

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,9 @@ function Chat:render()
771771
-- Overlay section header with nice display
772772
local header_value = self.headers[message.role]
773773
local header_line = message.section.start_line - 2
774+
if message.model then
775+
header_value = header_value .. ' (' .. message.model .. ')'
776+
end
774777

775778
vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, header_line, 0, {
776779
conceal = '',

0 commit comments

Comments
 (0)