33#include " ../../utils.hpp"
44#include " ../rotary_embedding/rotary_embedding.hpp"
55#include < string>
6+ #include < tuple>
67
78namespace infinilm ::layers::attention {
89
@@ -48,7 +49,10 @@ Attention::Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config
4849 init_kv_cache_quant_params (register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_);
4950
5051 rank_qkv_output_size_ = qkv_proj_->out_features () / static_cast <size_t >(tp_size);
51- this ->_initialize_preallocated_workspace ();
52+ enable_workspace_manager_ = infinilm::global_state::get_infinilm_config ().enable_workspace_manager ;
53+ if (enable_workspace_manager_) {
54+ this ->_register_inference_buffer ();
55+ }
5256}
5357
5458infinicore::Tensor Attention::forward (const infinicore::Tensor &positions,
@@ -68,8 +72,15 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
6872 size_t seq_len = shape[1 ];
6973
7074 // 1. Project Q, K, V
71- auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, rank_qkv_output_size_});
72- auto [q, k, v] = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
75+ infinicore::Tensor q;
76+ infinicore::Tensor k;
77+ infinicore::Tensor v;
78+ if (enable_workspace_manager_) {
79+ auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, rank_qkv_output_size_});
80+ std::tie (q, k, v) = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
81+ } else {
82+ std::tie (q, k, v) = qkv_proj_->forward_split (hidden_states_mutable);
83+ }
7384
7485 // 2. Reshape for multi-head attention
7586 auto q_reshaped = q->view ({batch_size, seq_len, num_attention_heads_, head_dim_});
@@ -96,10 +107,13 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
96107 // 5. Attn Backend calculate
97108 auto attn_output = attn_->forward (q_rope, k_reshaped, v_reshaped);
98109
99- // 7. Project output
100- auto o_output = max_o_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, hidden_size_});
101- o_proj_->forward_ (o_output, attn_output);
102- return o_output;
110+ // 6. Project output
111+ if (enable_workspace_manager_) {
112+ auto o_output = max_o_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, hidden_size_});
113+ o_proj_->forward_ (o_output, attn_output);
114+ return o_output;
115+ }
116+ return o_proj_->forward (attn_output);
103117}
104118
105119infinicore::Tensor Attention::forward_paged_ (const infinicore::Tensor &position_ids,
@@ -114,8 +128,15 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
114128 ASSERT_EQ (batch_size, 1 );
115129
116130 // 1. Project Q, K, V
117- auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, rank_qkv_output_size_});
118- auto [q, k, v] = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
131+ infinicore::Tensor q;
132+ infinicore::Tensor k;
133+ infinicore::Tensor v;
134+ if (enable_workspace_manager_) {
135+ auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, rank_qkv_output_size_});
136+ std::tie (q, k, v) = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
137+ } else {
138+ std::tie (q, k, v) = qkv_proj_->forward_split (hidden_states_mutable);
139+ }
119140
120141 // 2. Reshape for multi-head attention
121142 auto q_reshaped = q->view ({seq_len, num_attention_heads_, head_dim_});
@@ -142,35 +163,44 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
142163 auto attn_output = attn_->forward (q_reshaped, k_reshaped, v_reshaped);
143164
144165 // 6. Project output
145- auto o_output = max_o_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, hidden_size_});
146- o_proj_->forward_ (o_output, attn_output);
147- return o_output;
166+ if (enable_workspace_manager_) {
167+ auto o_output = max_o_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, hidden_size_});
168+ o_proj_->forward_ (o_output, attn_output);
169+ return o_output;
170+ }
171+ return o_proj_->forward (attn_output);
148172}
149173
150- void Attention::_initialize_preallocated_workspace () {
174+ void Attention::_register_inference_buffer () {
151175 const auto &infinilm_config = infinilm::global_state::get_infinilm_config ();
152- auto &preallocated_workspace = infinilm::global_state::get_forward_context ().preallocated_workspace ;
176+ auto &workspace_manager = infinilm::global_state::get_forward_context ().workspace_manager ;
153177 const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens ;
154178
179+ ASSERT (rank_qkv_output_size_ > 0 && hidden_size_ > 0 );
180+
155181 const std::string attention_cache_key = std::string (" Attention_max_num_batched_tokens_" )
156182 + std::to_string (max_num_batched_tokens) + " _rank_qkv_output_size_"
157183 + std::to_string (rank_qkv_output_size_) + " _hidden_size_"
158184 + std::to_string (hidden_size_) + " _dtype_"
159185 + infinicore::toString (dtype_) + " _device_"
160186 + device_.toString ();
161187
162- size_t max_output_size = std::max (rank_qkv_output_size_, hidden_size_);
163- if (preallocated_workspace.find (attention_cache_key) == preallocated_workspace.end ()) {
164- auto attention_buffer = infinicore::Tensor::empty ({max_num_batched_tokens * max_output_size}, dtype_, device_);
165- preallocated_workspace[attention_cache_key] = attention_buffer;
166- }
167-
168- auto attention_buffer = preallocated_workspace.at (attention_cache_key);
169- const auto attention_buffer_shape = attention_buffer->shape ();
170- ASSERT (attention_buffer_shape[0 ] == max_num_batched_tokens * max_output_size);
171-
172- max_qkv_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * rank_qkv_output_size_}})->view ({max_num_batched_tokens, rank_qkv_output_size_});
173- max_o_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * hidden_size_}})->view ({max_num_batched_tokens, hidden_size_});
188+ const size_t max_output_size = std::max (rank_qkv_output_size_, hidden_size_);
189+ const infinicore::Shape attention_buffer_shape = {max_num_batched_tokens * max_output_size};
190+ workspace_manager.register_buffer (
191+ attention_cache_key,
192+ attention_buffer_shape,
193+ dtype_,
194+ device_,
195+ [this , max_num_batched_tokens, max_output_size](const infinicore::Tensor &attention_buffer) {
196+ const auto attention_buffer_shape = attention_buffer->shape ();
197+ ASSERT (attention_buffer_shape[0 ] == max_num_batched_tokens * max_output_size);
198+
199+ max_qkv_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * rank_qkv_output_size_}})
200+ ->view ({max_num_batched_tokens, rank_qkv_output_size_});
201+ max_o_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * hidden_size_}})
202+ ->view ({max_num_batched_tokens, hidden_size_});
203+ });
174204}
175205
176206void init_kv_cache_quant_params (std::function<void (const std::string &, infinicore::nn::Parameter)> register_fn,
0 commit comments