@@ -140,6 +140,7 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
140140 const auto & layer_state = per_layer_states_[layer_idx];
141141
142142 const auto window_size = layer_state.window_size ;
143+ const auto seq_len = layer_state.window_index * layer_state.window_size ;
143144 const auto & key_cache_shape_in = layer_state.key_cache_shape_in ;
144145 const auto & key_cache_shape_out = layer_state.key_cache_shape_out ;
145146 const auto & value_cache_shape_in = layer_state.value_cache_shape_in ;
@@ -151,10 +152,10 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
151152 int64_t num_key_cache_chunks = key_cache_shape_in[0 ] * key_cache_shape_in[2 ];
152153 for (int64_t j = 0 ; j < num_key_cache_chunks; ++j) {
153154 {
154- cpu_span<uint8_t > key_cache_dst (key_cache_in_data + j * key_cache_shape_in[3 ],
155- key_cache_shape_in[ 3 ] - window_size );
156- cpu_span<uint8_t > key_cache_src (key_cache_in_data + j * key_cache_shape_in[3 ] + window_size ,
157- key_cache_shape_in[ 3 ] - window_size );
155+ cpu_span<uint8_t > key_cache_dst (key_cache_in_data + j * key_cache_shape_in[3 ] + key_cache_shape_in[ 3 ] - seq_len - window_size ,
156+ seq_len );
157+ cpu_span<uint8_t > key_cache_src (key_cache_in_data + j * key_cache_shape_in[3 ] + key_cache_shape_in[ 3 ] - seq_len ,
158+ seq_len );
158159 std::copy (key_cache_src.begin (), key_cache_src.end (), key_cache_dst.begin ());
159160 }
160161 {
@@ -171,11 +172,12 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
171172
172173 for (int64_t j = 0 ; j < value_cache_shape_in[0 ]; ++j) {
173174 {
174- cpu_span<uint8_t > value_cache_dst (value_cache_in_data + (j * value_cache_shape_in[2 ] * value_cache_shape_in[3 ]),
175- (value_cache_shape_in[2 ] - window_size) * value_cache_shape_in[3 ]);
175+ cpu_span<uint8_t > value_cache_dst (value_cache_in_data + (j * value_cache_shape_in[2 ] * value_cache_shape_in[3 ]) +
176+ ((value_cache_shape_in[2 ] - seq_len - window_size) * value_cache_shape_in[3 ]),
177+ seq_len * value_cache_shape_in[3 ]);
176178 cpu_span<uint8_t > value_cache_src (value_cache_in_data + (j * value_cache_shape_in[2 ] * value_cache_shape_in[3 ]) +
177- (window_size * value_cache_shape_in[3 ]),
178- (value_cache_shape_in[ 2 ] - window_size) * value_cache_shape_in[3 ]);
179+ ((value_cache_shape_in[ 2 ] - seq_len) * value_cache_shape_in[3 ]),
180+ seq_len * value_cache_shape_in[3 ]);
179181 std::copy (value_cache_src.begin (), value_cache_src.end (), value_cache_dst.begin ());
180182 }
181183 {
@@ -287,6 +289,7 @@ void WindowedKeyValueCache::TransitionLayerToTokenGeneration(size_t layer_idx) {
287289 value_caches_out_[layer_idx] = OrtValue::CreateTensor (Allocator (), updated_value_cache_shape_out, type_);
288290
289291 // update values in per-layer state
292+ layer_state.window_index = layer_state.window_index * layer_state.window_size / updated_window_size;
290293 layer_state.window_size = updated_window_size;
291294 layer_state.key_cache_shape_in = updated_key_cache_shape_in;
292295 layer_state.value_cache_shape_in = updated_value_cache_shape_in;
0 commit comments