Skip to content

Commit f2c8bab

Browse files
committed
Add chunked prefill and prefill CUDA graph
1 parent e340689 commit f2c8bab

16 files changed

Lines changed: 266 additions & 30 deletions

csrc/engine/compiler/general_compiler.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
#include "general_compiler.hpp"
22

33
namespace infinilm::engine {
4-
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {
4+
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier, bool enable_chunk_prefill_graph)
5+
: GraphCompiler(model, barrier), enable_chunk_prefill_graph_(enable_chunk_prefill_graph) {
56
static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_, barrier);
7+
chunk_prefill_compiler_ = std::make_unique<ChunkPrefillCompiler>(model_, barrier);
68
paged_compiler_ = std::make_unique<PagedCompiler>(model_, barrier);
79
}
810

911
// Compiles the sub-compilers owned by GeneralCompiler in a fixed order:
// static batching first, then chunk-prefill, then paged.
// The chunk-prefill compiler is compiled only when
// enable_chunk_prefill_graph_ is true; the other two always are.
void GeneralCompiler::compile() {
1012
static_batching_compiler_->compile();
13+
if (enable_chunk_prefill_graph_) {
14+
chunk_prefill_compiler_->compile();
15+
}
1116
paged_compiler_->compile();
1217
}
1318

@@ -19,6 +24,11 @@ GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Inp
1924
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
2025
return result;
2126
}
27+
// chunk-prefill must be checked before decode (decode would also match if chunk_size==1)
28+
result = chunk_prefill_compiler_.get()->get_compiled(input);
29+
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
30+
return result;
31+
}
2232
result = paged_compiler_.get()->get_compiled(input);
2333
return result;
2434
}
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
#pragma once
22

3+
#include "chunk_prefill_compiler.hpp"
34
#include "paged_compiler.hpp"
45
#include "static_batching_compiler.hpp"
56

67
namespace infinilm::engine {
78
class GeneralCompiler : public GraphCompiler {
89
public:
9-
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
10+
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier, bool enable_chunk_prefill_graph = false);
1011

1112
void compile() override;
1213

@@ -15,5 +16,7 @@ class GeneralCompiler : public GraphCompiler {
1516
private:
1617
std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
1718
std::unique_ptr<PagedCompiler> paged_compiler_;
19+
std::unique_ptr<ChunkPrefillCompiler> chunk_prefill_compiler_;
20+
bool enable_chunk_prefill_graph_;
1821
};
1922
} // namespace infinilm::engine

csrc/engine/infer_engine.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ InferEngine::InferEngine(
2525
infinicore::Device::Type device_type,
2626
const cache::CacheConfig *cache_config,
2727
bool enable_graph_compiling,
28+
bool enable_chunk_prefill_graph,
2829
backends::AttentionBackend attention_backend) // Changed parameter
2930
: communication_group_(distributed_config, device_type),
3031
legacy_model_config_(config),
@@ -43,6 +44,7 @@ InferEngine::InferEngine(
4344
cache_config_ != nullptr ? cache_config_.get() : nullptr,
4445
barrier_.get(),
4546
enable_graph_compiling,
47+
enable_chunk_prefill_graph,
4648
attention_backend_));
4749
}
4850

@@ -56,6 +58,7 @@ InferEngine::InferEngine(
5658
infinicore::Device::Type device_type,
5759
const cache::CacheConfig *cache_config,
5860
bool enable_graph_compiling,
61+
bool enable_chunk_prefill_graph,
5962
backends::AttentionBackend attention_backend,
6063
std::optional<infinicore::DataType> kv_cache_dtype) // Changed parameter
6164
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
@@ -82,6 +85,7 @@ InferEngine::InferEngine(
8285
cache_config_ != nullptr ? cache_config_.get() : nullptr,
8386
barrier_.get(),
8487
enable_graph_compiling,
88+
enable_chunk_prefill_graph,
8589
attention_backend_));
8690
}
8791
// Compile the model on all workers

csrc/engine/infer_engine.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class InferEngine {
3939
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
4040
const cache::CacheConfig *cache_config = nullptr,
4141
bool enable_graph_compiling = false,
42+
bool enable_chunk_prefill_graph = false,
4243
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
4344

4445
InferEngine(
@@ -47,6 +48,7 @@ class InferEngine {
4748
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
4849
const cache::CacheConfig *cache_config = nullptr,
4950
bool enable_graph_compiling = false,
51+
bool enable_chunk_prefill_graph = false,
5052
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
5153
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
5254

csrc/engine/rank_worker.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
2727
const cache::CacheConfig *cache_config,
2828
RankBarrier *barrier,
2929
bool enable_graph_compiling,
30+
bool enable_chunk_prefill_graph,
3031
backends::AttentionBackend attention_backend)
3132
: legacy_model_config_(model_config),
3233
rank_info_(rank_info),
3334
attention_backend_(attention_backend),
3435
enable_graph_compiling_(enable_graph_compiling),
36+
enable_chunk_prefill_graph_(enable_chunk_prefill_graph),
3537
job_cmd_(Command::INIT),
3638
has_job_(false),
3739
job_done_(false),
@@ -56,12 +58,14 @@ RankWorker::RankWorker(
5658
const cache::CacheConfig *cache_config,
5759
RankBarrier *barrier,
5860
bool enable_graph_compiling,
61+
bool enable_chunk_prefill_graph,
5962
backends::AttentionBackend attention_backend)
6063
: infinilm_config_(infinilm_config),
6164
model_config_(infinilm_config->model_config),
6265
rank_info_(rank_info),
6366
attention_backend_(attention_backend),
6467
enable_graph_compiling_(enable_graph_compiling),
68+
enable_chunk_prefill_graph_(enable_chunk_prefill_graph),
6569
job_cmd_(Command::INIT),
6670
has_job_(false),
6771
job_done_(false),
@@ -303,7 +307,7 @@ void RankWorker::thread_loop() {
303307
throw std::runtime_error("Failed to create model");
304308
}
305309
if (enable_graph_compiling_) {
306-
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
310+
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_, enable_chunk_prefill_graph_);
307311
}
308312

309313
init_done_ = true;

csrc/engine/rank_worker.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,15 @@ class RankWorker {
7575
const cache::CacheConfig *cache_config,
7676
RankBarrier *barrier,
7777
bool enable_graph_compiling,
78+
bool enable_chunk_prefill_graph,
7879
backends::AttentionBackend attention_backend);
7980

8081
RankWorker(std::shared_ptr<infinilm::global_state::InfinilmConfig> infinilm_config,
8182
const distributed::RankInfo &rank_info,
8283
const cache::CacheConfig *cache_config,
8384
RankBarrier *barrier,
8485
bool enable_graph_compiling,
86+
bool enable_chunk_prefill_graph,
8587
backends::AttentionBackend attention_backend);
8688

8789
// Submit a parameter load job and wait until the load completes on the worker thread.
@@ -131,6 +133,7 @@ class RankWorker {
131133

132134
// Graph Compiling
133135
bool enable_graph_compiling_;
136+
bool enable_chunk_prefill_graph_;
134137
std::unique_ptr<GraphCompiler> compiler_;
135138

136139
// Command for the pending job (protected by mutex_)

csrc/pybind11/engine/engine.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,23 @@ inline void bind_infer_engine(py::module &m) {
3737
infinicore::Device::Type dev,
3838
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
3939
bool enable_graph_compiling,
40+
bool enable_chunk_prefill_graph,
4041
const std::string &attention_backend) {
4142
return std::make_shared<InferEngine>(
4243
cfg,
4344
dist,
4445
dev,
4546
cache_cfg ? cache_cfg.get() : nullptr,
4647
enable_graph_compiling,
48+
enable_chunk_prefill_graph,
4749
infinilm::backends::parse_attention_backend(attention_backend));
4850
}),
4951
py::arg("config"),
5052
py::arg("distributed_config") = distributed::DistConfig(),
5153
py::arg("device_type") = infinicore::context::getDevice().getType(),
5254
py::arg("cache_config") = py::none(),
5355
py::arg("enable_graph_compiling") = false,
56+
py::arg("enable_chunk_prefill_graph") = false,
5457
py::arg("attention_backend") = "default")
5558
.def("load_param", &InferEngine::load_param,
5659
py::arg("name"), py::arg("param"),
@@ -81,6 +84,7 @@ inline void bind_infer_engine(py::module &m) {
8184
infinicore::Device::Type dev,
8285
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
8386
bool enable_graph_compiling,
87+
bool enable_chunk_prefill_graph,
8488
const std::string &attention_backend,
8589
std::optional<infinicore::DataType> kv_cache_dtype) {
8690
return std::make_shared<InferEngine>(
@@ -89,6 +93,7 @@ inline void bind_infer_engine(py::module &m) {
8993
dev,
9094
cache_cfg ? cache_cfg.get() : nullptr,
9195
enable_graph_compiling,
96+
enable_chunk_prefill_graph,
9297
infinilm::backends::parse_attention_backend(attention_backend),
9398
kv_cache_dtype);
9499
}),
@@ -97,6 +102,7 @@ inline void bind_infer_engine(py::module &m) {
97102
py::arg("device_type") = infinicore::context::getDevice().getType(),
98103
py::arg("cache_config") = py::none(),
99104
py::arg("enable_graph_compiling") = false,
105+
py::arg("enable_chunk_prefill_graph") = false,
100106
py::arg("attention_backend") = "default",
101107
py::arg("kv_cache_dtype") = py::none())
102108
.def("load_param", &InferEngine::load_param,

python/infinilm/base_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def __init__(self):
6161

6262
self.attn = self.args.attn
6363
self.enable_graph = self.args.enable_graph
64+
self.enable_chunk_prefill_graph = self.args.enable_chunk_prefill_graph
65+
self.chunk_size = self.args.chunk_size
6466
self.enable_paged_attn = self.args.enable_paged_attn
6567
self.num_blocks = self.args.num_blocks
6668
self.block_size = self.args.block_size
@@ -122,6 +124,8 @@ def _add_common_args(self):
122124
choices=["default", "paged-attn", "flash-attn"],
123125
)
124126
self.parser.add_argument("--enable-graph", action="store_true")
127+
self.parser.add_argument("--enable-chunk-prefill-graph", action="store_true", help="enable chunk-prefill graph compiling")
128+
self.parser.add_argument("--chunk-size", type=int, default=512, help="tokens per chunked-prefill slice (0 to disable)")
125129
self.parser.add_argument(
126130
"--enable-paged-attn",
127131
action="store_true",

python/infinilm/infer_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
distributed_config=DistConfig(1),
4646
cache_config=None,
4747
enable_graph_compiling=False,
48+
enable_chunk_prefill_graph=False,
4849
attention_backend="default",
4950
kv_cache_dtype=None,
5051
):
@@ -60,6 +61,7 @@ def __init__(
6061
device._underlying.type,
6162
cache_config,
6263
enable_graph_compiling,
64+
enable_chunk_prefill_graph,
6365
attention_backend,
6466
(
6567
parse_dtype(kv_cache_dtype)._underlying

python/infinilm/llm/llm.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ class EngineConfig:
7272
top_p: float = 0.8
7373
top_k: int = 1
7474
enable_graph: bool = False
75+
enable_chunk_prefill_graph: bool = False
76+
chunk_size: int = 0
7577
attn_backend: str = "default"
7678
skip_load: bool = False
7779

@@ -91,6 +93,7 @@ def __init__(self, config: EngineConfig):
9193
device=self.device,
9294
distributed_config=DistConfig(config.tensor_parallel_size),
9395
enable_graph_compiling=config.enable_graph,
96+
enable_chunk_prefill_graph=config.enable_chunk_prefill_graph,
9497
attention_backend=config.attn_backend,
9598
)
9699

@@ -167,6 +170,8 @@ def _init_device(self):
167170

168171
def add_request(self, request: InferenceRequest):
169172
"""Add a request to the scheduler."""
173+
if self.cache_type == "paged" and self.config.chunk_size > 0:
174+
request.chunk_size = self.config.chunk_size
170175
self.scheduler.add_request(request)
171176

172177
def step(self) -> tuple[list[InferenceRequest], list[tuple]]:
@@ -210,14 +215,39 @@ def _update_requests(
210215
sampled_tokens: List[int],
211216
) -> List[tuple]:
212217
"""Update request status after inference step."""
213-
if is_prefill:
218+
# Detect a chunked-prefill mid-step: single request, prefill phase,
219+
# and this chunk does not yet cover the whole prompt. In that case
220+
# we must NOT consume a sampled token, NOT commit prefill blocks,
221+
# and re-enqueue the request to keep chunking.
222+
chunk_mid_step = (
223+
is_prefill
224+
and len(requests) == 1
225+
and requests[0].is_chunking()
226+
and not requests[0].chunk_is_last()
227+
)
228+
229+
if is_prefill and not chunk_mid_step:
214230
match self.cache_type:
215231
case "paged":
216232
self.scheduler.cache_manager.reset_req_blocks()
217233
case "static":
218234
self.scheduler.update_cache()
219235
case _:
220236
raise ValueError(f"Unsupported cache_type: {self.cache_type}")
237+
238+
if chunk_mid_step:
239+
req = requests[0]
240+
req.chunk_prefill_offset += req.chunk_size
241+
# If this request was aborted while chunking, drop it.
242+
if req.is_aborted():
243+
logger.info(
244+
f"Request {req.request_id} aborted by client during chunked-prefill"
245+
)
246+
return []
247+
# Re-enqueue to keep producing chunks; no token sampled yet.
248+
self.scheduler.requeue_chunking(req)
249+
return []
250+
221251
pending = []
222252
for req, token_id in zip(requests, sampled_tokens):
223253
if req.is_aborted():
@@ -227,6 +257,10 @@ def _update_requests(
227257
continue
228258

229259
if req.is_prefill:
260+
# Clean up chunked-prefill state on the final chunk so the
261+
# next forward pass on this request takes the decode path.
262+
req.chunk_prefill_offset = 0
263+
req.chunk_size = 0
230264
req.is_prefill = False
231265

232266
req.generated_token_ids.append(token_id)
@@ -361,6 +395,8 @@ def __init__(
361395
top_p: float = 0.8,
362396
top_k: int = 1,
363397
enable_graph: bool = False,
398+
enable_chunk_prefill_graph: bool = False,
399+
chunk_size: int = 0,
364400
attn_backend: str = "default",
365401
skip_load: bool = False,
366402
):
@@ -398,6 +434,8 @@ def __init__(
398434
top_p=top_p,
399435
top_k=top_k,
400436
enable_graph=enable_graph,
437+
enable_chunk_prefill_graph=enable_chunk_prefill_graph,
438+
chunk_size=chunk_size,
401439
attn_backend=attn_backend,
402440
skip_load=skip_load,
403441
)
@@ -539,6 +577,8 @@ def __init__(
539577
top_p: float = 0.8,
540578
top_k: int = 1,
541579
enable_graph: bool = False,
580+
enable_chunk_prefill_graph: bool = False,
581+
chunk_size: int = 0,
542582
attn_backend: str = "default",
543583
):
544584
"""Initialize AsyncLLMEngine.
@@ -575,6 +615,8 @@ def __init__(
575615
top_p=top_p,
576616
top_k=top_k,
577617
enable_graph=enable_graph,
618+
enable_chunk_prefill_graph=enable_chunk_prefill_graph,
619+
chunk_size=chunk_size,
578620
attn_backend=attn_backend,
579621
)
580622
self.engine = LLMEngine(config)

0 commit comments

Comments
 (0)