|
5 | 5 | #include <cstring> |
6 | 6 | #include <fstream> |
7 | 7 | #include <iostream> |
| 8 | +#include <limits> |
| 9 | +#include <memory> |
8 | 10 | #include <sstream> |
9 | 11 |
|
10 | 12 | namespace infinicore { |
@@ -93,214 +95,135 @@ void print_data_bf16(const uint16_t *data, const Shape &shape, const Strides &st |
93 | 95 | } |
94 | 96 | } |
95 | 97 |
|
| 98 | +// Template function for writing data recursively to binary file (handles non-contiguous tensors) |
| 99 | +template <typename T> |
| 100 | +void write_binary_data(std::ofstream &out, const T *data, const Shape &shape, const Strides &strides, size_t dim) { |
| 101 | + if (dim == shape.size() - 1) { |
| 102 | + // Write the innermost dimension |
| 103 | + for (size_t i = 0; i < shape[dim]; i++) { |
| 104 | + out.write(reinterpret_cast<const char *>(&data[i * strides[dim]]), sizeof(T)); |
| 105 | + } |
| 106 | + } else { |
| 107 | + // Recursively process higher dimensions |
| 108 | + for (size_t i = 0; i < shape[dim]; i++) { |
| 109 | + write_binary_data(out, data + i * strides[dim], shape, strides, dim + 1); |
| 110 | + } |
| 111 | + } |
| 112 | +} |
| 113 | + |
96 | 114 | void TensorImpl::debug(const std::string &filename) const { |
97 | 115 | // Synchronize device if needed |
98 | 116 | context::syncDevice(); |
99 | 117 |
|
100 | 118 | std::cout << info() << std::endl; |
101 | 119 |
|
102 | 120 | const std::byte *cpu_data = nullptr; |
103 | | - std::byte *allocated_memory = nullptr; |
| 121 | + std::unique_ptr<std::byte[]> allocated_memory; // RAII: 自动管理内存 |
104 | 122 |
|
105 | 123 | // Copy data to CPU if not already on CPU |
106 | 124 | if (this->device().getType() != Device::Type::CPU) { |
107 | | - size_t mem_size = this->numel() * dsize(this->dtype()); |
108 | | - allocated_memory = new std::byte[mem_size]; |
109 | | - context::memcpyD2H(allocated_memory, this->data(), mem_size); |
110 | | - cpu_data = allocated_memory; |
| 125 | + size_t numel = this->numel(); |
| 126 | + size_t element_size = dsize(this->dtype()); |
| 127 | + |
| 128 | + // 检查乘法溢出 |
| 129 | + if (numel > 0 && element_size > std::numeric_limits<size_t>::max() / numel) { |
| 130 | + std::cerr << "Error: Memory size calculation overflow for tensor with " |
| 131 | + << numel << " elements of size " << element_size << "\n"; |
| 132 | + return; |
| 133 | + } |
| 134 | + |
| 135 | + size_t mem_size = numel * element_size; |
| 136 | + allocated_memory = std::make_unique<std::byte[]>(mem_size); |
| 137 | + context::memcpyD2H(allocated_memory.get(), this->data(), mem_size); |
| 138 | + cpu_data = allocated_memory.get(); |
111 | 139 | } else { |
112 | 140 | cpu_data = this->data(); |
113 | 141 | } |
114 | 142 |
|
115 | | - // If filename is provided, save to file |
| 143 | + // If filename is provided, save to binary file |
116 | 144 | if (!filename.empty()) { |
117 | | - // Determine file format based on extension |
118 | | - bool is_text_format = false; |
119 | | - size_t dot_pos = filename.find_last_of('.'); |
120 | | - if (dot_pos != std::string::npos) { |
121 | | - std::string ext = filename.substr(dot_pos); |
122 | | - is_text_format = (ext == ".txt"); |
| 145 | + std::ofstream outFile(filename, std::ios::binary); |
| 146 | + if (!outFile) { |
| 147 | + std::cerr << "Error opening file for writing: " << filename << "\n"; |
| 148 | + return; // allocated_memory 会自动释放(RAII) |
123 | 149 | } |
124 | 150 |
|
125 | | - if (is_text_format) { |
126 | | - // Save as text format |
127 | | - std::ofstream outFile(filename); |
128 | | - if (!outFile) { |
129 | | - std::cerr << "Error opening file for writing: " << filename << "\n"; |
130 | | - if (allocated_memory) { |
131 | | - delete[] allocated_memory; |
132 | | - } |
133 | | - return; |
134 | | - } |
135 | | - |
136 | | - // Write header with tensor information |
137 | | - outFile << "# Tensor Debug Output\n"; |
138 | | - outFile << "# Shape: ["; |
139 | | - for (size_t i = 0; i < this->shape().size(); ++i) { |
140 | | - outFile << this->shape()[i]; |
141 | | - if (i < this->shape().size() - 1) { |
142 | | - outFile << ", "; |
143 | | - } |
144 | | - } |
145 | | - outFile << "]\n"; |
146 | | - outFile << "# Strides: ["; |
147 | | - for (size_t i = 0; i < this->strides().size(); ++i) { |
148 | | - outFile << this->strides()[i]; |
149 | | - if (i < this->strides().size() - 1) { |
150 | | - outFile << ", "; |
151 | | - } |
152 | | - } |
153 | | - outFile << "]\n"; |
154 | | - outFile << "# Dtype: " << toString(this->dtype()) << "\n"; |
155 | | - outFile << "# Contiguous: " << (this->is_contiguous() ? "Yes" : "No") << "\n"; |
156 | | - outFile << "# Elements: " << this->numel() << "\n"; |
157 | | - outFile << "#\n"; |
158 | | - |
159 | | - // Helper function to write data recursively |
160 | | - std::function<void(const std::byte *, const Shape &, const Strides &, size_t, std::ofstream &)> write_data; |
161 | | - |
| 151 | + // Check if tensor is contiguous - for optimization |
| 152 | + if (this->is_contiguous()) { |
| 153 | + // Fast path: contiguous tensor, write in one go |
| 154 | + size_t mem_size = this->numel() * dsize(this->dtype()); |
| 155 | + outFile.write(reinterpret_cast<const char *>(cpu_data), mem_size); |
| 156 | + } else { |
| 157 | + // Slow path: non-contiguous tensor, write element by element using strides |
162 | 158 | switch (this->dtype()) { |
163 | 159 | case DataType::F16: |
164 | | - write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) { |
165 | | - const uint16_t *ptr = reinterpret_cast<const uint16_t *>(data); |
166 | | - if (dim == shape.size() - 1) { |
167 | | - for (size_t i = 0; i < shape[dim]; i++) { |
168 | | - out << f16_to_f32(ptr[i * strides[dim]]); |
169 | | - if (i < shape[dim] - 1) { |
170 | | - out << " "; |
171 | | - } |
172 | | - } |
173 | | - out << "\n"; |
174 | | - } else { |
175 | | - for (size_t i = 0; i < shape[dim]; i++) { |
176 | | - write_data(data + i * strides[dim] * sizeof(uint16_t), shape, strides, dim + 1, out); |
177 | | - } |
178 | | - } |
179 | | - }; |
| 160 | + case DataType::BF16: |
| 161 | + write_binary_data(outFile, reinterpret_cast<const uint16_t *>(cpu_data), |
| 162 | + this->shape(), this->strides(), 0); |
180 | 163 | break; |
181 | 164 | case DataType::F32: |
182 | | - write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) { |
183 | | - const float *ptr = reinterpret_cast<const float *>(data); |
184 | | - if (dim == shape.size() - 1) { |
185 | | - for (size_t i = 0; i < shape[dim]; i++) { |
186 | | - out << ptr[i * strides[dim]]; |
187 | | - if (i < shape[dim] - 1) { |
188 | | - out << " "; |
189 | | - } |
190 | | - } |
191 | | - out << "\n"; |
192 | | - } else { |
193 | | - for (size_t i = 0; i < shape[dim]; i++) { |
194 | | - write_data(data + i * strides[dim] * sizeof(float), shape, strides, dim + 1, out); |
195 | | - } |
196 | | - } |
197 | | - }; |
| 165 | + write_binary_data(outFile, reinterpret_cast<const float *>(cpu_data), |
| 166 | + this->shape(), this->strides(), 0); |
198 | 167 | break; |
199 | 168 | case DataType::F64: |
200 | | - write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) { |
201 | | - const double *ptr = reinterpret_cast<const double *>(data); |
202 | | - if (dim == shape.size() - 1) { |
203 | | - for (size_t i = 0; i < shape[dim]; i++) { |
204 | | - out << ptr[i * strides[dim]]; |
205 | | - if (i < shape[dim] - 1) { |
206 | | - out << " "; |
207 | | - } |
208 | | - } |
209 | | - out << "\n"; |
210 | | - } else { |
211 | | - for (size_t i = 0; i < shape[dim]; i++) { |
212 | | - write_data(data + i * strides[dim] * sizeof(double), shape, strides, dim + 1, out); |
213 | | - } |
214 | | - } |
215 | | - }; |
| 169 | + write_binary_data(outFile, reinterpret_cast<const double *>(cpu_data), |
| 170 | + this->shape(), this->strides(), 0); |
216 | 171 | break; |
217 | | - case DataType::I32: |
218 | | - write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) { |
219 | | - const int32_t *ptr = reinterpret_cast<const int32_t *>(data); |
220 | | - if (dim == shape.size() - 1) { |
221 | | - for (size_t i = 0; i < shape[dim]; i++) { |
222 | | - out << ptr[i * strides[dim]]; |
223 | | - if (i < shape[dim] - 1) { |
224 | | - out << " "; |
225 | | - } |
226 | | - } |
227 | | - out << "\n"; |
228 | | - } else { |
229 | | - for (size_t i = 0; i < shape[dim]; i++) { |
230 | | - write_data(data + i * strides[dim] * sizeof(int32_t), shape, strides, dim + 1, out); |
231 | | - } |
232 | | - } |
233 | | - }; |
| 172 | + case DataType::U64: |
| 173 | + write_binary_data(outFile, reinterpret_cast<const uint64_t *>(cpu_data), |
| 174 | + this->shape(), this->strides(), 0); |
234 | 175 | break; |
235 | 176 | case DataType::I64: |
236 | | - write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) { |
237 | | - const int64_t *ptr = reinterpret_cast<const int64_t *>(data); |
238 | | - if (dim == shape.size() - 1) { |
239 | | - for (size_t i = 0; i < shape[dim]; i++) { |
240 | | - out << ptr[i * strides[dim]]; |
241 | | - if (i < shape[dim] - 1) { |
242 | | - out << " "; |
243 | | - } |
244 | | - } |
245 | | - out << "\n"; |
246 | | - } else { |
247 | | - for (size_t i = 0; i < shape[dim]; i++) { |
248 | | - write_data(data + i * strides[dim] * sizeof(int64_t), shape, strides, dim + 1, out); |
249 | | - } |
250 | | - } |
251 | | - }; |
| 177 | + write_binary_data(outFile, reinterpret_cast<const int64_t *>(cpu_data), |
| 178 | + this->shape(), this->strides(), 0); |
252 | 179 | break; |
253 | | - case DataType::BF16: |
254 | | - write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) { |
255 | | - const uint16_t *ptr = reinterpret_cast<const uint16_t *>(data); |
256 | | - if (dim == shape.size() - 1) { |
257 | | - for (size_t i = 0; i < shape[dim]; i++) { |
258 | | - out << bf16_to_f32(ptr[i * strides[dim]]); |
259 | | - if (i < shape[dim] - 1) { |
260 | | - out << " "; |
261 | | - } |
262 | | - } |
263 | | - out << "\n"; |
264 | | - } else { |
265 | | - for (size_t i = 0; i < shape[dim]; i++) { |
266 | | - write_data(data + i * strides[dim] * sizeof(uint16_t), shape, strides, dim + 1, out); |
267 | | - } |
268 | | - } |
269 | | - }; |
| 180 | + case DataType::U32: |
| 181 | + write_binary_data(outFile, reinterpret_cast<const uint32_t *>(cpu_data), |
| 182 | + this->shape(), this->strides(), 0); |
| 183 | + break; |
| 184 | + case DataType::I32: |
| 185 | + write_binary_data(outFile, reinterpret_cast<const int32_t *>(cpu_data), |
| 186 | + this->shape(), this->strides(), 0); |
| 187 | + break; |
| 188 | + case DataType::U16: |
| 189 | + write_binary_data(outFile, reinterpret_cast<const uint16_t *>(cpu_data), |
| 190 | + this->shape(), this->strides(), 0); |
| 191 | + break; |
| 192 | + case DataType::I16: |
| 193 | + write_binary_data(outFile, reinterpret_cast<const int16_t *>(cpu_data), |
| 194 | + this->shape(), this->strides(), 0); |
| 195 | + break; |
| 196 | + case DataType::U8: |
| 197 | + write_binary_data(outFile, reinterpret_cast<const uint8_t *>(cpu_data), |
| 198 | + this->shape(), this->strides(), 0); |
| 199 | + break; |
| 200 | + case DataType::I8: |
| 201 | + write_binary_data(outFile, reinterpret_cast<const int8_t *>(cpu_data), |
| 202 | + this->shape(), this->strides(), 0); |
| 203 | + break; |
| 204 | + case DataType::BOOL: |
| 205 | + // 布尔类型特殊处理:转换为 uint8_t 以保证跨平台一致性 |
| 206 | + write_binary_data(outFile, reinterpret_cast<const uint8_t *>(cpu_data), |
| 207 | + this->shape(), this->strides(), 0); |
270 | 208 | break; |
271 | 209 | default: |
272 | | - outFile << "# Unsupported data type for text output\n"; |
273 | | - outFile.close(); |
274 | | - if (allocated_memory) { |
275 | | - delete[] allocated_memory; |
276 | | - } |
| 210 | + std::cerr << "Unsupported data type for binary output\n"; |
277 | 211 | return; |
278 | 212 | } |
| 213 | + } |
279 | 214 |
|
280 | | - // Write the actual data |
281 | | - write_data(cpu_data, this->shape(), this->strides(), 0, outFile); |
282 | | - |
283 | | - outFile.close(); |
284 | | - std::cout << "Data written to text file: " << filename << "\n"; |
285 | | - } else { |
286 | | - // Save as binary format (default) |
287 | | - std::ofstream outFile(filename, std::ios::binary); |
288 | | - if (!outFile) { |
289 | | - std::cerr << "Error opening file for writing: " << filename << "\n"; |
290 | | - if (allocated_memory) { |
291 | | - delete[] allocated_memory; |
292 | | - } |
293 | | - return; |
294 | | - } |
295 | | - size_t mem_size = this->numel() * dsize(this->dtype()); |
296 | | - outFile.write(reinterpret_cast<const char *>(cpu_data), mem_size); |
297 | | - outFile.close(); |
298 | | - std::cout << "Data written to binary file: " << filename << "\n"; |
| 215 | + // 显式关闭文件并检查是否成功 |
| 216 | + outFile.close(); |
| 217 | + if (!outFile) { |
| 218 | + std::cerr << "Error: Failed to write data to file: " << filename << "\n"; |
| 219 | + return; |
299 | 220 | } |
300 | 221 |
|
301 | | - if (allocated_memory) { |
302 | | - delete[] allocated_memory; |
| 222 | + std::cout << "Data written to binary file: " << filename; |
| 223 | + if (!this->is_contiguous()) { |
| 224 | + std::cout << " (non-contiguous tensor, wrote " << this->numel() << " elements)"; |
303 | 225 | } |
| 226 | + std::cout << "\n"; |
304 | 227 | return; |
305 | 228 | } |
306 | 229 |
|
@@ -362,11 +285,6 @@ void TensorImpl::debug(const std::string &filename) const { |
362 | 285 | std::cout << "Unsupported data type for debug" << std::endl; |
363 | 286 | break; |
364 | 287 | } |
365 | | - |
366 | | - // Clean up allocated memory |
367 | | - if (allocated_memory) { |
368 | | - delete[] allocated_memory; |
369 | | - } |
370 | 288 | } |
371 | 289 |
|
372 | 290 | void TensorImpl::debug() const { |
|
0 commit comments