@@ -33,6 +33,7 @@ limitations under the License.
 #include "common/metrics.h"
 #include "common/options.h"
 #include "framework/block/hierarchy_block_manager_pool.h"
+#include "framework/kv_cache/kv_cache_shape.h"
 #include "framework/model/model_args.h"
 #include "framework/model_loader.h"
 #include "framework/xtensor/page_allocator.h"
@@ -373,7 +374,7 @@ int64_t LLMEngine::get_effective_xtensor_weight_size(
   return total_weight_size;
 }

-Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
+KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
   const int64_t max_cache_size = options_.max_cache_size();
   const double max_memory_utilization = options_.max_memory_utilization();

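The lines that actually derive cache_size_in_bytes from these two options fall outside this hunk. As rough orientation only, a generic derivation of such a budget looks like the sketch below (names hypothetical; the elided code here may differ):

    // Generic sketch, not this file's elided logic: cap the budget at
    // max_cache_size (when set) and at the utilization-scaled headroom.
    int64_t derive_cache_budget(int64_t total_memory,
                                int64_t allocated_memory,
                                double max_memory_utilization,
                                int64_t max_cache_size) {
      const int64_t headroom =
          static_cast<int64_t>(total_memory * max_memory_utilization) -
          allocated_memory;
      return max_cache_size > 0 ? std::min<int64_t>(headroom, max_cache_size)
                                : headroom;
    }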
@@ -426,16 +427,17 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
     }
   }

-  Engine::KVCacheCapacity kv_cache_cap;
-  kv_cache_cap.cache_size_in_bytes = std::max(cache_size_in_bytes, int64_t(0));
-  CHECK_GT(kv_cache_cap.cache_size_in_bytes, 0)
+  KVCacheCapacity kv_cache_cap;
+  kv_cache_cap.cache_size_in_bytes() =
+      std::max(cache_size_in_bytes, int64_t(0));
+  CHECK_GT(kv_cache_cap.cache_size_in_bytes(), 0)
       << "Available kv cache size must be greater than 0";
   GAUGE_SET(total_kv_cache_size_in_kilobytes,
-            kv_cache_cap.cache_size_in_bytes / 1024);
+            kv_cache_cap.cache_size_in_bytes() / 1024);

   for (auto& device : options_.devices()) {
     DeviceMonitor::get_instance().set_total_kv_cache_memory(
-        device.index(), kv_cache_cap.cache_size_in_bytes);
+        device.index(), kv_cache_cap.cache_size_in_bytes());
     DeviceMonitor::get_instance().set_total_activation_memory(device.index());
   }

@@ -484,7 +486,7 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
   // => per token: n_kv_heads floats for K + n_kv_heads for V.
   // MLA: key scale [num_blocks, 1, block_size] => one float per token.
   if (enable_kv_cache_quant) {
-    if (options_.enable_mla()) {
+    if (args_.enable_mla()) {
       // MLA scale shape is [num_blocks, 1, block_size] -> one float per token
       scale_slot_size = sizeof(float);
     } else {
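For concreteness, the per-token scale overhead described in the comments above works out as follows (a worked sketch; the n_kv_heads value is only an example):

    // Worked sketch of the quantization scale overhead described above.
    int64_t scale_slot_bytes(bool enable_mla, int64_t n_kv_heads) {
      // MLA: scale shape [num_blocks, 1, block_size] -> one float per token.
      if (enable_mla) return sizeof(float);
      // Otherwise: n_kv_heads floats for K plus n_kv_heads floats for V.
      return n_kv_heads * sizeof(float) * 2;
    }
    // e.g. scale_slot_bytes(true, 0) == 4; scale_slot_bytes(false, 8) == 64.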
@@ -511,165 +513,88 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
           (args_.linear_conv_kernel_dim() - 1);
     linear_slot_size = linear_ssm_slot_size + linear_conv_slot_size;
   }
-  kv_cache_cap.slot_size = slot_size;
-  kv_cache_cap.index_slot_size = index_slot_size;
-  kv_cache_cap.linear_slot_size = linear_slot_size;
-  kv_cache_cap.n_layers = args_.n_layers();
+  kv_cache_cap.slot_size() = slot_size;
+  kv_cache_cap.index_slot_size() = index_slot_size;
+  kv_cache_cap.linear_slot_size() = linear_slot_size;
+  kv_cache_cap.n_layers() = args_.n_layers();
+  kv_cache_cap.block_size() = options_.block_size();
 #if !defined(USE_NPU)
   // This adaptation is needed because the allocation of kv cache is based
   // on the number of layers, and the draft engine uses the same model as
   // the target engine, so we need to override the number of layers for the
   // draft engine.
   if (options_.is_draft_engine()) {
-    kv_cache_cap.n_layers = args_.num_nextn_predict_layers();
+    kv_cache_cap.n_layers() = args_.num_nextn_predict_layers();
   }
 #endif

-  kv_cache_cap.num_linear_state_blocks = FLAGS_max_seqs_per_batch + 2;
-  for (int64_t layer_id = 0; layer_id < kv_cache_cap.n_layers; ++layer_id) {
+  kv_cache_cap.num_linear_state_blocks() = FLAGS_max_seqs_per_batch + 2;
+  for (int64_t layer_id = 0; layer_id < kv_cache_cap.n_layers(); ++layer_id) {
     if (is_full_attention_layer(args_, layer_id)) {
-      ++kv_cache_cap.num_full_attention_layers;
+      ++kv_cache_cap.num_full_attention_layers();
     } else {
-      ++kv_cache_cap.num_linear_attention_layers;
+      ++kv_cache_cap.num_linear_attention_layers();
     }
   }

   // compute kv cache n_blocks
-  const int32_t block_size = options_.block_size();
+  const int64_t block_size = kv_cache_cap.block_size();
   const int64_t block_size_in_bytes =
       block_size * (slot_size + index_slot_size + scale_slot_size);
-  kv_cache_cap.linear_cache_size_in_bytes =
-      kv_cache_cap.num_linear_attention_layers *
-      kv_cache_cap.num_linear_state_blocks * kv_cache_cap.linear_slot_size;
+  kv_cache_cap.linear_cache_size_in_bytes() =
+      kv_cache_cap.num_linear_attention_layers() *
+      kv_cache_cap.num_linear_state_blocks() * kv_cache_cap.linear_slot_size();
   const int64_t available_full_cache_size_in_bytes =
-      kv_cache_cap.cache_size_in_bytes -
-      kv_cache_cap.linear_cache_size_in_bytes;
-  if (kv_cache_cap.linear_slot_size > 0) {
-    CHECK_GT(kv_cache_cap.cache_size_in_bytes,
-             kv_cache_cap.linear_cache_size_in_bytes)
+      kv_cache_cap.cache_size_in_bytes() -
+      kv_cache_cap.linear_cache_size_in_bytes();
+  if (kv_cache_cap.linear_slot_size() > 0) {
+    CHECK_GT(kv_cache_cap.cache_size_in_bytes(),
+             kv_cache_cap.linear_cache_size_in_bytes())
         << "failed to reserve linear state cache for linear-attention layers: "
         << "max_seqs_per_batch (" << FLAGS_max_seqs_per_batch
         << ") is too large. Please reduce max_seqs_per_batch to less than "
-        << kv_cache_cap.cache_size_in_bytes /
-               (kv_cache_cap.num_linear_attention_layers *
-                kv_cache_cap.linear_slot_size) -
+        << kv_cache_cap.cache_size_in_bytes() /
+               (kv_cache_cap.num_linear_attention_layers() *
+                kv_cache_cap.linear_slot_size()) -
                2;
   }
   CHECK_GT(available_full_cache_size_in_bytes, 0)
       << "no memory left for full-attention kv cache after reserving linear "
          "state cache";
   const int64_t full_attention_layers =
-      std::max<int64_t>(kv_cache_cap.num_full_attention_layers, 1);
-  kv_cache_cap.n_blocks = available_full_cache_size_in_bytes /
-                          (full_attention_layers * block_size_in_bytes);
-  CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache";
+      std::max<int64_t>(kv_cache_cap.num_full_attention_layers(), 1);
+  kv_cache_cap.n_blocks() = available_full_cache_size_in_bytes /
+                            (full_attention_layers * block_size_in_bytes);
+  CHECK_GT(kv_cache_cap.n_blocks(), 0) << "no n_blocks for kv cache";
   return kv_cache_cap;
 }

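Read end to end, the function reserves the linear-state cache first and spreads what remains across the full-attention layers. A condensed restatement of that arithmetic, for reference (abbreviated parameter names; not a drop-in replacement for the member function):

    #include <algorithm>
    #include <cstdint>

    int64_t estimate_n_blocks(int64_t cache_bytes, int64_t block_size,
                              int64_t slot_size, int64_t index_slot_size,
                              int64_t scale_slot_size, int64_t linear_slot_size,
                              int64_t n_full_layers, int64_t n_linear_layers,
                              int64_t max_seqs_per_batch) {
      // Linear-attention layers reserve max_seqs_per_batch + 2 state blocks.
      const int64_t linear_bytes =
          n_linear_layers * (max_seqs_per_batch + 2) * linear_slot_size;
      const int64_t full_bytes = cache_bytes - linear_bytes;  // CHECKed > 0
      const int64_t block_bytes =
          block_size * (slot_size + index_slot_size + scale_slot_size);
      return full_bytes / (std::max<int64_t>(n_full_layers, 1) * block_bytes);
    }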
-bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
+bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) {
   LOG(INFO) << "kv cache capacity: "
-            << readable_size(kv_cache_cap.cache_size_in_bytes)
-            << ", blocks: " << kv_cache_cap.n_blocks
-            << ", slot_size: " << kv_cache_cap.slot_size
-            << ", linear_slot_size: " << kv_cache_cap.linear_slot_size
-            << ", linear_blocks: " << kv_cache_cap.num_linear_state_blocks
+            << readable_size(kv_cache_cap.cache_size_in_bytes())
+            << ", blocks: " << kv_cache_cap.n_blocks()
+            << ", slot_size: " << kv_cache_cap.slot_size()
+            << ", linear_slot_size: " << kv_cache_cap.linear_slot_size()
+            << ", linear_blocks: " << kv_cache_cap.num_linear_state_blocks()
             << ", reserved_linear_bytes: "
-            << readable_size(kv_cache_cap.linear_cache_size_in_bytes)
-            << ", n_layers: " << kv_cache_cap.n_layers
+            << readable_size(kv_cache_cap.linear_cache_size_in_bytes())
+            << ", n_layers: " << kv_cache_cap.n_layers()
             << ", kv_cache_dtype: " << options_.kv_cache_dtype();

-  CHECK_GT(kv_cache_cap.n_blocks, 0) << "no memory for kv cache";
-  const int32_t block_size = options_.block_size();
-  bool enable_lighting_indexer = args_.index_n_heads() > 0;
-  bool enable_gdn_attention = has_linear_attention_layers(args_);
+  CHECK_GT(kv_cache_cap.n_blocks(), 0) << "no memory for kv cache";
+  const int32_t block_size = static_cast<int32_t>(kv_cache_cap.block_size());
+  const bool enable_gdn_attention = has_linear_attention_layers(args_);

   // init kv cache for each worker
-  std::vector<std::vector<int64_t>> kv_cache_shape;
-  kv_cache_shape.reserve(2);
-  if (options_.enable_mla()) {
-#if defined(USE_NPU)
-    if (args_.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) {
-      kv_cache_shape.emplace_back(
-          std::vector<int64_t>{kv_cache_cap.n_blocks,
-                               (args_.kv_lora_rank() + 15) / 16,
-                               block_size,
-                               16});
-      kv_cache_shape.emplace_back(
-          std::vector<int64_t>{kv_cache_cap.n_blocks,
-                               (args_.qk_rope_head_dim() + 15) / 16,
-                               block_size,
-                               16});
-    } else {
-      kv_cache_shape.emplace_back(std::vector<int64_t>{
-          kv_cache_cap.n_blocks, block_size, 1, args_.kv_lora_rank()});
-      kv_cache_shape.emplace_back(std::vector<int64_t>{
-          kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
-    }
-#else
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, 1, args_.kv_lora_rank()});
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
-#endif
-  } else {
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
-  }
-  if (enable_lighting_indexer) {
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, 1, args_.index_head_dim()});
-  }
-  if (enable_gdn_attention) {
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.num_linear_state_blocks,
-        args_.linear_conv_kernel_dim() - 1,
-        args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 +
-            args_.linear_key_head_dim() * n_local_linear_v_heads_});
-    kv_cache_shape.emplace_back(
-        std::vector<int64_t>{kv_cache_cap.num_linear_state_blocks,
-                             n_local_linear_v_heads_,
-                             args_.linear_key_head_dim(),
-                             args_.linear_value_head_dim()});
-  }
-#if defined(USE_MLU)
-  // transpose kv_cache layout for mlu
-  // default layout: [n_blocks, block_size, n_head, head_dim]
-  // => mlu layout: [n_blocks, n_head, block_size, head_dim]
-  for (auto& shape : kv_cache_shape) {
-    std::swap(shape[1], shape[2]);
-  }
-  if (options_.enable_mla()) {
-    kv_cache_shape[0][3] = args_.kv_lora_rank() + args_.qk_rope_head_dim();
-    kv_cache_shape[1] = std::vector<int64_t>{};
-  }
-#endif
+  const KVCacheShape kv_cache_shape(kv_cache_cap, args_, dp_local_tp_size_);

-#if defined(USE_ILU)
-  for (auto& shape : kv_cache_shape) {
-    std::swap(shape[1], shape[2]);
-  }
-#endif
-  LOG(INFO) << "Initializing k cache with shape: [" << kv_cache_shape[0] << "]";
-  LOG(INFO) << "Initializing v cache with shape: [" << kv_cache_shape[1] << "]";
-  if (enable_lighting_indexer) {
-    LOG(INFO) << "Initializing indexer cache with shape: [" << kv_cache_shape[2]
-              << "]";
-  }
-  if (enable_gdn_attention) {
-    LOG(INFO) << "GND Attention is enabled";
-    LOG(INFO) << "Initializing conv cache with shape: [" << kv_cache_shape[2]
-              << "]";
-    LOG(INFO) << "Initializing ssm cache with shape: [" << kv_cache_shape[3]
-              << "]";
-  }
+  kv_cache_shape.print_shapes();

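This hunk moves all of the removed shape construction behind the new KVCacheShape helper included at the top of the file. The call site only pins down its constructor and print_shapes(); a hypothetical sketch of that interface (the real definition lives in framework/kv_cache/kv_cache_shape.h and may differ):

    #include <cstdint>

    class KVCacheCapacity;  // from framework/kv_cache/kv_cache_shape.h
    class ModelArgs;        // from framework/model/model_args.h

    // Hypothetical interface implied by the call site above.
    class KVCacheShape {
     public:
      KVCacheShape(const KVCacheCapacity& cap,
                   const ModelArgs& args,
                   int64_t dp_local_tp_size);

      // Replaces the per-tensor LOG(INFO) lines removed above: logs the k/v
      // shapes, plus indexer and linear-state shapes when enabled.
      void print_shapes() const;
    };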
   // initialize block manager
   BlockManagerPool::Options options;
-  options.num_blocks(kv_cache_cap.n_blocks)
+  options.num_blocks(kv_cache_cap.n_blocks())
       .block_size(block_size)
-      .host_num_blocks(kv_cache_cap.n_blocks * options_.host_blocks_factor())
+      .host_num_blocks(kv_cache_cap.n_blocks() * options_.host_blocks_factor())
       .enable_linear_state(enable_gdn_attention)
       .enable_prefix_cache(
           FLAGS_enable_xtensor ? false : options_.enable_prefix_cache())
@@ -678,7 +603,7 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
       .enable_kvcache_store(options_.enable_kvcache_store())
       .enable_xtensor(FLAGS_enable_xtensor)
       .num_layers(args_.n_layers())
-      .slot_size(kv_cache_cap.slot_size)
+      .slot_size(kv_cache_cap.slot_size())
       .model_id(options_.model_id());

   if (options_.host_blocks_factor() > 1.0 || options_.enable_kvcache_store()) {