33#include " ../../utils.hpp"
44#include " ../rotary_embedding/rotary_embedding.hpp"
55#include < string>
6+ #include < tuple>
67
78namespace infinilm ::layers::attention {
89
@@ -48,7 +49,10 @@ Attention::Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config
4849 init_kv_cache_quant_params (register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_);
4950
5051 rank_qkv_output_size_ = qkv_proj_->out_features () / static_cast <size_t >(tp_size);
51- this ->_initialize_preallocated_workspace ();
52+ enable_workspace_manager_ = infinilm::global_state::get_infinilm_config ().enable_workspace_manager ;
53+ if (enable_workspace_manager_) {
54+ this ->_register_inference_buffer ();
55+ }
5256}
5357
5458infinicore::Tensor Attention::forward (const infinicore::Tensor &positions,
@@ -68,8 +72,13 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
6872 size_t seq_len = shape[1 ];
6973
7074 // 1. Project Q, K, V
71- auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, rank_qkv_output_size_});
72- auto [q, k, v] = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
75+ infinicore::Tensor q, k, v;
76+ if (enable_workspace_manager_) {
77+ auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, rank_qkv_output_size_});
78+ std::tie (q, k, v) = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
79+ } else {
80+ std::tie (q, k, v) = qkv_proj_->forward_split (hidden_states_mutable);
81+ }
7382
7483 // 2. Reshape for multi-head attention
7584 auto q_reshaped = q->view ({batch_size, seq_len, num_attention_heads_, head_dim_});
@@ -96,10 +105,13 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
96105 // 5. Attn Backend calculate
97106 auto attn_output = attn_->forward (q_rope, k_reshaped, v_reshaped);
98107
99- // 7. Project output
100- auto o_output = max_o_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, hidden_size_});
101- o_proj_->forward_ (o_output, attn_output);
102- return o_output;
108+ // 6. Project output
109+ if (enable_workspace_manager_) {
110+ auto o_output = max_o_output_->narrow ({{0 , 0 , batch_size * seq_len}})->view ({batch_size, seq_len, hidden_size_});
111+ o_proj_->forward_ (o_output, attn_output);
112+ return o_output;
113+ }
114+ return o_proj_->forward (attn_output);
103115}
104116
105117infinicore::Tensor Attention::forward_paged_ (const infinicore::Tensor &position_ids,
@@ -114,8 +126,13 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
114126 ASSERT_EQ (batch_size, 1 );
115127
116128 // 1. Project Q, K, V
117- auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, rank_qkv_output_size_});
118- auto [q, k, v] = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
129+ infinicore::Tensor q, k, v;
130+ if (enable_workspace_manager_) {
131+ auto qkv_output = max_qkv_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, rank_qkv_output_size_});
132+ std::tie (q, k, v) = qkv_proj_->forward_split_ (qkv_output, hidden_states_mutable);
133+ } else {
134+ std::tie (q, k, v) = qkv_proj_->forward_split (hidden_states_mutable);
135+ }
119136
120137 // 2. Reshape for multi-head attention
121138 auto q_reshaped = q->view ({seq_len, num_attention_heads_, head_dim_});
@@ -142,35 +159,44 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
142159 auto attn_output = attn_->forward (q_reshaped, k_reshaped, v_reshaped);
143160
144161 // 6. Project output
145- auto o_output = max_o_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, hidden_size_});
146- o_proj_->forward_ (o_output, attn_output);
147- return o_output;
162+ if (enable_workspace_manager_) {
163+ auto o_output = max_o_output_->narrow ({{0 , 0 , seq_len}})->view ({1 , seq_len, hidden_size_});
164+ o_proj_->forward_ (o_output, attn_output);
165+ return o_output;
166+ }
167+ return o_proj_->forward (attn_output);
148168}
149169
// Diff hunk: the removed `_initialize_preallocated_workspace` (lines marked `-`) is replaced by
// `_register_inference_buffer` (lines marked `+`); unmarked lines are common to both versions.
// The new version builds a key from max_num_batched_tokens, rank_qkv_output_size_, hidden_size_,
// dtype_ and device_, then calls workspace_manager.register_buffer with a 1-D shape of
// max_num_batched_tokens * max(rank_qkv_output_size_, hidden_size_) elements and a callback that
// sets max_qkv_output_ and max_o_output_ as views of the buffer passed to it.
150- void Attention::_initialize_preallocated_workspace () {
170+ void Attention::_register_inference_buffer () {
151171 const auto &infinilm_config = infinilm::global_state::get_infinilm_config ();
152- auto &preallocated_workspace = infinilm::global_state::get_forward_context ().preallocated_workspace ;
172+ auto &workspace_manager = infinilm::global_state::get_forward_context ().workspace_manager ;
153173 const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens ;
154174
// Both sizes must be non-zero before they are used to size the buffer and build the key.
175+ ASSERT (rank_qkv_output_size_ > 0 && hidden_size_ > 0 );
176+
// The key encodes every quantity that determines the buffer's size, dtype and device.
155177 const std::string attention_cache_key = std::string (" Attention_max_num_batched_tokens_" )
156178 + std::to_string (max_num_batched_tokens) + " _rank_qkv_output_size_"
157179 + std::to_string (rank_qkv_output_size_) + " _hidden_size_"
158180 + std::to_string (hidden_size_) + " _dtype_"
159181 + infinicore::toString (dtype_) + " _device_"
160182 + device_.toString ();
161183
// Removed version: created the buffer itself when the key was missing from preallocated_workspace,
// then fetched it back and set up the two views inline.
162- size_t max_output_size = std::max (rank_qkv_output_size_, hidden_size_);
163- if (preallocated_workspace.find (attention_cache_key) == preallocated_workspace.end ()) {
164- auto attention_buffer = infinicore::Tensor::empty ({max_num_batched_tokens * max_output_size}, dtype_, device_);
165- preallocated_workspace[attention_cache_key] = attention_buffer;
166- }
167-
168- auto attention_buffer = preallocated_workspace.at (attention_cache_key);
169- const auto attention_buffer_shape = attention_buffer->shape ();
170- ASSERT (attention_buffer_shape[0 ] == max_num_batched_tokens * max_output_size);
171-
172- max_qkv_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * rank_qkv_output_size_}})->view ({max_num_batched_tokens, rank_qkv_output_size_});
173- max_o_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * hidden_size_}})->view ({max_num_batched_tokens, hidden_size_});
// Added version: the buffer is sized for the larger of the two outputs so one buffer serves both views.
184+ const size_t max_output_size = std::max (rank_qkv_output_size_, hidden_size_);
185+ const infinicore::Shape attention_buffer_shape = {max_num_batched_tokens * max_output_size};
// NOTE(review): the callback captures `this`; if workspace_manager keeps the callback beyond this
// call, the Attention object must outlive it — confirm against register_buffer's contract.
186+ workspace_manager.register_buffer (
187+ attention_cache_key,
188+ attention_buffer_shape,
189+ dtype_,
190+ device_,
191+ [this , max_num_batched_tokens, max_output_size](const infinicore::Tensor &attention_buffer) {
// This local shadows the outer attention_buffer_shape; here it is the shape of the tensor received.
192+ const auto attention_buffer_shape = attention_buffer->shape ();
193+ ASSERT (attention_buffer_shape[0 ] == max_num_batched_tokens * max_output_size);
194+
// Both views are narrowed with the same leading arguments {0, 0, ...} from the same buffer, so they
// presumably overlap in memory — TODO confirm narrow's argument order is {dim, start, length}.
195+ max_qkv_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * rank_qkv_output_size_}})
196+ ->view ({max_num_batched_tokens, rank_qkv_output_size_});
197+ max_o_output_ = attention_buffer->narrow ({{0 , 0 , max_num_batched_tokens * hidden_size_}})
198+ ->view ({max_num_batched_tokens, hidden_size_});
199+ });
174200}
175201
176202void init_kv_cache_quant_params (std::function<void (const std::string &, infinicore::nn::Parameter)> register_fn,
0 commit comments