@@ -22,6 +22,7 @@ import {
2222import { resolveBatchItemsPrioritizationStrategy } from "./utils/resolveBatchItemsPrioritizationStrategy.js" ;
2323import { LlamaSampler } from "./LlamaSampler.js" ;
2424import { TokenPredictor } from "./TokenPredictor.js" ;
25+ import { padSafeContextSize } from "./utils/padSafeContextSize.js" ;
2526import type { Llama } from "../../bindings/Llama.js" ;
2627
2728const defaultLoraScale = 1 ;
@@ -98,12 +99,15 @@ export class LlamaContext {
9899 if ( _model . disposed )
99100 throw new DisposedError ( ) ;
100101
102+ const kvUnified = false ;
101103 this . _llama = _model . _llama ;
102104 this . _model = _model ;
103105 this . _backendContextDisposeGuard = new DisposeGuard ( [ this . _model . _backendModelDisposeGuard ] ) ;
104106 this . _modelPreventDisposalHandle = this . _model . _backendModelDisposeGuard . createPreventDisposalHandle ( ) ;
105107 this . _totalSequences = Math . max ( 1 , Math . floor ( sequences ) ) ;
106- this . _contextSize = Math . max ( 2 , contextSize ) ;
108+ this . _contextSize = kvUnified
109+ ? Math . floor ( padSafeContextSize ( Math . max ( 2 , contextSize ) * this . _totalSequences , "up" ) / this . _totalSequences )
110+ : padSafeContextSize ( Math . max ( 2 , contextSize ) , "up" ) ;
107111 this . _batchSize = Math . max ( batchSize , this . _totalSequences ) ;
108112 this . _flashAttention = flashAttention ;
109113 this . _idealThreads = typeof threads === "number"
@@ -124,7 +128,7 @@ export class LlamaContext {
124128 this . _performanceTracking = ! ! performanceTracking ;
125129 this . _swaFullCache = ! ! swaFullCache ;
126130 this . _ctx = new this . _llama . _bindings . AddonContext ( this . _model . _model , removeNullFields ( {
127- contextSize : this . _contextSize * this . _totalSequences , // each sequence needs its own <contextSize> of cells
131+ contextSize : padSafeContextSize ( this . _contextSize * this . _totalSequences , "up" ) , // each sequence needs its own <contextSize> of cells
128132 batchSize : this . _batchSize + (
129133 ( ! this . _swaFullCache && this . model . fileInsights . swaSize != null && this . model . fileInsights . swaSize > 0 )
130134 ? 1 // +1 to handle edge cases with SWA KV cache
0 commit comments