Skip to content

Commit bb68ca5

Browse files
committed
add chunk_prefill_compiler.cpp/.hpp
1 parent f2c8bab commit bb68ca5

2 files changed

Lines changed: 228 additions & 0 deletions

File tree

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
#include "chunk_prefill_compiler.hpp"
2+
#include "infinicore/context/context.hpp"
3+
4+
5+
namespace {
6+
inline void set_zeros(infinicore::Tensor &tensor) {
7+
std::vector<uint8_t> zeros(tensor->nbytes(), 0);
8+
infinicore::context::memcpyH2D(tensor->data(), zeros.data(), tensor->nbytes(), false);
9+
}
10+
} // namespace
11+
12+
namespace infinilm::engine {
13+
14+
ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
15+
: GraphCompiler(model, barrier) {
16+
// Enumerate chunk sizes for chunk-prefill
17+
for (size_t cs : {64, 128, 256, 512, 1024, 2048}) {
18+
chunk_sizes_.push_back(cs);
19+
}
20+
// Enumerate batch sizes for prefill (typically smaller than decode)
21+
for (size_t b = 1; b < 32; b++) {
22+
prefill_batch_sizes_.push_back(b);
23+
}
24+
for (size_t b = 32; b < 64; b += 8) {
25+
prefill_batch_sizes_.push_back(b);
26+
}
27+
for (size_t b = 64; b < 128; b += 16) {
28+
prefill_batch_sizes_.push_back(b);
29+
}
30+
for (size_t b = 128; b < 256; b += 32) {
31+
prefill_batch_sizes_.push_back(b);
32+
}
33+
for (size_t b = 256; b <= 512; b += 64) {
34+
prefill_batch_sizes_.push_back(b);
35+
}
36+
}
37+
38+
void ChunkPrefillCompiler::compile() {
39+
if (model_->get_cache_config() != nullptr &&
40+
dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
41+
42+
const auto *paged_config =
43+
dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config());
44+
size_t nblocks = paged_config->num_blocks();
45+
46+
compiled_map_prefill_.clear();
47+
48+
// Max total tokens to avoid OOM during graph recording
49+
constexpr size_t MAX_TOTAL_TOKENS = 4096;
50+
51+
// Pre-allocate a shared block_tables_holder for the largest (batch_size) we'll use
52+
size_t max_batch = *std::max_element(prefill_batch_sizes_.begin(), prefill_batch_sizes_.end());
53+
size_t block_per_req = nblocks / max_batch;
54+
block_tables_holder_ = infinicore::Tensor::empty(
55+
{nblocks}, infinicore::DataType::I32, infinicore::context::getDevice());
56+
set_zeros(block_tables_holder_);
57+
58+
for (size_t b : prefill_batch_sizes_) {
59+
for (size_t cs : chunk_sizes_) {
60+
size_t total_tokens = b * cs;
61+
if (total_tokens > MAX_TOTAL_TOKENS) {
62+
continue;
63+
}
64+
65+
size_t bpr = nblocks / b; // block_per_req for this batch size
66+
67+
InfinilmModel::Input input;
68+
69+
// input_ids: [1, total_tokens] — all tokens for this batch packed together
70+
input.input_ids = infinicore::Tensor::empty(
71+
{1, total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
72+
set_zeros(input.input_ids.value());
73+
74+
// position_ids: [total_tokens]
75+
input.position_ids = infinicore::Tensor::empty(
76+
{total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
77+
set_zeros(input.position_ids.value());
78+
79+
// total_sequence_lengths: [b], set to cs (first-chunk scenario)
80+
input.total_sequence_lengths = infinicore::Tensor::empty(
81+
{b}, infinicore::DataType::I32, infinicore::context::getDevice());
82+
{
83+
std::vector<int32_t> tsl(b, static_cast<int32_t>(cs));
84+
infinicore::context::memcpyH2D(
85+
input.total_sequence_lengths.value()->data(),
86+
tsl.data(), b * sizeof(int32_t), false);
87+
}
88+
89+
// input_offsets: [b+1], stride = cs
90+
input.input_offsets = infinicore::Tensor::empty(
91+
{b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
92+
{
93+
std::vector<int32_t> offsets(b + 1);
94+
for (size_t i = 0; i <= b; i++) {
95+
offsets[i] = static_cast<int32_t>(i * cs);
96+
}
97+
infinicore::context::memcpyH2D(
98+
input.input_offsets.value()->data(),
99+
offsets.data(), (b + 1) * sizeof(int32_t), false);
100+
}
101+
102+
// cu_seqlens: [b+1], same layout as input_offsets for prefill
103+
input.cu_seqlens = infinicore::Tensor::empty(
104+
{b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
105+
{
106+
std::vector<int32_t> cu(b + 1);
107+
for (size_t i = 0; i <= b; i++) {
108+
cu[i] = static_cast<int32_t>(i * cs);
109+
}
110+
infinicore::context::memcpyH2D(
111+
input.cu_seqlens.value()->data(),
112+
cu.data(), (b + 1) * sizeof(int32_t), false);
113+
}
114+
115+
// block_tables: view into the shared holder [b, bpr]
116+
input.block_tables = block_tables_holder_->as_strided(
117+
{b, bpr}, {(ptrdiff_t)bpr, 1});
118+
119+
// slot_mapping: [total_tokens]
120+
input.slot_mapping = infinicore::Tensor::empty(
121+
{total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
122+
set_zeros(input.slot_mapping.value());
123+
124+
barrier_->wait();
125+
infinicore::context::startGraphRecording();
126+
auto output = model_->forward(input);
127+
auto graph = infinicore::context::stopGraphRecording();
128+
barrier_->wait();
129+
130+
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
131+
new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
132+
133+
compiled_map_prefill_[std::make_tuple(b, cs)] =
134+
CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
135+
}
136+
}
137+
}
138+
}
139+
140+
ChunkPrefillCompiler::Compiled ChunkPrefillCompiler::get_compiled(const InfinilmModel::Input &input) {
141+
if (model_->get_cache_config() == nullptr ||
142+
!dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
143+
return {nullptr, nullptr};
144+
}
145+
146+
if (!input.block_tables.has_value() || !input.input_ids.has_value()) {
147+
return {nullptr, nullptr};
148+
}
149+
150+
size_t batch_size = input.block_tables.value()->size(0);
151+
size_t block_per_req = input.block_tables.value()->size(1);
152+
size_t total_tokens = input.input_ids.value()->size(1);
153+
154+
// Prefill: total_tokens is a multiple of batch_size, and chunk_size > 1
155+
if (total_tokens == 0 || total_tokens % batch_size != 0) {
156+
return {nullptr, nullptr};
157+
}
158+
size_t chunk_size = total_tokens / batch_size;
159+
if (chunk_size <= 1) {
160+
// Single-token case belongs to decode
161+
return {nullptr, nullptr};
162+
}
163+
164+
auto result = compiled_map_prefill_.find(std::make_tuple(batch_size, chunk_size));
165+
if (result == compiled_map_prefill_.end()) {
166+
return {nullptr, nullptr};
167+
}
168+
169+
auto &graph_input = result->second.input;
170+
171+
graph_input.input_ids.value()->copy_from(input.input_ids.value());
172+
graph_input.position_ids.value()->copy_from(input.position_ids.value());
173+
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
174+
graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
175+
graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value());
176+
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
177+
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
178+
179+
auto graph = std::get<0>(result->second.compiled);
180+
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
181+
new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
182+
183+
return std::make_tuple(graph, shared_output);
184+
}
185+
186+
} // namespace infinilm::engine
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#pragma once
2+
3+
#include "graph_compiler.hpp"
4+
5+
#include <unordered_map>
6+
7+
namespace infinilm::engine {
8+
class ChunkPrefillCompiler : public GraphCompiler {
9+
public:
10+
ChunkPrefillCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
11+
12+
void compile() override;
13+
14+
Compiled get_compiled(const InfinilmModel::Input &input) override;
15+
16+
private:
17+
struct TupleHash {
18+
size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
19+
auto h1 = std::hash<size_t>{}(std::get<0>(t));
20+
auto h2 = std::hash<size_t>{}(std::get<1>(t));
21+
return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
22+
}
23+
};
24+
25+
std::vector<size_t> chunk_sizes_;
26+
std::vector<size_t> prefill_batch_sizes_;
27+
28+
infinicore::Tensor block_tables_holder_;
29+
30+
struct CompiledResult {
31+
InfinilmModel::Input input;
32+
Compiled compiled;
33+
};
34+
35+
// Key: (batch_size, chunk_size)
36+
std::unordered_map<
37+
std::tuple<size_t, size_t>,
38+
CompiledResult,
39+
TupleHash>
40+
compiled_map_prefill_;
41+
};
42+
} // namespace infinilm::engine

0 commit comments

Comments
 (0)