Skip to content

Commit 2aea7df

Browse files
ma-hangwangpengcheng
authored andcommitted
refactor: improve WorkspaceManager buffer registration
1 parent 6bb2040 commit 2aea7df

1 file changed

Lines changed: 148 additions & 148 deletions

File tree

csrc/global_state/workspace_manager.hpp

Lines changed: 148 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -11,151 +11,151 @@
1111

1212
namespace infinilm::global_state {
1313

14-
// /**
15-
// * @brief Unified GPU inference workspace manager.
16-
// *
17-
// * Phase 1: modules register buffer layouts via ``register_buffer``.
18-
// * Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views.
19-
// */
20-
// class WorkspaceManager {
21-
// public:
22-
// using BindFn = std::function<void(const infinicore::Tensor &)>;
23-
24-
// WorkspaceManager() = default;
25-
// ~WorkspaceManager() = default;
26-
27-
// /**
28-
// * @brief Register a buffer appended at the current scratch_buffer tail.
29-
// *
30-
// * @param name Unique cache key; duplicate keys share one slot.
31-
// * @param shape Tensor shape for the bound view.
32-
// * @param dtype Element type of the bound view.
33-
// * @param device Device on which scratch_buffer is allocated.
34-
// * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
35-
// */
36-
// void register_buffer(const std::string &name,
37-
// const infinicore::Shape &shape,
38-
// const infinicore::DataType &dtype,
39-
// const infinicore::Device &device,
40-
// BindFn bind_fn) {
41-
// register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true);
42-
// }
43-
44-
// /**
45-
// * @brief Register a buffer pinned at a fixed byte offset.
46-
// *
47-
// * @param name Unique cache key; duplicate keys share one slot.
48-
// * @param offset Byte offset in scratch_buffer (currently only 0 is supported).
49-
// * @param shape Tensor shape for the bound view.
50-
// * @param dtype Element type of the bound view.
51-
// * @param device Device on which scratch_buffer is allocated.
52-
// * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
53-
// */
54-
// void register_buffer(const std::string &name,
55-
// size_t offset,
56-
// const infinicore::Shape &shape,
57-
// const infinicore::DataType &dtype,
58-
// const infinicore::Device &device,
59-
// BindFn bind_fn) {
60-
// ASSERT(0 == offset);
61-
// register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false);
62-
// }
63-
64-
// /**
65-
// * @brief Allocate scratch_buffer and run all registered bind callbacks.
66-
// *
67-
// * @param device Device on which scratch_buffer is allocated.
68-
// */
69-
// void finalize_and_bind(const infinicore::Device &device) {
70-
// ASSERT(!finalized_);
71-
// if (total_bytes_ == 0) {
72-
// finalized_ = true;
73-
// return;
74-
// }
75-
76-
// ASSERT(device.getType() != infinicore::Device::Type::CPU);
77-
78-
// scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, device);
79-
80-
// spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0);
81-
82-
// for (auto &[name, reg] : registrations_) {
83-
// auto *base_ptr = scratch_buffer_->data() + reg.offset;
84-
// auto view = infinicore::Tensor::from_blob(static_cast<void *>(base_ptr), reg.shape, reg.dtype, device);
85-
// inference_buffers_[name] = view;
86-
// for (auto &bind_fn : reg.bind_callbacks) {
87-
// bind_fn(view);
88-
// }
89-
// }
90-
91-
// finalized_ = true;
92-
// }
93-
94-
// private:
95-
// /** @brief Metadata for one registered region in scratch_buffer. */
96-
// struct BufferRegistration {
97-
// size_t offset{0};
98-
// size_t aligned_bytes{0};
99-
// infinicore::Shape shape;
100-
// infinicore::DataType dtype;
101-
// infinicore::Device device;
102-
// std::vector<BindFn> bind_callbacks;
103-
// };
104-
105-
// void register_buffer_impl(const std::string &name,
106-
// size_t offset,
107-
// const infinicore::Shape &shape,
108-
// const infinicore::DataType &dtype,
109-
// const infinicore::Device &device,
110-
// BindFn bind_fn,
111-
// bool bump_tail) {
112-
// ASSERT(!finalized_);
113-
// ASSERT(device.getType() != infinicore::Device::Type::CPU);
114-
115-
// auto compute_numel = [](const infinicore::Shape &shape) {
116-
// size_t numel = 1;
117-
// for (const auto dim : shape) {
118-
// numel *= dim;
119-
// }
120-
// return numel;
121-
// };
122-
123-
// auto align_up = [](size_t n, size_t alignment = 512) {
124-
// return (n + alignment - 1) & ~(alignment - 1);
125-
// };
126-
127-
// const size_t actual_bytes = compute_numel(shape) * infinicore::dsize(dtype);
128-
// const size_t aligned_bytes = align_up(actual_bytes);
129-
130-
// if (registrations_.find(name) == registrations_.end()) {
131-
// BufferRegistration reg;
132-
// reg.offset = offset;
133-
// reg.aligned_bytes = aligned_bytes;
134-
// reg.shape = shape;
135-
// reg.dtype = dtype;
136-
// reg.device = device;
137-
138-
// if (bump_tail) {
139-
// total_bytes_ += aligned_bytes;
140-
// } else {
141-
// total_bytes_ = std::max(total_bytes_, offset + aligned_bytes);
142-
// }
143-
// registrations_.emplace(name, std::move(reg));
144-
// }
145-
146-
// auto &reg = registrations_.at(name);
147-
// ASSERT(reg.aligned_bytes == aligned_bytes);
148-
// ASSERT(reg.shape == shape);
149-
// ASSERT(reg.dtype == dtype);
150-
// ASSERT(reg.device == device);
151-
// reg.bind_callbacks.push_back(std::move(bind_fn));
152-
// }
153-
154-
// size_t total_bytes_{0};
155-
// bool finalized_{false};
156-
// infinicore::Tensor scratch_buffer_;
157-
// std::unordered_map<std::string, BufferRegistration> registrations_;
158-
// std::unordered_map<std::string, infinicore::Tensor> inference_buffers_;
159-
// };
160-
161-
}; // namespace infinilm::global_state
14+
/**
15+
* @brief Unified GPU inference workspace manager.
16+
*
17+
* Phase 1: modules register buffer layouts via ``register_buffer``.
18+
* Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views.
19+
*/
20+
class WorkspaceManager {
21+
public:
22+
using BindFn = std::function<void(const infinicore::Tensor &)>;
23+
24+
WorkspaceManager() = default;
25+
~WorkspaceManager() = default;
26+
27+
/**
28+
* @brief Register a buffer appended at the current scratch_buffer tail.
29+
*
30+
* @param name Unique cache key; duplicate keys share one slot.
31+
* @param shape Tensor shape for the bound view.
32+
* @param dtype Element type of the bound view.
33+
* @param device Device on which scratch_buffer is allocated.
34+
* @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
35+
*/
36+
void register_buffer(const std::string &name,
37+
const infinicore::Shape &shape,
38+
const infinicore::DataType &dtype,
39+
const infinicore::Device &device,
40+
BindFn bind_fn) {
41+
register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true);
42+
}
43+
44+
/**
45+
* @brief Register a buffer pinned at a fixed byte offset.
46+
*
47+
* @param name Unique cache key; duplicate keys share one slot.
48+
* @param offset Byte offset in scratch_buffer (currently only 0 is supported).
49+
* @param shape Tensor shape for the bound view.
50+
* @param dtype Element type of the bound view.
51+
* @param device Device on which scratch_buffer is allocated.
52+
* @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
53+
*/
54+
void register_buffer(const std::string &name,
55+
size_t offset,
56+
const infinicore::Shape &shape,
57+
const infinicore::DataType &dtype,
58+
const infinicore::Device &device,
59+
BindFn bind_fn) {
60+
ASSERT(0 == offset);
61+
register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false);
62+
}
63+
64+
/**
65+
* @brief Allocate scratch_buffer and run all registered bind callbacks.
66+
*
67+
* @param device Device on which scratch_buffer is allocated.
68+
*/
69+
void finalize_and_bind(const infinicore::Device &device) {
70+
ASSERT(!finalized_);
71+
if (total_bytes_ == 0) {
72+
finalized_ = true;
73+
return;
74+
}
75+
76+
ASSERT(device.getType() != infinicore::Device::Type::CPU);
77+
78+
scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, device);
79+
80+
spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0);
81+
82+
for (auto &[name, reg] : registrations_) {
83+
auto *base_ptr = scratch_buffer_->data() + reg.offset;
84+
auto view = infinicore::Tensor::from_blob(static_cast<void *>(base_ptr), reg.shape, reg.dtype, device);
85+
inference_buffers_[name] = view;
86+
for (auto &bind_fn : reg.bind_callbacks) {
87+
bind_fn(view);
88+
}
89+
}
90+
91+
finalized_ = true;
92+
}
93+
94+
private:
95+
/** @brief Metadata for one registered region in scratch_buffer. */
96+
struct BufferRegistration {
97+
size_t offset{0};
98+
size_t aligned_bytes{0};
99+
infinicore::Shape shape;
100+
infinicore::DataType dtype;
101+
infinicore::Device device;
102+
std::vector<BindFn> bind_callbacks;
103+
};
104+
105+
void register_buffer_impl(const std::string &name,
106+
size_t offset,
107+
const infinicore::Shape &shape,
108+
const infinicore::DataType &dtype,
109+
const infinicore::Device &device,
110+
BindFn bind_fn,
111+
bool bump_tail) {
112+
ASSERT(!finalized_);
113+
ASSERT(device.getType() != infinicore::Device::Type::CPU);
114+
115+
auto compute_numel = [](const infinicore::Shape &shape) {
116+
size_t numel = 1;
117+
for (const auto dim : shape) {
118+
numel *= dim;
119+
}
120+
return numel;
121+
};
122+
123+
auto align_up = [](size_t n, size_t alignment = 512) {
124+
return (n + alignment - 1) & ~(alignment - 1);
125+
};
126+
127+
const size_t actual_bytes = compute_numel(shape) * infinicore::dsize(dtype);
128+
const size_t aligned_bytes = align_up(actual_bytes);
129+
130+
if (registrations_.find(name) == registrations_.end()) {
131+
BufferRegistration reg;
132+
reg.offset = offset;
133+
reg.aligned_bytes = aligned_bytes;
134+
reg.shape = shape;
135+
reg.dtype = dtype;
136+
reg.device = device;
137+
138+
if (bump_tail) {
139+
total_bytes_ += aligned_bytes;
140+
} else {
141+
total_bytes_ = std::max(total_bytes_, offset + aligned_bytes);
142+
}
143+
registrations_.emplace(name, std::move(reg));
144+
}
145+
146+
auto &reg = registrations_.at(name);
147+
ASSERT(reg.aligned_bytes == aligned_bytes);
148+
ASSERT(reg.shape == shape);
149+
ASSERT(reg.dtype == dtype);
150+
ASSERT(reg.device == device);
151+
reg.bind_callbacks.push_back(std::move(bind_fn));
152+
}
153+
154+
size_t total_bytes_{0};
155+
bool finalized_{false};
156+
infinicore::Tensor scratch_buffer_;
157+
std::unordered_map<std::string, BufferRegistration> registrations_;
158+
std::unordered_map<std::string, infinicore::Tensor> inference_buffers_;
159+
};
160+
161+
}; // namespace infinilm::global_state

0 commit comments

Comments
 (0)