Skip to content

Commit 6bb2040

Browse files
author
wangpengcheng
committed
refactor with register_inference_buffer.
1 parent ed1c8c0 commit 6bb2040

18 files changed

Lines changed: 474 additions & 190 deletions

File tree

csrc/engine/rank_worker.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,14 @@ void RankWorker::thread_loop() {
278278
if (!model_) {
279279
throw std::runtime_error("Failed to create model");
280280
}
281+
282+
infinicore::context::syncStream();
283+
284+
if (infinilm_config_->enable_workspace_manager) {
285+
forward_context_.workspace_manager.finalize_and_bind(rank_info_.device);
286+
}
287+
infinicore::context::syncStream();
288+
281289
if (enable_graph_compiling_) {
282290
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
283291
}

csrc/global_state/forward_context.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#pragma once
22

33
#include "../models/infinilm_model.hpp"
4-
#include <unordered_map>
4+
#include "../utils.hpp"
5+
#include "workspace_manager.hpp"
6+
#include <vector>
57

68
namespace infinilm::global_state {
79

@@ -49,9 +51,7 @@ struct ForwardContext {
4951
AttentionMetadata attn_metadata;
5052
MultiModalMetadata mm_metadata;
5153
std::vector<infinicore::Tensor> kv_cache_vec;
52-
53-
// preallocated workspace for some modules
54-
std::unordered_map<std::string, infinicore::Tensor> preallocated_workspace;
54+
WorkspaceManager workspace_manager;
5555
};
5656

5757
void initialize_forward_context(ForwardContext &forward_context);

csrc/global_state/infinilm_config.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@ struct InfinilmConfig {
1919
: attention_backend(backend),
2020
model_config(model_config),
2121
max_num_batched_tokens(max_num_batched_tokens) {
22-
const size_t max_position_embeddings = model_config->get<size_t>("max_position_embeddings");
23-
ASSERT(max_num_batched_tokens >= 512 && max_num_batched_tokens <= max_position_embeddings);
22+
23+
if (max_num_batched_tokens > 0) {
24+
const size_t max_position_embeddings = model_config->get<size_t>("max_position_embeddings");
25+
ASSERT(max_num_batched_tokens >= 512 && max_num_batched_tokens <= max_position_embeddings);
26+
enable_workspace_manager = true;
27+
}
2428
}
2529

2630
public:
2731
infinilm::backends::AttentionBackend attention_backend;
2832
std::shared_ptr<infinilm::config::ModelConfig> model_config;
2933
size_t max_num_batched_tokens = 0;
34+
bool enable_workspace_manager{false};
3035
};
3136

3237
/**
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#pragma once
2+
3+
#include "../models/infinilm_model.hpp"
4+
#include "../utils.hpp"
5+
#include <algorithm>
6+
#include <cstdio>
7+
#include <functional>
8+
#include <string>
9+
#include <unordered_map>
10+
#include <vector>
11+
12+
namespace infinilm::global_state {
13+
14+
// /**
15+
// * @brief Unified GPU inference workspace manager.
16+
// *
17+
// * Phase 1: modules register buffer layouts via ``register_buffer``.
18+
// * Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views.
19+
// */
20+
// class WorkspaceManager {
21+
// public:
22+
// using BindFn = std::function<void(const infinicore::Tensor &)>;
23+
24+
// WorkspaceManager() = default;
25+
// ~WorkspaceManager() = default;
26+
27+
// /**
28+
// * @brief Register a buffer appended at the current scratch_buffer tail.
29+
// *
30+
// * @param name Unique cache key; duplicate keys share one slot.
31+
// * @param shape Tensor shape for the bound view.
32+
// * @param dtype Element type of the bound view.
33+
// * @param device Device on which scratch_buffer is allocated.
34+
// * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
35+
// */
36+
// void register_buffer(const std::string &name,
37+
// const infinicore::Shape &shape,
38+
// const infinicore::DataType &dtype,
39+
// const infinicore::Device &device,
40+
// BindFn bind_fn) {
41+
// register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true);
42+
// }
43+
44+
// /**
45+
// * @brief Register a buffer pinned at a fixed byte offset.
46+
// *
47+
// * @param name Unique cache key; duplicate keys share one slot.
48+
// * @param offset Byte offset in scratch_buffer (currently only 0 is supported).
49+
// * @param shape Tensor shape for the bound view.
50+
// * @param dtype Element type of the bound view.
51+
// * @param device Device on which scratch_buffer is allocated.
52+
// * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
53+
// */
54+
// void register_buffer(const std::string &name,
55+
// size_t offset,
56+
// const infinicore::Shape &shape,
57+
// const infinicore::DataType &dtype,
58+
// const infinicore::Device &device,
59+
// BindFn bind_fn) {
60+
// ASSERT(0 == offset);
61+
// register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false);
62+
// }
63+
64+
// /**
65+
// * @brief Allocate scratch_buffer and run all registered bind callbacks.
66+
// *
67+
// * @param device Device on which scratch_buffer is allocated.
68+
// */
69+
// void finalize_and_bind(const infinicore::Device &device) {
70+
// ASSERT(!finalized_);
71+
// if (total_bytes_ == 0) {
72+
// finalized_ = true;
73+
// return;
74+
// }
75+
76+
// ASSERT(device.getType() != infinicore::Device::Type::CPU);
77+
78+
// scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, device);
79+
80+
// spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0);
81+
82+
// for (auto &[name, reg] : registrations_) {
83+
// auto *base_ptr = scratch_buffer_->data() + reg.offset;
84+
// auto view = infinicore::Tensor::from_blob(static_cast<void *>(base_ptr), reg.shape, reg.dtype, device);
85+
// inference_buffers_[name] = view;
86+
// for (auto &bind_fn : reg.bind_callbacks) {
87+
// bind_fn(view);
88+
// }
89+
// }
90+
91+
// finalized_ = true;
92+
// }
93+
94+
// private:
95+
// /** @brief Metadata for one registered region in scratch_buffer. */
96+
// struct BufferRegistration {
97+
// size_t offset{0};
98+
// size_t aligned_bytes{0};
99+
// infinicore::Shape shape;
100+
// infinicore::DataType dtype;
101+
// infinicore::Device device;
102+
// std::vector<BindFn> bind_callbacks;
103+
// };
104+
105+
// void register_buffer_impl(const std::string &name,
106+
// size_t offset,
107+
// const infinicore::Shape &shape,
108+
// const infinicore::DataType &dtype,
109+
// const infinicore::Device &device,
110+
// BindFn bind_fn,
111+
// bool bump_tail) {
112+
// ASSERT(!finalized_);
113+
// ASSERT(device.getType() != infinicore::Device::Type::CPU);
114+
115+
// auto compute_numel = [](const infinicore::Shape &shape) {
116+
// size_t numel = 1;
117+
// for (const auto dim : shape) {
118+
// numel *= dim;
119+
// }
120+
// return numel;
121+
// };
122+
123+
// auto align_up = [](size_t n, size_t alignment = 512) {
124+
// return (n + alignment - 1) & ~(alignment - 1);
125+
// };
126+
127+
// const size_t actual_bytes = compute_numel(shape) * infinicore::dsize(dtype);
128+
// const size_t aligned_bytes = align_up(actual_bytes);
129+
130+
// if (registrations_.find(name) == registrations_.end()) {
131+
// BufferRegistration reg;
132+
// reg.offset = offset;
133+
// reg.aligned_bytes = aligned_bytes;
134+
// reg.shape = shape;
135+
// reg.dtype = dtype;
136+
// reg.device = device;
137+
138+
// if (bump_tail) {
139+
// total_bytes_ += aligned_bytes;
140+
// } else {
141+
// total_bytes_ = std::max(total_bytes_, offset + aligned_bytes);
142+
// }
143+
// registrations_.emplace(name, std::move(reg));
144+
// }
145+
146+
// auto &reg = registrations_.at(name);
147+
// ASSERT(reg.aligned_bytes == aligned_bytes);
148+
// ASSERT(reg.shape == shape);
149+
// ASSERT(reg.dtype == dtype);
150+
// ASSERT(reg.device == device);
151+
// reg.bind_callbacks.push_back(std::move(bind_fn));
152+
// }
153+
154+
// size_t total_bytes_{0};
155+
// bool finalized_{false};
156+
// infinicore::Tensor scratch_buffer_;
157+
// std::unordered_map<std::string, BufferRegistration> registrations_;
158+
// std::unordered_map<std::string, infinicore::Tensor> inference_buffers_;
159+
// };
160+
161+
}; // namespace infinilm::global_state

csrc/layers/attention/attention.cpp

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "../../utils.hpp"
44
#include "../rotary_embedding/rotary_embedding.hpp"
55
#include <string>
6+
#include <tuple>
67

78
namespace infinilm::layers::attention {
89

@@ -48,7 +49,10 @@ Attention::Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config
4849
init_kv_cache_quant_params(register_fn, device_, kv_cache_k_scale_, kv_cache_v_scale_);
4950

5051
rank_qkv_output_size_ = qkv_proj_->out_features() / static_cast<size_t>(tp_size);
51-
this->_initialize_preallocated_workspace();
52+
enable_workspace_manager_ = infinilm::global_state::get_infinilm_config().enable_workspace_manager;
53+
if (enable_workspace_manager_) {
54+
this->_register_inference_buffer();
55+
}
5256
}
5357

5458
infinicore::Tensor Attention::forward(const infinicore::Tensor &positions,
@@ -68,8 +72,13 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
6872
size_t seq_len = shape[1];
6973

7074
// 1. Project Q, K, V
71-
auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_});
72-
auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable);
75+
infinicore::Tensor q, k, v;
76+
if (enable_workspace_manager_) {
77+
auto qkv_output = max_qkv_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, rank_qkv_output_size_});
78+
std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable);
79+
} else {
80+
std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable);
81+
}
7382

7483
// 2. Reshape for multi-head attention
7584
auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_});
@@ -96,10 +105,13 @@ infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &position
96105
// 5. Attn Backend calculate
97106
auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped);
98107

99-
// 7. Project output
100-
auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_});
101-
o_proj_->forward_(o_output, attn_output);
102-
return o_output;
108+
// 6. Project output
109+
if (enable_workspace_manager_) {
110+
auto o_output = max_o_output_->narrow({{0, 0, batch_size * seq_len}})->view({batch_size, seq_len, hidden_size_});
111+
o_proj_->forward_(o_output, attn_output);
112+
return o_output;
113+
}
114+
return o_proj_->forward(attn_output);
103115
}
104116

105117
infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_ids,
@@ -114,8 +126,13 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
114126
ASSERT_EQ(batch_size, 1);
115127

116128
// 1. Project Q, K, V
117-
auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_});
118-
auto [q, k, v] = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable);
129+
infinicore::Tensor q, k, v;
130+
if (enable_workspace_manager_) {
131+
auto qkv_output = max_qkv_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, rank_qkv_output_size_});
132+
std::tie(q, k, v) = qkv_proj_->forward_split_(qkv_output, hidden_states_mutable);
133+
} else {
134+
std::tie(q, k, v) = qkv_proj_->forward_split(hidden_states_mutable);
135+
}
119136

120137
// 2. Reshape for multi-head attention
121138
auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_});
@@ -142,35 +159,44 @@ infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &position_
142159
auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped);
143160

144161
// 6. Project output
145-
auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_});
146-
o_proj_->forward_(o_output, attn_output);
147-
return o_output;
162+
if (enable_workspace_manager_) {
163+
auto o_output = max_o_output_->narrow({{0, 0, seq_len}})->view({1, seq_len, hidden_size_});
164+
o_proj_->forward_(o_output, attn_output);
165+
return o_output;
166+
}
167+
return o_proj_->forward(attn_output);
148168
}
149169

150-
void Attention::_initialize_preallocated_workspace() {
170+
void Attention::_register_inference_buffer() {
151171
const auto &infinilm_config = infinilm::global_state::get_infinilm_config();
152-
auto &preallocated_workspace = infinilm::global_state::get_forward_context().preallocated_workspace;
172+
auto &workspace_manager = infinilm::global_state::get_forward_context().workspace_manager;
153173
const size_t max_num_batched_tokens = infinilm_config.max_num_batched_tokens;
154174

175+
ASSERT(rank_qkv_output_size_ > 0 && hidden_size_ > 0);
176+
155177
const std::string attention_cache_key = std::string("Attention_max_num_batched_tokens_")
156178
+ std::to_string(max_num_batched_tokens) + "_rank_qkv_output_size_"
157179
+ std::to_string(rank_qkv_output_size_) + "_hidden_size_"
158180
+ std::to_string(hidden_size_) + "_dtype_"
159181
+ infinicore::toString(dtype_) + "_device_"
160182
+ device_.toString();
161183

162-
size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_);
163-
if (preallocated_workspace.find(attention_cache_key) == preallocated_workspace.end()) {
164-
auto attention_buffer = infinicore::Tensor::empty({max_num_batched_tokens * max_output_size}, dtype_, device_);
165-
preallocated_workspace[attention_cache_key] = attention_buffer;
166-
}
167-
168-
auto attention_buffer = preallocated_workspace.at(attention_cache_key);
169-
const auto attention_buffer_shape = attention_buffer->shape();
170-
ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size);
171-
172-
max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})->view({max_num_batched_tokens, rank_qkv_output_size_});
173-
max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})->view({max_num_batched_tokens, hidden_size_});
184+
const size_t max_output_size = std::max(rank_qkv_output_size_, hidden_size_);
185+
const infinicore::Shape attention_buffer_shape = {max_num_batched_tokens * max_output_size};
186+
workspace_manager.register_buffer(
187+
attention_cache_key,
188+
attention_buffer_shape,
189+
dtype_,
190+
device_,
191+
[this, max_num_batched_tokens, max_output_size](const infinicore::Tensor &attention_buffer) {
192+
const auto attention_buffer_shape = attention_buffer->shape();
193+
ASSERT(attention_buffer_shape[0] == max_num_batched_tokens * max_output_size);
194+
195+
max_qkv_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * rank_qkv_output_size_}})
196+
->view({max_num_batched_tokens, rank_qkv_output_size_});
197+
max_o_output_ = attention_buffer->narrow({{0, 0, max_num_batched_tokens * hidden_size_}})
198+
->view({max_num_batched_tokens, hidden_size_});
199+
});
174200
}
175201

176202
void init_kv_cache_quant_params(std::function<void(const std::string &, infinicore::nn::Parameter)> register_fn,

csrc/layers/attention/attention.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#include "../../global_state/global_state.hpp"
66
#include "../linear/linear.hpp"
77
#include "backends/attention_layer.hpp"
8-
#include "infinicore/device.hpp"
9-
#include "infinicore/dtype.hpp"
108
#include "infinicore/nn/module.hpp"
119
#include "infinicore/nn/rope.hpp"
1210
#include "infinicore/tensor.hpp"
@@ -39,7 +37,7 @@ class Attention : public infinicore::nn::Module {
3937
infinicore::Tensor forward_paged_(const infinicore::Tensor &positions,
4038
const infinicore::Tensor &hidden_states) const;
4139

42-
void _initialize_preallocated_workspace();
40+
void _register_inference_buffer();
4341

4442
protected:
4543
std::shared_ptr<infinilm::layers::linear::QKVParallelLinear> qkv_proj_;
@@ -61,11 +59,10 @@ class Attention : public infinicore::nn::Module {
6159
INFINICORE_NN_PARAMETER(kv_cache_v_scale);
6260

6361
private:
64-
size_t rank_qkv_output_size_;
65-
66-
// preallocated workspace for Attention
67-
infinicore::Tensor max_qkv_output_;
68-
infinicore::Tensor max_o_output_;
62+
bool enable_workspace_manager_{false};
63+
size_t rank_qkv_output_size_{0};
64+
infinicore::Tensor max_qkv_output_; // inference buffer for Attention
65+
infinicore::Tensor max_o_output_; // inference buffer for Attention
6966
};
7067
void init_kv_cache_quant_params(std::function<void(const std::string &, infinicore::nn::Parameter)> register_fn,
7168
const infinicore::Device &device,

0 commit comments

Comments
 (0)