|
21 | 21 | #include "arrow/c/dlpack.h" |
22 | 22 | #include "arrow/c/dlpack_abi.h" |
23 | 23 | #include "arrow/memory_pool.h" |
| 24 | +#include "arrow/tensor.h" |
24 | 25 | #include "arrow/testing/gtest_util.h" |
25 | 26 |
|
26 | 27 | namespace arrow::dlpack { |
@@ -48,7 +49,6 @@ void CheckDLTensor(const std::shared_ptr<Array>& arr, |
48 | 49 | ASSERT_EQ(1, dltensor.ndim); |
49 | 50 |
|
50 | 51 | ASSERT_EQ(dlpack_type, dltensor.dtype.code); |
51 | | - |
52 | 52 | ASSERT_EQ(arrow_type->bit_width(), dltensor.dtype.bits); |
53 | 53 | ASSERT_EQ(1, dltensor.dtype.lanes); |
54 | 54 | ASSERT_EQ(DLDeviceType::kDLCPU, dltensor.device.device_type); |
@@ -126,4 +126,93 @@ TEST_F(TestExportArray, TestErrors) { |
126 | 126 | arrow::dlpack::ExportDevice(array_boolean)); |
127 | 127 | } |
128 | 128 |
|
| 129 | +class TestExportTensor : public ::testing::Test { |
| 130 | + public: |
| 131 | + void SetUp() {} |
| 132 | +}; |
| 133 | + |
| 134 | +void CheckDLTensor(const std::shared_ptr<Tensor>& t, |
| 135 | + const std::shared_ptr<DataType>& tensor_type, |
| 136 | + DLDataTypeCode dlpack_type, std::vector<int64_t> shape, |
| 137 | + std::vector<int64_t> strides) { |
| 138 | + ASSERT_OK_AND_ASSIGN(auto dlmtensor, arrow::dlpack::ExportTensor(t)); |
| 139 | + auto dltensor = dlmtensor->dl_tensor; |
| 140 | + |
| 141 | + ASSERT_EQ(t->data()->data(), dltensor.data); |
| 142 | + ASSERT_EQ(t->ndim(), dltensor.ndim); |
| 143 | + ASSERT_EQ(0, dltensor.byte_offset); |
| 144 | + for (int i = 0; i < t->ndim(); i++) { |
| 145 | + ASSERT_EQ(shape.data()[i], dltensor.shape[i]); |
| 146 | + ASSERT_EQ(strides.data()[i], dltensor.strides[i]); |
| 147 | + } |
| 148 | + |
| 149 | + ASSERT_EQ(dlpack_type, dltensor.dtype.code); |
| 150 | + ASSERT_EQ(tensor_type->bit_width(), dltensor.dtype.bits); |
| 151 | + ASSERT_EQ(1, dltensor.dtype.lanes); |
| 152 | + ASSERT_EQ(DLDeviceType::kDLCPU, dltensor.device.device_type); |
| 153 | + ASSERT_EQ(0, dltensor.device.device_id); |
| 154 | + |
| 155 | + ASSERT_OK_AND_ASSIGN(auto device, arrow::dlpack::ExportDevice(t)); |
| 156 | + ASSERT_EQ(DLDeviceType::kDLCPU, device.device_type); |
| 157 | + ASSERT_EQ(0, device.device_id); |
| 158 | + |
| 159 | + dlmtensor->deleter(dlmtensor); |
| 160 | +} |
| 161 | + |
| 162 | +TEST_F(TestExportTensor, TestTensor) { |
| 163 | + const std::vector<std::pair<std::shared_ptr<DataType>, DLDataTypeCode>> cases = { |
| 164 | + {int8(), DLDataTypeCode::kDLInt}, |
| 165 | + {uint8(), DLDataTypeCode::kDLUInt}, |
| 166 | + { |
| 167 | + int16(), |
| 168 | + DLDataTypeCode::kDLInt, |
| 169 | + }, |
| 170 | + {uint16(), DLDataTypeCode::kDLUInt}, |
| 171 | + { |
| 172 | + int32(), |
| 173 | + DLDataTypeCode::kDLInt, |
| 174 | + }, |
| 175 | + {uint32(), DLDataTypeCode::kDLUInt}, |
| 176 | + { |
| 177 | + int64(), |
| 178 | + DLDataTypeCode::kDLInt, |
| 179 | + }, |
| 180 | + {uint64(), DLDataTypeCode::kDLUInt}, |
| 181 | + {float16(), DLDataTypeCode::kDLFloat}, |
| 182 | + {float32(), DLDataTypeCode::kDLFloat}, |
| 183 | + {float64(), DLDataTypeCode::kDLFloat}}; |
| 184 | + |
| 185 | + const auto allocated_bytes = arrow::default_memory_pool()->bytes_allocated(); |
| 186 | + |
| 187 | + for (auto [arrow_type, dlpack_type] : cases) { |
| 188 | + std::vector<int64_t> shape = {3, 6}; |
| 189 | + std::vector<int64_t> dlpack_strides = {6, 1}; |
| 190 | + std::shared_ptr<Tensor> tensor = TensorFromJSON( |
| 191 | + arrow_type, "[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9]", shape); |
| 192 | + |
| 193 | + CheckDLTensor(tensor, arrow_type, dlpack_type, shape, dlpack_strides); |
| 194 | + } |
| 195 | + |
| 196 | + ASSERT_EQ(allocated_bytes, arrow::default_memory_pool()->bytes_allocated()); |
| 197 | +} |
| 198 | + |
| 199 | +TEST_F(TestExportTensor, TestTensorStrided) { |
| 200 | + std::vector<int64_t> shape = {2, 2, 2}; |
| 201 | + std::vector<int64_t> strides = {sizeof(float) * 4, sizeof(float) * 2, |
| 202 | + sizeof(float) * 1}; |
| 203 | + std::vector<int64_t> dlpack_strides = {4, 2, 1}; |
| 204 | + std::shared_ptr<Tensor> tensor = |
| 205 | + TensorFromJSON(float32(), "[1, 2, 3, 4, 5, 6, 1, 1]", shape, strides); |
| 206 | + |
| 207 | + CheckDLTensor(tensor, float32(), DLDataTypeCode::kDLFloat, shape, dlpack_strides); |
| 208 | + |
| 209 | + std::vector<int64_t> f_strides = {sizeof(float) * 1, sizeof(float) * 2, |
| 210 | + sizeof(float) * 4}; |
| 211 | + std::vector<int64_t> f_dlpack_strides = {1, 2, 4}; |
| 212 | + std::shared_ptr<Tensor> f_tensor = |
| 213 | + TensorFromJSON(float32(), "[1, 2, 3, 4, 5, 6, 1, 1]", shape, f_strides); |
| 214 | + |
| 215 | + CheckDLTensor(f_tensor, float32(), DLDataTypeCode::kDLFloat, shape, f_dlpack_strides); |
| 216 | +} |
| 217 | + |
129 | 218 | } // namespace arrow::dlpack |
0 commit comments