Skip to content

Commit 7134334

Browse files
committed
Update base for Update on "[ET Device Support] Extract shared device test utilities to reduce redundancy"
Extract DeviceAwarePartitioner, CpuOnlyPartitioner, and MockCudaAllocator into shared test utility modules to eliminate duplicated definitions across test files. Python: Create executorch/exir/backend/test/device_util.py with DeviceAwarePartitioner (configurable target_device, default "cuda:0"), CpuOnlyPartitioner, and AddOperatorSupport. Update 3 consumer test files. C++: Create executorch/runtime/core/test/mock_cuda_allocator.h with a canonical MockCudaAllocator (malloc/free/memcpy-backed, with call tracking). Update 4 consumer test files. Differential Revision: [D99925172](https://our.internmc.facebook.com/intern/diff/D99925172/) [ghstack-poisoned]
1 parent c6e67e6 commit 7134334

3 files changed

Lines changed: 262 additions & 22 deletions

File tree

extension/tensor/tensor_ptr.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ TensorPtr make_tensor_ptr(
6262
std::vector<executorch::aten::StridesType> strides,
6363
executorch::aten::ScalarType type,
6464
executorch::aten::TensorShapeDynamism dynamism,
65-
std::function<void(void*)> deleter) {
65+
std::function<void(void*)> deleter,
66+
runtime::etensor::DeviceType device_type,
67+
runtime::etensor::DeviceIndex device_index) {
6668
const auto dim = sizes.size();
6769
ET_CHECK_MSG(
6870
dim_order.empty() || dim_order.size() == dim,
@@ -102,6 +104,7 @@ TensorPtr make_tensor_ptr(
102104

103105
strides = std::move(computed_strides);
104106

107+
TensorPtr cpu_tensor;
105108
#ifndef USE_ATEN_LIB
106109
executorch::aten::TensorImpl tensor_impl(
107110
type,
@@ -117,9 +120,9 @@ TensorPtr make_tensor_ptr(
117120
std::move(dim_order),
118121
std::move(strides),
119122
std::move(deleter));
120-
const auto tensor_ptr = &storage->tensor;
121-
return std::shared_ptr<executorch::aten::Tensor>(
122-
std::move(storage), tensor_ptr);
123+
const auto raw_tensor_ptr = &storage->tensor;
124+
cpu_tensor = std::shared_ptr<executorch::aten::Tensor>(
125+
std::move(storage), raw_tensor_ptr);
123126
#else
124127
auto options = c10::TensorOptions()
125128
.dtype(c10::scalarTypeToTypeMeta(type))
@@ -137,8 +140,13 @@ TensorPtr make_tensor_ptr(
137140
c10::DispatchKeySet(c10::DispatchKey::CPU),
138141
options.dtype());
139142
tensor_impl->set_sizes_and_strides(sizes, strides);
140-
return std::make_shared<executorch::aten::Tensor>(std::move(tensor_impl));
143+
cpu_tensor =
144+
std::make_shared<executorch::aten::Tensor>(std::move(tensor_impl));
141145
#endif // USE_ATEN_LIB
146+
if (device_type != runtime::etensor::DeviceType::CPU) {
147+
return clone_tensor_ptr_to_device(cpu_tensor, device_type, device_index);
148+
}
149+
return cpu_tensor;
142150
}
143151

144152
TensorPtr make_tensor_ptr(
@@ -147,7 +155,9 @@ TensorPtr make_tensor_ptr(
147155
std::vector<executorch::aten::DimOrderType> dim_order,
148156
std::vector<executorch::aten::StridesType> strides,
149157
executorch::aten::ScalarType type,
150-
executorch::aten::TensorShapeDynamism dynamism) {
158+
executorch::aten::TensorShapeDynamism dynamism,
159+
runtime::etensor::DeviceType device_type,
160+
runtime::etensor::DeviceIndex device_index) {
151161
ET_CHECK_MSG(
152162
data.size() ==
153163
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
@@ -162,7 +172,9 @@ TensorPtr make_tensor_ptr(
162172
type,
163173
dynamism,
164174
// Data is moved into the deleter and is destroyed together with Storage.
165-
[data = std::move(data)](void*) {});
175+
[data = std::move(data)](void*) {},
176+
device_type,
177+
device_index);
166178
}
167179

168180
TensorPtr clone_tensor_ptr(

extension/tensor/tensor_ptr.h

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ using TensorPtr = std::shared_ptr<executorch::aten::Tensor>;
4141
* @param deleter A custom deleter function for managing the lifetime of the
4242
* data buffer. If provided, this deleter will be called when the managed Tensor
4343
* object is destroyed.
44+
* @param device_type The target device type (default CPU, meaning no copy).
45+
* @param device_index The target device index (default 0).
4446
* @return A TensorPtr that manages the newly created Tensor.
4547
*/
4648
TensorPtr make_tensor_ptr(
@@ -52,7 +54,10 @@ TensorPtr make_tensor_ptr(
5254
executorch::aten::ScalarType::Float,
5355
const executorch::aten::TensorShapeDynamism dynamism =
5456
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
55-
std::function<void(void*)> deleter = nullptr);
57+
std::function<void(void*)> deleter = nullptr,
58+
runtime::etensor::DeviceType device_type =
59+
runtime::etensor::DeviceType::CPU,
60+
runtime::etensor::DeviceIndex device_index = 0);
5661

5762
/**
5863
* Creates a TensorPtr that manages a Tensor with the specified properties.
@@ -64,6 +69,8 @@ TensorPtr make_tensor_ptr(
6469
* @param deleter A custom deleter function for managing the lifetime of the
6570
* data buffer. If provided, this deleter will be called when the managed Tensor
6671
* object is destroyed.
72+
* @param device_type The target device type (default CPU, meaning no copy).
73+
* @param device_index The target device index (default 0).
6774
* @return A TensorPtr that manages the newly created Tensor.
6875
*/
6976
inline TensorPtr make_tensor_ptr(
@@ -73,9 +80,20 @@ inline TensorPtr make_tensor_ptr(
7380
executorch::aten::ScalarType::Float,
7481
const executorch::aten::TensorShapeDynamism dynamism =
7582
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
76-
std::function<void(void*)> deleter = nullptr) {
83+
std::function<void(void*)> deleter = nullptr,
84+
runtime::etensor::DeviceType device_type =
85+
runtime::etensor::DeviceType::CPU,
86+
runtime::etensor::DeviceIndex device_index = 0) {
7787
return make_tensor_ptr(
78-
std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
88+
std::move(sizes),
89+
data,
90+
{},
91+
{},
92+
type,
93+
dynamism,
94+
std::move(deleter),
95+
device_type,
96+
device_index);
7997
}
8098

8199
/**
@@ -96,6 +114,8 @@ inline TensorPtr make_tensor_ptr(
96114
* @param type The scalar type of the tensor elements. If it differs from the
97115
* deduced type, the data will be cast to this type if allowed.
98116
* @param dynamism Specifies the mutability of the tensor's shape.
117+
* @param device_type The target device type (default CPU, meaning no copy).
118+
* @param device_index The target device index (default 0).
99119
* @return A TensorPtr that manages the newly created TensorImpl.
100120
*/
101121
template <
@@ -109,7 +129,10 @@ inline TensorPtr make_tensor_ptr(
109129
std::vector<executorch::aten::StridesType> strides = {},
110130
executorch::aten::ScalarType type = deduced_type,
111131
executorch::aten::TensorShapeDynamism dynamism =
112-
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
132+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
133+
runtime::etensor::DeviceType device_type =
134+
runtime::etensor::DeviceType::CPU,
135+
runtime::etensor::DeviceIndex device_index = 0) {
113136
ET_CHECK_MSG(
114137
data.size() ==
115138
executorch::aten::compute_numel(sizes.data(), sizes.size()),
@@ -145,7 +168,9 @@ inline TensorPtr make_tensor_ptr(
145168
std::move(strides),
146169
type,
147170
dynamism,
148-
[data_ptr = std::move(data_ptr)](void*) {});
171+
[data_ptr = std::move(data_ptr)](void*) {},
172+
device_type,
173+
device_index);
149174
}
150175
const auto raw_data_ptr = data.data();
151176
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
@@ -156,7 +181,9 @@ inline TensorPtr make_tensor_ptr(
156181
std::move(strides),
157182
type,
158183
dynamism,
159-
[data_ptr = std::move(data_ptr)](void*) {});
184+
[data_ptr = std::move(data_ptr)](void*) {},
185+
device_type,
186+
device_index);
160187
}
161188

162189
/**
@@ -174,6 +201,8 @@ inline TensorPtr make_tensor_ptr(
174201
* @param type The scalar type of the tensor elements. If it differs from the
175202
* deduced type, the data will be cast to this type if allowed.
176203
* @param dynamism Specifies the mutability of the tensor's shape.
204+
* @param device_type The target device type (default CPU, meaning no copy).
205+
* @param device_index The target device index (default 0).
177206
* @return A TensorPtr that manages the newly created TensorImpl.
178207
*/
179208
template <
@@ -184,11 +213,21 @@ inline TensorPtr make_tensor_ptr(
184213
std::vector<T> data,
185214
executorch::aten::ScalarType type = deduced_type,
186215
executorch::aten::TensorShapeDynamism dynamism =
187-
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
216+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
217+
runtime::etensor::DeviceType device_type =
218+
runtime::etensor::DeviceType::CPU,
219+
runtime::etensor::DeviceIndex device_index = 0) {
188220
std::vector<executorch::aten::SizesType> sizes{
189221
executorch::aten::SizesType(data.size())};
190222
return make_tensor_ptr(
191-
std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
223+
std::move(sizes),
224+
std::move(data),
225+
{0},
226+
{1},
227+
type,
228+
dynamism,
229+
device_type,
230+
device_index);
192231
}
193232

194233
/**
@@ -211,6 +250,8 @@ inline TensorPtr make_tensor_ptr(
211250
* @param type The scalar type of the tensor elements. If it differs from the
212251
* deduced type, the data will be cast to this type if allowed.
213252
* @param dynamism Specifies the mutability of the tensor's shape.
253+
* @param device_type The target device type (default CPU, meaning no copy).
254+
* @param device_index The target device index (default 0).
214255
* @return A TensorPtr that manages the newly created TensorImpl.
215256
*/
216257
template <
@@ -224,14 +265,19 @@ inline TensorPtr make_tensor_ptr(
224265
std::vector<executorch::aten::StridesType> strides = {},
225266
executorch::aten::ScalarType type = deduced_type,
226267
executorch::aten::TensorShapeDynamism dynamism =
227-
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
268+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
269+
runtime::etensor::DeviceType device_type =
270+
runtime::etensor::DeviceType::CPU,
271+
runtime::etensor::DeviceIndex device_index = 0) {
228272
return make_tensor_ptr(
229273
std::move(sizes),
230274
std::vector<T>(std::move(list)),
231275
std::move(dim_order),
232276
std::move(strides),
233277
type,
234-
dynamism);
278+
dynamism,
279+
device_type,
280+
device_index);
235281
}
236282

237283
/**
@@ -251,6 +297,8 @@ inline TensorPtr make_tensor_ptr(
251297
* @param type The scalar type of the tensor elements. If it differs from the
252298
* deduced type, the data will be cast to this type if allowed.
253299
* @param dynamism Specifies the mutability of the tensor's shape.
300+
* @param device_type The target device type (default CPU, meaning no copy).
301+
* @param device_index The target device index (default 0).
254302
* @return A TensorPtr that manages the newly created TensorImpl.
255303
*/
256304
template <
@@ -261,11 +309,21 @@ inline TensorPtr make_tensor_ptr(
261309
std::initializer_list<T> list,
262310
executorch::aten::ScalarType type = deduced_type,
263311
executorch::aten::TensorShapeDynamism dynamism =
264-
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
312+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
313+
runtime::etensor::DeviceType device_type =
314+
runtime::etensor::DeviceType::CPU,
315+
runtime::etensor::DeviceIndex device_index = 0) {
265316
std::vector<executorch::aten::SizesType> sizes{
266317
executorch::aten::SizesType(list.size())};
267318
return make_tensor_ptr(
268-
std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
319+
std::move(sizes),
320+
std::move(list),
321+
{0},
322+
{1},
323+
type,
324+
dynamism,
325+
device_type,
326+
device_index);
269327
}
270328

271329
/**
@@ -294,6 +352,8 @@ inline TensorPtr make_tensor_ptr(T value) {
294352
* @param strides A vector specifying the strides of each dimension.
295353
* @param type The scalar type of the tensor elements.
296354
* @param dynamism Specifies the mutability of the tensor's shape.
355+
* @param device_type The target device type (default CPU, meaning no copy).
356+
* @param device_index The target device index (default 0).
297357
* @return A TensorPtr managing the newly created Tensor.
298358
*/
299359
TensorPtr make_tensor_ptr(
@@ -303,7 +363,10 @@ TensorPtr make_tensor_ptr(
303363
std::vector<executorch::aten::StridesType> strides,
304364
executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
305365
executorch::aten::TensorShapeDynamism dynamism =
306-
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
366+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
367+
runtime::etensor::DeviceType device_type =
368+
runtime::etensor::DeviceType::CPU,
369+
runtime::etensor::DeviceIndex device_index = 0);
307370

308371
/**
309372
* Creates a TensorPtr that manages a Tensor with the specified properties.
@@ -316,16 +379,28 @@ TensorPtr make_tensor_ptr(
316379
* @param data A vector containing the raw memory for the tensor's data.
317380
* @param type The scalar type of the tensor elements.
318381
* @param dynamism Specifies the mutability of the tensor's shape.
382+
* @param device_type The target device type (default CPU, meaning no copy).
383+
* @param device_index The target device index (default 0).
319384
* @return A TensorPtr managing the newly created Tensor.
320385
*/
321386
inline TensorPtr make_tensor_ptr(
322387
std::vector<executorch::aten::SizesType> sizes,
323388
std::vector<uint8_t> data,
324389
executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
325390
executorch::aten::TensorShapeDynamism dynamism =
326-
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
391+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
392+
runtime::etensor::DeviceType device_type =
393+
runtime::etensor::DeviceType::CPU,
394+
runtime::etensor::DeviceIndex device_index = 0) {
327395
return make_tensor_ptr(
328-
std::move(sizes), std::move(data), {}, {}, type, dynamism);
396+
std::move(sizes),
397+
std::move(data),
398+
{},
399+
{},
400+
type,
401+
dynamism,
402+
device_type,
403+
device_index);
329404
}
330405

331406
/**

0 commit comments

Comments (0)