Skip to content

Commit c811d4c

Browse files
author
wangpengcheng
committed
issue/407 - Release the GIL; preallocated workspace
1 parent 89c0a16 commit c811d4c

47 files changed

Lines changed: 731 additions & 110 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

csrc/engine/infer_engine.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@ InferEngine::InferEngine(
1414
const cache::CacheConfig *cache_config,
1515
bool enable_graph_compiling,
1616
backends::AttentionBackend attention_backend,
17-
std::optional<infinicore::DataType> kv_cache_dtype) // Changed parameter
17+
std::optional<infinicore::DataType> kv_cache_dtype, // Changed parameter
18+
size_t max_num_batched_tokens)
1819
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
1920
if (cache_config != nullptr) {
2021
cache_config_ = cache_config->unique_copy();
2122
}
2223

2324
// Load model config if model_path is provided, model_path must be valid, and config.json exists
2425
this->model_config_ = infinilm::config::ConfigFactory::createConfig(config_str);
25-
auto infinilm_config = std::make_shared<infinilm::global_state::InfinilmConfig>(attention_backend, this->model_config_);
26+
auto infinilm_config = std::make_shared<infinilm::global_state::InfinilmConfig>(attention_backend, this->model_config_, max_num_batched_tokens);
2627

2728
// Only support offline int8 kv cache quantization in this version
2829
if (kv_cache_dtype.has_value()) {

csrc/engine/infer_engine.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class InferEngine {
2727
const cache::CacheConfig *cache_config = nullptr,
2828
bool enable_graph_compiling = false,
2929
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
30-
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
30+
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt,
31+
size_t max_num_batched_tokens = 2048);
3132

3233
// Load a parameter to all workers (each can extract its shard inside RankWorker)
3334
void load_param(const std::string &name, const infinicore::Tensor &param);

csrc/global_state/forward_context.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "../models/infinilm_model.hpp"
4+
#include <unordered_map>
45

56
namespace infinilm::global_state {
67

@@ -43,6 +44,9 @@ struct AttentionMetadata {
4344
struct ForwardContext {
4445
AttentionMetadata attn_metadata;
4546
std::vector<infinicore::Tensor> kv_cache_vec;
47+
48+
// preallocated workspace for some modules
49+
std::unordered_map<std::string, infinicore::Tensor> preallocated_workspace;
4650
};
4751

4852
void initialize_forward_context(ForwardContext &forward_context);

csrc/global_state/infinilm_config.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@ struct InfinilmConfig {
1414
public:
1515
InfinilmConfig() = default;
1616
InfinilmConfig(const infinilm::backends::AttentionBackend &backend,
17-
const std::shared_ptr<infinilm::config::ModelConfig> &model_config)
17+
const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
18+
size_t max_num_batched_tokens)
1819
: attention_backend(backend),
19-
model_config(model_config) {}
20+
model_config(model_config),
21+
max_num_batched_tokens(max_num_batched_tokens) {
22+
const size_t max_position_embeddings = model_config->get<size_t>("max_position_embeddings");
23+
ASSERT(max_num_batched_tokens >= 512 && max_num_batched_tokens <= max_position_embeddings);
24+
}
2025

2126
public:
2227
infinilm::backends::AttentionBackend attention_backend;
2328
std::shared_ptr<infinilm::config::ModelConfig> model_config;
29+
size_t max_num_batched_tokens = 0;
2430
};
2531

2632
/**

csrc/layers/attention/attention.cpp

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
#include "attention.hpp"
2+
#include "../../global_state/global_state.hpp"
23
#include "../../utils.hpp"
34
#include "../rotary_embedding/rotary_embedding.hpp"
5+
#include <string>
46

57
namespace infinilm::layers::attention {
68

79
Attention::Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
810
size_t layer_idx,
9-
const infinicore::Device &device) {
11+
const infinicore::Device &device)
12+
: device_(device),
13+
dtype_(model_config->get_dtype()) {
1014
layer_idx_ = layer_idx;
1115
hidden_size_ = model_config->get<size_t>("hidden_size");
1216
head_dim_ = model_config->get<size_t>("head_dim");
1317

14-
const auto &dtype{model_config->get_dtype()};
1518
size_t total_num_heads = model_config->get<size_t>("num_attention_heads");
1619
size_t total_num_kv_heads = model_config->get<size_t>("num_key_value_heads");
1720
bool use_bias = model_config->get_or<bool>("attention_bias", true);
@@ -31,18 +34,21 @@ Attention::Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config
3134
qkv_proj_ = std::make_shared<layers::linear::QKVParallelLinear>(
3235
hidden_size_, head_dim_, total_num_heads, total_num_kv_heads,
3336
"q_proj", "k_proj", "v_proj", register_fn,
34-
quantization_method, use_bias, dtype, device, rank_info);
37+
quantization_method, use_bias, dtype_, device_, rank_info);
3538
o_proj_ = this->register_module<layers::linear::RowParallelLinear>(
3639
"o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method,
37-
use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm);
40+
use_output_bias, dtype_, device_, tp_rank, tp_size, rank_info.comm);
3841

39-
rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device);
42+
rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device_);
4043

4144
float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
4245
attn_ = std::make_shared<AttentionLayer>(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_,
43-
kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_);
46+
kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_, device_);
4447

45-
init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_);
48+
init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_);
49+
50+
rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast<size_t>(tp_size);
51+
this->_initialize_preallocated_workspace();
4652
}
4753

4854
infinicore::Tensor Attention::forward(const infinicore::Tensor &positions,
@@ -62,7 +68,8 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
6268
size_t seq_len = shape[1];
6369

6470
// 1. Project Q, K, V
65-
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
71+
auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_});
72+
auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable);
6673

6774
// 2. Reshape for multi-head attention
6875
auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_});
@@ -90,8 +97,9 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
9097
auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped);
9198

9299
// 7. Project output
93-
auto output = o_proj_->forward(attn_output);
94-
return output;
100+
auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_});
101+
o_proj_->forward_(o_output, attn_output);
102+
return o_output;
95103
}
96104

97105
infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ids,
@@ -106,7 +114,8 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
106114
ASSERT_EQ(batch_size, 1);
107115

108116
// 1. Project Q, K, V
109-
auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
117+
auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_});
118+
auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable);
110119

111120
// 2. Reshape for multi-head attention
112121
auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_});
@@ -133,8 +142,35 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
133142
auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped);
134143

135144
// 6. Project output
136-
auto output = o_proj_->forward(attn_output);
137-
return output;
145+
auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_});
146+
o_proj_->forward_(o_output, attn_output);
147+
return o_output;
148+
}
149+
150+
void Attention::_initialize_preallocated_workspace() {
151+
const auto &infinilm_config = infinilm::global_state::get_infinilm_config();
152+
auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace;
153+
const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens;
154+
155+
const std::string attention_cache_key = std::string("Attention_max_num_batched_tokens_")
156+
+ std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_"
157+
+ std::to_string(rank_qkv_output_size_) + "_hidden_size_"
158+
+ std::to_string(hidden_size_) + "_dtype_"
159+
+ infinicore::toString(dtype_) + "_device_"
160+
+ device_.toString();
161+
162+
size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_);
163+
if (preallocated_workspace.find(attention_cache_key) == preallocated_workspace.end()) {
164+
auto attention_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_);
165+
preallocated_workspace[attention_cache_key] = attention_buffer;
166+
}
167+
168+
auto attention_buffer = preallocated_workspace.at(attention_cache_key);
169+
const auto attention_buffer_shape = attention_buffer->shape();
170+
ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size);
171+
172+
max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})->view({max_num_batched_tokens, rank_qkv_output_size_});
173+
max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_});
138174
}
139175

140176
void init_kv_cache_quant_params(std::function<void(const std::string &, infinicore::nn::Parameter)> register_fn,

csrc/layers/attention/attention.hpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "../../global_state/global_state.hpp"
66
#include "../linear/linear.hpp"
77
#include "backends/attention_layer.hpp"
8+
#include "infinicore/device.hpp"
9+
#include "infinicore/dtype.hpp"
810
#include "infinicore/nn/module.hpp"
911
#include "infinicore/nn/rope.hpp"
1012
#include "infinicore/tensor.hpp"
@@ -37,6 +39,8 @@ class Attention : public infinicore::nn::Module {
3739
infinicore::Tensor forward_paged_(const infinicore::Tensor &positions,
3840
const infinicore::Tensor &hidden_states) const;
3941

42+
void _initialize_preallocated_workspace();
43+
4044
protected:
4145
std::shared_ptr<infinilm::layers::linear::QKVParallelLinear> qkv_proj_;
4246
std::shared_ptr<infinilm::layers::linear::RowParallelLinear> o_proj_;
@@ -49,13 +53,22 @@ class Attention : public infinicore::nn::Module {
4953
size_t num_key_value_heads_;
5054
size_t hidden_size_;
5155
size_t head_dim_;
56+
infinicore::Device device_;
57+
infinicore::DataType dtype_;
5258

5359
// For off-line kv cache quantization
5460
INFINICORE_NN_PARAMETER(kv_cache_k_scale);
5561
INFINICORE_NN_PARAMETER(kv_cache_v_scale);
62+
63+
private:
64+
size_t rank_qkv_output_size_;
65+
66+
// preallocated workspace for Attention
67+
infinicore::Tensor max_qkv_output_;
68+
infinicore::Tensor max_o_output_;
5669
};
5770
void init_kv_cache_quant_params(std::function<void(const std::string &, infinicore::nn::Parameter)> register_fn,
58-
const infinicore::Device &device,
59-
infinicore::nn::Parameter &kv_cache_k_scale,
60-
infinicore::nn::Parameter &kv_cache_v_scale);
71+
const infinicore::Device &device,
72+
infinicore::nn::Parameter &kv_cache_k_scale,
73+
infinicore::nn::Parameter &kv_cache_v_scale);
6174
} // namespace infinilm::layers::attention

csrc/layers/attention/backends/attention_layer.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@ AttentionLayer::AttentionLayer(size_t num_heads,
99
size_t layer_idx,
1010
infinicore::Tensor k_scale,
1111
infinicore::Tensor v_scale,
12-
::infinilm::backends::AttentionBackend attn_backend) : k_scale_(k_scale), v_scale_(v_scale), layer_idx_(layer_idx), attn_backend_(attn_backend) {
12+
::infinilm::backends::AttentionBackend attn_backend,
13+
const infinicore::Device &device) : k_scale_(k_scale), v_scale_(v_scale), layer_idx_(layer_idx), attn_backend_(attn_backend) {
1314
switch (attn_backend) {
1415
case ::infinilm::backends::AttentionBackend::STATIC_ATTN:
15-
attn_backend_impl_ = std::make_shared<backends::StaticAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
16+
attn_backend_impl_ = std::make_shared<backends::StaticAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx, device);
1617
break;
1718
case ::infinilm::backends::AttentionBackend::PAGED_ATTN:
18-
attn_backend_impl_ = std::make_shared<backends::PagedAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
19+
attn_backend_impl_ = std::make_shared<backends::PagedAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx, device);
1920
break;
2021
case ::infinilm::backends::AttentionBackend::FLASH_ATTN:
21-
attn_backend_impl_ = std::make_shared<backends::FlashAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
22+
attn_backend_impl_ = std::make_shared<backends::FlashAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx, device);
2223
break;
2324
default:
2425
throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend");

csrc/layers/attention/backends/attention_layer.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class AttentionLayer {
3131
size_t layer_idx,
3232
infinicore::Tensor k_scale,
3333
infinicore::Tensor v_scale,
34-
::infinilm::backends::AttentionBackend attention_backend);
34+
::infinilm::backends::AttentionBackend attention_backend,
35+
const infinicore::Device &device);
3536

3637
infinicore::Tensor forward(infinicore::Tensor &query,
3738
infinicore::Tensor &key,

csrc/layers/attention/backends/flash_attn.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,38 @@
11
#include "flash_attn.hpp"
22

3+
#include "../../../global_state/global_state.hpp"
34
#include "../../../utils.hpp"
45
#include "infinicore/ops.hpp"
56
#include "infinicore/ops/mha_kvcache.hpp"
67
#include "infinicore/ops/mha_varlen.hpp"
8+
#include <string>
79

810
namespace infinilm::layers::attention::backends {
911

1012
FlashAttentionImpl::FlashAttentionImpl(size_t num_heads,
1113
size_t head_size,
1214
float scale,
1315
size_t num_kv_heads,
14-
size_t layer_idx)
16+
size_t layer_idx,
17+
const infinicore::Device &device)
1518
: num_heads_(num_heads),
1619
head_size_(head_size),
1720
scale_(scale),
1821
num_kv_heads_(num_kv_heads),
1922
layer_idx_(layer_idx),
20-
head_dim_(head_size) {
23+
head_dim_(head_size),
24+
device_(device) {
2125

2226
const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config();
2327
if (!infinilm_config.model_config) {
2428
throw std::runtime_error("infinilm::layers::attention::backends::FlashAttentionImpl: model_config is null");
2529
}
26-
max_position_embeddings_ = infinilm_config.model_config->get<size_t>("max_position_embeddings");
30+
31+
const auto &model_config = infinilm_config.model_config;
32+
dtype_ = model_config->get_dtype();
33+
max_position_embeddings_ = model_config->get<size_t>("max_position_embeddings");
34+
35+
this->_initialize_preallocated_workspace();
2736
}
2837

2938
infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer,
@@ -48,8 +57,9 @@ infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer,
4857
bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);
4958

5059
// 2. Compute attention
51-
infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device());
60+
infinicore::Tensor attn_output;
5261
if (is_prefill) {
62+
attn_output = max_attn_output_->narrow({{0, 0, seq_len}});
5363
infinicore::op::mha_varlen_(
5464
attn_output,
5565
query,
@@ -99,4 +109,27 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> FlashAttentionImpl::do_kv_cac
99109
return {k_cache_layer, v_cache_layer};
100110
}
101111

112+
void FlashAttentionImpl::_initialize_preallocated_workspace() {
113+
const auto &infinilm_config = infinilm::global_state::get_infinilm_config();
114+
auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace;
115+
const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens;
116+
117+
const std::string cache_key = std::string("FlashAttentionImpl_max_num_batched_tokens_")
118+
+ std::to_string(max_num_batched_tokens) + "_num_heads_"
119+
+ std::to_string(num_heads_) + "_head_dim_"
120+
+ std::to_string(head_dim_) + "_dtype_"
121+
+ infinicore::toString(dtype_) + "_device_"
122+
+ device_.toString();
123+
124+
if (preallocated_workspace.find(cache_key) == preallocated_workspace.end()) {
125+
auto flash_attention_impl_buffer = infinicore::Tensor::empty({max_num_batched_tokens, num_heads_, head_dim_}, dtype_, device_);
126+
preallocated_workspace[cache_key] = flash_attention_impl_buffer;
127+
}
128+
129+
auto flash_attention_impl_buffer = preallocated_workspace.at(cache_key);
130+
const auto buffer_shape = flash_attention_impl_buffer->shape();
131+
ASSERT(buffer_shape[0] == max_num_batched_tokens && buffer_shape[1] == num_heads_ && buffer_shape[2] == head_dim_);
132+
133+
max_attn_output_ = flash_attention_impl_buffer;
134+
}
102135
} // namespace infinilm::layers::attention::backends

csrc/layers/attention/backends/flash_attn.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

33
#include "../../../global_state/global_state.hpp"
4+
#include "infinicore/device.hpp"
5+
#include "infinicore/dtype.hpp"
46
#include "infinicore/tensor.hpp"
57
#include <tuple>
68

@@ -16,7 +18,8 @@ class FlashAttentionImpl {
1618
size_t head_size,
1719
float scale,
1820
size_t num_kv_heads,
19-
size_t layer_idx);
21+
size_t layer_idx,
22+
const infinicore::Device &device);
2023

2124
/**
2225
* @brief Forward pass with FlashAttention.
@@ -43,12 +46,20 @@ class FlashAttentionImpl {
4346
const infinicore::Tensor slot_mapping) const;
4447

4548
private:
49+
void _initialize_preallocated_workspace();
50+
4651
size_t num_heads_;
4752
size_t head_size_;
4853
float scale_;
4954
size_t num_kv_heads_;
5055
size_t layer_idx_;
5156
size_t head_dim_; // Note: head_dim equals to head_size
5257
size_t max_position_embeddings_;
58+
infinicore::Device device_;
59+
infinicore::DataType dtype_;
60+
61+
// preallocated workspace for FlashAttentionImpl
62+
infinicore::Tensor max_attn_output_;
5363
};
64+
5465
} // namespace infinilm::layers::attention::backends

0 commit comments

Comments
 (0)