Skip to content

Commit 1b73c32

Browse files
author
chilukam-qti
committed
KV cache optimization based on sequence length
Optimized the sliding-window KV cache copy from present to past by copying only the portion of the cache covered by the current sequence length (seq_len) instead of the entire context length.
1 parent 8b61ec4 commit 1b73c32

1 file changed

Lines changed: 11 additions & 8 deletions

File tree

src/models/windowed_kv_cache.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
140140
const auto& layer_state = per_layer_states_[layer_idx];
141141

142142
const auto window_size = layer_state.window_size;
143+
const auto seq_len = layer_state.window_index * layer_state.window_size;
143144
const auto& key_cache_shape_in = layer_state.key_cache_shape_in;
144145
const auto& key_cache_shape_out = layer_state.key_cache_shape_out;
145146
const auto& value_cache_shape_in = layer_state.value_cache_shape_in;
@@ -151,10 +152,10 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
151152
int64_t num_key_cache_chunks = key_cache_shape_in[0] * key_cache_shape_in[2];
152153
for (int64_t j = 0; j < num_key_cache_chunks; ++j) {
153154
{
154-
cpu_span<uint8_t> key_cache_dst(key_cache_in_data + j * key_cache_shape_in[3],
155-
key_cache_shape_in[3] - window_size);
156-
cpu_span<uint8_t> key_cache_src(key_cache_in_data + j * key_cache_shape_in[3] + window_size,
157-
key_cache_shape_in[3] - window_size);
155+
cpu_span<uint8_t> key_cache_dst(key_cache_in_data + j * key_cache_shape_in[3] + key_cache_shape_in[3] - seq_len - window_size,
156+
seq_len);
157+
cpu_span<uint8_t> key_cache_src(key_cache_in_data + j * key_cache_shape_in[3] + key_cache_shape_in[3] - seq_len,
158+
seq_len);
158159
std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin());
159160
}
160161
{
@@ -171,11 +172,12 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
171172

172173
for (int64_t j = 0; j < value_cache_shape_in[0]; ++j) {
173174
{
174-
cpu_span<uint8_t> value_cache_dst(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]),
175-
(value_cache_shape_in[2] - window_size) * value_cache_shape_in[3]);
175+
cpu_span<uint8_t> value_cache_dst(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]) +
176+
((value_cache_shape_in[2] - seq_len - window_size) * value_cache_shape_in[3]),
177+
seq_len * value_cache_shape_in[3]);
176178
cpu_span<uint8_t> value_cache_src(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]) +
177-
(window_size * value_cache_shape_in[3]),
178-
(value_cache_shape_in[2] - window_size) * value_cache_shape_in[3]);
179+
((value_cache_shape_in[2] - seq_len) * value_cache_shape_in[3]),
180+
seq_len * value_cache_shape_in[3]);
179181
std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin());
180182
}
181183
{
@@ -287,6 +289,7 @@ void WindowedKeyValueCache::TransitionLayerToTokenGeneration(size_t layer_idx) {
287289
value_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_out, type_);
288290

289291
// update values in per-layer state
292+
layer_state.window_index = layer_state.window_index * layer_state.window_size / updated_window_size;
290293
layer_state.window_size = updated_window_size;
291294
layer_state.key_cache_shape_in = updated_key_cache_shape_in;
292295
layer_state.value_cache_shape_in = updated_value_cache_shape_in;

0 commit comments

Comments
 (0)