diff --git a/README.md b/README.md index 359b237..eaa6dc9 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,9 @@ to store the credential in clear-text in a configuration file. As an alternative to providing the API key via the `OPENAI_API_KEY` environment variable, the user is encouraged to use the `api_key_cmd` configuration option. + + +#### api_key_cmd using script The `api_key_cmd` configuration option takes a string, which is executed at startup, and whose output is used as the API key. @@ -222,6 +225,71 @@ Note that the `api_key_cmd` arguments are split by whitespace. If you need whitespace inside an argument (for example to reference a path with spaces), you can wrap it in a separate script. +#### api_key_cmd using lua function + +Here is another way to provide the key by using Lua function + +```lua +local get_api_key = function (callback) + local job = require("plenary.job") + local url = "https://my-enterprise.com/key/management.url" + local value = "some-world-readable-client-id" + + job:new({ + command = "curl", + args = { + url, + "--silent", "-X", "POST", + "-H", "Accept: */*", + "-H", "Content-Type: application/x-www-form-urlencoded", + "-H", "Authorization: Basic " .. value, + "-d", "grant_type=client_credentials", + }, + on_exit = vim.schedule_wrap(function(response, exit_code) + vim.notify("Key job: exitcode " .. vim.inspect(exit_code) .. ", Key: " .. vim.inspect(response:result()), vim.log.levels.INFO) + + if exit_code ~= 0 then + -- curl failed + vim.notify("Key: failed to obtain key" .. vim.inspect(response), vim.log.levels.ERROR) + return + end + + -- Get stdout which is a json string + local result = table.concat(response:result(), "\n") + + local ok, json = pcall(vim.json.decode, result) + if not ok or not json then + vim.notify("Key: error decoding response " .. vim.inspect(result), vim.log.levels.ERROR) + return + end + + if json and json["access_token"] then + -- Notify the callback with the key and the valid duration + if callback then + callback(json["access_token"], json["expires_in"] or nil) + else + vim.env.OPENAI_API_KEY = json["access_token"] + if json["expires_in"] then + vim.env.OPENAI_API_KEY_EXPIRES = json["expires_in"] + end + end + return json["access_token"], json["expires_in"] or nil + end + end), + }) + :start() +end + +require("chatgpt").setup({ + api_key_cmd = 'get_api_key' +}) +``` + +### limited time key support +When "api_key_cmd" is run to obtain the key, it could optionally provide the key validity time. +ChatGPT plugin will automatically re-request "api_key_cmd" when time exceeds the validity time +when the next time ChatGPT is invoked. See the lua function section above for example. + ## Usage Plugin exposes following commands: diff --git a/lua/chatgpt/api.lua b/lua/chatgpt/api.lua index a557fda..28b06fb 100644 --- a/lua/chatgpt/api.lua +++ b/lua/chatgpt/api.lua @@ -5,13 +5,125 @@ local Utils = require("chatgpt.utils") local Api = {} +local key_expiry_timestamp = nil + +local function updateAuthenticationKey(key, timeout_in_secs) + if not key then + logger.warn("OPENAI_API_KEY callback is nil") + return + end + + if timeout_in_secs then + key_expiry_timestamp = os.time() + timeout_in_secs + end + + Api.OPENAI_API_KEY = key + if Api["OPENAI_API_TYPE"] == "azure" then + Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OPENAI_API_KEY + else + Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OPENAI_API_KEY + end +end + +local splitCommandIntoTable = function(command) + local cmd = {} + for word in command:gmatch("%S+") do + table.insert(cmd, word) + end + return cmd +end + +local function loadConfigFromCommand(command, optionName, callback, defaultValue) + if type(command) == "function" then + return command(callback) -- or callback(defaultValue) + else + local cmd = splitCommandIntoTable(command) + job + :new({ + command = cmd[1], + args = vim.list_slice(cmd, 2, #cmd), + on_exit = function(j, exit_code) + if exit_code ~= 0 then + logger.warn("Config '" .. optionName .. "' did not return a value when executed") + return + end + local value = j:result()[1]:gsub("%s+$", "") + if value ~= nil and value ~= "" then + callback(value) + elseif defaultValue ~= nil and defaultValue ~= "" then + callback(defaultValue) + end + end, + }) + :start() + return + end +end + +local function loadConfigFromEnv(envName, configName, callback) + local variable = os.getenv(envName) + if not variable then + return + end + local value = variable:gsub("%s+$", "") + Api[configName] = value + if callback then + callback(value) + end +end + +local function loadOptionalConfig(envName, configName, optionName, callback, defaultValue) + loadConfigFromEnv(envName, configName) + if Api[configName] then + callback(Api[configName]) + elseif Config.options[optionName] ~= nil and Config.options[optionName] ~= "" then + loadConfigFromCommand(Config.options[optionName], optionName, callback, defaultValue) + else + callback(defaultValue) + end +end + +local function loadRequiredConfig(envName, configName, optionName, callback, defaultValue) + loadConfigFromEnv(envName, configName, callback) + if not Api[configName] then + if Config.options[optionName] ~= nil and Config.options[optionName] ~= "" then + loadConfigFromCommand(Config.options[optionName], optionName, callback, defaultValue) + else + logger.warn(configName .. " variable not set") + return + end + end +end + +-- Check if the key is valid and start the job +local function startJobWithKeyValidation(cb) + if not Api["OPENAI_API_KEY"] or (key_expiry_timestamp and os.time() > key_expiry_timestamp) then + if key_expiry_timestamp then + logger.info("startJobWithKeyValidation: Key expired at " .. key_expiry_timestamp) + end + Api["OPENAI_API_KEY"] = nil + loadRequiredConfig( + "OPENAI_API_KEY", + "OPENAI_API_KEY", + "api_key_cmd", + vim.schedule_wrap(function(value, timeout) + updateAuthenticationKey(value, timeout) + cb() + end) + ) + else + cb() + end + return 0 +end + function Api.completions(custom_params, cb) local openai_params = Utils.collapsed_openai_params(Config.options.openai_params) local params = vim.tbl_extend("keep", custom_params, openai_params) Api.make_call(Api.COMPLETIONS_URL, params, cb) end -function Api.chat_completions(custom_params, cb, should_stop) +local function doChatCompletions(custom_params, cb, should_stop) local openai_params = Utils.collapsed_openai_params(Config.options.openai_params) local params = vim.tbl_extend("keep", custom_params, openai_params) -- the custom params contains if model is not constant but function @@ -23,6 +135,7 @@ function Api.chat_completions(custom_params, cb, should_stop) if stream then local raw_chunks = "" local state = "START" + local prev_chunk -- store incomplete line from previous chunk cb = vim.schedule_wrap(cb) @@ -50,6 +163,10 @@ function Api.chat_completions(custom_params, cb, should_stop) "curl", args, function(chunk) + if prev_chunk ~= nil then + chunk = prev_chunk .. chunk + prev_chunk = nil + end local ok, json = pcall(vim.json.decode, chunk) if ok and json ~= nil then if json.error ~= nil then @@ -70,16 +187,18 @@ function Api.chat_completions(custom_params, cb, should_stop) }) if ok and json ~= nil then if - json - and json.choices - and json.choices[1] - and json.choices[1].delta - and json.choices[1].delta.content + json + and json.choices + and json.choices[1] + and json.choices[1].delta + and json.choices[1].delta.content then cb(json.choices[1].delta.content, state) raw_chunks = raw_chunks .. json.choices[1].delta.content state = "CONTINUE" end + else + prev_chunk = line end end end @@ -97,6 +216,12 @@ function Api.chat_completions(custom_params, cb, should_stop) end end +function Api.chat_completions(custom_params, cb, should_stop) + startJobWithKeyValidation(function() + doChatCompletions(custom_params, cb, should_stop) + end) +end + function Api.edits(custom_params, cb) local openai_params = Utils.collapsed_openai_params(Config.options.openai_params) local params = vim.tbl_extend("keep", custom_params, openai_params) @@ -136,15 +261,17 @@ function Api.make_call(url, params, cb) end end - Api.job = job - :new({ - command = "curl", - args = args, - on_exit = vim.schedule_wrap(function(response, exit_code) - Api.handle_response(response, exit_code, cb) - end), - }) - :start() + startJobWithKeyValidation(function() + Api.job = job + :new({ + command = "curl", + args = args, + on_exit = vim.schedule_wrap(function(response, exit_code) + Api.handle_response(response, exit_code, cb) + end), + }) + :start() + end) end Api.handle_response = vim.schedule_wrap(function(response, exit_code, cb) @@ -192,71 +319,6 @@ function Api.close() end end -local splitCommandIntoTable = function(command) - local cmd = {} - for word in command:gmatch("%S+") do - table.insert(cmd, word) - end - return cmd -end - -local function loadConfigFromCommand(command, optionName, callback, defaultValue) - local cmd = splitCommandIntoTable(command) - job - :new({ - command = cmd[1], - args = vim.list_slice(cmd, 2, #cmd), - on_exit = function(j, exit_code) - if exit_code ~= 0 then - logger.warn("Config '" .. optionName .. "' did not return a value when executed") - return - end - local value = j:result()[1]:gsub("%s+$", "") - if value ~= nil and value ~= "" then - callback(value) - elseif defaultValue ~= nil and defaultValue ~= "" then - callback(defaultValue) - end - end, - }) - :start() -end - -local function loadConfigFromEnv(envName, configName, callback) - local variable = os.getenv(envName) - if not variable then - return - end - local value = variable:gsub("%s+$", "") - Api[configName] = value - if callback then - callback(value) - end -end - -local function loadOptionalConfig(envName, configName, optionName, callback, defaultValue) - loadConfigFromEnv(envName, configName) - if Api[configName] then - callback(Api[configName]) - elseif Config.options[optionName] ~= nil and Config.options[optionName] ~= "" then - loadConfigFromCommand(Config.options[optionName], optionName, callback, defaultValue) - else - callback(defaultValue) - end -end - -local function loadRequiredConfig(envName, configName, optionName, callback, defaultValue) - loadConfigFromEnv(envName, configName, callback) - if not Api[configName] then - if Config.options[optionName] ~= nil and Config.options[optionName] ~= "" then - loadConfigFromCommand(Config.options[optionName], optionName, callback, defaultValue) - else - logger.warn(configName .. " variable not set") - return - end - end -end - local function loadAzureConfigs() loadRequiredConfig("OPENAI_API_BASE", "OPENAI_API_BASE", "azure_api_base_cmd", function(base) Api.OPENAI_API_BASE = base @@ -273,15 +335,15 @@ local function loadAzureConfigs() if Api["OPENAI_API_BASE"] and Api["OPENAI_API_AZURE_ENGINE"] then Api.COMPLETIONS_URL = Api.OPENAI_API_BASE - .. "/openai/deployments/" - .. Api.OPENAI_API_AZURE_ENGINE - .. "/completions?api-version=" - .. Api.OPENAI_API_AZURE_VERSION + .. "/openai/deployments/" + .. Api.OPENAI_API_AZURE_ENGINE + .. "/completions?api-version=" + .. Api.OPENAI_API_AZURE_VERSION Api.CHAT_COMPLETIONS_URL = Api.OPENAI_API_BASE - .. "/openai/deployments/" - .. Api.OPENAI_API_AZURE_ENGINE - .. "/chat/completions?api-version=" - .. Api.OPENAI_API_AZURE_VERSION + .. "/openai/deployments/" + .. Api.OPENAI_API_AZURE_ENGINE + .. "/chat/completions?api-version=" + .. Api.OPENAI_API_AZURE_VERSION end end, "2023-05-15" @@ -310,16 +372,12 @@ function Api.setup() Api.EDITS_URL = ensureUrlProtocol(Api.OPENAI_API_HOST .. "/v1/edits") end, "api.openai.com") - loadRequiredConfig("OPENAI_API_KEY", "OPENAI_API_KEY", "api_key_cmd", function(key) - Api.OPENAI_API_KEY = key - + loadRequiredConfig("OPENAI_API_KEY", "OPENAI_API_KEY", "api_key_cmd", function(key, timeout) loadOptionalConfig("OPENAI_API_TYPE", "OPENAI_API_TYPE", "api_type_cmd", function(type) if type == "azure" then loadAzureConfigs() - Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OPENAI_API_KEY - else - Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OPENAI_API_KEY end + updateAuthenticationKey(key, timeout) end, "") end) end diff --git a/lua/chatgpt/prompts.lua b/lua/chatgpt/prompts.lua index 5e280e8..826d6bd 100644 --- a/lua/chatgpt/prompts.lua +++ b/lua/chatgpt/prompts.lua @@ -1,3 +1,6 @@ +if not package.loaded["telescope"] then + return {} +end local pickers = require("telescope.pickers") local conf = require("telescope.config").values local actions = require("telescope.actions")