-
Notifications
You must be signed in to change notification settings - Fork 72
Expand file tree
/
Copy pathpaged_compiler.cpp
More file actions
125 lines (111 loc) · 6.51 KB
/
Copy pathpaged_compiler.cpp
File metadata and controls
125 lines (111 loc) · 6.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include "paged_compiler.hpp"
#include "../../global_state/global_state.hpp"
#include "../../utils.hpp"
namespace infinilm::engine {
/// Builds a paged-KV graph compiler and fills decode_batch_sizes_ with the
/// decode batch sizes that compile() will record a graph for.
///
/// Small batches get a graph at every size; larger batches are sampled with
/// progressively coarser strides. The resulting list is ascending:
/// 1..63, 64..112 (step 16), 128..224 (step 32), 256..512 (step 64).
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
    : GraphCompiler(model, barrier) {
    // Each tier contributes first, first + stride, ... while the value is < last.
    struct Tier {
        size_t first;
        size_t last;
        size_t stride;
    };
    const Tier tiers[] = {
        {1, 64, 1},     // 1, 2, ..., 63
        {64, 128, 16},  // 64, 80, 96, 112
        {128, 256, 32}, // 128, 160, 192, 224
        {256, 513, 64}, // 256, 320, 384, 448, 512
    };
    for (const Tier &tier : tiers) {
        for (size_t size = tier.first; size < tier.last; size += tier.stride) {
            decode_batch_sizes_.push_back(size);
        }
    }
}
void PagedCompiler::compile() {
if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
size_t nblocks = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())->num_blocks();
size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
compiled_map_decode_.clear();
block_tables_holder_ = infinicore::Tensor::empty(
{nblocks * max_batch_size}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(block_tables_holder_);
for (size_t b : decode_batch_sizes_) {
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice());
set_zeros(input.input_ids.value());
set_zeros(input.position_ids.value());
set_zeros(input.total_sequence_lengths.value());
std::vector<int32_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
std::vector<int32_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) {
input_offsets_vec[i] = i;
}
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
input.cu_seqlens = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
infinicore::context::memcpyH2D(input.cu_seqlens.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false);
const size_t block_per_req = nblocks;
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
set_zeros(input.slot_mapping.value());
// Attention reads attn_metadata from thread-local forward context.
infinilm::global_state::get_forward_context().attn_metadata = {
input.past_sequence_lengths,
input.total_sequence_lengths,
input.input_offsets,
input.cu_seqlens,
input.block_tables,
input.slot_mapping,
};
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
compiled_map_decode_[b] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}
}
/// Returns the pre-recorded decode graph matching `input`, after staging `input`
/// into that graph's captured input buffers.
///
/// Returns {nullptr, nullptr} (the caller should run eagerly) when:
///   - the model's cache config is not a cache::PagedKVCacheConfig;
///   - `input_ids` or `block_tables` is absent;
///   - the batch is not decode-only (token count != number of block-table rows);
///   - no graph was recorded for this batch size; or
///   - the block-table width exceeds the width the graph was recorded with.
/// All fallback checks run before any buffer is touched, so a fallback leaves the
/// recorded graph's inputs unmodified.
///
/// On success, the returned tuple holds the recorded graph and the logits output
/// captured at compile time.
PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &input) {
    // dynamic_cast of a null pointer yields null, so this also covers "no cache config".
    if (dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config()) == nullptr) {
        return {nullptr, nullptr};
    }
    // Without these the shape checks below cannot be made; take the eager path
    // instead of throwing std::bad_optional_access.
    if (!input.input_ids.has_value() || !input.block_tables.has_value()) {
        return {nullptr, nullptr};
    }

    const size_t batch_size = input.block_tables.value()->size(0);
    const size_t block_per_req = input.block_tables.value()->size(1);
    // Only decode-only batches (one token per request) were recorded.
    if (batch_size != input.input_ids.value()->size(1)) {
        return {nullptr, nullptr};
    }
    auto result = compiled_map_decode_.find(batch_size);
    if (result == compiled_map_decode_.end()) {
        return {nullptr, nullptr};
    }

    auto &graph_input = result->second.input;
    auto &graph_block_tables = graph_input.block_tables.value();
    const size_t compiled_block_per_req = graph_block_tables->size(1);
    if (block_per_req > compiled_block_per_req) {
        // Runtime width exceeds compiled graph slot; fall back to eager path.
        // Checked before staging so no device copies are wasted on a graph we won't run.
        return {nullptr, nullptr};
    }

    graph_input.input_ids.value()->copy_from(input.input_ids.value());
    graph_input.position_ids.value()->copy_from(input.position_ids.value());
    graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
    graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
    graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value());
    // Initialize full padding to -1, then overwrite the narrowed logical region.
    // This matches scheduler padding semantics without risking -1 access during graph recording.
    set_minus_one(graph_block_tables);
    graph_block_tables->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
    graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());

    auto graph = std::get<0>(result->second.compiled);
    // Reuse the GraphTensor output captured at compile time.
    // Do not call resume_from_blob_() on workspace-backed logits:
    // that registers a second deleter on the same GPU block and
    // triggers double free in PinnableBlockAllocator.
    auto shared_output = std::get<1>(result->second.compiled);
    return std::make_tuple(graph, shared_output);
}
} // namespace infinilm::engine