From 108984f8a2d515fa1031784b926745239bb2c398 Mon Sep 17 00:00:00 2001 From: Clement-Wang26 Date: Wed, 25 Mar 2026 20:40:36 +0800 Subject: [PATCH] feat: support in-batch prefix cache. --- xllm/core/common/global_flags.cpp | 5 ++ xllm/core/common/global_flags.h | 1 + xllm/core/common/help_formatter.h | 3 +- xllm/core/common/options.cpp | 1 + xllm/core/common/options.h | 1 + xllm/core/distributed_runtime/llm_master.cpp | 2 + xllm/core/distributed_runtime/rec_master.cpp | 2 + xllm/core/distributed_runtime/vlm_master.cpp | 2 + .../framework/block/block_manager_pool.cpp | 37 +++++++++++---- .../core/framework/block/block_manager_pool.h | 1 + .../framework/block/block_manager_test.cpp | 46 ++++++++++++++++++- .../block/hierarchy_block_manager_pool.cpp | 10 +++- xllm/core/framework/block/kv_cache_manager.h | 1 + .../xtensor/xtensor_manager_pool.cpp | 8 ++-- .../framework/xtensor/xtensor_manager_pool.h | 5 +- .../scheduler/chunked_prefill_scheduler.cpp | 4 ++ xllm/core/scheduler/continuous_scheduler.cpp | 25 +++++++++- xllm/core/scheduler/continuous_scheduler.h | 7 +++ .../scheduler/continuous_scheduler_test.cpp | 39 ++++++++++++++++ xllm/core/scheduler/mix_scheduler.cpp | 2 + .../core/scheduler/prefill_only_scheduler.cpp | 6 ++- xllm/xllm.cpp | 1 + 22 files changed, 188 insertions(+), 21 deletions(-) diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 61b3a792e..b97664301 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -285,6 +285,11 @@ DEFINE_bool(enable_prefix_cache, true, "Whether to enable the prefix cache for the block manager."); +DEFINE_bool( + enable_in_batch_prefix_cache, + true, + "Whether to cache admitted prefill full blocks into prefix cache."); + DEFINE_bool(enable_cache_upload, false, "Whether to upload cache info to service. This feature is only " diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index a2f4c80c5..252d27a8a 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -48,6 +48,7 @@ DECLARE_double(max_memory_utilization); DECLARE_string(kv_cache_dtype); DECLARE_bool(enable_prefix_cache); +DECLARE_bool(enable_in_batch_prefix_cache); DECLARE_bool(enable_cache_upload); diff --git a/xllm/core/common/help_formatter.h b/xllm/core/common/help_formatter.h index 0c1f714d2..e5083442f 100644 --- a/xllm/core/common/help_formatter.h +++ b/xllm/core/common/help_formatter.h @@ -42,6 +42,7 @@ const OptionCategory kCommonOptions = {"COMMON OPTIONS", "enable_prefill_sp", "enable_schedule_overlap", "enable_prefix_cache", + "enable_in_batch_prefix_cache", "enable_shm", "enable_graph", "enable_graph_mode_decode_no_padding", @@ -101,7 +102,7 @@ const OptionCategory kBeamSearchOptions = {"BEAM SEARCH OPTIONS", const OptionCategory kPrefixCacheOptions = { "PREFIX CACHE OPTIONS", - {"enable_prefix_cache", "xxh3_128bits_seed"}}; + {"enable_prefix_cache", "enable_in_batch_prefix_cache", "xxh3_128bits_seed"}}; const OptionCategory kOtherOptions = { "OTHER OPTIONS", diff --git a/xllm/core/common/options.cpp b/xllm/core/common/options.cpp index 3b15cbd41..94e577ea9 100644 --- a/xllm/core/common/options.cpp +++ b/xllm/core/common/options.cpp @@ -29,6 +29,7 @@ std::string Options::to_string() const { << ", max_cache_size: " << max_cache_size() << ", max_memory_utilization: " << max_memory_utilization() << ", enable_prefix_cache: " << enable_prefix_cache() + << ", enable_in_batch_prefix_cache: " << enable_in_batch_prefix_cache() << ", max_tokens_per_batch: " << max_tokens_per_batch() << ", max_seqs_per_batch: " << max_seqs_per_batch() << ", max_tokens_per_chunk_for_prefill: " diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 702d3dd4b..62773d712 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -62,6 +62,7 @@ class Options { PROPERTY(double, max_memory_utilization) = 0.9; PROPERTY(bool, enable_prefix_cache) = true; + PROPERTY(bool, enable_in_batch_prefix_cache) = true; // max tokens num per batch PROPERTY(int32_t, max_tokens_per_batch) = 20480; diff --git a/xllm/core/distributed_runtime/llm_master.cpp b/xllm/core/distributed_runtime/llm_master.cpp index 337b8b314..0a7ec8f86 100644 --- a/xllm/core/distributed_runtime/llm_master.cpp +++ b/xllm/core/distributed_runtime/llm_master.cpp @@ -86,6 +86,8 @@ LLMMaster::LLMMaster(const Options& options) .enable_pd_ooc(options_.enable_pd_ooc()) .enable_schedule_overlap(options_.enable_schedule_overlap()) .enable_chunked_prefill(options_.enable_chunked_prefill()) + .enable_prefix_cache(options_.enable_prefix_cache()) + .enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache()) .instance_name(options_.instance_name()) .instance_role(options_.instance_role()) .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) diff --git a/xllm/core/distributed_runtime/rec_master.cpp b/xllm/core/distributed_runtime/rec_master.cpp index 391b84046..09249b8f1 100644 --- a/xllm/core/distributed_runtime/rec_master.cpp +++ b/xllm/core/distributed_runtime/rec_master.cpp @@ -504,6 +504,8 @@ RecMaster::RecMaster(const Options& options) .enable_disagg_pd(options_.enable_disagg_pd()) .enable_schedule_overlap(options_.enable_schedule_overlap()) .enable_chunked_prefill(options_.enable_chunked_prefill()) + .enable_prefix_cache(options_.enable_prefix_cache()) + .enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache()) .instance_role(options_.instance_role()) .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) .enable_service_routing(options_.enable_service_routing()) diff --git a/xllm/core/distributed_runtime/vlm_master.cpp b/xllm/core/distributed_runtime/vlm_master.cpp index eb54b7f93..a3c749720 100755 --- a/xllm/core/distributed_runtime/vlm_master.cpp +++ b/xllm/core/distributed_runtime/vlm_master.cpp @@ -66,6 +66,8 @@ VLMMaster::VLMMaster(const Options& options) options.max_tokens_per_chunk_for_prefill()) .enable_disagg_pd(options_.enable_disagg_pd()) .enable_chunked_prefill(options_.enable_chunked_prefill()) + .enable_prefix_cache(options_.enable_prefix_cache()) + .enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache()) .instance_name(options_.instance_name()) .instance_role(options_.instance_role()) .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) diff --git a/xllm/core/framework/block/block_manager_pool.cpp b/xllm/core/framework/block/block_manager_pool.cpp index bd7a28f92..57536baa3 100644 --- a/xllm/core/framework/block/block_manager_pool.cpp +++ b/xllm/core/framework/block/block_manager_pool.cpp @@ -16,7 +16,6 @@ limitations under the License. #include "block_manager_pool.h" #include -#include #include "block_manager_impl.h" #include "common/global_flags.h" @@ -125,7 +124,7 @@ bool BlockManagerPool::allocate_embedding_id(Sequence* sequence, void BlockManagerPool::deallocate_embedding_id(Sequence* sequence, int32_t dp_rank) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); CHECK_GE(dp_rank, 0); CHECK_LT(static_cast(dp_rank), embedding_managers_.size()); auto embedding_block = sequence->reset_embedding_block(); @@ -152,7 +151,7 @@ void BlockManagerPool::deallocate(std::vector& sequences) { } void BlockManagerPool::deallocate(Sequence* sequence) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); // add blocks to the prefix cache int32_t dp_rank = get_dp_rank(sequence); cache(sequence); @@ -173,13 +172,13 @@ void BlockManagerPool::reset_transfer_infos() { } bool BlockManagerPool::allocate(Sequence* sequence) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); return allocate(sequence, sequence->num_tokens()); } bool BlockManagerPool::allocate(std::vector& sequences) { for (auto* sequence : sequences) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); if (!allocate(sequence, sequence->num_tokens())) { // should we gurantee the atomicity of the allocation? all or nothing? return false; @@ -190,7 +189,7 @@ bool BlockManagerPool::allocate(std::vector& sequences) { bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) { AUTO_COUNTER(allocate_blocks_latency_seconds); - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); int32_t dp_rank = get_dp_rank(sequence); const bool started_empty = sequence->kv_state().num_kv_blocks() == 0; const bool needs_embedding_id = !sequence->has_embedding_id(); @@ -340,10 +339,30 @@ void BlockManagerPool::allocate_shared(Sequence* sequence) { } void BlockManagerPool::cache(Sequence* sequence) { + cache(sequence, sequence->kv_state().kv_cache_tokens_num()); +} + +void BlockManagerPool::cache(Sequence* sequence, size_t num_tokens) { + CHECK(sequence != nullptr); + if (!options_.enable_prefix_cache()) { + return; + } + + const size_t block_size = static_cast(options_.block_size()); + const size_t available_tokens_num = + std::min({num_tokens, + sequence->kv_state().num_kv_blocks() * block_size, + sequence->tokens().size()}); + const size_t existed_shared_blocks_num = + sequence->kv_state().shared_kv_blocks_num(); + if (available_tokens_num <= existed_shared_blocks_num * block_size) { + return; + } + int32_t dp_rank = get_dp_rank(sequence); - const auto token_ids = sequence->cached_tokens(); + const auto token_ids = sequence->tokens().slice(0, available_tokens_num); auto* blocks = sequence->kv_state().mutable_kv_blocks(); - auto existed_shared_blocks_num = sequence->kv_state().shared_kv_blocks_num(); + CHECK_GE(blocks->size(), existed_shared_blocks_num); block_managers_[dp_rank]->cache( token_ids, *blocks, existed_shared_blocks_num); } @@ -399,7 +418,7 @@ double BlockManagerPool::kv_cache_utilization() const { // currently use only for profile, which not need prefix cache. // If more often used in the future, can be integrated into deallocate function. void BlockManagerPool::deallocate_without_cache(Sequence* sequence) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); int32_t dp_rank = get_dp_rank(sequence); block_managers_[dp_rank]->deallocate(sequence->kv_state().kv_blocks()); deallocate_embedding_id(sequence, dp_rank); diff --git a/xllm/core/framework/block/block_manager_pool.h b/xllm/core/framework/block/block_manager_pool.h index b8668f2d4..16c5dfb01 100644 --- a/xllm/core/framework/block/block_manager_pool.h +++ b/xllm/core/framework/block/block_manager_pool.h @@ -66,6 +66,7 @@ class BlockManagerPool : public KVCacheManager { virtual void allocate_shared(Sequence* sequence) override; virtual void cache(Sequence* sequence) override; + virtual void cache(Sequence* sequence, size_t num_tokens) override; virtual std::vector>* get_swap_block_transfer_infos() override; diff --git a/xllm/core/framework/block/block_manager_test.cpp b/xllm/core/framework/block/block_manager_test.cpp index f27c0b7ad..6bf0a6e90 100644 --- a/xllm/core/framework/block/block_manager_test.cpp +++ b/xllm/core/framework/block/block_manager_test.cpp @@ -17,6 +17,8 @@ limitations under the License. #include #include "block_manager_impl.h" +#include "block_manager_pool.h" +#include "framework/request/request.h" namespace xllm { @@ -119,4 +121,46 @@ TEST(BlockManagerTest, Basic) { } } -} // namespace xllm \ No newline at end of file +TEST(BlockManagerPoolTest, + CachePrefixClampsToSequenceTokensWhenBudgetIsOverestimated) { + BlockManagerPool::Options options; + options.num_blocks_ = 16; + options.block_size_ = 4; + options.enable_prefix_cache_ = true; + BlockManagerPool pool(options, 1); + + RequestSamplingParam sampling_param; + SchedulerParam scheduler_param; + StoppingChecker stopping_checker; + stopping_checker.set_max_generated_tokens(1); + stopping_checker.set_max_context_len(64); + stopping_checker.set_ignore_eos(true); + + std::vector prompt_token_ids = {1, 2, 3, 4, 5, 6, 7}; + RequestState request_state("prompt", + prompt_token_ids, + sampling_param, + scheduler_param, + stopping_checker, + /*seq_capacity=*/32, + /*n=*/1, + /*best_of=*/1, + /*logprobs=*/false, + /*stream=*/false, + /*echo=*/false, + /*skip_special_tokens=*/false, + /*enable_schedule_overlap=*/false, + nullptr, + nullptr); + + auto request = std::make_shared( + "request_id", "x_request_id", "x_request_time", request_state, "service"); + auto* sequence = request->sequences()[0].get(); + ASSERT_TRUE(pool.allocate(sequence)); + + EXPECT_NO_FATAL_FAILURE( + pool.cache(sequence, /*num_tokens=*/8)); + EXPECT_EQ(pool.num_blocks_in_prefix_cache()[0], 1); +} + +} // namespace xllm diff --git a/xllm/core/framework/block/hierarchy_block_manager_pool.cpp b/xllm/core/framework/block/hierarchy_block_manager_pool.cpp index 9d8a04140..4ef18a7ac 100644 --- a/xllm/core/framework/block/hierarchy_block_manager_pool.cpp +++ b/xllm/core/framework/block/hierarchy_block_manager_pool.cpp @@ -51,7 +51,7 @@ HierarchyBlockManagerPool::HierarchyBlockManagerPool( } void HierarchyBlockManagerPool::deallocate(Sequence* sequence) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); // add blocks to the prefix cache int32_t dp_rank = BlockManagerPool::get_dp_rank(sequence); BlockManagerPool::cache(sequence); @@ -83,7 +83,13 @@ void HierarchyBlockManagerPool::deallocate(Sequence* sequence) { host_block_managers_[dp_rank]->allocate(needed_block_num)); } - for (size_t i = cached_host_block_num; i < host_blocks->size(); i++) { + // Only offload blocks that are fully computed on device. + // In-batch cache insertion may register blocks before this step is executed, + // so bound offload range by cached_device_block_num to avoid copying + // uncomputed data to host/store. + const size_t offload_end_block_num = + std::min({cached_device_block_num, host_blocks->size(), blocks->size()}); + for (size_t i = cached_host_block_num; i < offload_end_block_num; i++) { if (blocks->at(i).ref_count() != 2) { continue; } diff --git a/xllm/core/framework/block/kv_cache_manager.h b/xllm/core/framework/block/kv_cache_manager.h index bea512303..65e97475f 100644 --- a/xllm/core/framework/block/kv_cache_manager.h +++ b/xllm/core/framework/block/kv_cache_manager.h @@ -55,6 +55,7 @@ class KVCacheManager { virtual void allocate_shared(Sequence* sequence) = 0; virtual void cache(Sequence* sequence) = 0; + virtual void cache(Sequence* sequence, size_t num_tokens) = 0; virtual std::vector>* get_swap_block_transfer_infos() = 0; diff --git a/xllm/core/framework/xtensor/xtensor_manager_pool.cpp b/xllm/core/framework/xtensor/xtensor_manager_pool.cpp index da4a4465d..8f1df435e 100644 --- a/xllm/core/framework/xtensor/xtensor_manager_pool.cpp +++ b/xllm/core/framework/xtensor/xtensor_manager_pool.cpp @@ -155,13 +155,13 @@ int32_t XTensorManagerPool::get_dp_rank(Sequence* sequence) const { } bool XTensorManagerPool::allocate(Sequence* sequence) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); return allocate(sequence, sequence->num_tokens()); } bool XTensorManagerPool::allocate(std::vector& sequences) { for (auto* sequence : sequences) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); if (!allocate(sequence)) { return false; } @@ -189,7 +189,7 @@ void XTensorManagerPool::deallocate(Request* request) { void XTensorManagerPool::deallocate(std::vector& sequences) { for (auto* sequence : sequences) { - DCHECK(sequence != nullptr); + CHECK(sequence != nullptr); deallocate(sequence); } } @@ -244,4 +244,4 @@ double XTensorManagerPool::kv_cache_utilization() const { ->kv_cache_utilization(); } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/xtensor/xtensor_manager_pool.h b/xllm/core/framework/xtensor/xtensor_manager_pool.h index c70de5a9e..1a4c728fc 100644 --- a/xllm/core/framework/xtensor/xtensor_manager_pool.h +++ b/xllm/core/framework/xtensor/xtensor_manager_pool.h @@ -40,6 +40,9 @@ class XTensorManagerPool final : public KVCacheManager { // unimplemented functions void cache(Sequence* sequence) override { NOT_IMPLEMENTED(); } + void cache(Sequence* sequence, size_t num_tokens) override { + NOT_IMPLEMENTED(); + } bool allocate(Sequence* sequence, size_t num_tokens, @@ -108,4 +111,4 @@ class XTensorManagerPool final : public KVCacheManager { std::vector> xtensor_manager_servers_; std::string collective_server_name_; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/scheduler/chunked_prefill_scheduler.cpp b/xllm/core/scheduler/chunked_prefill_scheduler.cpp index e38d8b200..5aae77b3d 100644 --- a/xllm/core/scheduler/chunked_prefill_scheduler.cpp +++ b/xllm/core/scheduler/chunked_prefill_scheduler.cpp @@ -166,6 +166,8 @@ void ChunkedPrefillScheduler::handle_running_queue_requests( running_sequences_budgets_.insert(running_sequences_budgets_.end(), candidate_token_budgets.begin(), candidate_token_budgets.end()); + cache_in_batch_prefix(candidate_sequences, + candidate_token_budgets); remaining_token_budget -= allocated_tokens; remaining_seq_budget -= allocated_seqs; estimate_latency += allocated_estimate_latency; @@ -399,6 +401,8 @@ void ChunkedPrefillScheduler::handle_prefill_requests( running_sequences_budgets_.insert(running_sequences_budgets_.end(), prefill_sequences_budget.begin(), prefill_sequences_budget.end()); + cache_in_batch_prefix(prefill_sequences, + prefill_sequences_budget); } // maybe can pre-compute if prompt beyond length if (running_sequences_.empty() && !waiting_priority_queue.empty() && diff --git a/xllm/core/scheduler/continuous_scheduler.cpp b/xllm/core/scheduler/continuous_scheduler.cpp index a0b2fb93a..61bfd1b0f 100644 --- a/xllm/core/scheduler/continuous_scheduler.cpp +++ b/xllm/core/scheduler/continuous_scheduler.cpp @@ -105,7 +105,8 @@ ContinuousScheduler::ContinuousScheduler(Engine* engine, const Options& options) kv_cache_manager_ = engine_->block_manager_pool(); CHECK(kv_cache_manager_ != nullptr); - enable_prefix_cache_ = FLAGS_enable_prefix_cache; + enable_prefix_cache_ = options_.enable_prefix_cache(); + enable_in_batch_prefix_cache_ = options_.enable_in_batch_prefix_cache(); last_batch_.resize(options_.dp_size()); @@ -228,6 +229,25 @@ bool ContinuousScheduler::check_if_enough_to_evict( return false; } +void ContinuousScheduler::cache_in_batch_prefix( + const std::vector& sequences, + const std::vector& current_step_token_budgets) { + if (!enable_prefix_cache_ || !enable_in_batch_prefix_cache_ || sequences.empty()) { + return; + } + CHECK_EQ(sequences.size(), current_step_token_budgets.size()); + for (size_t i = 0; i < sequences.size(); ++i) { + Sequence* sequence = sequences[i]; + if (sequence == nullptr || !sequence->is_prefill_stage()) { + continue; + } + const size_t max_handle_num_tokens = + sequence->kv_state().kv_cache_tokens_num() + + current_step_token_budgets[i]; + kv_cache_manager_->cache(sequence, max_handle_num_tokens); + } +} + void ContinuousScheduler::handle_prefill_requests( double& latency_budget, double& estimate_latency, @@ -401,6 +421,8 @@ void ContinuousScheduler::handle_prefill_requests( running_sequences_budgets_.insert(running_sequences_budgets_.end(), prefill_sequences_budget.begin(), prefill_sequences_budget.end()); + cache_in_batch_prefix(prefill_sequences, + prefill_sequences_budget); } // maybe can pre-compute if prompt beyond length if (running_sequences_.empty() && !waiting_priority_queue.empty() && @@ -993,7 +1015,6 @@ std::vector ContinuousScheduler::prepare_batch() { } else { kv_cache_manager_->transfer_blocks(); } - GAUGE_SET(num_pending_requests, pending_requests_.load(std::memory_order_relaxed)); GAUGE_SET(num_running_requests, running_requests_.size()); diff --git a/xllm/core/scheduler/continuous_scheduler.h b/xllm/core/scheduler/continuous_scheduler.h index c6fb04f34..08e65c620 100644 --- a/xllm/core/scheduler/continuous_scheduler.h +++ b/xllm/core/scheduler/continuous_scheduler.h @@ -91,6 +91,8 @@ class ContinuousScheduler : public Scheduler { PROPERTY(bool, enable_schedule_overlap) = true; PROPERTY(bool, enable_chunked_prefill) = true; + PROPERTY(bool, enable_prefix_cache) = true; + PROPERTY(bool, enable_in_batch_prefix_cache) = true; PROPERTY(bool, enable_service_routing) = false; @@ -218,6 +220,7 @@ class ContinuousScheduler : public Scheduler { std::unique_ptr profile_manager_; bool enable_prefix_cache_ = false; + bool enable_in_batch_prefix_cache_ = false; // the number of requests that are waiting to be scheduled std::atomic pending_requests_{0}; @@ -291,6 +294,10 @@ class ContinuousScheduler : public Scheduler { size_t max_handle_num_tokens, size_t& num_request_to_evict); + void cache_in_batch_prefix( + const std::vector& sequences, + const std::vector& current_step_token_budgets); + // build a batch of requests from the priority queue virtual std::vector prepare_batch(); diff --git a/xllm/core/scheduler/continuous_scheduler_test.cpp b/xllm/core/scheduler/continuous_scheduler_test.cpp index 0df840c34..ed6b34501 100644 --- a/xllm/core/scheduler/continuous_scheduler_test.cpp +++ b/xllm/core/scheduler/continuous_scheduler_test.cpp @@ -333,6 +333,45 @@ TEST(ContinuousSchedulerFactoryTest, opt.max_tokens_per_chunk_for_prefill()); } +TEST(ContinuousSchedulerTest, + InBatchCachePrefillBlocksIncreaseSharedBlocksForLaterRequests) { + const auto run_with_in_batch_prefix_cache = + [](bool enable_in_batch_prefix_cache) -> size_t { + ScopedBoolFlagValue enable_prefix_cache_flag(FLAGS_enable_prefix_cache, + true); + + ContinuousScheduler::Options opt = + create_scheduler_options(1024, 16, 0, 1024, 1); + opt.enable_in_batch_prefix_cache_ = enable_in_batch_prefix_cache; + auto engine = std::make_unique(32, 4, true); + auto scheduler = std::make_unique(engine.get(), opt); + + auto first_request = + generate_request_with_prompt_tokens({1, 2, 3, 4, 5, 6, 7, 8}, 1, 30000); + auto second_request = + generate_request_with_prompt_tokens({1, 2, 3, 4, 5, 6, 7, 8}, 1, 30000); + scheduler->add_request(first_request); + scheduler->add_request(second_request); + + auto batch = scheduler->prepare_batch_test(); + EXPECT_EQ(batch.size(), 1); + EXPECT_EQ(batch[0].size(), 2); + EXPECT_EQ(first_request->sequences()[0]->kv_state().shared_kv_blocks_num(), + 0); + return second_request->sequences()[0]->kv_state().shared_kv_blocks_num(); + }; + + const size_t second_request_shared_blocks_when_enabled = + run_with_in_batch_prefix_cache(true); + const size_t second_request_shared_blocks_when_disabled = + run_with_in_batch_prefix_cache(false); + + EXPECT_GT(second_request_shared_blocks_when_enabled, + second_request_shared_blocks_when_disabled); + EXPECT_GT(second_request_shared_blocks_when_enabled, 0); + EXPECT_EQ(second_request_shared_blocks_when_disabled, 0); +} + // TEST-1: // test preempt TEST(ContinuousSchedulerTest, OnDecodePreemptOffDecode) { diff --git a/xllm/core/scheduler/mix_scheduler.cpp b/xllm/core/scheduler/mix_scheduler.cpp index 5aad7e2b4..ce76440da 100644 --- a/xllm/core/scheduler/mix_scheduler.cpp +++ b/xllm/core/scheduler/mix_scheduler.cpp @@ -402,6 +402,8 @@ void MixScheduler::handle_running_queue_requests( running_sequences_budgets_.insert(running_sequences_budgets_.end(), candidate_token_budgets.begin(), candidate_token_budgets.end()); + cache_in_batch_prefix(candidate_sequences, + candidate_token_budgets); remaining_token_budget -= allocated_tokens; remaining_seq_budget -= allocated_seqs; remaining_copy_blocks_budget -= allocated_copy_blocks; diff --git a/xllm/core/scheduler/prefill_only_scheduler.cpp b/xllm/core/scheduler/prefill_only_scheduler.cpp index 0821edc65..5df641171 100644 --- a/xllm/core/scheduler/prefill_only_scheduler.cpp +++ b/xllm/core/scheduler/prefill_only_scheduler.cpp @@ -211,6 +211,8 @@ void PrefillOnlyScheduler::handle_prefill_requests( running_sequences_budgets_.insert(running_sequences_budgets_.end(), prefill_sequences_budget.begin(), prefill_sequences_budget.end()); + cache_in_batch_prefix(prefill_sequences, + prefill_sequences_budget); } // maybe can pre-compute if prompt beyond length if (running_sequences_.empty() && !waiting_priority_queue.empty() && @@ -397,6 +399,8 @@ void PrefillOnlyScheduler::handle_last_step_prefill_requests( running_sequences_budgets_.insert(running_sequences_budgets_.end(), prefill_sequences_budget.begin(), prefill_sequences_budget.end()); + cache_in_batch_prefix(prefill_sequences, + prefill_sequences_budget); } // maybe can pre-compute if prompt beyond length if (running_sequences_.empty() && !last_step_prefill_requests.empty() && @@ -673,4 +677,4 @@ std::vector PrefillOnlyScheduler::prepare_batch() { return batches; } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index fcb1f98ba..1350aa35d 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -193,6 +193,7 @@ int run() { .max_cache_size(FLAGS_max_cache_size) .max_memory_utilization(FLAGS_max_memory_utilization) .enable_prefix_cache(FLAGS_enable_prefix_cache) + .enable_in_batch_prefix_cache(FLAGS_enable_in_batch_prefix_cache) .max_tokens_per_batch(FLAGS_max_tokens_per_batch) .max_seqs_per_batch(FLAGS_max_seqs_per_batch) .max_tokens_per_chunk_for_prefill(FLAGS_max_tokens_per_chunk_for_prefill)