|
| 1 | +#include "workspace_manager.hpp" |
| 2 | + |
| 3 | +#include "../utils.hpp" |
| 4 | +#include "parallel_state.hpp" |
| 5 | + |
| 6 | +#include <algorithm> |
| 7 | +#include <iomanip> |
| 8 | +#include <sstream> |
| 9 | + |
| 10 | +namespace infinilm::global_state { |
| 11 | + |
| 12 | +namespace { |
| 13 | + |
// Byte granularity to which every scratch-buffer slot size is rounded up.
constexpr size_t k_scratch_align_bytes = 512;
// Rounding in this file is done with a bitmask, which is only correct when the
// alignment is a non-zero power of two; enforce that at compile time.
static_assert(k_scratch_align_bytes != 0 && (k_scratch_align_bytes & (k_scratch_align_bytes - 1)) == 0,
              "k_scratch_align_bytes must be a non-zero power of two");
| 15 | + |
| 16 | +size_t compute_numel(const infinicore::Shape &shape) { |
| 17 | + size_t numel = 1; |
| 18 | + for (const auto dim : shape) { |
| 19 | + numel *= dim; |
| 20 | + } |
| 21 | + return numel; |
| 22 | +} |
| 23 | + |
| 24 | +size_t align_up(size_t n, size_t alignment = k_scratch_align_bytes) { |
| 25 | + return (n + alignment - 1) & ~(alignment - 1); |
| 26 | +} |
| 27 | + |
| 28 | +size_t compute_aligned_bytes(const infinicore::Shape &shape, const infinicore::DataType &dtype) { |
| 29 | + return align_up(compute_numel(shape) * infinicore::dsize(dtype)); |
| 30 | +} |
| 31 | + |
| 32 | +} // namespace |
| 33 | + |
| 34 | +void WorkspaceManager::register_buffer(const std::string &name, |
| 35 | + const infinicore::Shape &shape, |
| 36 | + const infinicore::DataType &dtype, |
| 37 | + const infinicore::Device &device) { |
| 38 | + _register_buffer_impl(name, total_bytes_, shape, dtype, device, true); |
| 39 | +} |
| 40 | + |
| 41 | +void WorkspaceManager::register_buffer(const std::string &name, |
| 42 | + size_t offset, |
| 43 | + const infinicore::Shape &shape, |
| 44 | + const infinicore::DataType &dtype, |
| 45 | + const infinicore::Device &device) { |
| 46 | + ASSERT(0 == offset); |
| 47 | + _register_buffer_impl(name, offset, shape, dtype, device, false); |
| 48 | +} |
| 49 | + |
| 50 | +infinicore::Tensor WorkspaceManager::_make_runtime_view(size_t offset, |
| 51 | + const infinicore::Shape &shape, |
| 52 | + const infinicore::DataType &dtype, |
| 53 | + const infinicore::Device &device) { |
| 54 | + auto *base_ptr = scratch_buffer_->data() + offset; |
| 55 | + return infinicore::Tensor::from_blob(static_cast<void *>(base_ptr), shape, dtype, device); |
| 56 | +} |
| 57 | + |
| 58 | +infinicore::Tensor WorkspaceManager::get_buffer(const std::string &buffer_name, |
| 59 | + const infinicore::Shape &shape, |
| 60 | + const infinicore::DataType &dtype, |
| 61 | + const infinicore::Device &device) { |
| 62 | + ASSERT(finalized_); |
| 63 | + ASSERT(!scratch_buffer_.empty()); |
| 64 | + |
| 65 | + auto cached = runtime_buffers_.find(buffer_name); |
| 66 | + if (cached != runtime_buffers_.end()) { |
| 67 | + return cached->second; |
| 68 | + } |
| 69 | + |
| 70 | + auto &rank_device = get_tensor_model_parallel_rank_info().device; |
| 71 | + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); |
| 72 | + |
| 73 | + auto registered = registrations_.find(buffer_name); |
| 74 | + if (registered != registrations_.end()) { |
| 75 | + const auto ® = registered->second; |
| 76 | + auto tensor = _make_runtime_view(reg.offset, shape, dtype, rank_device); |
| 77 | + runtime_buffers_.emplace(buffer_name, tensor); |
| 78 | + return tensor; |
| 79 | + } |
| 80 | + |
| 81 | + const size_t offset = scratch_buffer_offset_; |
| 82 | + ASSERT(offset + aligned_bytes <= total_bytes_); |
| 83 | + |
| 84 | + auto tensor = _make_runtime_view(offset, shape, dtype, rank_device); |
| 85 | + runtime_buffers_.emplace(buffer_name, tensor); |
| 86 | + scratch_buffer_offset_ += aligned_bytes; |
| 87 | + return tensor; |
| 88 | +} |
| 89 | + |
| 90 | +infinicore::Tensor WorkspaceManager::get_buffer(const std::string &buffer_name, |
| 91 | + size_t offset, |
| 92 | + const infinicore::Shape &shape, |
| 93 | + const infinicore::DataType &dtype, |
| 94 | + const infinicore::Device &device) { |
| 95 | + ASSERT(finalized_); |
| 96 | + ASSERT(!scratch_buffer_.empty()); |
| 97 | + |
| 98 | + auto cached = runtime_buffers_.find(buffer_name); |
| 99 | + if (cached != runtime_buffers_.end()) { |
| 100 | + return cached->second; |
| 101 | + } |
| 102 | + |
| 103 | + auto &rank_device = get_tensor_model_parallel_rank_info().device; |
| 104 | + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); |
| 105 | + |
| 106 | + auto registered = registrations_.find(buffer_name); |
| 107 | + if (registered != registrations_.end()) { |
| 108 | + const auto ® = registered->second; |
| 109 | + auto tensor = _make_runtime_view(reg.offset, shape, dtype, rank_device); |
| 110 | + runtime_buffers_.emplace(buffer_name, tensor); |
| 111 | + return tensor; |
| 112 | + } |
| 113 | + |
| 114 | + ASSERT(offset + aligned_bytes <= total_bytes_); |
| 115 | + |
| 116 | + auto tensor = _make_runtime_view(offset, shape, dtype, rank_device); |
| 117 | + runtime_buffers_.emplace(buffer_name, tensor); |
| 118 | + return tensor; |
| 119 | +} |
| 120 | + |
| 121 | +void WorkspaceManager::reset_runtime_buffers() { |
| 122 | + ASSERT(finalized_); |
| 123 | + scratch_buffer_offset_ = 0; |
| 124 | + runtime_buffers_.clear(); |
| 125 | +} |
| 126 | + |
| 127 | +void WorkspaceManager::finalize_and_bind() { |
| 128 | + ASSERT(!finalized_); |
| 129 | + runtime_buffers_.clear(); |
| 130 | + scratch_buffer_offset_ = 0; |
| 131 | + |
| 132 | + if (total_bytes_ == 0) { |
| 133 | + finalized_ = true; |
| 134 | + return; |
| 135 | + } |
| 136 | + |
| 137 | + auto &rank_device = get_tensor_model_parallel_rank_info().device; |
| 138 | + |
| 139 | + scratch_buffer_ = infinicore::Tensor::empty({total_bytes_}, infinicore::DataType::U8, rank_device); |
| 140 | + |
| 141 | + spdlog::info("WorkspaceManager: finalize_and_bind {:.3f} MB", total_bytes_ / 1024.0 / 1024.0); |
| 142 | + |
| 143 | + for (auto &entry : registrations_) { |
| 144 | + auto ® = entry.second; |
| 145 | + auto *base_ptr = scratch_buffer_->data() + reg.offset; |
| 146 | + ASSERT(rank_device == reg.device); |
| 147 | + reg.bound_view = infinicore::Tensor::from_blob(static_cast<void *>(base_ptr), reg.shape, reg.dtype, rank_device); |
| 148 | + } |
| 149 | + |
| 150 | + scratch_buffer_offset_ = 0; |
| 151 | + finalized_ = true; |
| 152 | +} |
| 153 | + |
| 154 | +void WorkspaceManager::log_registrations() const { |
| 155 | + std::vector<std::string> names; |
| 156 | + names.reserve(registrations_.size()); |
| 157 | + for (const auto &entry : registrations_) { |
| 158 | + names.push_back(entry.first); |
| 159 | + } |
| 160 | + std::sort(names.begin(), names.end(), [this](const std::string &a, const std::string &b) { |
| 161 | + return registrations_.at(a).offset < registrations_.at(b).offset; |
| 162 | + }); |
| 163 | + |
| 164 | + std::ostringstream oss; |
| 165 | + oss << std::fixed << std::setprecision(3); |
| 166 | + oss << "\n========== WorkspaceManager registrations ==========\n"; |
| 167 | + oss << " " << std::setw(16) << std::left << "finalized:" << finalized_ << "\n"; |
| 168 | + oss << " " << std::setw(16) << std::left << "slots:" << registrations_.size() << "\n"; |
| 169 | + oss << " " << std::setw(16) << std::left << "runtime_buffers:" << runtime_buffers_.size() << "\n"; |
| 170 | + oss << " " << std::setw(16) << std::left << "scratch_bytes:" |
| 171 | + << total_bytes_ << " (" << (total_bytes_ / 1024.0 / 1024.0) << " MB)\n"; |
| 172 | + oss << " " << std::setw(16) << std::left << "scratch_buffer_offset_:" |
| 173 | + << scratch_buffer_offset_ << " (" << (scratch_buffer_offset_ / 1024.0 / 1024.0) << " MB)\n"; |
| 174 | + oss << " note: scratch_bytes=max span; registered slots may overlap.\n"; |
| 175 | + oss << "----------------------------------------------------\n"; |
| 176 | + |
| 177 | + auto memory_end = [](const BufferRegistration ®) { |
| 178 | + return reg.offset + reg.aligned_bytes; |
| 179 | + }; |
| 180 | + auto ranges_overlap = [](size_t a_start, size_t a_end, size_t b_start, size_t b_end) { |
| 181 | + return a_start < b_end && b_start < a_end; |
| 182 | + }; |
| 183 | + |
| 184 | + for (size_t slot_idx = 0; slot_idx < names.size(); ++slot_idx) { |
| 185 | + const auto &name = names[slot_idx]; |
| 186 | + const auto ® = registrations_.at(name); |
| 187 | + const size_t mem_start = reg.offset; |
| 188 | + const size_t mem_end = memory_end(reg); |
| 189 | + |
| 190 | + std::string shape_str = "["; |
| 191 | + for (size_t i = 0; i < reg.shape.size(); ++i) { |
| 192 | + if (i > 0) { |
| 193 | + shape_str += ", "; |
| 194 | + } |
| 195 | + shape_str += std::to_string(reg.shape[i]); |
| 196 | + } |
| 197 | + shape_str += "]"; |
| 198 | + |
| 199 | + std::string overlap_str = "none"; |
| 200 | + { |
| 201 | + std::ostringstream overlap_oss; |
| 202 | + bool first = true; |
| 203 | + for (size_t other_idx = 0; other_idx < names.size(); ++other_idx) { |
| 204 | + if (other_idx == slot_idx) { |
| 205 | + continue; |
| 206 | + } |
| 207 | + const auto &other = registrations_.at(names[other_idx]); |
| 208 | + if (ranges_overlap(mem_start, mem_end, other.offset, memory_end(other))) { |
| 209 | + if (!first) { |
| 210 | + overlap_oss << ", "; |
| 211 | + } |
| 212 | + overlap_oss << "slot " << other_idx; |
| 213 | + first = false; |
| 214 | + } |
| 215 | + } |
| 216 | + if (!first) { |
| 217 | + overlap_str = overlap_oss.str(); |
| 218 | + } |
| 219 | + } |
| 220 | + |
| 221 | + oss << " [slot " << slot_idx << "]\n"; |
| 222 | + oss << " " << std::setw(16) << std::left << "layout:" |
| 223 | + << (reg.is_bump_tail ? "bump" : "pinned@0") << "\n"; |
| 224 | + oss << " " << std::setw(16) << std::left << "memory:" |
| 225 | + << "[" << mem_start << ", " << mem_end << ") " |
| 226 | + << "(" << (reg.aligned_bytes / 1024.0 / 1024.0) << " MB)\n"; |
| 227 | + oss << " " << std::setw(16) << std::left << "overlaps:" << overlap_str << "\n"; |
| 228 | + oss << " " << std::setw(16) << std::left << "name:" << name << "\n"; |
| 229 | + oss << " " << std::setw(16) << std::left << "shape:" << shape_str << "\n"; |
| 230 | + oss << " " << std::setw(16) << std::left << "dtype:" << infinicore::toString(reg.dtype) << "\n"; |
| 231 | + oss << " " << std::setw(16) << std::left << "device:" << reg.device.toString() << "\n"; |
| 232 | + oss << " " << std::setw(16) << std::left << "bound:" << finalized_ << "\n"; |
| 233 | + if (slot_idx + 1 < names.size()) { |
| 234 | + oss << "\n"; |
| 235 | + } |
| 236 | + } |
| 237 | + oss << "====================================================\n"; |
| 238 | + |
| 239 | + spdlog::info("{}", oss.str()); |
| 240 | +} |
| 241 | + |
| 242 | +void WorkspaceManager::_register_buffer_impl(const std::string &name, |
| 243 | + size_t offset, |
| 244 | + const infinicore::Shape &shape, |
| 245 | + const infinicore::DataType &dtype, |
| 246 | + const infinicore::Device &device, |
| 247 | + bool bump_tail) { |
| 248 | + ASSERT(!finalized_); |
| 249 | + ASSERT(device == get_tensor_model_parallel_rank_info().device); |
| 250 | + |
| 251 | + const size_t aligned_bytes = compute_aligned_bytes(shape, dtype); |
| 252 | + |
| 253 | + if (registrations_.find(name) == registrations_.end()) { |
| 254 | + BufferRegistration reg; |
| 255 | + reg.offset = offset; |
| 256 | + reg.aligned_bytes = aligned_bytes; |
| 257 | + reg.is_bump_tail = bump_tail; |
| 258 | + reg.shape = shape; |
| 259 | + reg.dtype = dtype; |
| 260 | + reg.device = device; |
| 261 | + |
| 262 | + if (bump_tail) { |
| 263 | + total_bytes_ += aligned_bytes; |
| 264 | + } else { |
| 265 | + total_bytes_ = std::max(total_bytes_, offset + aligned_bytes); |
| 266 | + } |
| 267 | + registrations_.emplace(name, std::move(reg)); |
| 268 | + } |
| 269 | + |
| 270 | + auto ® = registrations_.at(name); |
| 271 | + ASSERT(reg.is_bump_tail == bump_tail); |
| 272 | + ASSERT(reg.aligned_bytes == aligned_bytes); |
| 273 | + ASSERT(reg.shape == shape); |
| 274 | + ASSERT(reg.dtype == dtype); |
| 275 | + ASSERT(reg.device == device); |
| 276 | +} |
| 277 | + |
| 278 | +} // namespace infinilm::global_state |
0 commit comments