Skip to content

Commit f73e410

Browse files
authored
refactor: refactor the creation of kvcache shape. (#1320)
1 parent 8ff6030 commit f73e410

48 files changed

Lines changed: 901 additions & 639 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 4 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -68,41 +68,9 @@ bool CommChannel::check_health() {
6868
return resp.ok();
6969
}
7070

71-
bool CommChannel::allocate_kv_cache(
72-
const std::vector<std::vector<int64_t>>& kv_cache_shape) {
71+
bool CommChannel::allocate_kv_cache(const KVCacheShape& kv_cache_shape) {
7372
proto::AllocateKVCacheRequest request;
74-
75-
auto* shape = request.mutable_kv_cache_shape();
76-
shape->mutable_key_shape()->Reserve(kv_cache_shape[0].size());
77-
shape->mutable_value_shape()->Reserve(kv_cache_shape[1].size());
78-
79-
// add key shape
80-
for (size_t i = 0; i < kv_cache_shape[0].size(); ++i) {
81-
shape->add_key_shape(kv_cache_shape[0][i]);
82-
}
83-
84-
// add value shape
85-
for (size_t i = 0; i < kv_cache_shape[1].size(); ++i) {
86-
shape->add_value_shape(kv_cache_shape[1][i]);
87-
}
88-
89-
// add index shape if exists
90-
if (kv_cache_shape.size() == kKVCacheShapeSizeWithIndex) {
91-
shape->mutable_index_shape()->Reserve(kv_cache_shape[2].size());
92-
for (size_t i = 0; i < kv_cache_shape[2].size(); ++i) {
93-
shape->add_index_shape(kv_cache_shape[2][i]);
94-
}
95-
} else if (kv_cache_shape.size() == kKVCacheShapeSizeWithConvAndSsm) {
96-
// Use for Qwen-3.5, Qwen3-next, etc
97-
shape->mutable_conv_shape()->Reserve(kv_cache_shape[2].size());
98-
shape->mutable_ssm_shape()->Reserve(kv_cache_shape[3].size());
99-
for (size_t i = 0; i < kv_cache_shape[2].size(); ++i) {
100-
shape->add_conv_shape(kv_cache_shape[2][i]);
101-
}
102-
for (size_t i = 0; i < kv_cache_shape[3].size(); ++i) {
103-
shape->add_ssm_shape(kv_cache_shape[3][i]);
104-
}
105-
}
73+
kv_cache_shape.to_proto(request.mutable_kv_cache_shape());
10674
proto::Status s;
10775
brpc::Controller cntl;
10876
stub_->AllocateKVCache(&cntl, &request, &s, nullptr);
@@ -331,30 +299,9 @@ bool CommChannel::process_group_test() {
331299
}
332300

333301
bool CommChannel::allocate_kv_cache_with_transfer(
334-
const std::vector<std::vector<int64_t>>& kv_cache_shape) {
302+
const KVCacheShape& kv_cache_shape) {
335303
proto::AllocateKVCacheRequest request;
336-
337-
auto* shape = request.mutable_kv_cache_shape();
338-
shape->mutable_key_shape()->Reserve(kv_cache_shape[0].size());
339-
shape->mutable_value_shape()->Reserve(kv_cache_shape[1].size());
340-
341-
// add key shape
342-
for (size_t i = 0; i < kv_cache_shape[0].size(); ++i) {
343-
shape->add_key_shape(kv_cache_shape[0][i]);
344-
}
345-
346-
// add value shape
347-
for (size_t i = 0; i < kv_cache_shape[1].size(); ++i) {
348-
shape->add_value_shape(kv_cache_shape[1][i]);
349-
}
350-
351-
// add index shape if exists
352-
if (kv_cache_shape.size() > 2) {
353-
shape->mutable_index_shape()->Reserve(kv_cache_shape[2].size());
354-
for (size_t i = 0; i < kv_cache_shape[2].size(); ++i) {
355-
shape->add_index_shape(kv_cache_shape[2][i]);
356-
}
357-
}
304+
kv_cache_shape.to_proto(request.mutable_kv_cache_shape());
358305

359306
proto::Status s;
360307
brpc::Controller cntl;

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,14 @@ limitations under the License.
2424
#include <vector>
2525

2626
#include "common/types.h"
27+
#include "framework/kv_cache/kv_cache_shape.h"
2728
#include "framework/xtensor/xtensor.h"
2829
#include "runtime/forward_params.h"
2930
#include "runtime/params_utils.h"
3031
#include "worker.pb.h"
3132

3233
namespace xllm {
3334

34-
static constexpr size_t kKVCacheShapeSizeWithIndex = 3;
35-
static constexpr size_t kKVCacheShapeSizeWithConvAndSsm = 4;
36-
3735
class CommChannel {
3836
public:
3937
CommChannel() = default;
@@ -43,8 +41,7 @@ class CommChannel {
4341

4442
virtual bool hello();
4543

46-
virtual bool allocate_kv_cache(
47-
const std::vector<std::vector<int64_t>>& kv_cache_shape);
44+
virtual bool allocate_kv_cache(const KVCacheShape& kv_cache_shape);
4845

4946
virtual bool get_device_info(std::string& device_ip, uint16_t& port);
5047

@@ -93,7 +90,7 @@ class CommChannel {
9390
virtual bool process_group_test();
9491

9592
virtual bool allocate_kv_cache_with_transfer(
96-
const std::vector<std::vector<int64_t>>& kv_cache_shape);
93+
const KVCacheShape& kv_cache_shape);
9794

9895
virtual void transfer_kv_blocks(
9996
const std::vector<BlockTransferInfo>& block_transfer_info,

xllm/core/distributed_runtime/engine.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,20 +174,6 @@ class Engine {
174174
return false;
175175
};
176176

177-
struct KVCacheCapacity {
178-
int64_t n_blocks = 0;
179-
int64_t n_pages = 0; // for continuous kvcache
180-
int64_t cache_size_in_bytes = 0;
181-
int64_t slot_size = 0;
182-
int64_t index_slot_size = 0;
183-
int64_t linear_slot_size = 0;
184-
int64_t linear_cache_size_in_bytes = 0;
185-
int64_t n_layers = 0;
186-
int64_t num_linear_state_blocks = 0;
187-
int64_t num_full_attention_layers = 0;
188-
int64_t num_linear_attention_layers = 0;
189-
};
190-
191177
protected:
192178
// model args
193179
ModelArgs args_;

xllm/core/distributed_runtime/llm_engine.cpp

Lines changed: 51 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ limitations under the License.
3333
#include "common/metrics.h"
3434
#include "common/options.h"
3535
#include "framework/block/hierarchy_block_manager_pool.h"
36+
#include "framework/kv_cache/kv_cache_shape.h"
3637
#include "framework/model/model_args.h"
3738
#include "framework/model_loader.h"
3839
#include "framework/xtensor/page_allocator.h"
@@ -373,7 +374,7 @@ int64_t LLMEngine::get_effective_xtensor_weight_size(
373374
return total_weight_size;
374375
}
375376

376-
Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
377+
KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
377378
const int64_t max_cache_size = options_.max_cache_size();
378379
const double max_memory_utilization = options_.max_memory_utilization();
379380

@@ -426,16 +427,17 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
426427
}
427428
}
428429

429-
Engine::KVCacheCapacity kv_cache_cap;
430-
kv_cache_cap.cache_size_in_bytes = std::max(cache_size_in_bytes, int64_t(0));
431-
CHECK_GT(kv_cache_cap.cache_size_in_bytes, 0)
430+
KVCacheCapacity kv_cache_cap;
431+
kv_cache_cap.cache_size_in_bytes() =
432+
std::max(cache_size_in_bytes, int64_t(0));
433+
CHECK_GT(kv_cache_cap.cache_size_in_bytes(), 0)
432434
<< "Available kv cache size must be greater than 0";
433435
GAUGE_SET(total_kv_cache_size_in_kilobytes,
434-
kv_cache_cap.cache_size_in_bytes / 1024);
436+
kv_cache_cap.cache_size_in_bytes() / 1024);
435437

436438
for (auto& device : options_.devices()) {
437439
DeviceMonitor::get_instance().set_total_kv_cache_memory(
438-
device.index(), kv_cache_cap.cache_size_in_bytes);
440+
device.index(), kv_cache_cap.cache_size_in_bytes());
439441
DeviceMonitor::get_instance().set_total_activation_memory(device.index());
440442
}
441443

@@ -484,7 +486,7 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
484486
// => per token: n_kv_heads floats for K + n_kv_heads for V.
485487
// MLA: key scale [num_blocks, 1, block_size] => one float per token.
486488
if (enable_kv_cache_quant) {
487-
if (options_.enable_mla()) {
489+
if (args_.enable_mla()) {
488490
// MLA scale shape is [num_blocks, 1, block_size] -> one float per token
489491
scale_slot_size = sizeof(float);
490492
} else {
@@ -511,165 +513,88 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
511513
(args_.linear_conv_kernel_dim() - 1);
512514
linear_slot_size = linear_ssm_slot_size + linear_conv_slot_size;
513515
}
514-
kv_cache_cap.slot_size = slot_size;
515-
kv_cache_cap.index_slot_size = index_slot_size;
516-
kv_cache_cap.linear_slot_size = linear_slot_size;
517-
kv_cache_cap.n_layers = args_.n_layers();
516+
kv_cache_cap.slot_size() = slot_size;
517+
kv_cache_cap.index_slot_size() = index_slot_size;
518+
kv_cache_cap.linear_slot_size() = linear_slot_size;
519+
kv_cache_cap.n_layers() = args_.n_layers();
520+
kv_cache_cap.block_size() = options_.block_size();
518521
#if !defined(USE_NPU)
519522
// this adoption is because the allocation of kv cache is based on
520523
// the number of layers, and the draft engine is using the same model as the
521524
// target engine.
522525
// so we need to override the number of layers for the draft engine.
523526
if (options_.is_draft_engine()) {
524-
kv_cache_cap.n_layers = args_.num_nextn_predict_layers();
527+
kv_cache_cap.n_layers() = args_.num_nextn_predict_layers();
525528
}
526529
#endif
527530

528-
kv_cache_cap.num_linear_state_blocks = FLAGS_max_seqs_per_batch + 2;
529-
for (int64_t layer_id = 0; layer_id < kv_cache_cap.n_layers; ++layer_id) {
531+
kv_cache_cap.num_linear_state_blocks() = FLAGS_max_seqs_per_batch + 2;
532+
for (int64_t layer_id = 0; layer_id < kv_cache_cap.n_layers(); ++layer_id) {
530533
if (is_full_attention_layer(args_, layer_id)) {
531-
++kv_cache_cap.num_full_attention_layers;
534+
++kv_cache_cap.num_full_attention_layers();
532535
} else {
533-
++kv_cache_cap.num_linear_attention_layers;
536+
++kv_cache_cap.num_linear_attention_layers();
534537
}
535538
}
536539

537540
// compute kv cache n_blocks
538-
const int32_t block_size = options_.block_size();
541+
const int64_t block_size = kv_cache_cap.block_size();
539542
const int64_t block_size_in_bytes =
540543
block_size * (slot_size + index_slot_size + scale_slot_size);
541-
kv_cache_cap.linear_cache_size_in_bytes =
542-
kv_cache_cap.num_linear_attention_layers *
543-
kv_cache_cap.num_linear_state_blocks * kv_cache_cap.linear_slot_size;
544+
kv_cache_cap.linear_cache_size_in_bytes() =
545+
kv_cache_cap.num_linear_attention_layers() *
546+
kv_cache_cap.num_linear_state_blocks() * kv_cache_cap.linear_slot_size();
544547
const int64_t available_full_cache_size_in_bytes =
545-
kv_cache_cap.cache_size_in_bytes -
546-
kv_cache_cap.linear_cache_size_in_bytes;
547-
if (kv_cache_cap.linear_slot_size > 0) {
548-
CHECK_GT(kv_cache_cap.cache_size_in_bytes,
549-
kv_cache_cap.linear_cache_size_in_bytes)
548+
kv_cache_cap.cache_size_in_bytes() -
549+
kv_cache_cap.linear_cache_size_in_bytes();
550+
if (kv_cache_cap.linear_slot_size() > 0) {
551+
CHECK_GT(kv_cache_cap.cache_size_in_bytes(),
552+
kv_cache_cap.linear_cache_size_in_bytes())
550553
<< "failed to reserve linear state cache for linear-attention layers: "
551554
<< "max_seqs_per_batch (" << FLAGS_max_seqs_per_batch
552555
<< ") is too large. Please reduce max_seqs_per_batch to less than "
553-
<< kv_cache_cap.cache_size_in_bytes /
554-
(kv_cache_cap.num_linear_attention_layers *
555-
kv_cache_cap.linear_slot_size) -
556+
<< kv_cache_cap.cache_size_in_bytes() /
557+
(kv_cache_cap.num_linear_attention_layers() *
558+
kv_cache_cap.linear_slot_size()) -
556559
2;
557560
}
558561
CHECK_GT(available_full_cache_size_in_bytes, 0)
559562
<< "no memory left for full-attention kv cache after reserving linear "
560563
"state cache";
561564
const int64_t full_attention_layers =
562-
std::max<int64_t>(kv_cache_cap.num_full_attention_layers, 1);
563-
kv_cache_cap.n_blocks = available_full_cache_size_in_bytes /
564-
(full_attention_layers * block_size_in_bytes);
565-
CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache";
565+
std::max<int64_t>(kv_cache_cap.num_full_attention_layers(), 1);
566+
kv_cache_cap.n_blocks() = available_full_cache_size_in_bytes /
567+
(full_attention_layers * block_size_in_bytes);
568+
CHECK_GT(kv_cache_cap.n_blocks(), 0) << "no n_blocks for kv cache";
566569
return kv_cache_cap;
567570
}
568571

569-
bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
572+
bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) {
570573
LOG(INFO) << "kv cache capacity: "
571-
<< readable_size(kv_cache_cap.cache_size_in_bytes)
572-
<< ", blocks: " << kv_cache_cap.n_blocks
573-
<< ", slot_size: " << kv_cache_cap.slot_size
574-
<< ", linear_slot_size: " << kv_cache_cap.linear_slot_size
575-
<< ", linear_blocks: " << kv_cache_cap.num_linear_state_blocks
574+
<< readable_size(kv_cache_cap.cache_size_in_bytes())
575+
<< ", blocks: " << kv_cache_cap.n_blocks()
576+
<< ", slot_size: " << kv_cache_cap.slot_size()
577+
<< ", linear_slot_size: " << kv_cache_cap.linear_slot_size()
578+
<< ", linear_blocks: " << kv_cache_cap.num_linear_state_blocks()
576579
<< ", reserved_linear_bytes: "
577-
<< readable_size(kv_cache_cap.linear_cache_size_in_bytes)
578-
<< ", n_layers: " << kv_cache_cap.n_layers
580+
<< readable_size(kv_cache_cap.linear_cache_size_in_bytes())
581+
<< ", n_layers: " << kv_cache_cap.n_layers()
579582
<< ", kv_cache_dtype: " << options_.kv_cache_dtype();
580583

581-
CHECK_GT(kv_cache_cap.n_blocks, 0) << "no memory for kv cache";
582-
const int32_t block_size = options_.block_size();
583-
bool enable_lighting_indexer = args_.index_n_heads() > 0;
584-
bool enable_gdn_attention = has_linear_attention_layers(args_);
584+
CHECK_GT(kv_cache_cap.n_blocks(), 0) << "no memory for kv cache";
585+
const int32_t block_size = static_cast<int32_t>(kv_cache_cap.block_size());
586+
const bool enable_gdn_attention = has_linear_attention_layers(args_);
585587

586588
// init kv cache for each worker
587-
std::vector<std::vector<int64_t>> kv_cache_shape;
588-
kv_cache_shape.reserve(2);
589-
if (options_.enable_mla()) {
590-
#if defined(USE_NPU)
591-
if (args_.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) {
592-
kv_cache_shape.emplace_back(
593-
std::vector<int64_t>{kv_cache_cap.n_blocks,
594-
(args_.kv_lora_rank() + 15) / 16,
595-
block_size,
596-
16});
597-
kv_cache_shape.emplace_back(
598-
std::vector<int64_t>{kv_cache_cap.n_blocks,
599-
(args_.qk_rope_head_dim() + 15) / 16,
600-
block_size,
601-
16});
602-
} else {
603-
kv_cache_shape.emplace_back(std::vector<int64_t>{
604-
kv_cache_cap.n_blocks, block_size, 1, args_.kv_lora_rank()});
605-
kv_cache_shape.emplace_back(std::vector<int64_t>{
606-
kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
607-
}
608-
#else
609-
kv_cache_shape.emplace_back(std::vector<int64_t>{
610-
kv_cache_cap.n_blocks, block_size, 1, args_.kv_lora_rank()});
611-
kv_cache_shape.emplace_back(std::vector<int64_t>{
612-
kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
613-
#endif
614-
} else {
615-
kv_cache_shape.emplace_back(std::vector<int64_t>{
616-
kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
617-
kv_cache_shape.emplace_back(std::vector<int64_t>{
618-
kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
619-
}
620-
if (enable_lighting_indexer) {
621-
kv_cache_shape.emplace_back(std::vector<int64_t>{
622-
kv_cache_cap.n_blocks, block_size, 1, args_.index_head_dim()});
623-
}
624-
if (enable_gdn_attention) {
625-
kv_cache_shape.emplace_back(std::vector<int64_t>{
626-
kv_cache_cap.num_linear_state_blocks,
627-
args_.linear_conv_kernel_dim() - 1,
628-
args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 +
629-
args_.linear_key_head_dim() * n_local_linear_v_heads_});
630-
kv_cache_shape.emplace_back(
631-
std::vector<int64_t>{kv_cache_cap.num_linear_state_blocks,
632-
n_local_linear_v_heads_,
633-
args_.linear_key_head_dim(),
634-
args_.linear_value_head_dim()});
635-
}
636-
#if defined(USE_MLU)
637-
// transpose kv_cache layout for mlu
638-
// default layout: [n_blocks, block_size, n_head, head_dim]
639-
// => mlu layout: [n_blocks, n_head, block_size, head_dim]
640-
for (auto& shape : kv_cache_shape) {
641-
std::swap(shape[1], shape[2]);
642-
}
643-
if (options_.enable_mla()) {
644-
kv_cache_shape[0][3] = args_.kv_lora_rank() + args_.qk_rope_head_dim();
645-
kv_cache_shape[1] = std::vector<int64_t>{};
646-
}
647-
#endif
589+
const KVCacheShape kv_cache_shape(kv_cache_cap, args_, dp_local_tp_size_);
648590

649-
#if defined(USE_ILU)
650-
for (auto& shape : kv_cache_shape) {
651-
std::swap(shape[1], shape[2]);
652-
}
653-
#endif
654-
LOG(INFO) << "Initializing k cache with shape: [" << kv_cache_shape[0] << "]";
655-
LOG(INFO) << "Initializing v cache with shape: [" << kv_cache_shape[1] << "]";
656-
if (enable_lighting_indexer) {
657-
LOG(INFO) << "Initializing indexer cache with shape: [" << kv_cache_shape[2]
658-
<< "]";
659-
}
660-
if (enable_gdn_attention) {
661-
LOG(INFO) << "GND Attention is enabled";
662-
LOG(INFO) << "Initializing conv cache with shape: [" << kv_cache_shape[2]
663-
<< "]";
664-
LOG(INFO) << "Initializing ssm cache with shape: [" << kv_cache_shape[3]
665-
<< "]";
666-
}
591+
kv_cache_shape.print_shapes();
667592

668593
// initialize block manager
669594
BlockManagerPool::Options options;
670-
options.num_blocks(kv_cache_cap.n_blocks)
595+
options.num_blocks(kv_cache_cap.n_blocks())
671596
.block_size(block_size)
672-
.host_num_blocks(kv_cache_cap.n_blocks * options_.host_blocks_factor())
597+
.host_num_blocks(kv_cache_cap.n_blocks() * options_.host_blocks_factor())
673598
.enable_linear_state(enable_gdn_attention)
674599
.enable_prefix_cache(
675600
FLAGS_enable_xtensor ? false : options_.enable_prefix_cache())
@@ -678,7 +603,7 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
678603
.enable_kvcache_store(options_.enable_kvcache_store())
679604
.enable_xtensor(FLAGS_enable_xtensor)
680605
.num_layers(args_.n_layers())
681-
.slot_size(kv_cache_cap.slot_size)
606+
.slot_size(kv_cache_cap.slot_size())
682607
.model_id(options_.model_id());
683608

684609
if (options_.host_blocks_factor() > 1.0 || options_.enable_kvcache_store()) {

0 commit comments

Comments
 (0)