@@ -49,13 +49,15 @@ StaticKVCache::StaticKVCache(
4949 : Cache(),
5050 k_dim_(k_dim),
5151 v_dim_(v_dim),
52- num_rank_k_heads_(num_k_heads / rank_info.tp_size),
53- num_rank_v_heads_(num_v_heads / rank_info.tp_size),
5452 rank_batch_size_(config.max_batch_size()),
5553 cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
5654 rank_num_layers_(num_layers),
5755 dtype_(dtype) {
5856
57+ bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0 );
58+
59+ num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size );
60+ num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size );
5961 // Allocate K cache
6062 k_caches_ = infinicore::Tensor::empty (
6163 {rank_num_layers_,
@@ -90,15 +92,20 @@ infinicore::Tensor StaticKVCache::create_layer_kv_cache(
9092 const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info ();
9193
9294 size_t rank_batch_size = (config.max_batch_size ());
93- size_t num_rank_kv_heads = (num_k_heads / rank_info.tp_size );
9495 size_t kv_dim = k_dim;
96+
97+ bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0 );
98+
99+ size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size );
100+ size_t num_rank_v_heads = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size );
101+
95102 size_t cache_len = (config.max_cache_len () == std::numeric_limits<infinicore::Size>::max () || config.max_cache_len () == 0 ? max_positional_embedding : config.max_cache_len ());
96103
97104 // Allocate KV cache
98105 infinicore::Tensor kv_cache = infinicore::Tensor::empty (
99106 {2 ,
100107 rank_batch_size,
101- num_rank_kv_heads ,
108+ num_rank_k_heads ,
102109 cache_len,
103110 kv_dim},
104111 dtype,
@@ -186,12 +193,15 @@ PagedKVCache::PagedKVCache(
186193 : Cache(),
187194 k_dim_(k_dim),
188195 v_dim_(v_dim),
189- num_rank_k_heads_(num_k_heads / rank_info.tp_size),
190- num_rank_v_heads_(num_v_heads / rank_info.tp_size),
191196 rank_num_layers_(num_layers),
192197 dtype_(dtype),
193198 num_blocks_per_layer_(config.num_blocks()),
194199 block_size_(config.block_size()) {
200+
201+ bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0 );
202+
203+ num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size );
204+ num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size );
195205 // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
196206 k_caches_ = infinicore::Tensor::empty (
197207 {rank_num_layers_,
@@ -224,8 +234,11 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
224234
225235 const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info ();
226236
227- size_t num_rank_kv_heads (num_k_heads / rank_info.tp_size );
228237 size_t kv_dim = k_dim;
238+ bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0 );
239+
240+ size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size );
241+ size_t num_rank_v_heads = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size );
229242
230243 size_t num_blocks_per_layer = config.num_blocks ();
231244 size_t block_size = config.block_size ();
@@ -234,7 +247,7 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
234247 infinicore::Tensor kv_cache = infinicore::Tensor::empty (
235248 {2 ,
236249 num_blocks_per_layer,
237- num_rank_kv_heads ,
250+ num_rank_k_heads ,
238251 block_size,
239252 kv_dim},
240253 dtype,
0 commit comments