-
Notifications
You must be signed in to change notification settings - Fork 61
Expand file tree
/
Copy pathstatic_batching_compiler.cpp
More file actions
56 lines (48 loc) · 3.1 KB
/
static_batching_compiler.cpp
File metadata and controls
56 lines (48 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include "static_batching_compiler.hpp"
#include "../../cache/cache.hpp"
namespace infinilm::engine {
/// Construct a compiler bound to a model and a cross-rank barrier.
/// Both are forwarded verbatim to the GraphCompiler base; no extra state is kept here.
StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {}
void StaticBatchingCompiler::compile() {
if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())) {
size_t b = dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())->max_batch_size();
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
compiled_map_[std::make_tuple(b, 1)] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}
/// Look up a pre-compiled graph matching the shape of `input`.
///
/// Returns {nullptr, nullptr} when the cache config is not static or when no
/// graph was compiled for (batch_size, seqlen). On a hit, copies the caller's
/// tensors into the graph's bound input tensors (the recorded graph replays
/// against those fixed buffers) and returns the graph plus a fresh Output
/// whose logits tensor is re-materialized from the recorded blob.
StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled(
    const InfinilmModel::Input &input) {
    // dynamic_cast on a null pointer is null, so one cast covers both the
    // "no config" and "wrong config type" cases (the original checked twice).
    if (dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config()) == nullptr) {
        return std::make_tuple(nullptr, nullptr);
    }

    const size_t batch_size = input.input_ids.value()->size(0);
    const size_t seqlen = input.input_ids.value()->size(1);

    const auto result = compiled_map_.find(std::make_tuple(batch_size, seqlen));
    if (result == compiled_map_.end()) {
        return std::make_tuple(nullptr, nullptr); // No graph compiled for this shape.
    }

    // Copy the caller's data into the tensors the graph was recorded against.
    auto &graph_input = result->second.input;
    graph_input.input_ids.value()->copy_from(input.input_ids.value());
    graph_input.position_ids.value()->copy_from(input.position_ids.value());
    graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value());
    graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());

    // Bind by const reference until the return copy — avoids one extra
    // shared_ptr refcount round-trip versus copying into a local first.
    const auto &graph = std::get<0>(result->second.compiled);
    // `new` + brace-init kept (not make_shared): Output may be an aggregate,
    // which C++17 make_shared cannot brace-initialize.
    auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
    return std::make_tuple(graph, shared_output);
}
} // namespace infinilm::engine