Skip to content

Commit 6d1fa23

Browse files
author
wangpengcheng
committed
issue/407 - refine the code
1 parent 2aea7df commit 6d1fa23

14 files changed

Lines changed: 372 additions & 210 deletions

File tree

csrc/engine/rank_worker.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,8 @@ void RankWorker::thread_loop() {
282282
infinicore::context::syncStream();
283283

284284
if (infinilm_config_->enable_workspace_manager) {
285-
forward_context_.workspace_manager.finalize_and_bind(rank_info_.device);
285+
forward_context_.workspace_manager.finalize_and_bind();
286+
// forward_context_.workspace_manager.log_registrations();
286287
}
287288
infinicore::context::syncStream();
288289

@@ -402,6 +403,7 @@ void RankWorker::thread_loop() {
402403
try {
403404
{
404405
std::lock_guard<std::mutex> lk(mutex_);
406+
infinilm::global_state::get_forward_context().workspace_manager.reset_runtime_buffers();
405407

406408
infinicore::Tensor logits;
407409
// Try to get compiled graph
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
#include "workspace_manager.hpp"
2+
3+
#include "../utils.hpp"
4+
#include "parallel_state.hpp"
5+
6+
#include <algorithm>
7+
#include <iomanip>
8+
#include <sstream>
9+
10+
namespace infinilm::global_state {
11+
12+
namespace {
13+
14+
// Default alignment, in bytes, used by align_up() when rounding buffer sizes.
// align_up() rounds with a bit mask, so this value must stay a power of two.
constexpr size_t k_scratch_align_bytes = 512;
15+
16+
// Returns the product of all dimensions in `shape`.
// An empty shape yields 1 (the empty product).
size_t compute_numel(const infinicore::Shape &shape) {
    size_t product = 1;
    for (auto it = shape.begin(); it != shape.end(); ++it) {
        product *= *it;
    }
    return product;
}
23+
24+
// Rounds `n` up to the next multiple of `alignment`.
// The mask arithmetic is only valid when `alignment` is a power of two.
size_t align_up(size_t n, size_t alignment = k_scratch_align_bytes) {
    const size_t mask = alignment - 1;
    return (n + mask) & ~mask;
}
27+
28+
// Returns the byte size of a tensor with the given shape and dtype, rounded
// up to the default scratch alignment.
size_t compute_aligned_bytes(const infinicore::Shape &shape, const infinicore::DataType &dtype) {
    const size_t raw_bytes = compute_numel(shape) * infinicore::dsize(dtype);
    return align_up(raw_bytes);
}
31+
32+
} // namespace
33+
34+
void WorkspaceManager::register_buffer(const std::string &name,
35+
const infinicore::Shape &shape,
36+
const infinicore::DataType &dtype,
37+
const infinicore::Device &device) {
38+
_register_buffer_impl(name, total_bytes_, shape, dtype, device, true);
39+
}
40+
41+
void WorkspaceManager::register_buffer(const std::string &name,
42+
size_t offset,
43+
const infinicore::Shape &shape,
44+
const infinicore::DataType &dtype,
45+
const infinicore::Device &device) {
46+
ASSERT(0 == offset);
47+
_register_buffer_impl(name, offset, shape, dtype, device, false);
48+
}
49+
50+
// Wraps the scratch memory starting `offset` bytes into scratch_buffer_ as a
// tensor with the given shape, dtype and device. No bounds checking is done
// here; callers are responsible for validating the range.
infinicore::Tensor WorkspaceManager::_make_runtime_view(size_t offset,
                                                        const infinicore::Shape &shape,
                                                        const infinicore::DataType &dtype,
                                                        const infinicore::Device &device) {
    void *view_start = static_cast<void *>(scratch_buffer_->data() + offset);
    return infinicore::Tensor::from_blob(view_start, shape, dtype, device);
}
57+
58+
// Returns a tensor view into the scratch buffer for `buffer_name`.
//
// Resolution order:
//   1. A view already handed out under this name since the last reset is
//      returned as-is (the requested shape/dtype are not re-checked).
//   2. If the name was registered, the view starts at the registered offset.
//   3. Otherwise the view is carved from the bump cursor
//      (scratch_buffer_offset_), which then advances by the aligned size.
//
// In cases 2 and 3 the requested range must fit inside the scratch buffer.
//
// NOTE: the `device` argument is not consulted; views are always created on
// the tensor-model-parallel rank device.
infinicore::Tensor WorkspaceManager::get_buffer(const std::string &buffer_name,
                                                const infinicore::Shape &shape,
                                                const infinicore::DataType &dtype,
                                                const infinicore::Device &device) {
    ASSERT(finalized_);
    ASSERT(!scratch_buffer_.empty());

    auto cached = runtime_buffers_.find(buffer_name);
    if (cached != runtime_buffers_.end()) {
        return cached->second;
    }

    auto &rank_device = get_tensor_model_parallel_rank_info().device;
    const size_t aligned_bytes = compute_aligned_bytes(shape, dtype);

    auto registered = registrations_.find(buffer_name);
    if (registered != registrations_.end()) {
        const auto &reg = registered->second;
        // The caller's shape may differ from the registered one, so make sure
        // the resulting view still lies inside the scratch allocation.
        ASSERT(reg.offset + aligned_bytes <= total_bytes_);
        auto tensor = _make_runtime_view(reg.offset, shape, dtype, rank_device);
        runtime_buffers_.emplace(buffer_name, tensor);
        return tensor;
    }

    // Unregistered name: allocate from the bump cursor.
    const size_t offset = scratch_buffer_offset_;
    ASSERT(offset + aligned_bytes <= total_bytes_);

    auto tensor = _make_runtime_view(offset, shape, dtype, rank_device);
    runtime_buffers_.emplace(buffer_name, tensor);
    scratch_buffer_offset_ += aligned_bytes;
    return tensor;
}
89+
90+
// Returns a tensor view into the scratch buffer for `buffer_name`, using an
// explicit byte `offset` when the name is not registered.
//
// Resolution order:
//   1. A view already handed out under this name since the last reset is
//      returned as-is (the requested shape/dtype are not re-checked).
//   2. If the name was registered, the registered offset is used and the
//      `offset` argument is ignored.
//   3. Otherwise the view starts at `offset`. The bump cursor
//      (scratch_buffer_offset_) is not advanced by this overload.
//
// In cases 2 and 3 the requested range must fit inside the scratch buffer.
//
// NOTE: the `device` argument is not consulted; views are always created on
// the tensor-model-parallel rank device.
infinicore::Tensor WorkspaceManager::get_buffer(const std::string &buffer_name,
                                                size_t offset,
                                                const infinicore::Shape &shape,
                                                const infinicore::DataType &dtype,
                                                const infinicore::Device &device) {
    ASSERT(finalized_);
    ASSERT(!scratch_buffer_.empty());

    auto cached = runtime_buffers_.find(buffer_name);
    if (cached != runtime_buffers_.end()) {
        return cached->second;
    }

    auto &rank_device = get_tensor_model_parallel_rank_info().device;
    const size_t aligned_bytes = compute_aligned_bytes(shape, dtype);

    auto registered = registrations_.find(buffer_name);
    if (registered != registrations_.end()) {
        const auto &reg = registered->second;
        // The caller's shape may differ from the registered one, so make sure
        // the resulting view still lies inside the scratch allocation.
        ASSERT(reg.offset + aligned_bytes <= total_bytes_);
        auto tensor = _make_runtime_view(reg.offset, shape, dtype, rank_device);
        runtime_buffers_.emplace(buffer_name, tensor);
        return tensor;
    }

    ASSERT(offset + aligned_bytes <= total_bytes_);

    auto tensor = _make_runtime_view(offset, shape, dtype, rank_device);
    runtime_buffers_.emplace(buffer_name, tensor);
    return tensor;
}
120+
121+
// Drops every view handed out by get_buffer() and rewinds the bump cursor so
// the scratch buffer can be reused.
//
// This is deliberately safe to call before finalize_and_bind(): get_buffer()
// asserts on finalized_, so no runtime view can exist yet and the reset is a
// no-op. Callers that reset unconditionally therefore do not trip an
// assertion when the workspace was never finalized.
void WorkspaceManager::reset_runtime_buffers() {
    scratch_buffer_offset_ = 0;
    runtime_buffers_.clear();
}
126+
127+
void WorkspaceManager::finalize_and_bind() {
128+
ASSERT(!finalized_);
129+
runtime_buffers_.clear();
130+
scratch_buffer_offset_ = 0;
131+
132+
if (total_bytes_ == 0) {
133+
finalized_ = true;
134+
return;
135+
}
136+
137+
auto &rank_device = get_tensor_model_parallel_rank_info().device;
138+
139+
scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, rank_device);
140+
141+
spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0);
142+
143+
for (auto &entry : registrations_) {
144+
auto &reg = entry.second;
145+
auto *base_ptr = scratch_buffer_->data() + reg.offset;
146+
ASSERT(rank_device == reg.device);
147+
reg.bound_view = infinicore::Tensor::from_blob(static_cast<void *>(base_ptr), reg.shape, reg.dtype, rank_device);
148+
}
149+
150+
scratch_buffer_offset_ = 0;
151+
finalized_ = true;
152+
}
153+
154+
void WorkspaceManager::log_registrations() const {
155+
std::vector<std::string> names;
156+
names.reserve(registrations_.size());
157+
for (const auto &entry : registrations_) {
158+
names.push_back(entry.first);
159+
}
160+
std::sort(names.begin(), names.end(), [this](const std::string &a, const std::string &b) {
161+
return registrations_.at(a).offset < registrations_.at(b).offset;
162+
});
163+
164+
std::ostringstream oss;
165+
oss << std::fixed << std::setprecision(3);
166+
oss << "\n========== WorkspaceManager registrations ==========\n";
167+
oss << " " << std::setw(16) << std::left << "finalized:" << finalized_ << "\n";
168+
oss << " " << std::setw(16) << std::left << "slots:" << registrations_.size() << "\n";
169+
oss << " " << std::setw(16) << std::left << "runtime_buffers:" << runtime_buffers_.size() << "\n";
170+
oss << " " << std::setw(16) << std::left << "scratch_bytes:"
171+
<< total_bytes_ << " (" << (total_bytes_ / 1024.0 / 1024.0) << " MB)\n";
172+
oss << " " << std::setw(16) << std::left << "scratch_buffer_offset_:"
173+
<< scratch_buffer_offset_ << " (" << (scratch_buffer_offset_ / 1024.0 / 1024.0) << " MB)\n";
174+
oss << " note: scratch_bytes=max span; registered slots may overlap.\n";
175+
oss << "----------------------------------------------------\n";
176+
177+
auto memory_end = [](const BufferRegistration &reg) {
178+
return reg.offset + reg.aligned_bytes;
179+
};
180+
auto ranges_overlap = [](size_t a_start, size_t a_end, size_t b_start, size_t b_end) {
181+
return a_start < b_end && b_start < a_end;
182+
};
183+
184+
for (size_t slot_idx = 0; slot_idx < names.size(); ++slot_idx) {
185+
const auto &name = names[slot_idx];
186+
const auto &reg = registrations_.at(name);
187+
const size_t mem_start = reg.offset;
188+
const size_t mem_end = memory_end(reg);
189+
190+
std::string shape_str = "[";
191+
for (size_t i = 0; i < reg.shape.size(); ++i) {
192+
if (i > 0) {
193+
shape_str += ", ";
194+
}
195+
shape_str += std::to_string(reg.shape[i]);
196+
}
197+
shape_str += "]";
198+
199+
std::string overlap_str = "none";
200+
{
201+
std::ostringstream overlap_oss;
202+
bool first = true;
203+
for (size_t other_idx = 0; other_idx < names.size(); ++other_idx) {
204+
if (other_idx == slot_idx) {
205+
continue;
206+
}
207+
const auto &other = registrations_.at(names[other_idx]);
208+
if (ranges_overlap(mem_start, mem_end, other.offset, memory_end(other))) {
209+
if (!first) {
210+
overlap_oss << ", ";
211+
}
212+
overlap_oss << "slot " << other_idx;
213+
first = false;
214+
}
215+
}
216+
if (!first) {
217+
overlap_str = overlap_oss.str();
218+
}
219+
}
220+
221+
oss << " [slot " << slot_idx << "]\n";
222+
oss << " " << std::setw(16) << std::left << "layout:"
223+
<< (reg.is_bump_tail ? "bump" : "pinned@0") << "\n";
224+
oss << " " << std::setw(16) << std::left << "memory:"
225+
<< "[" << mem_start << ", " << mem_end << ") "
226+
<< "(" << (reg.aligned_bytes / 1024.0 / 1024.0) << " MB)\n";
227+
oss << " " << std::setw(16) << std::left << "overlaps:" << overlap_str << "\n";
228+
oss << " " << std::setw(16) << std::left << "name:" << name << "\n";
229+
oss << " " << std::setw(16) << std::left << "shape:" << shape_str << "\n";
230+
oss << " " << std::setw(16) << std::left << "dtype:" << infinicore::toString(reg.dtype) << "\n";
231+
oss << " " << std::setw(16) << std::left << "device:" << reg.device.toString() << "\n";
232+
oss << " " << std::setw(16) << std::left << "bound:" << finalized_ << "\n";
233+
if (slot_idx + 1 < names.size()) {
234+
oss << "\n";
235+
}
236+
}
237+
oss << "====================================================\n";
238+
239+
spdlog::info("{}", oss.str());
240+
}
241+
242+
void WorkspaceManager::_register_buffer_impl(const std::string &name,
243+
size_t offset,
244+
const infinicore::Shape &shape,
245+
const infinicore::DataType &dtype,
246+
const infinicore::Device &device,
247+
bool bump_tail) {
248+
ASSERT(!finalized_);
249+
ASSERT(device == get_tensor_model_parallel_rank_info().device);
250+
251+
const size_t aligned_bytes = compute_aligned_bytes(shape, dtype);
252+
253+
if (registrations_.find(name) == registrations_.end()) {
254+
BufferRegistration reg;
255+
reg.offset = offset;
256+
reg.aligned_bytes = aligned_bytes;
257+
reg.is_bump_tail = bump_tail;
258+
reg.shape = shape;
259+
reg.dtype = dtype;
260+
reg.device = device;
261+
262+
if (bump_tail) {
263+
total_bytes_ += aligned_bytes;
264+
} else {
265+
total_bytes_ = std::max(total_bytes_, offset + aligned_bytes);
266+
}
267+
registrations_.emplace(name, std::move(reg));
268+
}
269+
270+
auto &reg = registrations_.at(name);
271+
ASSERT(reg.is_bump_tail == bump_tail);
272+
ASSERT(reg.aligned_bytes == aligned_bytes);
273+
ASSERT(reg.shape == shape);
274+
ASSERT(reg.dtype == dtype);
275+
ASSERT(reg.device == device);
276+
}
277+
278+
} // namespace infinilm::global_state

0 commit comments

Comments
 (0)