5 changes: 5 additions & 0 deletions xllm/core/common/global_flags.cpp
@@ -285,6 +285,11 @@ DEFINE_bool(enable_prefix_cache,
true,
"Whether to enable the prefix cache for the block manager.");

DEFINE_bool(
enable_in_batch_prefix_cache,
true,
"Whether to cache admitted prefill full blocks into prefix cache.");

DEFINE_bool(enable_cache_upload,
false,
"Whether to upload cache info to service. This feature is only "
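For context on the plumbing above: these flags use gflags, so the .cpp file owns the definition and default, global_flags.h republishes it with DECLARE_bool, and call sites read the generated FLAGS_ global. A minimal standalone sketch of that pattern (an illustrative program, not xllm code):

```cpp
#include <gflags/gflags.h>

#include <iostream>

// Definition side: one translation unit owns the flag, its default, and
// the --help text (mirrors global_flags.cpp above).
DEFINE_bool(enable_in_batch_prefix_cache,
            true,
            "Whether to cache admitted prefill full blocks into prefix cache.");

// A header that wants the flag carries the matching declaration:
//   DECLARE_bool(enable_in_batch_prefix_cache);

int main(int argc, char* argv[]) {
  // Parses --enable_in_batch_prefix_cache=false (and friends) out of argv.
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  std::cout << "in-batch prefix cache: "
            << (FLAGS_enable_in_batch_prefix_cache ? "on" : "off") << '\n';
  return 0;
}
```

Launching with --enable_in_batch_prefix_cache=false therefore disables the feature at startup without a rebuild.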
1 change: 1 addition & 0 deletions xllm/core/common/global_flags.h
@@ -48,6 +48,7 @@ DECLARE_double(max_memory_utilization);
DECLARE_string(kv_cache_dtype);

DECLARE_bool(enable_prefix_cache);
DECLARE_bool(enable_in_batch_prefix_cache);

DECLARE_bool(enable_cache_upload);

3 changes: 2 additions & 1 deletion xllm/core/common/help_formatter.h
@@ -42,6 +42,7 @@ const OptionCategory kCommonOptions = {"COMMON OPTIONS",
"enable_prefill_sp",
"enable_schedule_overlap",
"enable_prefix_cache",
"enable_in_batch_prefix_cache",
"enable_shm",
"enable_graph",
"enable_graph_mode_decode_no_padding",
@@ -101,7 +102,7 @@ const OptionCategory kBeamSearchOptions = {"BEAM SEARCH OPTIONS",

const OptionCategory kPrefixCacheOptions = {
"PREFIX CACHE OPTIONS",
{"enable_prefix_cache", "xxh3_128bits_seed"}};
{"enable_prefix_cache", "enable_in_batch_prefix_cache", "xxh3_128bits_seed"}};

const OptionCategory kOtherOptions = {
"OTHER OPTIONS",
1 change: 1 addition & 0 deletions xllm/core/common/options.cpp
@@ -29,6 +29,7 @@ std::string Options::to_string() const {
<< ", max_cache_size: " << max_cache_size()
<< ", max_memory_utilization: " << max_memory_utilization()
<< ", enable_prefix_cache: " << enable_prefix_cache()
<< ", enable_in_batch_prefix_cache: " << enable_in_batch_prefix_cache()
<< ", max_tokens_per_batch: " << max_tokens_per_batch()
<< ", max_seqs_per_batch: " << max_seqs_per_batch()
<< ", max_tokens_per_chunk_for_prefill: "
1 change: 1 addition & 0 deletions xllm/core/common/options.h
@@ -62,6 +62,7 @@ class Options {
PROPERTY(double, max_memory_utilization) = 0.9;

PROPERTY(bool, enable_prefix_cache) = true;
PROPERTY(bool, enable_in_batch_prefix_cache) = true;

// max tokens num per batch
PROPERTY(int32_t, max_tokens_per_batch) = 20480;
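The PROPERTY lines above rely on xllm's PROPERTY macro, whose definition is not part of this diff. A plausible minimal sketch (purely illustrative; the real macro may differ) that would make the in-class `= true` default and the fluent setter chains in the master constructors below type-check:

```cpp
#include <utility>

// Hypothetical PROPERTY macro: generates a getter, a fluent setter that
// returns *this for chaining, and ends on the field declaration so that
// "PROPERTY(bool, x) = true;" expands to end with "bool x_ = true;".
#define PROPERTY(T, name)                        \
 public:                                         \
  const T& name() const { return name##_; }      \
  auto& name(T value) {                          \
    name##_ = std::move(value);                  \
    return *this;                                \
  }                                              \
                                                 \
 private:                                        \
  T name##_

class Options {
 public:
  PROPERTY(bool, enable_prefix_cache) = true;
  PROPERTY(bool, enable_in_batch_prefix_cache) = true;
};

// Chaining then reads like the LLMMaster/RecMaster/VLMMaster constructors:
//   Options opts;
//   opts.enable_prefix_cache(true).enable_in_batch_prefix_cache(false);
```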
2 changes: 2 additions & 0 deletions xllm/core/distributed_runtime/llm_master.cpp
@@ -86,6 +86,8 @@ LLMMaster::LLMMaster(const Options& options)
.enable_pd_ooc(options_.enable_pd_ooc())
.enable_schedule_overlap(options_.enable_schedule_overlap())
.enable_chunked_prefill(options_.enable_chunked_prefill())
.enable_prefix_cache(options_.enable_prefix_cache())
.enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache())
.instance_name(options_.instance_name())
.instance_role(options_.instance_role())
.kv_cache_transfer_mode(options_.kv_cache_transfer_mode())
2 changes: 2 additions & 0 deletions xllm/core/distributed_runtime/rec_master.cpp
@@ -504,6 +504,8 @@ RecMaster::RecMaster(const Options& options)
.enable_disagg_pd(options_.enable_disagg_pd())
.enable_schedule_overlap(options_.enable_schedule_overlap())
.enable_chunked_prefill(options_.enable_chunked_prefill())
.enable_prefix_cache(options_.enable_prefix_cache())
.enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache())
.instance_role(options_.instance_role())
.kv_cache_transfer_mode(options_.kv_cache_transfer_mode())
.enable_service_routing(options_.enable_service_routing())
2 changes: 2 additions & 0 deletions xllm/core/distributed_runtime/vlm_master.cpp
@@ -66,6 +66,8 @@ VLMMaster::VLMMaster(const Options& options)
options.max_tokens_per_chunk_for_prefill())
.enable_disagg_pd(options_.enable_disagg_pd())
.enable_chunked_prefill(options_.enable_chunked_prefill())
.enable_prefix_cache(options_.enable_prefix_cache())
.enable_in_batch_prefix_cache(options_.enable_in_batch_prefix_cache())
.instance_name(options_.instance_name())
.instance_role(options_.instance_role())
.kv_cache_transfer_mode(options_.kv_cache_transfer_mode())
37 changes: 28 additions & 9 deletions xllm/core/framework/block/block_manager_pool.cpp
@@ -16,7 +16,6 @@ limitations under the License.
#include "block_manager_pool.h"

#include <algorithm>
#include <limits>

#include "block_manager_impl.h"
#include "common/global_flags.h"
@@ -125,7 +124,7 @@ bool BlockManagerPool::allocate_embedding_id(Sequence* sequence,

void BlockManagerPool::deallocate_embedding_id(Sequence* sequence,
int32_t dp_rank) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
CHECK_GE(dp_rank, 0);
CHECK_LT(static_cast<size_t>(dp_rank), embedding_managers_.size());
auto embedding_block = sequence->reset_embedding_block();
@@ -152,7 +151,7 @@ void BlockManagerPool::deallocate(std::vector<Sequence*>& sequences) {
}

void BlockManagerPool::deallocate(Sequence* sequence) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
// add blocks to the prefix cache
int32_t dp_rank = get_dp_rank(sequence);
cache(sequence);
@@ -173,13 +172,13 @@ void BlockManagerPool::reset_transfer_infos() {
}

bool BlockManagerPool::allocate(Sequence* sequence) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
return allocate(sequence, sequence->num_tokens());
}

bool BlockManagerPool::allocate(std::vector<Sequence*>& sequences) {
for (auto* sequence : sequences) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
if (!allocate(sequence, sequence->num_tokens())) {
// should we guarantee the atomicity of the allocation? all or nothing?
return false;
@@ -190,7 +189,7 @@ bool BlockManagerPool::allocate(std::vector<Sequence*>& sequences) {

bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) {
AUTO_COUNTER(allocate_blocks_latency_seconds);
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
int32_t dp_rank = get_dp_rank(sequence);
const bool started_empty = sequence->kv_state().num_kv_blocks() == 0;
const bool needs_embedding_id = !sequence->has_embedding_id();
@@ -340,10 +339,30 @@ void BlockManagerPool::allocate_shared(Sequence* sequence) {
}

void BlockManagerPool::cache(Sequence* sequence) {
cache(sequence, sequence->kv_state().kv_cache_tokens_num());
}

void BlockManagerPool::cache(Sequence* sequence, size_t num_tokens) {
CHECK(sequence != nullptr);
if (!options_.enable_prefix_cache()) {
return;
}

const size_t block_size = static_cast<size_t>(options_.block_size());
const size_t available_tokens_num =
std::min({num_tokens,
sequence->kv_state().num_kv_blocks() * block_size,
sequence->tokens().size()});
const size_t existed_shared_blocks_num =
sequence->kv_state().shared_kv_blocks_num();
if (available_tokens_num <= existed_shared_blocks_num * block_size) {
return;
}

int32_t dp_rank = get_dp_rank(sequence);
const auto token_ids = sequence->cached_tokens();
const auto token_ids = sequence->tokens().slice(0, available_tokens_num);
auto* blocks = sequence->kv_state().mutable_kv_blocks();
auto existed_shared_blocks_num = sequence->kv_state().shared_kv_blocks_num();
CHECK_GE(blocks->size(), existed_shared_blocks_num);
block_managers_[dp_rank]->cache(
token_ids, *blocks, existed_shared_blocks_num);
}
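The new overload is deliberately defensive: num_tokens comes from the scheduler's per-step budget and may overestimate what the sequence actually holds, so the insertable range is clamped to the minimum of the budget, the tokens backed by allocated KV blocks, and the tokens present in the sequence. A standalone sketch of just that arithmetic, using the same numbers as the new unit test below (plain values stand in for Sequence/KVCacheState state):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  const std::size_t block_size = 4;       // options_.block_size()
  const std::size_t num_tokens = 8;       // scheduler's (overestimated) budget
  const std::size_t num_kv_blocks = 2;    // blocks allocated to the sequence
  const std::size_t sequence_tokens = 7;  // tokens the sequence really holds
  const std::size_t shared_blocks = 0;    // blocks already in the prefix cache

  // The same clamp BlockManagerPool::cache(sequence, num_tokens) applies.
  const std::size_t available = std::min(
      {num_tokens, num_kv_blocks * block_size, sequence_tokens});  // -> 7

  if (available <= shared_blocks * block_size) {
    std::cout << "nothing new to insert\n";
    return 0;
  }
  // Assuming the underlying block manager inserts only full blocks, as the
  // flag's help text suggests: 7 / 4 = 1 block, matching the test's
  // expectation of one block in the prefix cache.
  std::cout << "full blocks inserted: " << available / block_size << '\n';
  return 0;
}
```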
@@ -399,7 +418,7 @@ double BlockManagerPool::kv_cache_utilization() const {
// Currently used only for profiling, which does not need the prefix cache.
// If used more often in the future, this can be folded into deallocate().
void BlockManagerPool::deallocate_without_cache(Sequence* sequence) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
int32_t dp_rank = get_dp_rank(sequence);
block_managers_[dp_rank]->deallocate(sequence->kv_state().kv_blocks());
deallocate_embedding_id(sequence, dp_rank);
1 change: 1 addition & 0 deletions xllm/core/framework/block/block_manager_pool.h
@@ -66,6 +66,7 @@ class BlockManagerPool : public KVCacheManager {

virtual void allocate_shared(Sequence* sequence) override;
virtual void cache(Sequence* sequence) override;
virtual void cache(Sequence* sequence, size_t num_tokens) override;

virtual std::vector<std::vector<BlockTransferInfo>>*
get_swap_block_transfer_infos() override;
46 changes: 45 additions & 1 deletion xllm/core/framework/block/block_manager_test.cpp
@@ -17,6 +17,8 @@ limitations under the License.
#include <gtest/gtest.h>

#include "block_manager_impl.h"
#include "block_manager_pool.h"
#include "framework/request/request.h"

namespace xllm {

@@ -119,4 +121,46 @@ TEST(BlockManagerTest, Basic) {
}
}

} // namespace xllm
TEST(BlockManagerPoolTest,
CachePrefixClampsToSequenceTokensWhenBudgetIsOverestimated) {
BlockManagerPool::Options options;
options.num_blocks_ = 16;
options.block_size_ = 4;
options.enable_prefix_cache_ = true;
BlockManagerPool pool(options, 1);

RequestSamplingParam sampling_param;
SchedulerParam scheduler_param;
StoppingChecker stopping_checker;
stopping_checker.set_max_generated_tokens(1);
stopping_checker.set_max_context_len(64);
stopping_checker.set_ignore_eos(true);

std::vector<int32_t> prompt_token_ids = {1, 2, 3, 4, 5, 6, 7};
RequestState request_state("prompt",
prompt_token_ids,
sampling_param,
scheduler_param,
stopping_checker,
/*seq_capacity=*/32,
/*n=*/1,
/*best_of=*/1,
/*logprobs=*/false,
/*stream=*/false,
/*echo=*/false,
/*skip_special_tokens=*/false,
/*enable_schedule_overlap=*/false,
nullptr,
nullptr);

auto request = std::make_shared<Request>(
"request_id", "x_request_id", "x_request_time", request_state, "service");
auto* sequence = request->sequences()[0].get();
ASSERT_TRUE(pool.allocate(sequence));

EXPECT_NO_FATAL_FAILURE(
pool.cache(sequence, /*num_tokens=*/8));
EXPECT_EQ(pool.num_blocks_in_prefix_cache()[0], 1);
}

} // namespace xllm
10 changes: 8 additions & 2 deletions xllm/core/framework/block/hierarchy_block_manager_pool.cpp
@@ -51,7 +51,7 @@ HierarchyBlockManagerPool::HierarchyBlockManagerPool(
}

void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
// add blocks to the prefix cache
int32_t dp_rank = BlockManagerPool::get_dp_rank(sequence);
BlockManagerPool::cache(sequence);
@@ -83,7 +83,13 @@ void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
host_block_managers_[dp_rank]->allocate(needed_block_num));
}

for (size_t i = cached_host_block_num; i < host_blocks->size(); i++) {
// Only offload blocks that are fully computed on device.
// In-batch cache insertion may register blocks before this step executes,
// so bound the offload range by cached_device_block_num to avoid copying
// uncomputed data to the host/store.
const size_t offload_end_block_num =
std::min({cached_device_block_num, host_blocks->size(), blocks->size()});
for (size_t i = cached_host_block_num; i < offload_end_block_num; i++) {
if (blocks->at(i).ref_count() != 2) {
continue;
}
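The subtlety this hunk guards against: with in-batch insertion, a block can be registered in the prefix cache before its contents are computed on device, so the device-to-host offload loop must stop at cached_device_block_num rather than walking every host block. An illustrative sketch of the bounded loop (stand-in counts, not xllm types):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  const std::size_t cached_device_block_num = 3;  // fully computed on device
  const std::size_t host_block_count = 5;         // host_blocks->size()
  const std::size_t device_block_count = 4;       // blocks->size()
  const std::size_t cached_host_block_num = 1;    // already offloaded to host

  // Same triple min as the patch: never offload past computed data, nor
  // past the end of either block vector.
  const std::size_t offload_end = std::min(
      {cached_device_block_num, host_block_count, device_block_count});  // 3

  for (std::size_t i = cached_host_block_num; i < offload_end; ++i) {
    std::cout << "offload device block " << i << " to host\n";  // blocks 1, 2
  }
  return 0;
}
```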
1 change: 1 addition & 0 deletions xllm/core/framework/block/kv_cache_manager.h
@@ -55,6 +55,7 @@ class KVCacheManager {

virtual void allocate_shared(Sequence* sequence) = 0;
virtual void cache(Sequence* sequence) = 0;
virtual void cache(Sequence* sequence, size_t num_tokens) = 0;

virtual std::vector<std::vector<BlockTransferInfo>>*
get_swap_block_transfer_infos() = 0;
8 changes: 4 additions & 4 deletions xllm/core/framework/xtensor/xtensor_manager_pool.cpp
@@ -155,13 +155,13 @@ int32_t XTensorManagerPool::get_dp_rank(Sequence* sequence) const {
}

bool XTensorManagerPool::allocate(Sequence* sequence) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
return allocate(sequence, sequence->num_tokens());
}

bool XTensorManagerPool::allocate(std::vector<Sequence*>& sequences) {
for (auto* sequence : sequences) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
if (!allocate(sequence)) {
return false;
}
@@ -189,7 +189,7 @@ void XTensorManagerPool::deallocate(Request* request) {

void XTensorManagerPool::deallocate(std::vector<Sequence*>& sequences) {
for (auto* sequence : sequences) {
DCHECK(sequence != nullptr);
CHECK(sequence != nullptr);
deallocate(sequence);
}
}
@@ -244,4 +244,4 @@ double XTensorManagerPool::kv_cache_utilization() const {
->kv_cache_utilization();
}

} // namespace xllm
} // namespace xllm
5 changes: 4 additions & 1 deletion xllm/core/framework/xtensor/xtensor_manager_pool.h
@@ -40,6 +40,9 @@ class XTensorManagerPool final : public KVCacheManager {

// unimplemented functions
void cache(Sequence* sequence) override { NOT_IMPLEMENTED(); }
void cache(Sequence* sequence, size_t num_tokens) override {
NOT_IMPLEMENTED();
}

bool allocate(Sequence* sequence,
size_t num_tokens,
@@ -108,4 +111,4 @@
std::vector<std::unique_ptr<XTensorManagerServer>> xtensor_manager_servers_;
std::string collective_server_name_;
};
} // namespace xllm
} // namespace xllm
4 changes: 4 additions & 0 deletions xllm/core/scheduler/chunked_prefill_scheduler.cpp
@@ -166,6 +166,8 @@ void ChunkedPrefillScheduler::handle_running_queue_requests(
running_sequences_budgets_.insert(running_sequences_budgets_.end(),
candidate_token_budgets.begin(),
candidate_token_budgets.end());
cache_in_batch_prefix(candidate_sequences,
candidate_token_budgets);
remaining_token_budget -= allocated_tokens;
remaining_seq_budget -= allocated_seqs;
estimate_latency += allocated_estimate_latency;
@@ -399,6 +401,8 @@ void ChunkedPrefillScheduler::handle_prefill_requests(
running_sequences_budgets_.insert(running_sequences_budgets_.end(),
prefill_sequences_budget.begin(),
prefill_sequences_budget.end());
cache_in_batch_prefix(prefill_sequences,
prefill_sequences_budget);
}
// maybe can pre-compute if prompt beyond length
if (running_sequences_.empty() && !waiting_priority_queue.empty() &&
25 changes: 23 additions & 2 deletions xllm/core/scheduler/continuous_scheduler.cpp
@@ -105,7 +105,8 @@ ContinuousScheduler::ContinuousScheduler(Engine* engine, const Options& options)
kv_cache_manager_ = engine_->block_manager_pool();
CHECK(kv_cache_manager_ != nullptr);

enable_prefix_cache_ = FLAGS_enable_prefix_cache;
enable_prefix_cache_ = options_.enable_prefix_cache();
enable_in_batch_prefix_cache_ = options_.enable_in_batch_prefix_cache();

last_batch_.resize(options_.dp_size());

@@ -228,6 +229,25 @@ bool ContinuousScheduler::check_if_enough_to_evict(
return false;
}

void ContinuousScheduler::cache_in_batch_prefix(
const std::vector<Sequence*>& sequences,
const std::vector<size_t>& current_step_token_budgets) {
if (!enable_prefix_cache_ || !enable_in_batch_prefix_cache_ ||
    sequences.empty()) {
return;
}
CHECK_EQ(sequences.size(), current_step_token_budgets.size());
for (size_t i = 0; i < sequences.size(); ++i) {
Sequence* sequence = sequences[i];
if (sequence == nullptr || !sequence->is_prefill_stage()) {
continue;
}
const size_t max_handle_num_tokens =
sequence->kv_state().kv_cache_tokens_num() +
current_step_token_budgets[i];
kv_cache_manager_->cache(sequence, max_handle_num_tokens);
}
}

void ContinuousScheduler::handle_prefill_requests(
double& latency_budget,
double& estimate_latency,
@@ -401,6 +421,8 @@ void ContinuousScheduler::handle_prefill_requests(
running_sequences_budgets_.insert(running_sequences_budgets_.end(),
prefill_sequences_budget.begin(),
prefill_sequences_budget.end());
cache_in_batch_prefix(prefill_sequences,
prefill_sequences_budget);
}
// maybe can pre-compute if prompt beyond length
if (running_sequences_.empty() && !waiting_priority_queue.empty() &&
@@ -993,7 +1015,6 @@ std::vector<Batch> ContinuousScheduler::prepare_batch() {
} else {
kv_cache_manager_->transfer_blocks();
}

GAUGE_SET(num_pending_requests,
pending_requests_.load(std::memory_order_relaxed));
GAUGE_SET(num_running_requests, running_requests_.size());
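Putting the two halves together: cache_in_batch_prefix() bounds each insertion at kv_cache_tokens_num plus the step's token budget, and BlockManagerPool::cache() clamps that bound to what actually exists, so fully admitted prefill blocks become shareable while the batch is still running; the apparent payoff is that a later request with the same prefix no longer waits for the first sequence to finish before hitting the cache. A worked instance with stand-in numbers (assuming a block size of 4):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  const std::size_t block_size = 4;

  // Scheduler side: sequence A starts with nothing in the KV cache and is
  // admitted with a 12-token prefill budget, so cache_in_batch_prefix()
  // calls cache(A, /*num_tokens=*/0 + 12).
  const std::size_t kv_cache_tokens_num = 0;
  const std::size_t step_budget = 12;
  const std::size_t bound = kv_cache_tokens_num + step_budget;

  // Block-manager side: A holds 12 prompt tokens in 3 allocated blocks and
  // shares none yet, so the clamp keeps all 12 tokens and 3 full blocks are
  // published to the prefix cache mid-batch.
  const std::size_t num_kv_blocks = 3;
  const std::size_t sequence_tokens = 12;
  const std::size_t available = std::min(
      {bound, num_kv_blocks * block_size, sequence_tokens});  // -> 12

  std::cout << "blocks published in-batch: " << available / block_size << '\n';

  // A later sequence in the same batch with the same 12-token prefix can now
  // reuse those blocks via allocate_shared() instead of recomputing them.
  return 0;
}
```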