Skip to content

Commit bd8be1b

Browse files
committed
[ET Device Support] Add device tensor helper functions to TensorPtr API
Pull Request resolved: #18761. Adds clone_tensor_ptr_to_device and clone_tensor_ptr_to_cpu to tensor_ptr.h for cloning tensors between host and device memory via the DeviceAllocatorRegistry. Extends the existing make_tensor_ptr(const TensorPtr&, ...) overload with optional device_type/device_index parameters (defaulting to CPU/0) for seamless device placement. ghstack-source-id: 364764398 @exported-using-ghexport Differential Revision: [D99913077](https://our.internmc.facebook.com/intern/diff/D99913077/)
1 parent 8e700ab commit bd8be1b

5 files changed

Lines changed: 791 additions & 25 deletions

File tree

extension/tensor/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def define_common_targets():
2424
],
2525
visibility = ["PUBLIC"],
2626
deps = [
27+
"//executorch/runtime/core:device_allocator",
2728
"//executorch/runtime/core/exec_aten/util:dim_order_util" + aten_suffix,
2829
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
2930
],

extension/tensor/tensor_ptr.cpp

Lines changed: 194 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <numeric>
1212

13+
#include <executorch/runtime/core/device_allocator.h>
1314
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1415

1516
namespace executorch {
@@ -61,7 +62,9 @@ TensorPtr make_tensor_ptr(
6162
std::vector<executorch::aten::StridesType> strides,
6263
executorch::aten::ScalarType type,
6364
executorch::aten::TensorShapeDynamism dynamism,
64-
std::function<void(void*)> deleter) {
65+
std::function<void(void*)> deleter,
66+
runtime::etensor::DeviceType device_type,
67+
runtime::etensor::DeviceIndex device_index) {
6568
const auto dim = sizes.size();
6669
ET_CHECK_MSG(
6770
dim_order.empty() || dim_order.size() == dim,
@@ -101,6 +104,7 @@ TensorPtr make_tensor_ptr(
101104

102105
strides = std::move(computed_strides);
103106

107+
TensorPtr cpu_tensor;
104108
#ifndef USE_ATEN_LIB
105109
executorch::aten::TensorImpl tensor_impl(
106110
type,
@@ -116,9 +120,9 @@ TensorPtr make_tensor_ptr(
116120
std::move(dim_order),
117121
std::move(strides),
118122
std::move(deleter));
119-
const auto tensor_ptr = &storage->tensor;
120-
return std::shared_ptr<executorch::aten::Tensor>(
121-
std::move(storage), tensor_ptr);
123+
const auto raw_tensor_ptr = &storage->tensor;
124+
cpu_tensor = std::shared_ptr<executorch::aten::Tensor>(
125+
std::move(storage), raw_tensor_ptr);
122126
#else
123127
auto options = c10::TensorOptions()
124128
.dtype(c10::scalarTypeToTypeMeta(type))
@@ -136,8 +140,13 @@ TensorPtr make_tensor_ptr(
136140
c10::DispatchKeySet(c10::DispatchKey::CPU),
137141
options.dtype());
138142
tensor_impl->set_sizes_and_strides(sizes, strides);
139-
return std::make_shared<executorch::aten::Tensor>(std::move(tensor_impl));
143+
cpu_tensor =
144+
std::make_shared<executorch::aten::Tensor>(std::move(tensor_impl));
140145
#endif // USE_ATEN_LIB
146+
if (device_type != runtime::etensor::DeviceType::CPU) {
147+
return clone_tensor_ptr_to_device(cpu_tensor, device_type, device_index);
148+
}
149+
return cpu_tensor;
141150
}
142151

143152
TensorPtr make_tensor_ptr(
@@ -146,7 +155,9 @@ TensorPtr make_tensor_ptr(
146155
std::vector<executorch::aten::DimOrderType> dim_order,
147156
std::vector<executorch::aten::StridesType> strides,
148157
executorch::aten::ScalarType type,
149-
executorch::aten::TensorShapeDynamism dynamism) {
158+
executorch::aten::TensorShapeDynamism dynamism,
159+
runtime::etensor::DeviceType device_type,
160+
runtime::etensor::DeviceIndex device_index) {
150161
ET_CHECK_MSG(
151162
data.size() ==
152163
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
@@ -161,7 +172,9 @@ TensorPtr make_tensor_ptr(
161172
type,
162173
dynamism,
163174
// Data is moved into the deleter and is destroyed together with Storage.
164-
[data = std::move(data)](void*) {});
175+
[data = std::move(data)](void*) {},
176+
device_type,
177+
device_index);
165178
}
166179

167180
TensorPtr clone_tensor_ptr(
@@ -248,5 +261,179 @@ runtime::Error resize_tensor_ptr(
248261
sizes.data(), sizes.size()));
249262
}
250263

264+
// ---- Device tensor helpers ----
265+
266+
namespace {
267+
268+
#ifndef USE_ATEN_LIB
269+
struct DeviceStorage final {
270+
executorch::aten::TensorImpl tensor_impl;
271+
executorch::aten::Tensor tensor;
272+
std::vector<executorch::aten::SizesType> sizes;
273+
std::vector<executorch::aten::DimOrderType> dim_order;
274+
std::vector<executorch::aten::StridesType> strides;
275+
std::function<void(void*)> deleter;
276+
277+
DeviceStorage(
278+
executorch::aten::TensorImpl&& tensor_impl,
279+
std::vector<executorch::aten::SizesType>&& sizes,
280+
std::vector<executorch::aten::DimOrderType>&& dim_order,
281+
std::vector<executorch::aten::StridesType>&& strides,
282+
std::function<void(void*)>&& deleter)
283+
: tensor_impl(std::move(tensor_impl)),
284+
tensor(&this->tensor_impl),
285+
sizes(std::move(sizes)),
286+
dim_order(std::move(dim_order)),
287+
strides(std::move(strides)),
288+
deleter(std::move(deleter)) {}
289+
290+
~DeviceStorage() {
291+
if (deleter) {
292+
deleter(tensor_impl.mutable_data());
293+
}
294+
}
295+
};
296+
#endif // USE_ATEN_LIB
297+
298+
// Wraps an existing device data pointer in a TensorPtr with contiguous
// layout metadata. The returned tensor shares ownership of a DeviceStorage
// that keeps the metadata alive and runs `deleter` on the data at the end.
TensorPtr make_tensor_ptr_with_device(
    std::vector<executorch::aten::SizesType> sizes,
    void* data,
    executorch::aten::ScalarType type,
    runtime::etensor::DeviceType device_type,
    runtime::etensor::DeviceIndex device_index,
    std::function<void(void*)> deleter) {
  const auto ndim = sizes.size();

  // Contiguous layout: identity dim order, strides derived from sizes.
  std::vector<executorch::aten::DimOrderType> dim_order(ndim);
  std::iota(dim_order.begin(), dim_order.end(), 0);

  std::vector<executorch::aten::StridesType> strides(ndim);
  if (ndim > 0) {
    const auto stride_err = runtime::dim_order_to_stride(
        sizes.data(), dim_order.data(), ndim, strides.data());
    ET_CHECK_MSG(
        stride_err == runtime::Error::Ok, "Failed to compute strides.");
  }

#ifndef USE_ATEN_LIB
  const auto dynamism = ndim > 0
      ? executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND
      : executorch::aten::TensorShapeDynamism::STATIC;
  executorch::aten::TensorImpl impl(
      type,
      ndim,
      sizes.data(),
      data,
      dim_order.data(),
      strides.data(),
      dynamism,
      device_type,
      device_index);
  // Moving the vectors into the storage keeps their heap buffers (and so the
  // raw pointers handed to the TensorImpl above) valid.
  auto owner = std::make_shared<DeviceStorage>(
      std::move(impl),
      std::move(sizes),
      std::move(dim_order),
      std::move(strides),
      std::move(deleter));
  // Aliasing shared_ptr: shares ownership of the storage block while
  // pointing at the Tensor embedded inside it.
  return std::shared_ptr<executorch::aten::Tensor>(owner, &owner->tensor);
#else
  // ATen tensors track their own device state; defer to the plain factory
  // and ignore the requested placement in this build mode.
  (void)device_type;
  (void)device_index;
  return make_tensor_ptr(
      std::move(sizes),
      data,
      type,
      executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
      std::move(deleter));
#endif // USE_ATEN_LIB
}
348+
349+
} // namespace
350+
351+
// Deep-copies a host tensor's bytes into freshly allocated device memory and
// returns a TensorPtr that owns the device buffer (freed back through the
// registered allocator). Aborts on any failure via ET_CHECK_MSG.
TensorPtr clone_tensor_ptr_to_device(
    const TensorPtr& cpu_tensor,
    runtime::etensor::DeviceType device_type,
    runtime::etensor::DeviceIndex device_index) {
  // CPU-to-CPU copies go through clone_tensor_ptr instead.
  ET_CHECK_MSG(
      device_type != runtime::etensor::DeviceType::CPU,
      "Target device must not be CPU; use clone_tensor_ptr for CPU-to-CPU copies.");

  auto* allocator = runtime::get_device_allocator(device_type);
  ET_CHECK_MSG(
      allocator != nullptr,
      "No device allocator registered for device type %d",
      static_cast<int>(device_type));

  const auto nbytes = cpu_tensor->nbytes();
  const auto* src = cpu_tensor->const_data_ptr();
  ET_CHECK_MSG(src != nullptr, "Source tensor has no data.");

  // Allocate on the device, then stage the host bytes across.
  auto alloc_result = allocator->allocate(nbytes, device_index);
  ET_CHECK_MSG(alloc_result.ok(), "Failed to allocate device memory.");
  void* dst = alloc_result.get();

  const auto copy_err =
      allocator->copy_host_to_device(dst, src, nbytes, device_index);
  ET_CHECK_MSG(copy_err == runtime::Error::Ok, "Host-to-device copy failed.");

  std::vector<executorch::aten::SizesType> sizes(
      cpu_tensor->sizes().begin(), cpu_tensor->sizes().end());

  // The deleter returns the buffer to the allocator. NOTE(review): it
  // captures the allocator pointer by value, which assumes the registered
  // allocator outlives every tensor cloned through it — confirm.
  return make_tensor_ptr_with_device(
      std::move(sizes),
      dst,
      cpu_tensor->scalar_type(),
      device_type,
      device_index,
      [allocator, device_index](void* ptr) {
        allocator->deallocate(ptr, device_index);
      });
}
390+
391+
// Deep-copies a device tensor's bytes into host memory and returns a plain
// CPU TensorPtr that owns the copied data. Aborts via ET_CHECK_MSG if the
// source is already on CPU, has no data, uses an unsupported device, or the
// copy fails.
TensorPtr clone_tensor_ptr_to_cpu(const TensorPtr& device_tensor) {
  const auto nbytes = device_tensor->nbytes();
  const auto* device_data = device_tensor->const_data_ptr();
  ET_CHECK_MSG(device_data != nullptr, "Source device tensor has no data.");

#ifndef USE_ATEN_LIB
  // ETensor records its device directly on the TensorImpl.
  const auto device_type = device_tensor->unsafeGetTensorImpl()->device_type();
  const auto device_index =
      device_tensor->unsafeGetTensorImpl()->device_index();
#else
  const auto& aten_device = device_tensor->device();
  ET_CHECK_MSG(!aten_device.is_cpu(), "Source tensor is already on CPU.");
  // Only CUDA maps onto an etensor DeviceType here. Fail loudly on any other
  // device (MPS, XLA, ...) instead of silently mapping it to CPU, which
  // previously produced the misleading "already on CPU" abort message.
  ET_CHECK_MSG(
      aten_device.is_cuda(),
      "Unsupported source device type for clone_tensor_ptr_to_cpu.");
  const auto device_type = runtime::etensor::DeviceType::CUDA;
  const auto device_index =
      static_cast<runtime::etensor::DeviceIndex>(aten_device.index());
#endif

  // Covers the non-ATen path, where the impl may report CPU.
  ET_CHECK_MSG(
      device_type != runtime::etensor::DeviceType::CPU,
      "Source tensor is already on CPU.");

  auto* allocator = runtime::get_device_allocator(device_type);
  ET_CHECK_MSG(
      allocator != nullptr,
      "No device allocator registered for device type %d",
      static_cast<int>(device_type));

  // Stage the bytes into a host buffer the resulting TensorPtr will own.
  std::vector<uint8_t> cpu_data(nbytes);
  auto err = allocator->copy_device_to_host(
      cpu_data.data(), device_data, nbytes, device_index);
  ET_CHECK_MSG(err == runtime::Error::Ok, "Device-to-host copy failed.");

  std::vector<executorch::aten::SizesType> sizes(
      device_tensor->sizes().begin(), device_tensor->sizes().end());

  // NOTE(review): dim_order/strides are defaulted, so the clone is always
  // contiguous; this assumes the device tensor is contiguous too — confirm.
  return make_tensor_ptr(
      std::move(sizes),
      std::move(cpu_data),
      {},
      {},
      device_tensor->scalar_type());
}
437+
251438
} // namespace extension
252439
} // namespace executorch

0 commit comments

Comments
 (0)