Skip to content

Commit b4d262b

Browse files
author
wangpengcheng
committed
issue/407 - Refine the code
1 parent b38dafd commit b4d262b

3 files changed

Lines changed: 136 additions & 55 deletions

File tree

csrc/engine/rank_worker.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,8 @@ void RankWorker::thread_loop() {
282282
infinicore::context::syncStream();
283283

284284
if (infinilm_config_->enable_workspace_manager) {
285-
forward_context_.workspace_manager.finalize_and_bind(rank_info_.device);
285+
forward_context_.workspace_manager.finalize_and_bind();
286+
// forward_context_.workspace_manager.log_registrations();
286287
}
287288
infinicore::context::syncStream();
288289

csrc/global_state/workspace_manager.hpp

Lines changed: 133 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,24 @@
22

33
#include "../models/infinilm_model.hpp"
44
#include "../utils.hpp"
5+
#include "parallel_state.hpp"
56
#include <algorithm>
6-
#include <cstdio>
77
#include <functional>
8+
#include <iomanip>
9+
#include <sstream>
810
#include <string>
911
#include <unordered_map>
1012
#include <vector>
1113

1214
namespace infinilm::global_state {
1315

1416
/**
15-
* @brief Unified GPU inference workspace manager.
17+
* @brief Unified GPU inference scratch buffer.
1618
*
17-
* Phase 1: modules register buffer layouts via ``register_buffer``.
18-
* Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views.
19+
* Flow: register_buffer -> finalize_and_bind -> log_registrations (optional).
20+
* Layout: bump (tail append) or pinned@0 (offset fixed at 0).
21+
* Slots may overlap; scratch_bytes is max span, not sum of slots. Safe use requires
22+
* temporal reuse across forward phases.
1923
*/
2024
class WorkspaceManager {
2125
public:
@@ -24,93 +28,168 @@ class WorkspaceManager {
2428
WorkspaceManager() = default;
2529
~WorkspaceManager() = default;
2630

27-
/**
28-
* @brief Register a buffer appended at the current scratch_buffer tail.
29-
*
30-
* @param name Unique cache key; duplicate keys share one slot.
31-
* @param shape Tensor shape for the bound view.
32-
* @param dtype Element type of the bound view.
33-
* @param device Device on which scratch_buffer is allocated.
34-
* @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
35-
*/
31+
/** @brief Register a bump slot at current total_bytes_. Same name reuses one slot. */
3632
void register_buffer(const std::string &name,
3733
const infinicore::Shape &shape,
3834
const infinicore::DataType &dtype,
3935
const infinicore::Device &device,
4036
BindFn bind_fn) {
41-
register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true);
37+
_register_buffer_impl(name, total_bytes_, shape, dtype, device, std::move(bind_fn), true);
4238
}
4339

44-
/**
45-
* @brief Register a buffer pinned at a fixed byte offset.
46-
*
47-
* @param name Unique cache key; duplicate keys share one slot.
48-
* @param offset Byte offset in scratch_buffer (currently only 0 is supported).
49-
* @param shape Tensor shape for the bound view.
50-
* @param dtype Element type of the bound view.
51-
* @param device Device on which scratch_buffer is allocated.
52-
* @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
53-
*/
40+
/** @brief Register a pinned@0 slot (only offset==0). May overlap bump slots. */
5441
void register_buffer(const std::string &name,
5542
size_t offset,
5643
const infinicore::Shape &shape,
5744
const infinicore::DataType &dtype,
5845
const infinicore::Device &device,
5946
BindFn bind_fn) {
6047
ASSERT(0 == offset);
61-
register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false);
48+
_register_buffer_impl(name, offset, shape, dtype, device, std::move(bind_fn), false);
6249
}
6350

64-
/**
65-
* @brief Allocate scratch_buffer and run all registered bind callbacks.
66-
*
67-
* @param device Device on which scratch_buffer is allocated.
68-
*/
69-
void finalize_and_bind(const infinicore::Device &device) {
51+
/** @brief Allocate scratch_buffer_ and run bind callbacks. */
52+
void finalize_and_bind() {
7053
ASSERT(!finalized_);
7154
if (total_bytes_ == 0) {
7255
finalized_ = true;
7356
return;
7457
}
7558

76-
ASSERT(device.getType() != infinicore::Device::Type::CPU);
59+
auto &rank_device = get_tensor_model_parallel_rank_info().device;
7760

78-
scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, device);
61+
scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, rank_device);
7962

8063
spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0);
8164

82-
for (auto &[name, reg] : registrations_) {
65+
for (auto &entry : registrations_) {
66+
auto &reg = entry.second;
8367
auto *base_ptr = scratch_buffer_->data() + reg.offset;
84-
auto view = infinicore::Tensor::from_blob(static_cast<void *>(base_ptr), reg.shape, reg.dtype, device);
85-
inference_buffers_[name] = view;
68+
ASSERT(rank_device == reg.device);
69+
reg.bound_view = infinicore::Tensor::from_blob(static_cast<void *>(base_ptr), reg.shape, reg.dtype, rank_device);
8670
for (auto &bind_fn : reg.bind_callbacks) {
87-
bind_fn(view);
71+
bind_fn(reg.bound_view);
8872
}
8973
}
9074

9175
finalized_ = true;
9276
}
9377

78+
/** @brief Log slot layout with memory ranges and overlap info. */
79+
void log_registrations() const {
80+
size_t total_callbacks = 0;
81+
for (const auto &entry : registrations_) {
82+
total_callbacks += entry.second.bind_callbacks.size();
83+
}
84+
85+
std::vector<std::string> names;
86+
names.reserve(registrations_.size());
87+
for (const auto &entry : registrations_) {
88+
names.push_back(entry.first);
89+
}
90+
std::sort(names.begin(), names.end(), [this](const std::string &a, const std::string &b) {
91+
return registrations_.at(a).offset < registrations_.at(b).offset;
92+
});
93+
94+
std::ostringstream oss;
95+
oss << std::fixed << std::setprecision(3);
96+
oss << "\n========== WorkspaceManager registrations ==========\n";
97+
oss << " " << std::setw(16) << std::left << "finalized:" << finalized_ << "\n";
98+
oss << " " << std::setw(16) << std::left << "slots:" << registrations_.size() << "\n";
99+
oss << " " << std::setw(16) << std::left << "bind_callbacks:" << total_callbacks << "\n";
100+
oss << " " << std::setw(16) << std::left << "scratch_bytes:"
101+
<< total_bytes_ << " (" << (total_bytes_ / 1024.0 / 1024.0) << " MB)\n";
102+
oss << " note: scratch_bytes=max span; slots may overlap (temporal reuse).\n";
103+
oss << "----------------------------------------------------\n";
104+
105+
auto memory_end = [](const BufferRegistration &reg) {
106+
return reg.offset + reg.aligned_bytes;
107+
};
108+
auto ranges_overlap = [](size_t a_start, size_t a_end, size_t b_start, size_t b_end) {
109+
return a_start < b_end && b_start < a_end;
110+
};
111+
112+
for (size_t slot_idx = 0; slot_idx < names.size(); ++slot_idx) {
113+
const auto &name = names[slot_idx];
114+
const auto &reg = registrations_.at(name);
115+
const size_t mem_start = reg.offset;
116+
const size_t mem_end = memory_end(reg);
117+
118+
std::string shape_str = "[";
119+
for (size_t i = 0; i < reg.shape.size(); ++i) {
120+
if (i > 0) {
121+
shape_str += ", ";
122+
}
123+
shape_str += std::to_string(reg.shape[i]);
124+
}
125+
shape_str += "]";
126+
127+
std::string overlap_str = "none";
128+
{
129+
std::ostringstream overlap_oss;
130+
bool first = true;
131+
for (size_t other_idx = 0; other_idx < names.size(); ++other_idx) {
132+
if (other_idx == slot_idx) {
133+
continue;
134+
}
135+
const auto &other = registrations_.at(names[other_idx]);
136+
if (ranges_overlap(mem_start, mem_end, other.offset, memory_end(other))) {
137+
if (!first) {
138+
overlap_oss << ", ";
139+
}
140+
overlap_oss << "slot " << other_idx;
141+
first = false;
142+
}
143+
}
144+
if (!first) {
145+
overlap_str = overlap_oss.str();
146+
}
147+
}
148+
149+
oss << " [slot " << slot_idx << "]\n";
150+
oss << " " << std::setw(16) << std::left << "layout:"
151+
<< (reg.is_bump_tail ? "bump" : "pinned@0") << "\n";
152+
oss << " " << std::setw(16) << std::left << "memory:"
153+
<< "[" << mem_start << ", " << mem_end << ") "
154+
<< "(" << (reg.aligned_bytes / 1024.0 / 1024.0) << " MB)\n";
155+
oss << " " << std::setw(16) << std::left << "overlaps:" << overlap_str << "\n";
156+
oss << " " << std::setw(16) << std::left << "name:" << name << "\n";
157+
oss << " " << std::setw(16) << std::left << "shape:" << shape_str << "\n";
158+
oss << " " << std::setw(16) << std::left << "dtype:" << infinicore::toString(reg.dtype) << "\n";
159+
oss << " " << std::setw(16) << std::left << "device:" << reg.device.toString() << "\n";
160+
oss << " " << std::setw(16) << std::left << "bind_callbacks:" << reg.bind_callbacks.size() << "\n";
161+
oss << " " << std::setw(16) << std::left << "bound:" << finalized_ << "\n";
162+
if (slot_idx + 1 < names.size()) {
163+
oss << "\n";
164+
}
165+
}
166+
oss << "====================================================\n";
167+
168+
spdlog::info("{}", oss.str());
169+
}
170+
94171
private:
95-
/** @brief Metadata for one registered region in scratch_buffer. */
172+
/** @brief Metadata for one registered view into scratch_buffer_. */
96173
struct BufferRegistration {
97-
size_t offset{0};
98-
size_t aligned_bytes{0};
99-
infinicore::Shape shape;
100-
infinicore::DataType dtype;
101-
infinicore::Device device;
102-
std::vector<BindFn> bind_callbacks;
174+
size_t offset{0}; // view start in scratch_buffer_ (not a unique partition id)
175+
size_t aligned_bytes{0}; // view span after alignment; used for scratch size accounting
176+
bool is_bump_tail{true}; // true=bump tail slot; false=pinned@0 slot
177+
infinicore::Shape shape; // shape of the bound inference view
178+
infinicore::DataType dtype; // element type of the bound inference view
179+
infinicore::Device device; // device passed at registration (must match rank device)
180+
infinicore::Tensor bound_view; // view into scratch_buffer_; valid after finalize_and_bind
181+
std::vector<BindFn> bind_callbacks; // callbacks that bind module tensors to bound_view
103182
};
104183

105-
void register_buffer_impl(const std::string &name,
106-
size_t offset,
107-
const infinicore::Shape &shape,
108-
const infinicore::DataType &dtype,
109-
const infinicore::Device &device,
110-
BindFn bind_fn,
111-
bool bump_tail) {
184+
void _register_buffer_impl(const std::string &name,
185+
size_t offset,
186+
const infinicore::Shape &shape,
187+
const infinicore::DataType &dtype,
188+
const infinicore::Device &device,
189+
BindFn bind_fn,
190+
bool bump_tail) {
112191
ASSERT(!finalized_);
113-
ASSERT(device.getType() != infinicore::Device::Type::CPU);
192+
ASSERT(device == get_tensor_model_parallel_rank_info().device);
114193

115194
auto compute_numel = [](const infinicore::Shape &shape) {
116195
size_t numel = 1;
@@ -131,6 +210,7 @@ class WorkspaceManager {
131210
BufferRegistration reg;
132211
reg.offset = offset;
133212
reg.aligned_bytes = aligned_bytes;
213+
reg.is_bump_tail = bump_tail;
134214
reg.shape = shape;
135215
reg.dtype = dtype;
136216
reg.device = device;
@@ -144,6 +224,7 @@ class WorkspaceManager {
144224
}
145225

146226
auto &reg = registrations_.at(name);
227+
ASSERT(reg.is_bump_tail == bump_tail);
147228
ASSERT(reg.aligned_bytes == aligned_bytes);
148229
ASSERT(reg.shape == shape);
149230
ASSERT(reg.dtype == dtype);
@@ -155,7 +236,6 @@ class WorkspaceManager {
155236
bool finalized_{false};
156237
infinicore::Tensor scratch_buffer_;
157238
std::unordered_map<std::string, BufferRegistration> registrations_;
158-
std::unordered_map<std::string, infinicore::Tensor> inference_buffers_;
159239
};
160240

161241
}; // namespace infinilm::global_state

csrc/layers/mlp/mlp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void MLP::_register_inference_buffer() {
8787
+ infinicore::toString(dtype_) + "_device_"
8888
+ device_.toString();
8989

90-
auto align_up = [](size_t n, size_t alignment = 256) {
90+
auto align_up = [](size_t n, size_t alignment = 512) {
9191
return (n + alignment - 1) & ~(alignment - 1);
9292
};
9393

0 commit comments

Comments
 (0)