Skip to content

Commit 108984f

Browse files
feat: support in-batch prefix cache.
1 parent f9e58db commit 108984f

22 files changed

Lines changed: 188 additions & 21 deletions

xllm/core/common/global_flags.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,11 @@ DEFINE_bool(enable_prefix_cache,
285285
true,
286286
"Whether to enable the prefix cache for the block manager.");
287287

288+
DEFINE_bool(
289+
enable_in_batch_prefix_cache,
290+
true,
291+
"Whether to cache admitted prefill full blocks into prefix cache.");
292+
288293
DEFINE_bool(enable_cache_upload,
289294
false,
290295
"Whether to upload cache info to service. This feature is only "

xllm/core/common/global_flags.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ DECLARE_double(max_memory_utilization);
4848
DECLARE_string(kv_cache_dtype);
4949

5050
DECLARE_bool(enable_prefix_cache);
51+
DECLARE_bool(enable_in_batch_prefix_cache);
5152

5253
DECLARE_bool(enable_cache_upload);
5354

xllm/core/common/help_formatter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ const OptionCategory kCommonOptions = {"COMMON OPTIONS",
4242
"enable_prefill_sp",
4343
"enable_schedule_overlap",
4444
"enable_prefix_cache",
45+
"enable_in_batch_prefix_cache",
4546
"enable_shm",
4647
"enable_graph",
4748
"enable_graph_mode_decode_no_padding",
@@ -101,7 +102,7 @@ const OptionCategory kBeamSearchOptions = {"BEAM SEARCH OPTIONS",
101102

102103
const OptionCategory kPrefixCacheOptions = {
103104
"PREFIX CACHE OPTIONS",
104-
{"enable_prefix_cache", "xxh3_128bits_seed"}};
105+
{"enable_prefix_cache", "enable_in_batch_prefix_cache", "xxh3_128bits_seed"}};
105106

106107
const OptionCategory kOtherOptions = {
107108
"OTHER OPTIONS",

xllm/core/common/options.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ std::string Options::to_string() const {
2929
<< ", max_cache_size: " << max_cache_size()
3030
<< ", max_memory_utilization: " << max_memory_utilization()
3131
<< ", enable_prefix_cache: " << enable_prefix_cache()
32+
<< ", enable_in_batch_prefix_cache: " << enable_in_batch_prefix_cache()
3233
<< ", max_tokens_per_batch: " << max_tokens_per_batch()
3334
<< ", max_seqs_per_batch: " << max_seqs_per_batch()
3435
<< ", max_tokens_per_chunk_for_prefill: "

xllm/core/common/options.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class Options {
6262
PROPERTY(double, max_memory_utilization) = 0.9;
6363

6464
PROPERTY(bool, enable_prefix_cache) = true;
65+
PROPERTY(bool, enable_in_batch_prefix_cache) = true;
6566

6667
// max tokens num per batch
6768
PROPERTY(int32_t, max_tokens_per_batch) = 20480;

xllm/core/distributed_runtime/llm_master.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ LLMMaster::LLMMaster(const Options& options)
8686
.enable_pd_ooc(options_.enable_pd_ooc())
8787
.enable_schedule_overlap(options_.enable_schedule_overlap())
8888
.enable_chunked_prefill(options_.enable_chunked_prefill())
89+
.enable_prefix_cache(options_.enable_prefix_cache())
90+
.enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache())
8991
.instance_name(options_.instance_name())
9092
.instance_role(options_.instance_role())
9193
.kv_cache_transfer_mode(options_.kv_cache_transfer_mode())

xllm/core/distributed_runtime/rec_master.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,8 @@ RecMaster::RecMaster(const Options& options)
504504
.enable_disagg_pd(options_.enable_disagg_pd())
505505
.enable_schedule_overlap(options_.enable_schedule_overlap())
506506
.enable_chunked_prefill(options_.enable_chunked_prefill())
507+
.enable_prefix_cache(options_.enable_prefix_cache())
508+
.enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache())
507509
.instance_role(options_.instance_role())
508510
.kv_cache_transfer_mode(options_.kv_cache_transfer_mode())
509511
.enable_service_routing(options_.enable_service_routing())

xllm/core/distributed_runtime/vlm_master.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ VLMMaster::VLMMaster(const Options& options)
6666
options.max_tokens_per_chunk_for_prefill())
6767
.enable_disagg_pd(options_.enable_disagg_pd())
6868
.enable_chunked_prefill(options_.enable_chunked_prefill())
69+
.enable_prefix_cache(options_.enable_prefix_cache())
70+
.enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache())
6971
.instance_name(options_.instance_name())
7072
.instance_role(options_.instance_role())
7173
.kv_cache_transfer_mode(options_.kv_cache_transfer_mode())

xllm/core/framework/block/block_manager_pool.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License.
1616
#include "block_manager_pool.h"
1717

1818
#include <algorithm>
19-
#include <limits>
2019

2120
#include "block_manager_impl.h"
2221
#include "common/global_flags.h"
@@ -125,7 +124,7 @@ bool BlockManagerPool::allocate_embedding_id(Sequence* sequence,
125124

126125
void BlockManagerPool::deallocate_embedding_id(Sequence* sequence,
127126
int32_t dp_rank) {
128-
DCHECK(sequence != nullptr);
127+
CHECK(sequence != nullptr);
129128
CHECK_GE(dp_rank, 0);
130129
CHECK_LT(static_cast<size_t>(dp_rank), embedding_managers_.size());
131130
auto embedding_block = sequence->reset_embedding_block();
@@ -152,7 +151,7 @@ void BlockManagerPool::deallocate(std::vector<Sequence*>& sequences) {
152151
}
153152

154153
void BlockManagerPool::deallocate(Sequence* sequence) {
155-
DCHECK(sequence != nullptr);
154+
CHECK(sequence != nullptr);
156155
// add blocks to the prefix cache
157156
int32_t dp_rank = get_dp_rank(sequence);
158157
cache(sequence);
@@ -173,13 +172,13 @@ void BlockManagerPool::reset_transfer_infos() {
173172
}
174173

175174
// Convenience overload: requests an allocation covering every token the
// sequence currently reports via num_tokens(), delegating to the
// (sequence, num_tokens) overload. Aborts via CHECK on a null sequence.
bool BlockManagerPool::allocate(Sequence* sequence) {
  CHECK(sequence != nullptr);
  const auto wanted_tokens = sequence->num_tokens();
  return allocate(sequence, wanted_tokens);
}
179178

180179
bool BlockManagerPool::allocate(std::vector<Sequence*>& sequences) {
181180
for (auto* sequence : sequences) {
182-
DCHECK(sequence != nullptr);
181+
CHECK(sequence != nullptr);
183182
if (!allocate(sequence, sequence->num_tokens())) {
184183
// Should we guarantee the atomicity of the allocation (all or nothing)?
185184
return false;
@@ -190,7 +189,7 @@ bool BlockManagerPool::allocate(std::vector<Sequence*>& sequences) {
190189

191190
bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) {
192191
AUTO_COUNTER(allocate_blocks_latency_seconds);
193-
DCHECK(sequence != nullptr);
192+
CHECK(sequence != nullptr);
194193
int32_t dp_rank = get_dp_rank(sequence);
195194
const bool started_empty = sequence->kv_state().num_kv_blocks() == 0;
196195
const bool needs_embedding_id = !sequence->has_embedding_id();
@@ -340,10 +339,30 @@ void BlockManagerPool::allocate_shared(Sequence* sequence) {
340339
}
341340

342341
void BlockManagerPool::cache(Sequence* sequence) {
342+
cache(sequence, sequence->kv_state().kv_cache_tokens_num());
343+
}
344+
345+
void BlockManagerPool::cache(Sequence* sequence, size_t num_tokens) {
346+
CHECK(sequence != nullptr);
347+
if (!options_.enable_prefix_cache()) {
348+
return;
349+
}
350+
351+
const size_t block_size = static_cast<size_t>(options_.block_size());
352+
const size_t available_tokens_num =
353+
std::min({num_tokens,
354+
sequence->kv_state().num_kv_blocks() * block_size,
355+
sequence->tokens().size()});
356+
const size_t existed_shared_blocks_num =
357+
sequence->kv_state().shared_kv_blocks_num();
358+
if (available_tokens_num <= existed_shared_blocks_num * block_size) {
359+
return;
360+
}
361+
343362
int32_t dp_rank = get_dp_rank(sequence);
344-
const auto token_ids = sequence->cached_tokens();
363+
const auto token_ids = sequence->tokens().slice(0, available_tokens_num);
345364
auto* blocks = sequence->kv_state().mutable_kv_blocks();
346-
auto existed_shared_blocks_num = sequence->kv_state().shared_kv_blocks_num();
365+
CHECK_GE(blocks->size(), existed_shared_blocks_num);
347366
block_managers_[dp_rank]->cache(
348367
token_ids, *blocks, existed_shared_blocks_num);
349368
}
@@ -399,7 +418,7 @@ double BlockManagerPool::kv_cache_utilization() const {
399418
// Currently used only for profiling, which does not need the prefix cache.
400419
// If it becomes more widely used, it can be integrated into the deallocate function.
401420
void BlockManagerPool::deallocate_without_cache(Sequence* sequence) {
402-
DCHECK(sequence != nullptr);
421+
CHECK(sequence != nullptr);
403422
int32_t dp_rank = get_dp_rank(sequence);
404423
block_managers_[dp_rank]->deallocate(sequence->kv_state().kv_blocks());
405424
deallocate_embedding_id(sequence, dp_rank);

xllm/core/framework/block/block_manager_pool.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class BlockManagerPool : public KVCacheManager {
6666

6767
virtual void allocate_shared(Sequence* sequence) override;
6868
virtual void cache(Sequence* sequence) override;
69+
virtual void cache(Sequence* sequence, size_t num_tokens) override;
6970

7071
virtual std::vector<std::vector<BlockTransferInfo>>*
7172
get_swap_block_transfer_infos() override;

0 commit comments

Comments
 (0)