Skip to content

Commit 0fa9dbc

Browse files
author
wangpengcheng
committed
issue/424 -Clean up unused code.
1 parent 3d7585b commit 0fa9dbc

37 files changed

Lines changed: 63 additions & 2781 deletions

csrc/cache/kv_cache.cpp

Lines changed: 7 additions & 218 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
#include "../global_state/global_state.hpp"
44
#include "../utils.hpp"
5-
#include "infinicore/ops.hpp"
6-
#include <stdexcept>
75

86
namespace infinilm::cache {
97
// ==========================
@@ -32,58 +30,12 @@ StaticKVCacheConfig::max_cache_len() const {
3230
return max_cache_len_;
3331
}
3432

33+
namespace StaticKVCache {
34+
3535
// ==========================
3636
// StaticKVCache
3737
// ==========================
38-
39-
StaticKVCache::StaticKVCache(
40-
infinicore::Size k_dim,
41-
infinicore::Size v_dim,
42-
infinicore::Size num_k_heads,
43-
infinicore::Size num_v_heads,
44-
infinicore::Size num_layers,
45-
infinicore::Size max_positional_embedding,
46-
infinicore::DataType dtype,
47-
const StaticKVCacheConfig &config,
48-
const engine::distributed::RankInfo &rank_info)
49-
: Cache(),
50-
k_dim_(k_dim),
51-
v_dim_(v_dim),
52-
rank_batch_size_(config.max_batch_size()),
53-
cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
54-
rank_num_layers_(num_layers),
55-
dtype_(dtype) {
56-
57-
bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);
58-
59-
num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
60-
num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);
61-
// Allocate K cache
62-
k_caches_ = infinicore::Tensor::empty(
63-
{rank_num_layers_,
64-
rank_batch_size_,
65-
num_rank_k_heads_,
66-
cache_len_,
67-
k_dim_},
68-
dtype_,
69-
rank_info.device);
70-
set_zeros(k_caches_);
71-
72-
// Allocate V cache
73-
v_caches_ = infinicore::Tensor::empty(
74-
{rank_num_layers_,
75-
rank_batch_size_,
76-
num_rank_v_heads_,
77-
cache_len_,
78-
v_dim_},
79-
dtype_,
80-
rank_info.device);
81-
set_zeros(v_caches_);
82-
83-
infinicore::context::syncStream();
84-
}
85-
86-
infinicore::Tensor StaticKVCache::create_layer_kv_cache(
38+
infinicore::Tensor create_layer_kv_cache(
8739
const infinicore::Size k_dim,
8840
const infinicore::Size v_dim,
8941
const infinicore::Size num_k_heads,
@@ -120,45 +72,7 @@ infinicore::Tensor StaticKVCache::create_layer_kv_cache(
12072

12173
return kv_cache;
12274
}
123-
124-
std::tuple<infinicore::Tensor, infinicore::Tensor>
125-
StaticKVCache::update(size_t layer_idx,
126-
const infinicore::Tensor &k,
127-
const infinicore::Tensor &v,
128-
const infinicore::Tensor &past_sequence_lengths) {
129-
ASSERT(layer_idx < rank_num_layers_);
130-
131-
auto batch_size = k->size(0);
132-
auto update_len = k->size(2);
133-
134-
ASSERT_EQ(batch_size, rank_batch_size_);
135-
136-
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
137-
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
138-
139-
auto device = k_cache_layer->device();
140-
141-
#ifdef ENABLE_KV_CACHING
142-
infinicore::op::kv_caching_(
143-
k_cache_layer,
144-
v_cache_layer,
145-
k,
146-
v,
147-
past_sequence_lengths);
148-
#else
149-
size_t cache_pos = reinterpret_cast<int32_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
150-
auto result_len = cache_pos + update_len;
151-
ASSERT(result_len <= cache_len_);
152-
153-
auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
154-
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
155-
156-
k_cache_update->copy_from(k);
157-
v_cache_update->copy_from(v);
158-
#endif
159-
160-
return {k_cache_layer, v_cache_layer};
161-
}
75+
}; // namespace StaticKVCache
16276

16377
// ==========================
16478
// PagedKVCacheConfig
@@ -185,56 +99,11 @@ PagedKVCacheConfig::block_size() const {
18599
return block_size_;
186100
}
187101

102+
namespace PagedKVCache {
188103
// ==========================
189104
// PagedKVCache
190105
// ==========================
191-
PagedKVCache::PagedKVCache(
192-
infinicore::Size k_dim,
193-
infinicore::Size v_dim,
194-
infinicore::Size num_k_heads,
195-
infinicore::Size num_v_heads,
196-
infinicore::Size num_layers,
197-
infinicore::DataType dtype,
198-
const PagedKVCacheConfig &config,
199-
const engine::distributed::RankInfo &rank_info)
200-
: Cache(),
201-
k_dim_(k_dim),
202-
v_dim_(v_dim),
203-
rank_num_layers_(num_layers),
204-
dtype_(dtype),
205-
num_blocks_per_layer_(config.num_blocks()),
206-
block_size_(config.block_size()) {
207-
208-
bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);
209-
210-
num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
211-
num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);
212-
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
213-
k_caches_ = infinicore::Tensor::empty(
214-
{rank_num_layers_,
215-
num_blocks_per_layer_,
216-
num_rank_k_heads_,
217-
block_size_,
218-
k_dim_},
219-
dtype_,
220-
rank_info.device);
221-
set_zeros(k_caches_);
222-
223-
// [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
224-
v_caches_ = infinicore::Tensor::empty(
225-
{rank_num_layers_,
226-
num_blocks_per_layer_,
227-
num_rank_v_heads_,
228-
block_size_,
229-
v_dim_},
230-
dtype_,
231-
rank_info.device);
232-
set_zeros(v_caches_);
233-
234-
infinicore::context::syncStream();
235-
}
236-
237-
infinicore::Tensor PagedKVCache::create_layer_kv_cache(
106+
infinicore::Tensor create_layer_kv_cache(
238107
infinicore::Size k_dim,
239108
infinicore::Size v_dim,
240109
infinicore::Size num_k_heads,
@@ -273,86 +142,6 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
273142

274143
return kv_cache;
275144
}
276-
277-
std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
278-
size_t layer_idx,
279-
const infinicore::Tensor &k,
280-
const infinicore::Tensor &v,
281-
const infinicore::Tensor &slot_mapping) {
282-
283-
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
284-
285-
infinicore::op::paged_caching_(
286-
k_cache_layer,
287-
v_cache_layer,
288-
k,
289-
v,
290-
slot_mapping);
291-
return {k_cache_layer, v_cache_layer};
292-
}
293-
294-
std::tuple<infinicore::Tensor, infinicore::Tensor>
295-
PagedKVCache::get_paged_kv(size_t layer_idx) {
296-
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
297-
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
298-
return {k_cache_layer, v_cache_layer};
299-
}
300-
301-
std::tuple<infinicore::Tensor, infinicore::Tensor>
302-
PagedKVCache::get_contiguous_kv(
303-
size_t layer_idx,
304-
const infinicore::Tensor block_tables,
305-
const infinicore::Tensor cache_lens,
306-
const infinicore::Tensor input_offsets,
307-
size_t request_id) {
308-
ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I32);
309-
ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I32);
310-
ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I32);
311-
312-
auto nreq = block_tables->size(0);
313-
auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
314-
auto cache_lens_cpu = cache_lens->to(infinicore::Device::cpu());
315-
auto input_offsets_cpu = input_offsets->to(infinicore::Device::cpu());
316-
infinicore::context::syncDevice();
317-
318-
// [num_blocks, num_rank_v_heads, block_size, v_dim]
319-
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
320-
321-
auto req = request_id;
322-
auto cache_lens_ptr = reinterpret_cast<const int32_t *>(cache_lens_cpu->data());
323-
auto input_offsets_ptr = reinterpret_cast<const int32_t *>(input_offsets_cpu->data());
324-
int32_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);
325-
326-
auto full_k = infinicore::Tensor::empty(
327-
{num_rank_k_heads_, (size_t)total_len, k_dim_},
328-
k_cache_layer->dtype(), k_cache_layer->device());
329-
330-
auto full_v = infinicore::Tensor::empty(
331-
{num_rank_v_heads_, (size_t)total_len, v_dim_},
332-
v_cache_layer->dtype(), v_cache_layer->device());
333-
334-
size_t nblocks = total_len / block_size_;
335-
size_t r = total_len % block_size_;
336-
337-
for (size_t b = 0; b < nblocks; b++) {
338-
size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
339-
340-
full_k->narrow({{1, b * block_size_, block_size_}})
341-
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
342-
full_v->narrow({{1, b * block_size_, block_size_}})
343-
->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
344-
}
345-
346-
if (r > 0) {
347-
size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data()));
348-
349-
full_k->narrow({{1, nblocks * block_size_, r}})
350-
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
351-
full_v->narrow({{1, nblocks * block_size_, r}})
352-
->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
353-
}
354-
355-
return {full_k, full_v};
356-
}
145+
}; // namespace PagedKVCache
357146

358147
} // namespace infinilm::cache

0 commit comments

Comments
 (0)