Skip to content

Commit fde8d7c

Browse files
Simon12345777wooway777
authored andcommitted
Add feature ChunkPrefill
1 parent e4cd0d7 commit fde8d7c

21 files changed

Lines changed: 1209 additions & 76 deletions
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
#include "chunk_prefill_compiler.hpp"
2+
#include "../../global_state/global_state.hpp"
3+
#include "infinicore/context/context.hpp"
4+
5+
6+
namespace infinilm::engine {
7+
8+
ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
9+
: GraphCompiler(model, barrier) {
10+
// Enumerate chunk sizes for chunk-prefill
11+
for (size_t cs : {256}) {
12+
chunk_sizes_.push_back(cs);
13+
}
14+
// Enumerate batch sizes for prefill (typically smaller than decode)
15+
for (size_t b = 1; b < 32; b++) {
16+
prefill_batch_sizes_.push_back(b);
17+
}
18+
for (size_t b = 32; b < 64; b += 8) {
19+
prefill_batch_sizes_.push_back(b);
20+
}
21+
for (size_t b = 64; b < 128; b += 16) {
22+
prefill_batch_sizes_.push_back(b);
23+
}
24+
for (size_t b = 128; b < 256; b += 32) {
25+
prefill_batch_sizes_.push_back(b);
26+
}
27+
for (size_t b = 256; b <= 512; b += 64) {
28+
prefill_batch_sizes_.push_back(b);
29+
}
30+
}
31+
32+
void ChunkPrefillCompiler::compile() {
33+
if (model_->get_cache_config() != nullptr &&
34+
dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
35+
36+
const auto *paged_config =
37+
dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config());
38+
size_t nblocks = paged_config->num_blocks();
39+
40+
compiled_map_prefill_.clear();
41+
42+
// Max total tokens to avoid OOM during graph recording
43+
constexpr size_t MAX_TOTAL_TOKENS = 4096;
44+
45+
// Pre-allocate a shared block_tables_holder for the largest (batch_size) we'll use
46+
size_t max_batch = *std::max_element(prefill_batch_sizes_.begin(), prefill_batch_sizes_.end());
47+
size_t block_per_req = nblocks / max_batch;
48+
block_tables_holder_ = infinicore::Tensor::zeros(
49+
{nblocks}, infinicore::DataType::I32, infinicore::context::getDevice());
50+
51+
for (size_t b : prefill_batch_sizes_) {
52+
for (size_t cs : chunk_sizes_) {
53+
size_t total_tokens = b * cs;
54+
if (total_tokens > MAX_TOTAL_TOKENS) {
55+
continue;
56+
}
57+
58+
size_t bpr = nblocks / b; // block_per_req for this batch size
59+
60+
InfinilmModel::Input input;
61+
62+
// input_ids: [1, total_tokens] — all tokens for this batch packed together
63+
input.input_ids = infinicore::Tensor::zeros(
64+
{1, total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
65+
66+
// position_ids: [total_tokens]
67+
input.position_ids = infinicore::Tensor::zeros(
68+
{total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
69+
70+
// total_sequence_lengths: [b], set to cs (first-chunk scenario)
71+
input.total_sequence_lengths = infinicore::Tensor::empty(
72+
{b}, infinicore::DataType::I32, infinicore::context::getDevice());
73+
{
74+
std::vector<int32_t> tsl(b, static_cast<int32_t>(cs));
75+
infinicore::context::memcpyH2D(
76+
input.total_sequence_lengths.value()->data(),
77+
tsl.data(), b * sizeof(int32_t), false);
78+
}
79+
80+
// input_offsets: [b+1], stride = cs
81+
input.input_offsets = infinicore::Tensor::empty(
82+
{b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
83+
{
84+
std::vector<int32_t> offsets(b + 1);
85+
for (size_t i = 0; i <= b; i++) {
86+
offsets[i] = static_cast<int32_t>(i * cs);
87+
}
88+
infinicore::context::memcpyH2D(
89+
input.input_offsets.value()->data(),
90+
offsets.data(), (b + 1) * sizeof(int32_t), false);
91+
}
92+
93+
// cu_seqlens: [b+1], same layout as input_offsets for prefill
94+
input.cu_seqlens = infinicore::Tensor::empty(
95+
{b + 1}, infinicore::DataType::I32, infinicore::context::getDevice());
96+
{
97+
std::vector<int32_t> cu(b + 1);
98+
for (size_t i = 0; i <= b; i++) {
99+
cu[i] = static_cast<int32_t>(i * cs);
100+
}
101+
infinicore::context::memcpyH2D(
102+
input.cu_seqlens.value()->data(),
103+
cu.data(), (b + 1) * sizeof(int32_t), false);
104+
}
105+
106+
// block_tables: view into the shared holder [b, bpr]
107+
input.block_tables = block_tables_holder_->as_strided(
108+
{b, bpr}, {(ptrdiff_t)bpr, 1});
109+
110+
// slot_mapping: [total_tokens]
111+
input.slot_mapping = infinicore::Tensor::zeros(
112+
{total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice());
113+
114+
// Attention reads attn_metadata from thread-local forward context.
115+
infinilm::global_state::get_forward_context().attn_metadata = {
116+
input.past_sequence_lengths,
117+
input.total_sequence_lengths,
118+
input.input_offsets,
119+
input.cu_seqlens,
120+
input.block_tables,
121+
input.slot_mapping,
122+
};
123+
124+
barrier_->wait();
125+
infinicore::context::startGraphRecording();
126+
auto output = model_->forward(input);
127+
auto graph = infinicore::context::stopGraphRecording();
128+
barrier_->wait();
129+
130+
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
131+
new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
132+
133+
compiled_map_prefill_[std::make_tuple(b, cs)] =
134+
CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
135+
}
136+
}
137+
}
138+
}
139+
140+
ChunkPrefillCompiler::Compiled ChunkPrefillCompiler::get_compiled(const InfinilmModel::Input &input) {
141+
if (model_->get_cache_config() == nullptr ||
142+
!dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
143+
return {nullptr, nullptr};
144+
}
145+
146+
if (!input.block_tables.has_value() || !input.input_ids.has_value()) {
147+
return {nullptr, nullptr};
148+
}
149+
150+
size_t batch_size = input.block_tables.value()->size(0);
151+
size_t block_per_req = input.block_tables.value()->size(1);
152+
size_t total_tokens = input.input_ids.value()->size(1);
153+
154+
// Prefill: total_tokens is a multiple of batch_size, and chunk_size > 1
155+
if (total_tokens == 0 || total_tokens % batch_size != 0) {
156+
return {nullptr, nullptr};
157+
}
158+
size_t chunk_size = total_tokens / batch_size;
159+
if (chunk_size <= 1) {
160+
// Single-token case belongs to decode
161+
return {nullptr, nullptr};
162+
}
163+
164+
auto result = compiled_map_prefill_.find(std::make_tuple(batch_size, chunk_size));
165+
if (result == compiled_map_prefill_.end()) {
166+
return {nullptr, nullptr};
167+
}
168+
169+
auto &graph_input = result->second.input;
170+
171+
graph_input.input_ids.value()->copy_from(input.input_ids.value());
172+
graph_input.position_ids.value()->copy_from(input.position_ids.value());
173+
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
174+
graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
175+
graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value());
176+
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
177+
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
178+
179+
auto graph = std::get<0>(result->second.compiled);
180+
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
181+
new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
182+
183+
return std::make_tuple(graph, shared_output);
184+
}
185+
186+
} // namespace infinilm::engine
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#pragma once
2+
3+
#include "graph_compiler.hpp"
4+
5+
#include <unordered_map>
6+
7+
namespace infinilm::engine {
8+
class ChunkPrefillCompiler : public GraphCompiler {
9+
public:
10+
ChunkPrefillCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
11+
12+
void compile() override;
13+
14+
Compiled get_compiled(const InfinilmModel::Input &input) override;
15+
16+
private:
17+
struct TupleHash {
18+
size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
19+
auto h1 = std::hash<size_t>{}(std::get<0>(t));
20+
auto h2 = std::hash<size_t>{}(std::get<1>(t));
21+
return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
22+
}
23+
};
24+
25+
std::vector<size_t> chunk_sizes_;
26+
std::vector<size_t> prefill_batch_sizes_;
27+
28+
infinicore::Tensor block_tables_holder_;
29+
30+
struct CompiledResult {
31+
InfinilmModel::Input input;
32+
Compiled compiled;
33+
};
34+
35+
// Key: (batch_size, chunk_size)
36+
std::unordered_map<
37+
std::tuple<size_t, size_t>,
38+
CompiledResult,
39+
TupleHash>
40+
compiled_map_prefill_;
41+
};
42+
} // namespace infinilm::engine

csrc/engine/compiler/general_compiler.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
#include "general_compiler.hpp"
22

33
namespace infinilm::engine {
4-
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {
4+
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier, bool enable_chunk_prefill_graph)
5+
: GraphCompiler(model, barrier), enable_chunk_prefill_graph_(enable_chunk_prefill_graph) {
56
static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_, barrier);
7+
chunk_prefill_compiler_ = std::make_unique<ChunkPrefillCompiler>(model_, barrier);
68
paged_compiler_ = std::make_unique<PagedCompiler>(model_, barrier);
79
}
810

911
void GeneralCompiler::compile() {
1012
static_batching_compiler_->compile();
13+
if (enable_chunk_prefill_graph_) {
14+
chunk_prefill_compiler_->compile();
15+
}
1116
paged_compiler_->compile();
1217
}
1318

@@ -19,6 +24,11 @@ GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Inp
1924
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
2025
return result;
2126
}
27+
// chunk-prefill must be checked before decode (decode would also match if chunk_size==1)
28+
result = chunk_prefill_compiler_.get()->get_compiled(input);
29+
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
30+
return result;
31+
}
2232
result = paged_compiler_.get()->get_compiled(input);
2333
return result;
2434
}
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
#pragma once
22

3+
#include "chunk_prefill_compiler.hpp"
34
#include "paged_compiler.hpp"
45
#include "static_batching_compiler.hpp"
56

67
namespace infinilm::engine {
78
class GeneralCompiler : public GraphCompiler {
89
public:
9-
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
10+
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier, bool enable_chunk_prefill_graph = false);
1011

1112
void compile() override;
1213

@@ -15,5 +16,7 @@ class GeneralCompiler : public GraphCompiler {
1516
private:
1617
std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
1718
std::unique_ptr<PagedCompiler> paged_compiler_;
19+
std::unique_ptr<ChunkPrefillCompiler> chunk_prefill_compiler_;
20+
bool enable_chunk_prefill_graph_;
1821
};
1922
} // namespace infinilm::engine

csrc/engine/infer_engine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ InferEngine::InferEngine(
1313
infinicore::Device::Type device_type,
1414
const cache::CacheConfig *cache_config,
1515
bool enable_graph_compiling,
16+
bool enable_chunk_prefill_graph,
1617
backends::AttentionBackend attention_backend,
1718
std::optional<infinicore::DataType> kv_cache_dtype) // Changed parameter
1819
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
@@ -39,6 +40,7 @@ InferEngine::InferEngine(
3940
cache_config_ != nullptr ? cache_config_.get() : nullptr,
4041
barrier_.get(),
4142
enable_graph_compiling,
43+
enable_chunk_prefill_graph,
4244
attention_backend_));
4345
}
4446
// Compile the model on all workers

csrc/engine/infer_engine.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class InferEngine {
2626
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
2727
const cache::CacheConfig *cache_config = nullptr,
2828
bool enable_graph_compiling = false,
29+
bool enable_chunk_prefill_graph = false,
2930
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
3031
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
3132

csrc/engine/rank_worker.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ RankWorker::RankWorker(
1616
const cache::CacheConfig *cache_config,
1717
RankBarrier *barrier,
1818
bool enable_graph_compiling,
19+
bool enable_chunk_prefill_graph,
1920
backends::AttentionBackend attention_backend)
2021
: infinilm_config_(infinilm_config),
2122
model_config_(infinilm_config->model_config),
2223
rank_info_(rank_info),
2324
attention_backend_(attention_backend),
2425
enable_graph_compiling_(enable_graph_compiling),
26+
enable_chunk_prefill_graph_(enable_chunk_prefill_graph),
2527
job_cmd_(Command::INIT),
2628
has_job_(false),
2729
job_done_(false),
@@ -270,7 +272,7 @@ void RankWorker::thread_loop() {
270272
throw std::runtime_error("Failed to create model");
271273
}
272274
if (enable_graph_compiling_) {
273-
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
275+
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_, enable_chunk_prefill_graph_);
274276
}
275277

276278
init_done_ = true;

csrc/engine/rank_worker.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class RankWorker {
7575
const cache::CacheConfig *cache_config,
7676
RankBarrier *barrier,
7777
bool enable_graph_compiling,
78+
bool enable_chunk_prefill_graph,
7879
backends::AttentionBackend attention_backend);
7980

8081
// Submit a parameter load job and wait until the load completes on the worker thread.
@@ -125,6 +126,7 @@ class RankWorker {
125126

126127
// Graph Compiling
127128
bool enable_graph_compiling_;
129+
bool enable_chunk_prefill_graph_;
128130
std::unique_ptr<GraphCompiler> compiler_;
129131

130132
// Command for the pending job (protected by mutex_)

csrc/pybind11/engine/engine.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ inline void bind_infer_engine(py::module &m) {
3737
infinicore::Device::Type dev,
3838
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
3939
bool enable_graph_compiling,
40+
bool enable_chunk_prefill_graph,
4041
const std::string &attention_backend,
4142
std::optional<infinicore::DataType> kv_cache_dtype) {
4243
return std::make_shared<InferEngine>(
@@ -45,6 +46,7 @@ inline void bind_infer_engine(py::module &m) {
4546
dev,
4647
cache_cfg ? cache_cfg.get() : nullptr,
4748
enable_graph_compiling,
49+
enable_chunk_prefill_graph,
4850
infinilm::backends::parse_attention_backend(attention_backend),
4951
kv_cache_dtype);
5052
}),
@@ -53,6 +55,7 @@ inline void bind_infer_engine(py::module &m) {
5355
py::arg("device_type") = infinicore::context::getDevice().getType(),
5456
py::arg("cache_config") = py::none(),
5557
py::arg("enable_graph_compiling") = false,
58+
py::arg("enable_chunk_prefill_graph") = false,
5659
py::arg("attention_backend") = "default",
5760
py::arg("kv_cache_dtype") = py::none())
5861
.def("load_param", &InferEngine::load_param,

python/infinilm/base_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def __init__(self):
6060

6161
self.attn = self.args.attn
6262
self.enable_graph = self.args.enable_graph
63+
self.enable_chunk_prefill_graph = self.args.enable_chunk_prefill_graph
64+
self.chunk_size = self.args.chunk_size
6365
self.enable_paged_attn = self.args.enable_paged_attn
6466
self.num_blocks = self.args.num_blocks
6567
self.block_size = self.args.block_size
@@ -123,6 +125,8 @@ def _add_common_args(self):
123125
choices=["default", "paged-attn", "flash-attn"],
124126
)
125127
self.parser.add_argument("--enable-graph", action="store_true")
128+
self.parser.add_argument("--enable-chunk-prefill-graph", action="store_true", help="enable chunk-prefill graph compiling")
129+
self.parser.add_argument("--chunk-size", type=int, default=0, help="tokens per chunked-prefill slice (0 to disable)")
126130
self.parser.add_argument(
127131
"--enable-paged-attn",
128132
action="store_true",

0 commit comments

Comments
 (0)