Skip to content

Commit c7054b7

Browse files
author
zhuyue
committed
refactor test scripts and remove txt write
add large-scale and non-contiguous tensor I/O tests
1 parent f72193d commit c7054b7

2 files changed

Lines changed: 513 additions & 529 deletions

File tree

src/infinicore/tensor/debug.cc

Lines changed: 96 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <cstring>
66
#include <fstream>
77
#include <iostream>
8+
#include <limits>
9+
#include <memory>
810
#include <sstream>
911

1012
namespace infinicore {
@@ -93,214 +95,135 @@ void print_data_bf16(const uint16_t *data, const Shape &shape, const Strides &st
9395
}
9496
}
9597

98+
// Template function for writing data recursively to binary file (handles non-contiguous tensors)
99+
template <typename T>
100+
void write_binary_data(std::ofstream &out, const T *data, const Shape &shape, const Strides &strides, size_t dim) {
101+
if (dim == shape.size() - 1) {
102+
// Write the innermost dimension
103+
for (size_t i = 0; i < shape[dim]; i++) {
104+
out.write(reinterpret_cast<const char *>(&data[i * strides[dim]]), sizeof(T));
105+
}
106+
} else {
107+
// Recursively process higher dimensions
108+
for (size_t i = 0; i < shape[dim]; i++) {
109+
write_binary_data(out, data + i * strides[dim], shape, strides, dim + 1);
110+
}
111+
}
112+
}
113+
96114
void TensorImpl::debug(const std::string &filename) const {
97115
// Synchronize device if needed
98116
context::syncDevice();
99117

100118
std::cout << info() << std::endl;
101119

102120
const std::byte *cpu_data = nullptr;
103-
std::byte *allocated_memory = nullptr;
121+
std::unique_ptr<std::byte[]> allocated_memory; // RAII: 自动管理内存
104122

105123
// Copy data to CPU if not already on CPU
106124
if (this->device().getType() != Device::Type::CPU) {
107-
size_t mem_size = this->numel() * dsize(this->dtype());
108-
allocated_memory = new std::byte[mem_size];
109-
context::memcpyD2H(allocated_memory, this->data(), mem_size);
110-
cpu_data = allocated_memory;
125+
size_t numel = this->numel();
126+
size_t element_size = dsize(this->dtype());
127+
128+
// 检查乘法溢出
129+
if (numel > 0 && element_size > std::numeric_limits<size_t>::max() / numel) {
130+
std::cerr << "Error: Memory size calculation overflow for tensor with "
131+
<< numel << " elements of size " << element_size << "\n";
132+
return;
133+
}
134+
135+
size_t mem_size = numel * element_size;
136+
allocated_memory = std::make_unique<std::byte[]>(mem_size);
137+
context::memcpyD2H(allocated_memory.get(), this->data(), mem_size);
138+
cpu_data = allocated_memory.get();
111139
} else {
112140
cpu_data = this->data();
113141
}
114142

115-
// If filename is provided, save to file
143+
// If filename is provided, save to binary file
116144
if (!filename.empty()) {
117-
// Determine file format based on extension
118-
bool is_text_format = false;
119-
size_t dot_pos = filename.find_last_of('.');
120-
if (dot_pos != std::string::npos) {
121-
std::string ext = filename.substr(dot_pos);
122-
is_text_format = (ext == ".txt");
145+
std::ofstream outFile(filename, std::ios::binary);
146+
if (!outFile) {
147+
std::cerr << "Error opening file for writing: " << filename << "\n";
148+
return; // allocated_memory 会自动释放(RAII)
123149
}
124150

125-
if (is_text_format) {
126-
// Save as text format
127-
std::ofstream outFile(filename);
128-
if (!outFile) {
129-
std::cerr << "Error opening file for writing: " << filename << "\n";
130-
if (allocated_memory) {
131-
delete[] allocated_memory;
132-
}
133-
return;
134-
}
135-
136-
// Write header with tensor information
137-
outFile << "# Tensor Debug Output\n";
138-
outFile << "# Shape: [";
139-
for (size_t i = 0; i < this->shape().size(); ++i) {
140-
outFile << this->shape()[i];
141-
if (i < this->shape().size() - 1) {
142-
outFile << ", ";
143-
}
144-
}
145-
outFile << "]\n";
146-
outFile << "# Strides: [";
147-
for (size_t i = 0; i < this->strides().size(); ++i) {
148-
outFile << this->strides()[i];
149-
if (i < this->strides().size() - 1) {
150-
outFile << ", ";
151-
}
152-
}
153-
outFile << "]\n";
154-
outFile << "# Dtype: " << toString(this->dtype()) << "\n";
155-
outFile << "# Contiguous: " << (this->is_contiguous() ? "Yes" : "No") << "\n";
156-
outFile << "# Elements: " << this->numel() << "\n";
157-
outFile << "#\n";
158-
159-
// Helper function to write data recursively
160-
std::function<void(const std::byte *, const Shape &, const Strides &, size_t, std::ofstream &)> write_data;
161-
151+
// Check if tensor is contiguous - for optimization
152+
if (this->is_contiguous()) {
153+
// Fast path: contiguous tensor, write in one go
154+
size_t mem_size = this->numel() * dsize(this->dtype());
155+
outFile.write(reinterpret_cast<const char *>(cpu_data), mem_size);
156+
} else {
157+
// Slow path: non-contiguous tensor, write element by element using strides
162158
switch (this->dtype()) {
163159
case DataType::F16:
164-
write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) {
165-
const uint16_t *ptr = reinterpret_cast<const uint16_t *>(data);
166-
if (dim == shape.size() - 1) {
167-
for (size_t i = 0; i < shape[dim]; i++) {
168-
out << f16_to_f32(ptr[i * strides[dim]]);
169-
if (i < shape[dim] - 1) {
170-
out << " ";
171-
}
172-
}
173-
out << "\n";
174-
} else {
175-
for (size_t i = 0; i < shape[dim]; i++) {
176-
write_data(data + i * strides[dim] * sizeof(uint16_t), shape, strides, dim + 1, out);
177-
}
178-
}
179-
};
160+
case DataType::BF16:
161+
write_binary_data(outFile, reinterpret_cast<const uint16_t *>(cpu_data),
162+
this->shape(), this->strides(), 0);
180163
break;
181164
case DataType::F32:
182-
write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) {
183-
const float *ptr = reinterpret_cast<const float *>(data);
184-
if (dim == shape.size() - 1) {
185-
for (size_t i = 0; i < shape[dim]; i++) {
186-
out << ptr[i * strides[dim]];
187-
if (i < shape[dim] - 1) {
188-
out << " ";
189-
}
190-
}
191-
out << "\n";
192-
} else {
193-
for (size_t i = 0; i < shape[dim]; i++) {
194-
write_data(data + i * strides[dim] * sizeof(float), shape, strides, dim + 1, out);
195-
}
196-
}
197-
};
165+
write_binary_data(outFile, reinterpret_cast<const float *>(cpu_data),
166+
this->shape(), this->strides(), 0);
198167
break;
199168
case DataType::F64:
200-
write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) {
201-
const double *ptr = reinterpret_cast<const double *>(data);
202-
if (dim == shape.size() - 1) {
203-
for (size_t i = 0; i < shape[dim]; i++) {
204-
out << ptr[i * strides[dim]];
205-
if (i < shape[dim] - 1) {
206-
out << " ";
207-
}
208-
}
209-
out << "\n";
210-
} else {
211-
for (size_t i = 0; i < shape[dim]; i++) {
212-
write_data(data + i * strides[dim] * sizeof(double), shape, strides, dim + 1, out);
213-
}
214-
}
215-
};
169+
write_binary_data(outFile, reinterpret_cast<const double *>(cpu_data),
170+
this->shape(), this->strides(), 0);
216171
break;
217-
case DataType::I32:
218-
write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) {
219-
const int32_t *ptr = reinterpret_cast<const int32_t *>(data);
220-
if (dim == shape.size() - 1) {
221-
for (size_t i = 0; i < shape[dim]; i++) {
222-
out << ptr[i * strides[dim]];
223-
if (i < shape[dim] - 1) {
224-
out << " ";
225-
}
226-
}
227-
out << "\n";
228-
} else {
229-
for (size_t i = 0; i < shape[dim]; i++) {
230-
write_data(data + i * strides[dim] * sizeof(int32_t), shape, strides, dim + 1, out);
231-
}
232-
}
233-
};
172+
case DataType::U64:
173+
write_binary_data(outFile, reinterpret_cast<const uint64_t *>(cpu_data),
174+
this->shape(), this->strides(), 0);
234175
break;
235176
case DataType::I64:
236-
write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) {
237-
const int64_t *ptr = reinterpret_cast<const int64_t *>(data);
238-
if (dim == shape.size() - 1) {
239-
for (size_t i = 0; i < shape[dim]; i++) {
240-
out << ptr[i * strides[dim]];
241-
if (i < shape[dim] - 1) {
242-
out << " ";
243-
}
244-
}
245-
out << "\n";
246-
} else {
247-
for (size_t i = 0; i < shape[dim]; i++) {
248-
write_data(data + i * strides[dim] * sizeof(int64_t), shape, strides, dim + 1, out);
249-
}
250-
}
251-
};
177+
write_binary_data(outFile, reinterpret_cast<const int64_t *>(cpu_data),
178+
this->shape(), this->strides(), 0);
252179
break;
253-
case DataType::BF16:
254-
write_data = [&write_data](const std::byte *data, const Shape &shape, const Strides &strides, size_t dim, std::ofstream &out) {
255-
const uint16_t *ptr = reinterpret_cast<const uint16_t *>(data);
256-
if (dim == shape.size() - 1) {
257-
for (size_t i = 0; i < shape[dim]; i++) {
258-
out << bf16_to_f32(ptr[i * strides[dim]]);
259-
if (i < shape[dim] - 1) {
260-
out << " ";
261-
}
262-
}
263-
out << "\n";
264-
} else {
265-
for (size_t i = 0; i < shape[dim]; i++) {
266-
write_data(data + i * strides[dim] * sizeof(uint16_t), shape, strides, dim + 1, out);
267-
}
268-
}
269-
};
180+
case DataType::U32:
181+
write_binary_data(outFile, reinterpret_cast<const uint32_t *>(cpu_data),
182+
this->shape(), this->strides(), 0);
183+
break;
184+
case DataType::I32:
185+
write_binary_data(outFile, reinterpret_cast<const int32_t *>(cpu_data),
186+
this->shape(), this->strides(), 0);
187+
break;
188+
case DataType::U16:
189+
write_binary_data(outFile, reinterpret_cast<const uint16_t *>(cpu_data),
190+
this->shape(), this->strides(), 0);
191+
break;
192+
case DataType::I16:
193+
write_binary_data(outFile, reinterpret_cast<const int16_t *>(cpu_data),
194+
this->shape(), this->strides(), 0);
195+
break;
196+
case DataType::U8:
197+
write_binary_data(outFile, reinterpret_cast<const uint8_t *>(cpu_data),
198+
this->shape(), this->strides(), 0);
199+
break;
200+
case DataType::I8:
201+
write_binary_data(outFile, reinterpret_cast<const int8_t *>(cpu_data),
202+
this->shape(), this->strides(), 0);
203+
break;
204+
case DataType::BOOL:
205+
// 布尔类型特殊处理:转换为 uint8_t 以保证跨平台一致性
206+
write_binary_data(outFile, reinterpret_cast<const uint8_t *>(cpu_data),
207+
this->shape(), this->strides(), 0);
270208
break;
271209
default:
272-
outFile << "# Unsupported data type for text output\n";
273-
outFile.close();
274-
if (allocated_memory) {
275-
delete[] allocated_memory;
276-
}
210+
std::cerr << "Unsupported data type for binary output\n";
277211
return;
278212
}
213+
}
279214

280-
// Write the actual data
281-
write_data(cpu_data, this->shape(), this->strides(), 0, outFile);
282-
283-
outFile.close();
284-
std::cout << "Data written to text file: " << filename << "\n";
285-
} else {
286-
// Save as binary format (default)
287-
std::ofstream outFile(filename, std::ios::binary);
288-
if (!outFile) {
289-
std::cerr << "Error opening file for writing: " << filename << "\n";
290-
if (allocated_memory) {
291-
delete[] allocated_memory;
292-
}
293-
return;
294-
}
295-
size_t mem_size = this->numel() * dsize(this->dtype());
296-
outFile.write(reinterpret_cast<const char *>(cpu_data), mem_size);
297-
outFile.close();
298-
std::cout << "Data written to binary file: " << filename << "\n";
215+
// 显式关闭文件并检查是否成功
216+
outFile.close();
217+
if (!outFile) {
218+
std::cerr << "Error: Failed to write data to file: " << filename << "\n";
219+
return;
299220
}
300221

301-
if (allocated_memory) {
302-
delete[] allocated_memory;
222+
std::cout << "Data written to binary file: " << filename;
223+
if (!this->is_contiguous()) {
224+
std::cout << " (non-contiguous tensor, wrote " << this->numel() << " elements)";
303225
}
226+
std::cout << "\n";
304227
return;
305228
}
306229

@@ -362,11 +285,6 @@ void TensorImpl::debug(const std::string &filename) const {
362285
std::cout << "Unsupported data type for debug" << std::endl;
363286
break;
364287
}
365-
366-
// Clean up allocated memory
367-
if (allocated_memory) {
368-
delete[] allocated_memory;
369-
}
370288
}
371289

372290
void TensorImpl::debug() const {

0 commit comments

Comments
 (0)