-
Notifications
You must be signed in to change notification settings - Fork 61
Expand file tree
/
Copy pathpaged_compiler.cpp
More file actions
107 lines (95 loc) · 5.32 KB
/
paged_compiler.cpp
File metadata and controls
107 lines (95 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include "paged_compiler.hpp"
namespace {
// TODO: replace with Tensor::zeros when it is available.
// Overwrites the whole tensor with zero bytes: a zero-filled host buffer of
// tensor->nbytes() bytes is staged and copied into tensor->data() via memcpyH2D.
// NOTE(review): the trailing `false` is passed through to memcpyH2D unchanged;
// its exact meaning is defined by infinicore -- confirm there.
inline void set_zeros(infinicore::Tensor &tensor) {
    const auto byte_count = tensor->nbytes();
    std::vector<uint8_t> host_zeros(byte_count, 0);
    infinicore::context::memcpyH2D(tensor->data(), host_zeros.data(), byte_count, false);
}
} // namespace
namespace infinilm::engine {
/// Builds the list of decode batch sizes that compile() will record graphs for.
///
/// The resulting sequence is, in order:
///   1..31 (step 1), 32..56 (step 8), 64..112 (step 16),
///   128..224 (step 32), 256..512 (step 64, 512 included).
///
/// @param model   forwarded to the GraphCompiler base class.
/// @param barrier forwarded to the GraphCompiler base class.
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
    : GraphCompiler(model, barrier) {
    // Half-open range [first, last) walked with a fixed stride.
    struct BatchRange {
        size_t first;
        size_t last;
        size_t stride;
    };

    const BatchRange ranges[] = {
        {1, 32, 1},
        {32, 64, 8},
        {64, 128, 16},
        {128, 256, 32},
        {256, 513, 64}, // exclusive bound 513 keeps 512 in the list
    };

    for (const BatchRange &range : ranges) {
        for (size_t batch = range.first; batch < range.last; batch += range.stride) {
            decode_batch_sizes_.push_back(batch);
        }
    }
}
void PagedCompiler::compile() {
if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
size_t nblocks = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())->num_blocks();
size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
compiled_map_decode_.clear();
block_tables_holder_ = infinicore::Tensor::empty(
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(block_tables_holder_);
for (size_t b : decode_batch_sizes_) {
size_t block_per_req = nblocks / b;
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.input_ids.value());
set_zeros(input.position_ids.value());
set_zeros(input.total_sequence_lengths.value());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.input_offsets.value());
std::vector<int64_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) {
input_offsets_vec[i] = i;
}
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.slot_mapping.value());
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
compiled_map_decode_[b] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}
}
/// Finds the recorded decode graph matching `input` and loads `input` into the
/// graph's own input tensors.
///
/// @param input  the request batch; block_tables and input_ids are read with
///               .value(), and all six graph inputs are copied from it.
/// @return {graph, output} for the matching batch size, where output wraps the
///         recorded logits after resume_from_blob_(). Returns
///         {nullptr, nullptr} when
///           - the model's cache config is not a cache::PagedKVCacheConfig,
///           - the batch is not decode-only (block_tables rows differ from
///             input_ids columns),
///           - no graph was recorded for this batch size, or
///           - the incoming block table is wider than the one the graph was
///             recorded with.
PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &input) {
    // dynamic_cast returns nullptr for a null config as well as for a
    // non-paged one, so no separate null check is required.
    if (dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config()) == nullptr) {
        return {nullptr, nullptr};
    }

    const size_t batch_size = input.block_tables.value()->size(0);
    const size_t block_per_req = input.block_tables.value()->size(1);

    // only support decode only batch
    if (batch_size != input.input_ids.value()->size(1)) {
        return {nullptr, nullptr};
    }

    auto result = compiled_map_decode_.find(batch_size);
    if (result == compiled_map_decode_.end()) {
        return {nullptr, nullptr};
    }

    auto &graph_input = result->second.input;

    // The recorded block table has a fixed width. A wider incoming table cannot
    // be narrowed into it, so report "no compiled graph" before touching any
    // of the graph's buffers.
    const size_t graph_block_per_req = graph_input.block_tables.value()->size(1);
    if (block_per_req > graph_block_per_req) {
        return {nullptr, nullptr};
    }

    graph_input.input_ids.value()->copy_from(input.input_ids.value());
    graph_input.position_ids.value()->copy_from(input.position_ids.value());
    graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
    graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
    // Only the leading block_per_req columns of the recorded table are overwritten.
    graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
    graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());

    auto graph = std::get<0>(result->second.compiled);
    auto shared_output = std::shared_ptr<InfinilmModel::Output>(
        new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});

    return std::make_tuple(graph, shared_output);
}
} // namespace infinilm::engine