diff --git a/src/config.ts b/src/config.ts index 9683fe3..669ee8d 100644 --- a/src/config.ts +++ b/src/config.ts @@ -1,3 +1,7 @@ +// Specify system prompt +// Only required for 'chat' type model +export const SYSTEM_PROMPT = "You are a professional writer. Completion the given story/writting based on user prompt. Answer only with the story/writting and nothing else. DO NOT REPEAT THE USER PROMPT IN THE STORY/WRITTING. DO NOT INLCUDE CONTENT SUCH as \"sure, here is the story\" or any similar response. At every beginning of paragraph, include a newline character"; + // This value is the default for a local llama.cpp server. // Change as needed if you use another provider. export const BASE_URL = "http://localhost:8080/v1/"; @@ -9,13 +13,72 @@ export const API_KEY = ""; // so it is usually not necessary to change this value. export const MAX_TOKENS = 500; +// Specify the model name used by the API. +// Update this value if a different model is required. +export const MODEL = "llama3-70b-8192"; + +// Specify the model type: 'chat' or 'completion' +// - 'chat': Most models, including GPT-4 Turbo, use this type +// - 'completion': Non-chat models, like GPT-3.5-Turbo-Instruct +export const MODEL_TYPE = "completion"; + // Add generation/sampling parameters here. Their effect depends on the API provider. // No attempt is made to normalize these parameters across different providers. // They are passed to the API endpoint unmodified. 
export const PARAMS = { - temperature: 1.0, - top_k: 0, - top_p: 1.0, - min_p: 0.02, - repeat_penalty: 1.0, + temperature: 1.0, + top_k: 0, + top_p: 1.0, + min_p: 0.02, + repeat_penalty: 1.0, }; + + +//Example - local model (model type: 'completion') + +// export const BASE_URL = "http://localhost:8080/v1/"; +// export const MODEL_TYPE = "completion"; +// export const MODEL = "llama3-70b-8192"; +// export const API_KEY = ""; +// export const MAX_TOKENS = 500; +// export const PARAMS = { +// temperature: 1.0, +// top_k: 0, +// top_p: 1.0, +// min_p: 0.02, +// repeat_penalty: 1.0, +// }; + + +//Example - OpenAI, gpt-3.5-turbo-instruct (model type: 'completion') + +// export const BASE_URL = "https://api.openai.com/v1"; +// export const MODEL_TYPE = "completion"; +// export const MODEL = "gpt-3.5-turbo-instruct"; +// export const API_KEY = ""; +// export const MAX_TOKENS = 500; +// export const PARAMS = { + // temperature: 1.0, + // top_k: 0, + // top_p: 1.0, + // min_p: 0.02, + // repeat_penalty: 1.0, +// }; + +//Example - Groq API, llama-70b (model type: 'chat') + +// export const BASE_URL = "https://api.groq.com/openai/v1/"; +// export const API_KEY = ""; +// export const MAX_TOKENS = 500; +// export const MODEL = "llama3-70b-8192"; +// export const MODEL_TYPE = "chat"; +// export const PARAMS = { + // temperature: 1.0, + // top_k: 0, + // top_p: 1.0, + // min_p: 0.02, + // repeat_penalty: 1.0, +// }; + + + diff --git a/src/main.ts b/src/main.ts index c857f51..ce58be5 100644 --- a/src/main.ts +++ b/src/main.ts @@ -11,7 +11,7 @@ import scrollIntoView from "scroll-into-view-if-needed"; import ParagraphClipboard from "./clipboard.ts"; import SplitEmbed from "./embed.ts"; -import { BASE_URL, API_KEY, MAX_TOKENS, PARAMS } from "./config.ts"; +import { BASE_URL, API_KEY, MAX_TOKENS, PARAMS, MODEL, MODEL_TYPE, SYSTEM_PROMPT } from "./config.ts"; enum State { Editing, @@ -41,23 +41,34 @@ function scrollEmbedIntoView() { } async function streamText(prompt: string, pane: Element): 
Promise { - const params: OpenAI.CompletionCreateParamsStreaming = { - stream: true, - // This parameter is ignored by most OpenAI-compatible local API providers. - model: "gpt-3.5-turbo-instruct", - prompt: prompt, - max_tokens: MAX_TOKENS, - // @ts-ignore: llama.cpp - n_predict: MAX_TOKENS, - // @ts-ignore: llama.cpp - cache_prompt: true, - ...PARAMS, - }; - - const controller = new AbortController(); - controllers.push(controller); - - const stream = await client.completions.create(params, { signal: controller.signal }); + let params: any; + let stream: any; + + if (MODEL_TYPE === "chat") { + params = { + stream: true, + model: MODEL, + messages: [{role: "system", "content": SYSTEM_PROMPT}, {role: "user", content: prompt }], + max_tokens: MAX_TOKENS + }; + const controller = new AbortController(); + controllers.push(controller); + stream = await client.chat.completions.create(params, { signal: controller.signal }); + } else { + params = { + stream: true, + model: MODEL, + prompt: prompt, + max_tokens: MAX_TOKENS + // @ts-ignore: llama.cpp + // n_predict: MAX_TOKENS, + // @ts-ignore: llama.cpp + // cache_prompt: true + }; + const controller = new AbortController(); + controllers.push(controller); + stream = await client.completions.create(params, { signal: controller.signal }); + } let text = ""; let startFound = false; @@ -65,7 +76,9 @@ async function streamText(prompt: string, pane: Element): Promise { for await (const chunk of stream) { let newText: string; - if (chunk.hasOwnProperty("choices")) { + if (MODEL_TYPE === "chat") { + newText = chunk.choices[0]?.delta?.content || ""; + } else if (chunk.hasOwnProperty("choices")) { newText = chunk.choices[0].text; } else { // @ts-ignore: llama.cpp