22
33#include " ../models/infinilm_model.hpp"
44#include " ../utils.hpp"
5+ #include " parallel_state.hpp"
56#include < algorithm>
6- #include < cstdio>
77#include < functional>
8+ #include < iomanip>
9+ #include < sstream>
810#include < string>
911#include < unordered_map>
1012#include < vector>
1113
1214namespace infinilm ::global_state {
1315
1416/* *
15- * @brief Unified GPU inference workspace manager .
17+ * @brief Unified GPU inference scratch buffer .
1618 *
17- * Phase 1: modules register buffer layouts via ``register_buffer``.
18- * Phase 2/3: ``finalize_and_bind`` allocates ``scratch_buffer_`` and binds views.
19+ * Flow: register_buffer -> finalize_and_bind -> log_registrations (optional).
20+ * Layout: bump (tail append) or pinned@0 (offset fixed at 0).
21+ * Slots may overlap; scratch_bytes is max span, not sum of slots. Safe use requires
22+ * temporal reuse across forward phases.
1923 */
2024class WorkspaceManager {
2125public:
@@ -24,93 +28,168 @@ class WorkspaceManager {
2428 WorkspaceManager () = default ;
2529 ~WorkspaceManager () = default ;
2630
27- /* *
28- * @brief Register a buffer appended at the current scratch_buffer tail.
29- *
30- * @param name Unique cache key; duplicate keys share one slot.
31- * @param shape Tensor shape for the bound view.
32- * @param dtype Element type of the bound view.
33- * @param device Device on which scratch_buffer is allocated.
34- * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
35- */
31+ /* * @brief Register a bump slot at current total_bytes_. Same name reuses one slot. */
3632 void register_buffer (const std::string &name,
3733 const infinicore::Shape &shape,
3834 const infinicore::DataType &dtype,
3935 const infinicore::Device &device,
4036 BindFn bind_fn) {
41- register_buffer_impl (name, total_bytes_, shape, dtype, device, std::move (bind_fn), true );
37+ _register_buffer_impl (name, total_bytes_, shape, dtype, device, std::move (bind_fn), true );
4238 }
4339
44- /* *
45- * @brief Register a buffer pinned at a fixed byte offset.
46- *
47- * @param name Unique cache key; duplicate keys share one slot.
48- * @param offset Byte offset in scratch_buffer (currently only 0 is supported).
49- * @param shape Tensor shape for the bound view.
50- * @param dtype Element type of the bound view.
51- * @param device Device on which scratch_buffer is allocated.
52- * @param bind_fn Callback invoked in ``finalize_and_bind`` with the bound view.
53- */
40+ /* * @brief Register a pinned@0 slot (only offset==0). May overlap bump slots. */
5441 void register_buffer (const std::string &name,
5542 size_t offset,
5643 const infinicore::Shape &shape,
5744 const infinicore::DataType &dtype,
5845 const infinicore::Device &device,
5946 BindFn bind_fn) {
6047 ASSERT (0 == offset);
61- register_buffer_impl (name, offset, shape, dtype, device, std::move (bind_fn), false );
48+ _register_buffer_impl (name, offset, shape, dtype, device, std::move (bind_fn), false );
6249 }
6350
64- /* *
65- * @brief Allocate scratch_buffer and run all registered bind callbacks.
66- *
67- * @param device Device on which scratch_buffer is allocated.
68- */
69- void finalize_and_bind (const infinicore::Device &device) {
51+ /* * @brief Allocate scratch_buffer_ and run bind callbacks. */
52+ void finalize_and_bind () {
7053 ASSERT (!finalized_);
7154 if (total_bytes_ == 0 ) {
7255 finalized_ = true ;
7356 return ;
7457 }
7558
76- ASSERT (device. getType () != infinicore::Device::Type:: CPU ) ;
59+ auto &rank_device = get_tensor_model_parallel_rank_info (). device ;
7760
78- scratch_buffer_ = infinicore::Tensor::empty ({total_bytes_}, infinicore::DataType::U8 , device );
61+ scratch_buffer_ = infinicore::Tensor::empty ({total_bytes_}, infinicore::DataType::U8 , rank_device );
7962
8063 spdlog::info (" WorkspaceManager: finalize_and_bind {:.3f} MB" , total_bytes_ / 1024.0 / 1024.0 );
8164
82- for (auto &[name, reg] : registrations_) {
65+ for (auto &entry : registrations_) {
66+ auto &reg = entry.second ;
8367 auto *base_ptr = scratch_buffer_->data () + reg.offset ;
84- auto view = infinicore::Tensor::from_blob ( static_cast < void *>(base_ptr), reg.shape , reg. dtype , device);
85- inference_buffers_[name] = view ;
68+ ASSERT (rank_device == reg.device );
69+ reg. bound_view = infinicore::Tensor::from_blob ( static_cast < void *>(base_ptr), reg. shape , reg. dtype , rank_device) ;
8670 for (auto &bind_fn : reg.bind_callbacks ) {
87- bind_fn (view );
71+ bind_fn (reg. bound_view );
8872 }
8973 }
9074
9175 finalized_ = true ;
9276 }
9377
78+ /* * @brief Log slot layout with memory ranges and overlap info. */
79+ void log_registrations () const {
80+ size_t total_callbacks = 0 ;
81+ for (const auto &entry : registrations_) {
82+ total_callbacks += entry.second .bind_callbacks .size ();
83+ }
84+
85+ std::vector<std::string> names;
86+ names.reserve (registrations_.size ());
87+ for (const auto &entry : registrations_) {
88+ names.push_back (entry.first );
89+ }
90+ std::sort (names.begin (), names.end (), [this ](const std::string &a, const std::string &b) {
91+ return registrations_.at (a).offset < registrations_.at (b).offset ;
92+ });
93+
94+ std::ostringstream oss;
95+ oss << std::fixed << std::setprecision (3 );
96+ oss << " \n ========== WorkspaceManager registrations ==========\n " ;
97+ oss << " " << std::setw (16 ) << std::left << " finalized:" << finalized_ << " \n " ;
98+ oss << " " << std::setw (16 ) << std::left << " slots:" << registrations_.size () << " \n " ;
99+ oss << " " << std::setw (16 ) << std::left << " bind_callbacks:" << total_callbacks << " \n " ;
100+ oss << " " << std::setw (16 ) << std::left << " scratch_bytes:"
101+ << total_bytes_ << " (" << (total_bytes_ / 1024.0 / 1024.0 ) << " MB)\n " ;
102+ oss << " note: scratch_bytes=max span; slots may overlap (temporal reuse).\n " ;
103+ oss << " ----------------------------------------------------\n " ;
104+
105+ auto memory_end = [](const BufferRegistration &reg) {
106+ return reg.offset + reg.aligned_bytes ;
107+ };
108+ auto ranges_overlap = [](size_t a_start, size_t a_end, size_t b_start, size_t b_end) {
109+ return a_start < b_end && b_start < a_end;
110+ };
111+
112+ for (size_t slot_idx = 0 ; slot_idx < names.size (); ++slot_idx) {
113+ const auto &name = names[slot_idx];
114+ const auto &reg = registrations_.at (name);
115+ const size_t mem_start = reg.offset ;
116+ const size_t mem_end = memory_end (reg);
117+
118+ std::string shape_str = " [" ;
119+ for (size_t i = 0 ; i < reg.shape .size (); ++i) {
120+ if (i > 0 ) {
121+ shape_str += " , " ;
122+ }
123+ shape_str += std::to_string (reg.shape [i]);
124+ }
125+ shape_str += " ]" ;
126+
127+ std::string overlap_str = " none" ;
128+ {
129+ std::ostringstream overlap_oss;
130+ bool first = true ;
131+ for (size_t other_idx = 0 ; other_idx < names.size (); ++other_idx) {
132+ if (other_idx == slot_idx) {
133+ continue ;
134+ }
135+ const auto &other = registrations_.at (names[other_idx]);
136+ if (ranges_overlap (mem_start, mem_end, other.offset , memory_end (other))) {
137+ if (!first) {
138+ overlap_oss << " , " ;
139+ }
140+ overlap_oss << " slot " << other_idx;
141+ first = false ;
142+ }
143+ }
144+ if (!first) {
145+ overlap_str = overlap_oss.str ();
146+ }
147+ }
148+
149+ oss << " [slot " << slot_idx << " ]\n " ;
150+ oss << " " << std::setw (16 ) << std::left << " layout:"
151+ << (reg.is_bump_tail ? " bump" : " pinned@0" ) << " \n " ;
152+ oss << " " << std::setw (16 ) << std::left << " memory:"
153+ << " [" << mem_start << " , " << mem_end << " ) "
154+ << " (" << (reg.aligned_bytes / 1024.0 / 1024.0 ) << " MB)\n " ;
155+ oss << " " << std::setw (16 ) << std::left << " overlaps:" << overlap_str << " \n " ;
156+ oss << " " << std::setw (16 ) << std::left << " name:" << name << " \n " ;
157+ oss << " " << std::setw (16 ) << std::left << " shape:" << shape_str << " \n " ;
158+ oss << " " << std::setw (16 ) << std::left << " dtype:" << infinicore::toString (reg.dtype ) << " \n " ;
159+ oss << " " << std::setw (16 ) << std::left << " device:" << reg.device .toString () << " \n " ;
160+ oss << " " << std::setw (16 ) << std::left << " bind_callbacks:" << reg.bind_callbacks .size () << " \n " ;
161+ oss << " " << std::setw (16 ) << std::left << " bound:" << finalized_ << " \n " ;
162+ if (slot_idx + 1 < names.size ()) {
163+ oss << " \n " ;
164+ }
165+ }
166+ oss << " ====================================================\n " ;
167+
168+ spdlog::info (" {}" , oss.str ());
169+ }
170+
94171private:
95- /* * @brief Metadata for one registered region in scratch_buffer . */
172+ /* * @brief Metadata for one registered view into scratch_buffer_ . */
96173 struct BufferRegistration {
97- size_t offset{0 };
98- size_t aligned_bytes{0 };
99- infinicore::Shape shape;
100- infinicore::DataType dtype;
101- infinicore::Device device;
102- std::vector<BindFn> bind_callbacks;
174+ size_t offset{0 }; // view start in scratch_buffer_ (not a unique partition id)
175+ size_t aligned_bytes{0 }; // view span after alignment; used for scratch size accounting
176+ bool is_bump_tail{true }; // true=bump tail slot; false=pinned@0 slot
177+ infinicore::Shape shape; // shape of the bound inference view
178+ infinicore::DataType dtype; // element type of the bound inference view
179+ infinicore::Device device; // device passed at registration (must match rank device)
180+ infinicore::Tensor bound_view; // view into scratch_buffer_; valid after finalize_and_bind
181+ std::vector<BindFn> bind_callbacks; // callbacks that bind module tensors to bound_view
103182 };
104183
105- void register_buffer_impl (const std::string &name,
106- size_t offset,
107- const infinicore::Shape &shape,
108- const infinicore::DataType &dtype,
109- const infinicore::Device &device,
110- BindFn bind_fn,
111- bool bump_tail) {
184+ void _register_buffer_impl (const std::string &name,
185+ size_t offset,
186+ const infinicore::Shape &shape,
187+ const infinicore::DataType &dtype,
188+ const infinicore::Device &device,
189+ BindFn bind_fn,
190+ bool bump_tail) {
112191 ASSERT (!finalized_);
113- ASSERT (device. getType () != infinicore::Device::Type:: CPU );
192+ ASSERT (device == get_tensor_model_parallel_rank_info (). device );
114193
115194 auto compute_numel = [](const infinicore::Shape &shape) {
116195 size_t numel = 1 ;
@@ -131,6 +210,7 @@ class WorkspaceManager {
131210 BufferRegistration reg;
132211 reg.offset = offset;
133212 reg.aligned_bytes = aligned_bytes;
213+ reg.is_bump_tail = bump_tail;
134214 reg.shape = shape;
135215 reg.dtype = dtype;
136216 reg.device = device;
@@ -144,6 +224,7 @@ class WorkspaceManager {
144224 }
145225
146226 auto &reg = registrations_.at (name);
227+ ASSERT (reg.is_bump_tail == bump_tail);
147228 ASSERT (reg.aligned_bytes == aligned_bytes);
148229 ASSERT (reg.shape == shape);
149230 ASSERT (reg.dtype == dtype);
@@ -155,7 +236,6 @@ class WorkspaceManager {
155236 bool finalized_{false };
156237 infinicore::Tensor scratch_buffer_;
157238 std::unordered_map<std::string, BufferRegistration> registrations_;
158- std::unordered_map<std::string, infinicore::Tensor> inference_buffers_;
159239};
160240
161241}; // namespace infinilm::global_state
0 commit comments