11#include " attention.hpp"
2+ #include " ../../global_state/global_state.hpp"
23#include " ../../utils.hpp"
34#include " ../rotary_embedding/rotary_embedding.hpp"
5+ #include < string>
46
57namespace infinilm ::layers::attention {
68
79Attention::Attention (std::shared_ptr<infinilm::config::ModelConfig> model_config,
810 size_t layer_idx,
9- const infinicore::Device &device) {
11+ const infinicore::Device &device)
12+ : device_(device),
13+ dtype_ (model_config->get_dtype ()) {
1014 layer_idx_ = layer_idx;
1115 hidden_size_ = model_config->get <size_t >(" hidden_size" );
1216 head_dim_ = model_config->get <size_t >(" head_dim" );
1317
14- const auto &dtype{model_config->get_dtype ()};
1518 size_t total_num_heads = model_config->get <size_t >(" num_attention_heads" );
1619 size_t total_num_kv_heads = model_config->get <size_t >(" num_key_value_heads" );
1720 bool use_bias = model_config->get_or <bool >(" attention_bias" , true );
@@ -31,18 +34,21 @@ Attention::Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config
3134 qkv_proj_ = std::make_shared<layers::linear::QKVParallelLinear>(
3235 hidden_size_, head_dim_, total_num_heads, total_num_kv_heads,
3336 " q_proj" , " k_proj" , " v_proj" , register_fn,
34- quantization_method, use_bias, dtype, device , rank_info);
37+ quantization_method, use_bias, dtype_, device_ , rank_info);
3538 o_proj_ = this ->register_module <layers::linear::RowParallelLinear>(
3639 " o_proj" , total_num_heads * head_dim_, hidden_size_, quantization_method,
37- use_output_bias, dtype, device , tp_rank, tp_size, rank_info.comm );
40+ use_output_bias, dtype_, device_ , tp_rank, tp_size, rank_info.comm );
3841
39- rotary_emb_ = infinilm::layers::rotary_embedding::get_rope (model_config, device );
42+ rotary_emb_ = infinilm::layers::rotary_embedding::get_rope (model_config, device_ );
4043
4144 float scaling = 1 .0f / std::sqrt (static_cast <float >(head_dim_));
4245 attn_ = std::make_shared<AttentionLayer>(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_,
43- kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_);
46+ kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device_ );
4447
45- init_kv_cache_quant_params (register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_);
48+ init_kv_cache_quant_params (register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_);
49+
50+ rank_qkv_output_size_ = qkv_proj_->out_features () / static_cast <size_t >(tp_size);
51+ this ->_initialize_preallocated_workspace ();
4652}
4753
4854infinicore::Tensor Attention::forward (const infinicore::Tensor &positions,
@@ -62,7 +68,8 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
6268 size_t seq_len = shape[1 ];
6369
6470 // 1. Project Q, K, V
65- auto [q, k, v] = qkv_proj_->forward_split (hidden_states_mutable);
71+ auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, rank_qkv_output_size_});
72+ auto [q, k, v] = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
6673
6774 // 2. Reshape for multi-head attention
6875 auto q_reshaped = q->view ({batch_size, seq_len, num_attention_heads_, head_dim_});
@@ -90,8 +97,9 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
9097 auto attn_output = attn_->forward (q_rope, k_reshaped, v_reshaped);
9198
9299 // 7. Project output
93- auto output = o_proj_->forward (attn_output);
94- return output;
100+ auto o_output = max_o_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, hidden_size_});
101+ o_proj_->forward_ (o_output, attn_output);
102+ return o_output;
95103}
96104
97105infinicore::Tensor Attention::forward_paged_ (const infinicore::Tensor &position_ids,
@@ -106,7 +114,8 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
106114 ASSERT_EQ (batch_size, 1 );
107115
108116 // 1. Project Q, K, V
109- auto [q, k, v] = qkv_proj_->forward_split (hidden_states_mutable);
117+ auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, rank_qkv_output_size_});
118+ auto [q, k, v] = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
110119
111120 // 2. Reshape for multi-head attention
112121 auto q_reshaped = q->view ({seq_len, num_attention_heads_, head_dim_});
@@ -133,8 +142,35 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
133142 auto attn_output = attn_->forward (q_reshaped, k_reshaped, v_reshaped);
134143
135144 // 6. Project output
136- auto output = o_proj_->forward (attn_output);
137- return output;
145+ auto o_output = max_o_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, hidden_size_});
146+ o_proj_->forward_ (o_output, attn_output);
147+ return o_output;
148+ }
149+
150+ void Attention::_initialize_preallocated_workspace () {
151+ const auto &infinilm_config = infinilm::global_state::get_infinilm_config ();
152+ auto &preallocated_workspace = infinilm::global_state::get_forward_context ().preallocated_workspace ;
153+ const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens ;
154+
155+ const std::string attention_cache_key = std::string (" Attention_max_num_batched_tokens_" )
156+ + std::to_string (max_num_batched_tokens) + " _rank_qkv_output_size_"
157+ + std::to_string (rank_qkv_output_size_) + " _hidden_size_"
158+ + std::to_string (hidden_size_) + " _dtype_"
159+ + infinicore::toString (dtype_) + " _device_"
160+ + device_.toString ();
161+
162+ size_t max_output_size = std::max (rank_qkv_output_size_, hidden_size_);
163+ if (preallocated_workspace.find (attention_cache_key) == preallocated_workspace.end ()) {
164+ auto attention_buffer = infinicore::Tensor::empty ({max_num_batched_tokens * max_output_size}, dtype_, device_);
165+ preallocated_workspace[attention_cache_key] = attention_buffer;
166+ }
167+
168+ auto attention_buffer = preallocated_workspace.at (attention_cache_key);
169+ const auto attention_buffer_shape = attention_buffer->shape ();
170+ ASSERT (attention_buffer_shape[0 ] == max_num_batched_tokens * max_output_size);
171+
172+ max_qkv_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * rank_qkv_output_size_}})->view ({max_num_batched_tokens, rank_qkv_output_size_});
173+ max_o_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * hidden_size_}})->view ({max_num_batched_tokens, hidden_size_});
138174}
139175
140176void init_kv_cache_quant_params (std::function<void (const std::string &, infinicore::nn::Parameter)> register_fn,
0 commit comments