Skip to content

Commit 47b678b

Browse files
committed
feat: context kv cache key and value type configurations
1 parent 1257846 commit 47b678b

19 files changed

Lines changed: 508 additions & 69 deletions

llama/addon/AddonContext.cpp

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -443,6 +443,20 @@ AddonContext::AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Ad
443443
context_params.no_perf = !(options.Get("performanceTracking").As<Napi::Boolean>().Value());
444444
}
445445

446+
if (options.Has("kvCacheKeyType") && options.Get("kvCacheKeyType").IsNumber()) {
447+
auto keyType = options.Get("kvCacheKeyType").As<Napi::Number>().Int32Value();
448+
if (keyType >= 0 && keyType < GGML_TYPE_COUNT) {
449+
context_params.type_k = keyType;
450+
}
451+
}
452+
453+
if (options.Has("kvCacheValueType") && options.Get("kvCacheValueType").IsNumber()) {
454+
auto valueType = options.Get("kvCacheValueType").As<Napi::Number>().Int32Value();
455+
if (valueType >= 0 && valueType < GGML_TYPE_COUNT) {
456+
context_params.type_v = valueType;
457+
}
458+
}
459+
446460
if (options.Has("swaFullCache")) {
447461
context_params.swa_full = options.Get("swaFullCache").As<Napi::Boolean>().Value();
448462
}
@@ -1063,7 +1077,7 @@ void AddonContext::init(Napi::Object exports) {
10631077
}
10641078

10651079
AddonContextSequenceCheckpoint::AddonContextSequenceCheckpoint(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonContextSequenceCheckpoint>(info) {
1066-
1080+
10671081
}
10681082
AddonContextSequenceCheckpoint::~AddonContextSequenceCheckpoint() {
10691083
dispose();
@@ -1099,7 +1113,7 @@ class AddonContextSequenceCheckpointInitWorker : public Napi::AsyncWorker {
10991113
checkpoint->minPos = llama_memory_seq_pos_min(llama_get_memory(context->ctx), checkpoint->sequenceId);
11001114
checkpoint->maxPos = llama_memory_seq_pos_max(llama_get_memory(context->ctx), checkpoint->sequenceId);
11011115
const size_t checkpointSize = llama_state_seq_get_size_ext(context->ctx, checkpoint->sequenceId, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
1102-
1116+
11031117
checkpoint->data.resize(checkpointSize, 0);
11041118
llama_state_seq_get_data_ext(context->ctx, checkpoint->data.data(), checkpointSize, checkpoint->sequenceId, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
11051119
} catch (const std::exception& e) {
@@ -1164,4 +1178,4 @@ void AddonContextSequenceCheckpoint::init(Napi::Object exports) {
11641178
}
11651179
)
11661180
);
1167-
}
1181+
}

src/bindings/AddonTypes.ts

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@ export type BindingModule = {
3131
ranking?: boolean,
3232
threads?: number,
3333
performanceTracking?: boolean,
34+
kvCacheKeyType?: number,
35+
kvCacheValueType?: number,
3436
swaFullCache?: boolean
3537
}): AddonContext
3638
},

src/cli/commands/ChatCommand.ts

Lines changed: 33 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@ import {withCliCommandDescriptionDocsUrl} from "../utils/withCliCommandDescripti
3131
import {ConsoleInteraction, ConsoleInteractionKey} from "../utils/ConsoleInteraction.js";
3232
import {DraftSequenceTokenPredictor} from "../../evaluator/LlamaContext/tokenPredictors/DraftSequenceTokenPredictor.js";
3333
import {ParsedXtcArg, parseXtcArg} from "../utils/parseXtcArg.js";
34+
import {GgmlType} from "../../gguf/types/GgufTensorInfoTypes.js";
3435

3536
type ChatCommand = {
3637
modelPath?: string,
@@ -46,6 +47,8 @@ type ChatCommand = {
4647
contextSize?: number,
4748
batchSize?: number,
4849
flashAttention?: boolean,
50+
kvCacheKeyType?: "currentQuant" | keyof typeof GgmlType,
51+
kvCacheValueType?: "currentQuant" | keyof typeof GgmlType,
4952
swaFullCache?: boolean,
5053
noTrimWhitespace: boolean,
5154
grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[1],
@@ -172,6 +175,24 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
172175
default: false,
173176
description: "Enable flash attention"
174177
})
178+
.option("kvCacheKeyType", {
179+
alias: "kvckt",
180+
type: "string",
181+
choices: [
182+
"currentQuant",
183+
...Object.keys(GgmlType).filter((key) => typeof key === "string") as (keyof typeof GgmlType)[]
184+
] as const,
185+
description: "The type of the key for the context KV cache tensors"
186+
})
187+
.option("kvCacheValueType", {
188+
alias: "kvcvt",
189+
type: "string",
190+
choices: [
191+
"currentQuant",
192+
...Object.keys(GgmlType).filter((key) => typeof key === "string") as (keyof typeof GgmlType)[]
193+
] as const,
194+
description: "The type of the value for the context KV cache tensors"
195+
})
175196
.option("swaFullCache", {
176197
alias: "noSwa",
177198
type: "boolean",
@@ -379,7 +400,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
379400
},
380401
async handler({
381402
modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt,
382-
promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention, swaFullCache,
403+
promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention, kvCacheKeyType, kvCacheValueType, swaFullCache,
383404
noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK,
384405
topP, seed, xtc, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
385406
repeatFrequencyPenalty, repeatPresencePenalty, dryRepeatPenaltyStrength, dryRepeatPenaltyBase, dryRepeatPenaltyAllowedLength,
@@ -390,8 +411,8 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
390411
try {
391412
await RunChat({
392413
modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize,
393-
batchSize, flashAttention, swaFullCache, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads,
394-
temperature, minP, topK, topP, seed, xtc,
414+
batchSize, flashAttention, kvCacheKeyType, kvCacheValueType, swaFullCache, noTrimWhitespace, grammar, jsonSchemaGrammarFile,
415+
threads, temperature, minP, topK, topP, seed, xtc,
395416
gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
396417
dryRepeatPenaltyStrength, dryRepeatPenaltyBase, dryRepeatPenaltyAllowedLength, dryRepeatPenaltyLastTokens,
397418
maxTokens, reasoningBudget, noHistory, environmentFunctions, tokenPredictionDraftModel, tokenPredictionModelContextSize,
@@ -408,7 +429,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
408429

409430
async function RunChat({
410431
modelPath: modelArg, header: headerArg, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja,
411-
contextSize, batchSize, flashAttention, swaFullCache, noTrimWhitespace, grammar: grammarArg,
432+
contextSize, batchSize, kvCacheKeyType, kvCacheValueType, flashAttention, swaFullCache, noTrimWhitespace, grammar: grammarArg,
412433
jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
413434
threads, temperature, minP, topK, topP, seed, xtc, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine,
414435
repeatFrequencyPenalty, repeatPresencePenalty, dryRepeatPenaltyStrength, dryRepeatPenaltyBase, dryRepeatPenaltyAllowedLength,
@@ -444,12 +465,16 @@ async function RunChat({
444465
const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
445466
flashAttention,
446467
swaFullCache,
468+
kvCacheKeyType,
469+
kvCacheValueType,
447470
useMmap
448471
});
449472
const resolvedDraftModelPath = (tokenPredictionDraftModel != null && tokenPredictionDraftModel !== "")
450473
? await resolveCommandGgufPath(tokenPredictionDraftModel, llama, headers, {
451474
flashAttention,
452475
swaFullCache,
476+
kvCacheKeyType,
477+
kvCacheValueType,
453478
useMmap,
454479
consoleTitle: "Draft model file"
455480
})
@@ -495,6 +520,8 @@ async function RunChat({
495520
? {fitContext: {contextSize}}
496521
: undefined,
497522
defaultContextFlashAttention: flashAttention,
523+
defaultContextKvCacheKeyType: kvCacheKeyType,
524+
defaultContextKvCacheValueType: kvCacheValueType,
498525
defaultContextSwaFullCache: swaFullCache,
499526
useMmap,
500527
useDirectIo,
@@ -530,6 +557,8 @@ async function RunChat({
530557
return await llama.loadModel({
531558
modelPath: resolvedDraftModelPath,
532559
defaultContextFlashAttention: flashAttention,
560+
defaultContextKvCacheKeyType: kvCacheKeyType,
561+
defaultContextKvCacheValueType: kvCacheValueType,
533562
defaultContextSwaFullCache: swaFullCache,
534563
useMmap,
535564
useDirectIo,

src/cli/commands/CompleteCommand.ts

Lines changed: 35 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ import {documentationPageUrls} from "../../config.js";
2323
import {ConsoleInteraction, ConsoleInteractionKey} from "../utils/ConsoleInteraction.js";
2424
import {DraftSequenceTokenPredictor} from "../../evaluator/LlamaContext/tokenPredictors/DraftSequenceTokenPredictor.js";
2525
import {ParsedXtcArg, parseXtcArg} from "../utils/parseXtcArg.js";
26+
import {GgmlType} from "../../gguf/types/GgufTensorInfoTypes.js";
2627

2728
type CompleteCommand = {
2829
modelPath?: string,
@@ -34,6 +35,8 @@ type CompleteCommand = {
3435
contextSize?: number,
3536
batchSize?: number,
3637
flashAttention?: boolean,
38+
kvCacheKeyType?: "currentQuant" | keyof typeof GgmlType,
39+
kvCacheValueType?: "currentQuant" | keyof typeof GgmlType,
3740
swaFullCache?: boolean,
3841
threads?: number,
3942
temperature: number,
@@ -129,6 +132,24 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
129132
default: false,
130133
description: "Enable flash attention"
131134
})
135+
.option("kvCacheKeyType", {
136+
alias: "kvckt",
137+
type: "string",
138+
choices: [
139+
"currentQuant",
140+
...Object.keys(GgmlType).filter((key) => typeof key === "string") as (keyof typeof GgmlType)[]
141+
] as const,
142+
description: "The type of the key for the context KV cache tensors"
143+
})
144+
.option("kvCacheValueType", {
145+
alias: "kvcvt",
146+
type: "string",
147+
choices: [
148+
"currentQuant",
149+
...Object.keys(GgmlType).filter((key) => typeof key === "string") as (keyof typeof GgmlType)[]
150+
] as const,
151+
description: "The type of the value for the context KV cache tensors"
152+
})
132153
.option("swaFullCache", {
133154
alias: "noSwa",
134155
type: "boolean",
@@ -299,15 +320,16 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
299320
},
300321
async handler({
301322
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
302-
flashAttention, swaFullCache, threads, temperature, minP, topK,
323+
flashAttention, kvCacheKeyType, kvCacheValueType, swaFullCache, threads, temperature, minP, topK,
303324
topP, seed, xtc, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
304325
repeatFrequencyPenalty, repeatPresencePenalty, dryRepeatPenaltyStrength, dryRepeatPenaltyBase, dryRepeatPenaltyAllowedLength,
305326
dryRepeatPenaltyLastTokens, maxTokens, tokenPredictionDraftModel, tokenPredictionModelContextSize,
306327
debug, numa, meter, timing, noMmap, useDirectIo, printTimings
307328
}) {
308329
try {
309330
await RunCompletion({
310-
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention, swaFullCache,
331+
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
332+
kvCacheKeyType, kvCacheValueType, swaFullCache,
311333
threads, temperature, minP, topK, topP, seed, xtc, gpuLayers, lastTokensRepeatPenalty,
312334
repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, dryRepeatPenaltyStrength,
313335
dryRepeatPenaltyBase, dryRepeatPenaltyAllowedLength, dryRepeatPenaltyLastTokens, maxTokens,
@@ -323,7 +345,8 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
323345

324346

325347
async function RunCompletion({
326-
modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention, swaFullCache,
348+
modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
349+
kvCacheKeyType, kvCacheValueType, swaFullCache,
327350
threads, temperature, minP, topK, topP, seed, xtc, gpuLayers,
328351
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
329352
dryRepeatPenaltyStrength, dryRepeatPenaltyBase, dryRepeatPenaltyAllowedLength, dryRepeatPenaltyLastTokens,
@@ -356,13 +379,17 @@ async function RunCompletion({
356379
const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
357380
flashAttention,
358381
swaFullCache,
359-
useMmap
382+
useMmap,
383+
kvCacheKeyType,
384+
kvCacheValueType
360385
});
361386
const resolvedDraftModelPath = (tokenPredictionDraftModel != null && tokenPredictionDraftModel !== "")
362387
? await resolveCommandGgufPath(tokenPredictionDraftModel, llama, headers, {
363388
flashAttention,
364389
swaFullCache,
365390
useMmap,
391+
kvCacheKeyType,
392+
kvCacheValueType,
366393
consoleTitle: "Draft model file"
367394
})
368395
: undefined;
@@ -400,6 +427,8 @@ async function RunCompletion({
400427
? {fitContext: {contextSize}}
401428
: undefined,
402429
defaultContextFlashAttention: flashAttention,
430+
defaultContextKvCacheKeyType: kvCacheKeyType,
431+
defaultContextKvCacheValueType: kvCacheValueType,
403432
defaultContextSwaFullCache: swaFullCache,
404433
useMmap,
405434
useDirectIo,
@@ -435,6 +464,8 @@ async function RunCompletion({
435464
return await llama.loadModel({
436465
modelPath: resolvedDraftModelPath,
437466
defaultContextFlashAttention: flashAttention,
467+
defaultContextKvCacheKeyType: kvCacheKeyType,
468+
defaultContextKvCacheValueType: kvCacheValueType,
438469
defaultContextSwaFullCache: swaFullCache,
439470
useMmap,
440471
useDirectIo,

0 commit comments

Comments
 (0)