Commit e70bcd0

fix: pad the context size to align with the implementation in llama.cpp
1 parent 24daf6d · commit e70bcd0

6 files changed

Lines changed: 102 additions & 71 deletions

src/cli/commands/inspect/commands/InspectMeasureCommand.ts

Lines changed: 3 additions & 0 deletions
```diff
@@ -21,6 +21,7 @@ import {withCliCommandDescriptionDocsUrl} from "../../../utils/withCliCommandDes
 import {documentationPageUrls} from "../../../../config.js";
 import {Llama} from "../../../../bindings/Llama.js";
 import {toBytes} from "../../../utils/toBytes.js";
+import {padSafeContextSize} from "../../../../evaluator/LlamaContext/utils/padSafeContextSize.js";
 
 type InspectMeasureCommand = {
     modelPath?: string,
@@ -952,6 +953,8 @@ function getContextSizesCheckPlan(trainContextSize: number, tests: number = 10,
         if (size < 2)
             size = 2;
 
+        size = padSafeContextSize(size, "up");
+
         if (res[res.length - 1] === size) {
             shouldStop = true;
             return;
```
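
For context, `getContextSizesCheckPlan` builds the list of context sizes that `inspect measure` benchmarks. Padding each candidate up to a multiple of 256 means nearby candidates can collapse into the same value, which the `res[res.length - 1] === size` check then turns into an early stop instead of a duplicate measurement. A minimal sketch of that effect, using the `padSafeContextSize` helper added in this commit (the candidate numbers are hypothetical):

```ts
// Import path assumes the repo root as the working directory:
import {padSafeContextSize} from "./src/evaluator/LlamaContext/utils/padSafeContextSize.js";

// Hypothetical candidate sizes, e.g. from a sweep down from the train context size:
const candidates = [8192, 5000, 3000, 1800, 1100, 700, 400, 200, 2];

// After `size = padSafeContextSize(size, "up")`, every candidate is a multiple
// of 256, and neighbors may collapse into the same padded value:
const padded = candidates.map((size) => padSafeContextSize(Math.max(2, size), "up"));
console.log(padded); // [8192, 5120, 3072, 2048, 1280, 768, 512, 256, 256]
// The duplicate trailing 256s are what the early-stop check catches.
```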

src/evaluator/LlamaContext/LlamaContext.ts

Lines changed: 6 additions & 2 deletions
```diff
@@ -22,6 +22,7 @@ import {
 import {resolveBatchItemsPrioritizationStrategy} from "./utils/resolveBatchItemsPrioritizationStrategy.js";
 import {LlamaSampler} from "./LlamaSampler.js";
 import {TokenPredictor} from "./TokenPredictor.js";
+import {padSafeContextSize} from "./utils/padSafeContextSize.js";
 import type {Llama} from "../../bindings/Llama.js";
 
 const defaultLoraScale = 1;
@@ -98,12 +99,15 @@ export class LlamaContext {
         if (_model.disposed)
             throw new DisposedError();
 
+        const kvUnified = false;
         this._llama = _model._llama;
         this._model = _model;
         this._backendContextDisposeGuard = new DisposeGuard([this._model._backendModelDisposeGuard]);
         this._modelPreventDisposalHandle = this._model._backendModelDisposeGuard.createPreventDisposalHandle();
         this._totalSequences = Math.max(1, Math.floor(sequences));
-        this._contextSize = Math.max(2, contextSize);
+        this._contextSize = kvUnified
+            ? Math.floor(padSafeContextSize(Math.max(2, contextSize) * this._totalSequences, "up") / this._totalSequences)
+            : padSafeContextSize(Math.max(2, contextSize), "up");
         this._batchSize = Math.max(batchSize, this._totalSequences);
         this._flashAttention = flashAttention;
         this._idealThreads = typeof threads === "number"
@@ -124,7 +128,7 @@ export class LlamaContext {
         this._performanceTracking = !!performanceTracking;
         this._swaFullCache = !!swaFullCache;
         this._ctx = new this._llama._bindings.AddonContext(this._model._model, removeNullFields({
-            contextSize: this._contextSize * this._totalSequences, // each sequence needs its own <contextSize> of cells
+            contextSize: padSafeContextSize(this._contextSize * this._totalSequences, "up"), // each sequence needs its own <contextSize> of cells
             batchSize: this._batchSize + (
                 (!this._swaFullCache && this.model.fileInsights.swaSize != null && this.model.fileInsights.swaSize > 0)
                     ? 1 // +1 to handle edge cases with SWA KV cache
```
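
`kvUnified` is hardcoded to `false` for now (presumably a placeholder for a future unified KV-cache mode), so the second branch runs: the per-sequence context size is padded up on its own, and the backend cache is then sized from the padded value. A worked sketch of both branches with hypothetical numbers:

```ts
// Import path assumes this file sits next to LlamaContext.ts:
import {padSafeContextSize} from "./utils/padSafeContextSize.js";

const contextSize = 900;
const totalSequences = 3;

// kvUnified === false (current behavior): pad each sequence's context size,
// then allocate that many KV cells per sequence in the backend.
const perSequence = padSafeContextSize(Math.max(2, contextSize), "up");      // 1024
const backendCells = padSafeContextSize(perSequence * totalSequences, "up"); // 3072

// kvUnified === true: pad the total cell count once and split it evenly,
// so each sequence can end up with fewer cells than its individually
// padded size.
const unifiedPerSequence = Math.floor(
    padSafeContextSize(Math.max(2, contextSize) * totalSequences, "up") / totalSequences
); // floor(2816 / 3) === 938
```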
src/evaluator/LlamaContext/utils/padSafeContextSize.ts

Lines changed: 20 additions & 0 deletions

```diff
@@ -0,0 +1,20 @@
+const contextSizePad = 256;
+
+export function padSafeContextSize(value: number, padDirection: "up" | "down", padding: number = contextSizePad) {
+    const paddedSize = ggmlPad(value, padding);
+
+    if (paddedSize === value)
+        return value;
+    else if (padDirection === "up")
+        return paddedSize;
+    else if (padDirection === "down") {
+        const smallerPaddedSize = ggmlPad(value - padding, padding);
+        if (smallerPaddedSize >= padding)
+            return smallerPaddedSize;
+    }
+
+    return paddedSize;
+}
+function ggmlPad(value: number, padding: number): number {
+    return ((value + padding - 1) & ~(padding - 1));
+}
```
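
For reference, `ggmlPad` mirrors ggml's `GGML_PAD` macro: `(value + padding - 1) & ~(padding - 1)` rounds `value` up to the next multiple of `padding`, which works only because the padding is a power of two (256 here). A few illustrative calls (hypothetical values, behavior taken from the code above):

```ts
// Import path assumes the repo root as the working directory:
import {padSafeContextSize} from "./src/evaluator/LlamaContext/utils/padSafeContextSize.js";

padSafeContextSize(1024, "up");   // 1024 - already a multiple of 256, returned as-is
padSafeContextSize(1000, "up");   // 1024 - rounded up to the next multiple of 256
padSafeContextSize(1000, "down"); // 768  - rounded down to the previous multiple of 256
padSafeContextSize(100, "down");  // 256  - "down" never returns less than one padding
                                  //        unit, so small values still get padded up
```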

src/gguf/insights/GgufInsights.ts

Lines changed: 5 additions & 1 deletion
```diff
@@ -5,6 +5,7 @@ import {GgufFileInfo} from "../types/GgufFileInfoTypes.js";
 import {GgufTensorInfo} from "../types/GgufTensorInfoTypes.js";
 import {GgufArchitectureType} from "../types/GgufMetadataTypes.js";
 import {getReadablePath} from "../../cli/utils/getReadablePath.js";
+import {padSafeContextSize} from "../../evaluator/LlamaContext/utils/padSafeContextSize.js";
 import {GgufInsightsConfigurationResolver} from "./GgufInsightsConfigurationResolver.js";
 import {GgufInsightsTokens} from "./GgufInsightsTokens.js";
 
@@ -211,6 +212,7 @@ export class GgufInsights {
         const llmData = this._ggufFileInfo.architectureMetadata;
         const tensorInfo = this._ggufFileInfo.fullTensorInfo ?? [];
         const slidingWindow = this.swaSize ?? 0;
+        const kvUnified = false;
         const usingSWA = !swaFullCache && slidingWindow > 0 && slidingWindow < contextSize &&
             (this.trainContextSize == null || slidingWindow < this.trainContextSize);
         const swaPattern = getSwaPatternForArchitecture(this._ggufFileInfo.metadata?.general?.architecture);
@@ -220,7 +222,9 @@ export class GgufInsights {
 
         // source: `llama_kv_cache_unified::get_padding` in `llama-kv-cache.cpp`
         const kvCachePadding = 1;
-        const actualContextSize = sequences * contextSize;
+        const actualContextSize = kvUnified
+            ? padSafeContextSize(sequences * contextSize, "up")
+            : sequences * padSafeContextSize(contextSize, "up");
         const kvSize = usingSWA
             ? (
                 (1 - nonSwaPercent) * Math.min(actualContextSize, ggmlPad(sequences * slidingWindow + batchSize, kvCachePadding)) +
```
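
Since `kvUnified` is also `false` here, the memory estimate pads each sequence's context size before multiplying, mirroring the backend `contextSize` that `LlamaContext` now passes to `AddonContext`. A worked sketch of how the two branches differ, with hypothetical numbers:

```ts
// Import path assumes the repo root as the working directory:
import {padSafeContextSize} from "./src/evaluator/LlamaContext/utils/padSafeContextSize.js";

const sequences = 3;
const contextSize = 900;

// Unified KV cache: pad the total cell count once.
const unified = padSafeContextSize(sequences * contextSize, "up"); // pad(2700) === 2816

// Split KV cache (the branch taken while kvUnified is false): pad per
// sequence, then multiply. This is never smaller than the unified value,
// so the estimate stays conservative.
const split = sequences * padSafeContextSize(contextSize, "up");   // 3 * 1024 === 3072
```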
