Skip to content

Commit 590b24b

Browse files
fix(client): optimize token counting algorithm
1 parent 011b0c4 commit 590b24b

1 file changed

Lines changed: 35 additions & 54 deletions

File tree

src/ChatGPTClient.js

Lines changed: 35 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import { fetchEventSource } from '@waylaidwanderer/fetch-event-source';
66
import { Agent } from 'undici';
77

88
const CHATGPT_MODEL = 'text-chat-davinci-002-sh-alpha-aoruigiofdj83';
9-
const CHATGPT_TOKENIZER = get_encoding('cl100k_base');
109

1110
export default class ChatGPTClient {
1211
constructor(
@@ -51,9 +50,10 @@ export default class ChatGPTClient {
5150
if (isChatGptModel) {
5251
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
5352
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
54-
// without tripping the stop sequences, so I'm using "##im_start##" instead.
53+
// without tripping the stop sequences, so I'm using "||>" instead.
5554
this.startToken = '||>';
5655
this.endToken = '';
56+
this.gptEncoder = get_encoding('cl100k_base');
5757
} else if (isUnofficialChatGptModel) {
5858
this.startToken = '<|im_start|>';
5959
this.endToken = '<|im_end|>';
@@ -62,9 +62,16 @@ export default class ChatGPTClient {
6262
'<|im_end|>': 100265,
6363
});
6464
} else {
65-
this.startToken = '<|endoftext|>';
66-
this.endToken = this.startToken;
67-
this.gptEncoder = encoding_for_model('text-davinci-003');
65+
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
66+
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
67+
// as a single token. So we're using this instead.
68+
this.startToken = '||>';
69+
this.endToken = '';
70+
try {
71+
this.gptEncoder = encoding_for_model(this.modelOptions.model);
72+
} catch {
73+
this.gptEncoder = encoding_for_model('text-davinci-003');
74+
}
6875
}
6976

7077
if (!this.modelOptions.stop) {
@@ -342,10 +349,7 @@ export default class ChatGPTClient {
342349

343350
let currentTokenCount;
344351
if (isChatGptModel) {
345-
currentTokenCount = this.constructor.getTokenCountForMessages([
346-
instructionsPayload,
347-
messagePayload,
348-
]);
352+
currentTokenCount = this.getTokenCountForMessage(instructionsPayload) + this.getTokenCountForMessage(messagePayload);
349353
} else {
350354
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
351355
}
@@ -370,21 +374,8 @@ export default class ChatGPTClient {
370374
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
371375
}
372376

373-
// The reason I don't simply get the token count of the messageString and add it to currentTokenCount is because
374-
// joined words may combine into a single token. Actually, that isn't really applicable here, but I can't
375-
// resist doing it the "proper" way.
376-
let newTokenCount;
377-
if (isChatGptModel) {
378-
newTokenCount = this.constructor.getTokenCountForMessages([
379-
instructionsPayload,
380-
{
381-
...messagePayload,
382-
content: newPromptBody,
383-
},
384-
]);
385-
} else {
386-
newTokenCount = this.getTokenCount(`${newPromptBody}${promptSuffix}`);
387-
}
377+
const tokenCountForMessage = this.getTokenCount(messageString);
378+
const newTokenCount = currentTokenCount + tokenCountForMessage;
388379
if (newTokenCount > maxTokenCount) {
389380
if (promptBody) {
390381
// This message would put us over the token limit, so don't add it.
@@ -395,6 +386,7 @@ export default class ChatGPTClient {
395386
}
396387
promptBody = newPromptBody;
397388
currentTokenCount = newTokenCount;
389+
// wait for next tick to avoid blocking the event loop
398390
await new Promise((resolve) => setTimeout(resolve, 0));
399391
return buildPromptBody();
400392
}
@@ -404,19 +396,14 @@ export default class ChatGPTClient {
404396
await buildPromptBody();
405397

406398
const prompt = `${promptBody}${promptSuffix}`;
407-
408-
let numTokens;
409399
if (isChatGptModel) {
410400
messagePayload.content = prompt;
411-
numTokens = this.constructor.getTokenCountForMessages([
412-
instructionsPayload,
413-
messagePayload,
414-
]);
415-
} else {
416-
numTokens = this.getTokenCount(prompt);
401+
// Add 2 tokens for metadata after all messages have been counted.
402+
currentTokenCount += 2;
417403
}
404+
418405
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
419-
this.modelOptions.max_tokens = Math.min(this.maxContextTokens - numTokens, this.maxResponseTokens);
406+
this.modelOptions.max_tokens = Math.min(this.maxContextTokens - currentTokenCount, this.maxResponseTokens);
420407

421408
if (isChatGptModel) {
422409
return [
@@ -434,30 +421,24 @@ export default class ChatGPTClient {
434421
/**
435422
* Algorithm adapted from "6. Counting tokens for chat API calls" of
436423
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
437-
* @param {*[]} messages
424+
*
425+
* An additional 2 tokens need to be added for metadata after all messages have been counted.
426+
*
427+
* @param {*} message
438428
*/
439-
static getTokenCountForMessages(messages) {
440-
// Get the encoding tokenizer
441-
const tokenizer = CHATGPT_TOKENIZER;
442-
443-
// Map each message to the number of tokens it contains
444-
const messageTokenCounts = messages.map((message) => {
445-
// Map each property of the message to the number of tokens it contains
446-
const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
447-
// Count the number of tokens in the property value
448-
const numTokens = tokenizer.encode(value).length;
449-
450-
// Subtract 1 token if the property key is 'name'
451-
const adjustment = (key === 'name') ? 1 : 0;
452-
return numTokens - adjustment;
453-
});
454-
455-
// Sum the number of tokens in all properties and add 4 for metadata
456-
return propertyTokenCounts.reduce((a, b) => a + b, 4);
429+
getTokenCountForMessage(message) {
430+
// Map each property of the message to the number of tokens it contains
431+
const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
432+
// Count the number of tokens in the property value
433+
const numTokens = this.getTokenCount(value);
434+
435+
// Subtract 1 token if the property key is 'name'
436+
const adjustment = (key === 'name') ? 1 : 0;
437+
return numTokens - adjustment;
457438
});
458439

459-
// Sum the number of tokens in all messages and add 2 for metadata
460-
return messageTokenCounts.reduce((a, b) => a + b, 2);
440+
// Sum the number of tokens in all properties and add 4 for metadata
441+
return propertyTokenCounts.reduce((a, b) => a + b, 4);
461442
}
462443

463444
/**

0 commit comments

Comments
 (0)