Commit dec52b1

Merge pull request #292 from InfiniTensor/Issue#291
Issue/291: kv replica when tp_size > kv_head_num
2 parents 5e61074 + fcc1a82

4 files changed: 47 additions & 21 deletions

csrc/cache/kv_cache.cpp

Lines changed: 21 additions & 8 deletions
@@ -49,13 +49,15 @@ StaticKVCache::StaticKVCache(
     : Cache(),
       k_dim_(k_dim),
       v_dim_(v_dim),
-      num_rank_k_heads_(num_k_heads / rank_info.tp_size),
-      num_rank_v_heads_(num_v_heads / rank_info.tp_size),
       rank_batch_size_(config.max_batch_size()),
       cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
       rank_num_layers_(num_layers),
       dtype_(dtype) {

+    bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);
+
+    num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
+    num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);
     // Allocate K cache
     k_caches_ = infinicore::Tensor::empty(
         {rank_num_layers_,
@@ -90,15 +92,20 @@ infinicore::Tensor StaticKVCache::create_layer_kv_cache(
     const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();

     size_t rank_batch_size = (config.max_batch_size());
-    size_t num_rank_kv_heads = (num_k_heads / rank_info.tp_size);
     size_t kv_dim = k_dim;
+
+    bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);
+
+    size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
+    size_t num_rank_v_heads = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);
+
     size_t cache_len = (config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len());

     // Allocate KV cache
     infinicore::Tensor kv_cache = infinicore::Tensor::empty(
         {2,
          rank_batch_size,
-         num_rank_kv_heads,
+         num_rank_k_heads,
          cache_len,
          kv_dim},
         dtype,
@@ -186,12 +193,15 @@ PagedKVCache::PagedKVCache(
     : Cache(),
       k_dim_(k_dim),
       v_dim_(v_dim),
-      num_rank_k_heads_(num_k_heads / rank_info.tp_size),
-      num_rank_v_heads_(num_v_heads / rank_info.tp_size),
       rank_num_layers_(num_layers),
       dtype_(dtype),
       num_blocks_per_layer_(config.num_blocks()),
       block_size_(config.block_size()) {
+
+    bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);
+
+    num_rank_k_heads_ = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
+    num_rank_v_heads_ = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);
     // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
     k_caches_ = infinicore::Tensor::empty(
         {rank_num_layers_,
@@ -224,8 +234,11 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(

     const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();

-    size_t num_rank_kv_heads(num_k_heads / rank_info.tp_size);
     size_t kv_dim = k_dim;
+    bool is_kv_replica = (num_k_heads < rank_info.tp_size && num_v_heads < rank_info.tp_size && num_k_heads == num_v_heads && rank_info.tp_size % num_k_heads == 0);
+
+    size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / rank_info.tp_size);
+    size_t num_rank_v_heads = is_kv_replica ? 1 : (num_v_heads / rank_info.tp_size);

     size_t num_blocks_per_layer = config.num_blocks();
     size_t block_size = config.block_size();
@@ -234,7 +247,7 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
     infinicore::Tensor kv_cache = infinicore::Tensor::empty(
         {2,
          num_blocks_per_layer,
-         num_rank_kv_heads,
+         num_rank_k_heads,
          block_size,
          kv_dim},
         dtype,
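
For orientation, here is a minimal standalone sketch (not part of the diff; head counts are hypothetical) of how the new is_kv_replica predicate drives the per-rank cache shape when a GQA model has fewer KV heads than tensor-parallel ranks:

// Sketch with assumed values: 4 KV heads shared across 8 TP ranks means
// each rank caches exactly one head, and two ranks replicate each head.
#include <cstddef>
#include <iostream>

int main() {
    std::size_t num_k_heads = 4, num_v_heads = 4; // hypothetical GQA config
    std::size_t tp_size = 8;                      // more ranks than KV heads

    bool is_kv_replica = (num_k_heads < tp_size && num_v_heads < tp_size
                          && num_k_heads == num_v_heads
                          && tp_size % num_k_heads == 0);

    std::size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / tp_size);

    // Prints "replica: 1, heads per rank: 1". With tp_size = 2 the predicate
    // would be false and each rank would hold 4 / 2 = 2 distinct heads.
    std::cout << "replica: " << is_kv_replica
              << ", heads per rank: " << num_rank_k_heads << "\n";
}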

csrc/layers/attention/attention.cpp

Lines changed: 1 addition & 4 deletions
@@ -22,12 +22,9 @@ Attention::Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config
     const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
     int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank();
     int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size();
-    if ((total_num_kv_heads < tp_size) || (0 != (total_num_kv_heads % tp_size))) {
-        throw std::runtime_error("infinilm::layers::attention::Attention: num_key_value_heads must be divisible by tp_size");
-    }

     num_attention_heads_ = total_num_heads / tp_size;
-    num_key_value_heads_ = total_num_kv_heads / tp_size;
+    num_key_value_heads_ = total_num_kv_heads < tp_size ? 1 : total_num_kv_heads / tp_size;

     auto quant_scheme = model_config->get_quant_scheme();
     auto quantization_method = model_config->get_quantization_method();
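
The removed guard rejected any config where the KV heads could not be sharded evenly; the replica path now handles the under-sharded case instead. A standalone sketch of the new arithmetic, with hypothetical sizes (32 query heads, 4 KV heads, tp_size = 8):

#include <cstddef>
#include <iostream>

int main() {
    std::size_t total_num_heads = 32, total_num_kv_heads = 4;
    std::size_t tp_size = 8;

    std::size_t num_attention_heads = total_num_heads / tp_size; // 4 per rank
    // The old code threw here (4 < 8); the replica path clamps to one KV
    // head per rank, matching the one-head cache shapes allocated above.
    std::size_t num_key_value_heads =
        total_num_kv_heads < tp_size ? 1 : total_num_kv_heads / tp_size;

    std::cout << num_attention_heads << " " << num_key_value_heads << "\n"; // 4 1
}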

csrc/layers/linear/fused_linear.cpp

Lines changed: 7 additions & 9 deletions
@@ -95,7 +95,7 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                      engine::distributed::RankInfo rank_info)
     : infinicore::nn::ColumnParallelLinear(
           hidden_size,
-          num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
+          calculate_out_feature_size(num_q_head, q_dim, num_k_head, k_dim, num_v_head, v_dim, rank_info),
           quantization,
           (q_bias || k_bias || v_bias),
           dtype,
@@ -110,18 +110,16 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
       num_v_head_(num_v_head),
       q_bias_(q_bias),
       k_bias_(k_bias),
-      v_bias_(v_bias) {
-    if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) {
-        throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size");
-    }
+      v_bias_(v_bias),
+      num_kv_head_replicas_(calculate_kv_replicas(num_k_head, rank_info.tp_size)) {

     if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) {
         throw std::runtime_error("q_bias, k_bias, v_bias must all match");
     }

     q_out_size_ = num_q_head_ * q_dim_ / tp_size_;
-    k_out_size_ = num_k_head_ * k_dim_ / tp_size_;
-    v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
+    k_out_size_ = num_kv_head_replicas_ * num_k_head_ * k_dim_ / tp_size_;
+    v_out_size_ = num_kv_head_replicas_ * num_v_head_ * v_dim_ / tp_size_;
 }

 std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
@@ -144,13 +142,13 @@ infinicore::nn::Parameter QKVParallelLinear::get_q_weight() const {
 infinicore::nn::Parameter QKVParallelLinear::get_k_weight() const {
     return infinicore::nn::Parameter(
         weight_->narrow({{0, q_out_size_, k_out_size_}}),
-        0, tp_rank_, tp_size_);
+        0, tp_rank_, tp_size_, num_k_head_);
 }

 infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const {
     return infinicore::nn::Parameter(
         weight_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}),
-        0, tp_rank_, tp_size_);
+        0, tp_rank_, tp_size_, num_v_head_);
 }

 infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale() const {
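
A sketch of the new per-rank output-size arithmetic, with hypothetical dims (num_k_head = 4, k_dim = 128, tp_size = 8). Under replication the formula collapses to exactly one head's worth of features per rank. The extra head-count argument now passed to infinicore::nn::Parameter presumably lets weight loading map each rank back to its source head; that signature lives in infinicore and is not shown in this diff.

#include <cassert>
#include <cstddef>

int main() {
    std::size_t num_k_head = 4, k_dim = 128, tp_size = 8;

    // In the replica case calculate_kv_replicas returns tp_size / num_k_head.
    std::size_t num_kv_head_replicas = tp_size / num_k_head; // 2

    // 2 * 4 * 128 / 8 = 128 = k_dim: each rank's K slice of the fused
    // weight is exactly one head wide, mirroring num_rank_k_heads == 1
    // in the KV cache.
    std::size_t k_out_size = num_kv_head_replicas * num_k_head * k_dim / tp_size;
    assert(k_out_size == k_dim);
}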

csrc/layers/linear/fused_linear.hpp

Lines changed: 18 additions & 0 deletions
@@ -81,6 +81,22 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
     bool has_k_bias() const;
     bool has_v_bias() const;

+private:
+    static size_t calculate_kv_replicas(size_t num_k_head, size_t tp_size) {
+        if (num_k_head % tp_size == 0) {
+            return 1;
+        }
+        if (tp_size % num_k_head == 0) {
+            return (tp_size + num_k_head - 1) / num_k_head;
+        }
+        throw std::runtime_error("Invalid KV head configuration");
+    }
+
+    static size_t
+    calculate_out_feature_size(size_t num_q_head, size_t q_dim, size_t num_k_head, size_t k_dim, size_t num_v_head, size_t v_dim, engine::distributed::RankInfo rank_info) {
+        return num_q_head * q_dim + num_k_head * k_dim * calculate_kv_replicas(num_k_head, rank_info.tp_size) + num_v_head * v_dim * calculate_kv_replicas(num_v_head, rank_info.tp_size);
+    }
+
 private:
     size_t q_dim_;
     size_t k_dim_;
@@ -94,6 +110,8 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
     size_t q_out_size_; // num_q_head * q_dim / tp_size
     size_t k_out_size_; // num_k_head * k_dim / tp_size
     size_t v_out_size_; // num_v_head * v_dim / tp_size
+
+    size_t num_kv_head_replicas_ = 1;
 };

 class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
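
To see the helper's behavior concretely, here is a small standalone check that copies calculate_kv_replicas out of the class (the configurations are hypothetical). Note the second branch is only reached when num_k_head does not divide into tp_size's shards, and since tp_size % num_k_head == 0 there, the ceiling division (tp_size + num_k_head - 1) / num_k_head is exact and equals tp_size / num_k_head.

#include <cassert>
#include <cstddef>
#include <stdexcept>

// Copy of the diff's helper, lifted to a free function for testing.
static std::size_t calculate_kv_replicas(std::size_t num_k_head, std::size_t tp_size) {
    if (num_k_head % tp_size == 0) {
        return 1; // enough KV heads to shard evenly, no replication
    }
    if (tp_size % num_k_head == 0) {
        return (tp_size + num_k_head - 1) / num_k_head; // == tp_size / num_k_head
    }
    throw std::runtime_error("Invalid KV head configuration");
}

int main() {
    assert(calculate_kv_replicas(8, 4) == 1); // 2 heads per rank, no replicas
    assert(calculate_kv_replicas(4, 8) == 2); // each head replicated twice
    assert(calculate_kv_replicas(4, 4) == 1); // one head per rank
    // calculate_kv_replicas(6, 4) would throw: 6 % 4 != 0 and 4 % 6 != 0.
}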
