22
33#include " ../global_state/global_state.hpp"
44#include " ../utils.hpp"
5- #include " infinicore/ops.hpp"
6- #include < stdexcept>
75
86namespace infinilm ::cache {
97// ==========================
@@ -32,58 +30,12 @@ StaticKVCacheConfig::max_cache_len() const {
3230 return max_cache_len_;
3331}
3432
33+ namespace StaticKVCache {
34+
3535// ==========================
3636// StaticKVCache
3737// ==========================
38-
39- StaticKVCache::StaticKVCache (
40- infinicore::Size k_dim,
41- infinicore::Size v_dim,
42- infinicore::Size num_k_heads,
43- infinicore::Size num_v_heads,
44- infinicore::Size num_layers,
45- infinicore::Size max_positional_embedding,
46- infinicore::DataType dtype,
47- const StaticKVCacheConfig &config,
48- const engine::distributed::RankInfo &rank_info)
49- : Cache(),
50- k_dim_(k_dim),
51- v_dim_(v_dim),
52- rank_batch_size_(config.max_batch_size()),
53- cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
54- rank_num_layers_(num_layers),
55- dtype_(dtype) {
56-
57- bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0 );
58-
59- num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size );
60- num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size );
61- // Allocate K cache
62- k_caches_ = infinicore::Tensor::empty (
63- {rank_num_layers_,
64- rank_batch_size_,
65- num_rank_k_heads_,
66- cache_len_,
67- k_dim_},
68- dtype_,
69- rank_info.device );
70- set_zeros (k_caches_);
71-
72- // Allocate V cache
73- v_caches_ = infinicore::Tensor::empty (
74- {rank_num_layers_,
75- rank_batch_size_,
76- num_rank_v_heads_,
77- cache_len_,
78- v_dim_},
79- dtype_,
80- rank_info.device );
81- set_zeros (v_caches_);
82-
83- infinicore::context::syncStream ();
84- }
85-
86- infinicore::Tensor StaticKVCache::create_layer_kv_cache (
38+ infinicore::Tensor create_layer_kv_cache (
8739 const infinicore::Size k_dim,
8840 const infinicore::Size v_dim,
8941 const infinicore::Size num_k_heads,
@@ -120,45 +72,7 @@ infinicore::Tensor StaticKVCache::create_layer_kv_cache(
12072
12173 return kv_cache;
12274}
123-
124- std::tuple<infinicore::Tensor, infinicore::Tensor>
125- StaticKVCache::update (size_t layer_idx,
126- const infinicore::Tensor &k,
127- const infinicore::Tensor &v,
128- const infinicore::Tensor &past_sequence_lengths) {
129- ASSERT (layer_idx < rank_num_layers_);
130-
131- auto batch_size = k->size (0 );
132- auto update_len = k->size (2 );
133-
134- ASSERT_EQ (batch_size, rank_batch_size_);
135-
136- auto k_cache_layer = k_caches_->narrow ({{0 , layer_idx, 1 }})->squeeze (0 );
137- auto v_cache_layer = v_caches_->narrow ({{0 , layer_idx, 1 }})->squeeze (0 );
138-
139- auto device = k_cache_layer->device ();
140-
141- #ifdef ENABLE_KV_CACHING
142- infinicore::op::kv_caching_ (
143- k_cache_layer,
144- v_cache_layer,
145- k,
146- v,
147- past_sequence_lengths);
148- #else
149- size_t cache_pos = reinterpret_cast <int32_t *>(past_sequence_lengths->to (infinicore::Device::cpu ())->data ())[0 ];
150- auto result_len = cache_pos + update_len;
151- ASSERT (result_len <= cache_len_);
152-
153- auto k_cache_update = k_cache_layer->narrow ({{2 , cache_pos, update_len}});
154- auto v_cache_update = v_cache_layer->narrow ({{2 , cache_pos, update_len}});
155-
156- k_cache_update->copy_from (k);
157- v_cache_update->copy_from (v);
158- #endif
159-
160- return {k_cache_layer, v_cache_layer};
161- }
75+ }; // namespace StaticKVCache
16276
16377// ==========================
16478// PagedKVCacheConfig
@@ -185,56 +99,11 @@ PagedKVCacheConfig::block_size() const {
18599 return block_size_;
186100}
187101
102+ namespace PagedKVCache {
188103// ==========================
189104// PagedKVCache
190105// ==========================
191- PagedKVCache::PagedKVCache (
192- infinicore::Size k_dim,
193- infinicore::Size v_dim,
194- infinicore::Size num_k_heads,
195- infinicore::Size num_v_heads,
196- infinicore::Size num_layers,
197- infinicore::DataType dtype,
198- const PagedKVCacheConfig &config,
199- const engine::distributed::RankInfo &rank_info)
200- : Cache(),
201- k_dim_(k_dim),
202- v_dim_(v_dim),
203- rank_num_layers_(num_layers),
204- dtype_(dtype),
205- num_blocks_per_layer_(config.num_blocks()),
206- block_size_(config.block_size()) {
207-
208- bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0 );
209-
210- num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size );
211- num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size );
212- // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
213- k_caches_ = infinicore::Tensor::empty (
214- {rank_num_layers_,
215- num_blocks_per_layer_,
216- num_rank_k_heads_,
217- block_size_,
218- k_dim_},
219- dtype_,
220- rank_info.device );
221- set_zeros (k_caches_);
222-
223- // [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
224- v_caches_ = infinicore::Tensor::empty (
225- {rank_num_layers_,
226- num_blocks_per_layer_,
227- num_rank_v_heads_,
228- block_size_,
229- v_dim_},
230- dtype_,
231- rank_info.device );
232- set_zeros (v_caches_);
233-
234- infinicore::context::syncStream ();
235- }
236-
237- infinicore::Tensor PagedKVCache::create_layer_kv_cache (
106+ infinicore::Tensor create_layer_kv_cache (
238107 infinicore::Size k_dim,
239108 infinicore::Size v_dim,
240109 infinicore::Size num_k_heads,
@@ -273,86 +142,6 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
273142
274143 return kv_cache;
275144}
276-
277- std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update (
278- size_t layer_idx,
279- const infinicore::Tensor &k,
280- const infinicore::Tensor &v,
281- const infinicore::Tensor &slot_mapping) {
282-
283- auto &&[k_cache_layer, v_cache_layer] = get_paged_kv (layer_idx);
284-
285- infinicore::op::paged_caching_ (
286- k_cache_layer,
287- v_cache_layer,
288- k,
289- v,
290- slot_mapping);
291- return {k_cache_layer, v_cache_layer};
292- }
293-
294- std::tuple<infinicore::Tensor, infinicore::Tensor>
295- PagedKVCache::get_paged_kv (size_t layer_idx) {
296- auto k_cache_layer = k_caches_->narrow ({{0 , layer_idx, 1 }})->squeeze (0 );
297- auto v_cache_layer = v_caches_->narrow ({{0 , layer_idx, 1 }})->squeeze (0 );
298- return {k_cache_layer, v_cache_layer};
299- }
300-
301- std::tuple<infinicore::Tensor, infinicore::Tensor>
302- PagedKVCache::get_contiguous_kv (
303- size_t layer_idx,
304- const infinicore::Tensor block_tables,
305- const infinicore::Tensor cache_lens,
306- const infinicore::Tensor input_offsets,
307- size_t request_id) {
308- ASSERT_EQ (block_tables->dtype (), infinicore::DataType::I32 );
309- ASSERT_EQ (cache_lens->dtype (), infinicore::DataType::I32 );
310- ASSERT_EQ (input_offsets->dtype (), infinicore::DataType::I32 );
311-
312- auto nreq = block_tables->size (0 );
313- auto block_tables_cpu = block_tables->to (infinicore::Device::cpu ());
314- auto cache_lens_cpu = cache_lens->to (infinicore::Device::cpu ());
315- auto input_offsets_cpu = input_offsets->to (infinicore::Device::cpu ());
316- infinicore::context::syncDevice ();
317-
318- // [num_blocks, num_rank_v_heads, block_size, v_dim]
319- auto &&[k_cache_layer, v_cache_layer] = get_paged_kv (layer_idx);
320-
321- auto req = request_id;
322- auto cache_lens_ptr = reinterpret_cast <const int32_t *>(cache_lens_cpu->data ());
323- auto input_offsets_ptr = reinterpret_cast <const int32_t *>(input_offsets_cpu->data ());
324- int32_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1 ] - input_offsets_ptr[req]);
325-
326- auto full_k = infinicore::Tensor::empty (
327- {num_rank_k_heads_, (size_t )total_len, k_dim_},
328- k_cache_layer->dtype (), k_cache_layer->device ());
329-
330- auto full_v = infinicore::Tensor::empty (
331- {num_rank_v_heads_, (size_t )total_len, v_dim_},
332- v_cache_layer->dtype (), v_cache_layer->device ());
333-
334- size_t nblocks = total_len / block_size_;
335- size_t r = total_len % block_size_;
336-
337- for (size_t b = 0 ; b < nblocks; b++) {
338- size_t bid = *((int32_t *)(block_tables_cpu->narrow ({{0 , req, 1 }, {1 , b, 1 }})->data ()));
339-
340- full_k->narrow ({{1 , b * block_size_, block_size_}})
341- ->copy_from (k_cache_layer->narrow ({{0 , bid, 1 }})->squeeze (0 ));
342- full_v->narrow ({{1 , b * block_size_, block_size_}})
343- ->copy_from (v_cache_layer->narrow ({{0 , bid, 1 }})->squeeze (0 ));
344- }
345-
346- if (r > 0 ) {
347- size_t bid = *((int32_t *)(block_tables_cpu->narrow ({{0 , req, 1 }, {1 , nblocks, 1 }})->data ()));
348-
349- full_k->narrow ({{1 , nblocks * block_size_, r}})
350- ->copy_from (k_cache_layer->narrow ({{0 , bid, 1 }})->squeeze (0 )->narrow ({{1 , 0 , r}}));
351- full_v->narrow ({{1 , nblocks * block_size_, r}})
352- ->copy_from (v_cache_layer->narrow ({{0 , bid, 1 }})->squeeze (0 )->narrow ({{1 , 0 , r}}));
353- }
354-
355- return {full_k, full_v};
356- }
145+ }; // namespace PagedKVCache
357146
358147} // namespace infinilm::cache
0 commit comments