 
 #include <numeric>
 
+#include <executorch/runtime/core/device_allocator.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
 
 namespace executorch {
@@ -61,7 +62,9 @@ TensorPtr make_tensor_ptr(
     std::vector<executorch::aten::StridesType> strides,
     executorch::aten::ScalarType type,
     executorch::aten::TensorShapeDynamism dynamism,
-    std::function<void(void*)> deleter) {
+    std::function<void(void*)> deleter,
+    runtime::etensor::DeviceType device_type,
+    runtime::etensor::DeviceIndex device_index) {
   const auto dim = sizes.size();
   ET_CHECK_MSG(
       dim_order.empty() || dim_order.size() == dim,
@@ -101,6 +104,7 @@ TensorPtr make_tensor_ptr(
 
   strides = std::move(computed_strides);
 
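+  // Build the tensor on CPU first; clone it to the requested device at the
+  // end if a non-CPU target was asked for.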
+  TensorPtr cpu_tensor;
 #ifndef USE_ATEN_LIB
   executorch::aten::TensorImpl tensor_impl(
       type,
@@ -116,9 +120,9 @@ TensorPtr make_tensor_ptr(
       std::move(dim_order),
       std::move(strides),
       std::move(deleter));
-  const auto tensor_ptr = &storage->tensor;
-  return std::shared_ptr<executorch::aten::Tensor>(
-      std::move(storage), tensor_ptr);
+  const auto raw_tensor_ptr = &storage->tensor;
+  cpu_tensor = std::shared_ptr<executorch::aten::Tensor>(
+      std::move(storage), raw_tensor_ptr);
 #else
   auto options = c10::TensorOptions()
                      .dtype(c10::scalarTypeToTypeMeta(type))
@@ -136,8 +140,13 @@ TensorPtr make_tensor_ptr(
       c10::DispatchKeySet(c10::DispatchKey::CPU),
       options.dtype());
   tensor_impl->set_sizes_and_strides(sizes, strides);
-  return std::make_shared<executorch::aten::Tensor>(std::move(tensor_impl));
+  cpu_tensor =
+      std::make_shared<executorch::aten::Tensor>(std::move(tensor_impl));
 #endif // USE_ATEN_LIB
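+  // Route non-CPU targets through a device copy; see the helpers added below.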
+  if (device_type != runtime::etensor::DeviceType::CPU) {
+    return clone_tensor_ptr_to_device(cpu_tensor, device_type, device_index);
+  }
+  return cpu_tensor;
 }
 
 TensorPtr make_tensor_ptr(
@@ -146,7 +155,9 @@ TensorPtr make_tensor_ptr(
     std::vector<executorch::aten::DimOrderType> dim_order,
     std::vector<executorch::aten::StridesType> strides,
     executorch::aten::ScalarType type,
-    executorch::aten::TensorShapeDynamism dynamism) {
+    executorch::aten::TensorShapeDynamism dynamism,
+    runtime::etensor::DeviceType device_type,
+    runtime::etensor::DeviceIndex device_index) {
   ET_CHECK_MSG(
       data.size() ==
           executorch::aten::compute_numel(sizes.data(), sizes.size()) *
@@ -161,7 +172,9 @@ TensorPtr make_tensor_ptr(
       type,
       dynamism,
       // Data is moved into the deleter and is destroyed together with Storage.
-      [data = std::move(data)](void*) {});
+      [data = std::move(data)](void*) {},
+      device_type,
+      device_index);
 }
 
 TensorPtr clone_tensor_ptr(
@@ -248,5 +261,179 @@ runtime::Error resize_tensor_ptr(
       sizes.data(), sizes.size()));
 }
 
+// ---- Device tensor helpers ----
+
+namespace {
+
+#ifndef USE_ATEN_LIB
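+// Owns the TensorImpl, its metadata vectors, and the deleter; the destructor
+// runs the deleter on the data pointer once the last Tensor reference dies.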
+struct DeviceStorage final {
+  executorch::aten::TensorImpl tensor_impl;
+  executorch::aten::Tensor tensor;
+  std::vector<executorch::aten::SizesType> sizes;
+  std::vector<executorch::aten::DimOrderType> dim_order;
+  std::vector<executorch::aten::StridesType> strides;
+  std::function<void(void*)> deleter;
+
+  DeviceStorage(
+      executorch::aten::TensorImpl&& tensor_impl,
+      std::vector<executorch::aten::SizesType>&& sizes,
+      std::vector<executorch::aten::DimOrderType>&& dim_order,
+      std::vector<executorch::aten::StridesType>&& strides,
+      std::function<void(void*)>&& deleter)
+      : tensor_impl(std::move(tensor_impl)),
+        tensor(&this->tensor_impl),
+        sizes(std::move(sizes)),
+        dim_order(std::move(dim_order)),
+        strides(std::move(strides)),
+        deleter(std::move(deleter)) {}
+
+  ~DeviceStorage() {
+    if (deleter) {
+      deleter(tensor_impl.mutable_data());
+    }
+  }
+};
+#endif // USE_ATEN_LIB
+
+TensorPtr make_tensor_ptr_with_device(
+    std::vector<executorch::aten::SizesType> sizes,
+    void* data,
+    executorch::aten::ScalarType type,
+    runtime::etensor::DeviceType device_type,
+    runtime::etensor::DeviceIndex device_index,
+    std::function<void(void*)> deleter) {
+  const auto dim = sizes.size();
+  std::vector<executorch::aten::DimOrderType> dim_order(dim);
+  std::iota(dim_order.begin(), dim_order.end(), 0);
+
+  std::vector<executorch::aten::StridesType> strides(dim);
+  if (dim > 0) {
+    auto error = runtime::dim_order_to_stride(
+        sizes.data(), dim_order.data(), dim, strides.data());
+    ET_CHECK_MSG(error == runtime::Error::Ok, "Failed to compute strides.");
+  }
+
+#ifndef USE_ATEN_LIB
+  executorch::aten::TensorImpl tensor_impl(
+      type,
+      dim,
+      sizes.data(),
+      data,
+      dim_order.data(),
+      strides.data(),
+      dim > 0 ? executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND
+              : executorch::aten::TensorShapeDynamism::STATIC,
+      device_type,
+      device_index);
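+  // The shared_ptr aliasing constructor keeps DeviceStorage alive for as long
+  // as any copy of the returned TensorPtr exists.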
+  auto storage = std::make_shared<DeviceStorage>(
+      std::move(tensor_impl),
+      std::move(sizes),
+      std::move(dim_order),
+      std::move(strides),
+      std::move(deleter));
+  const auto tensor_ptr = &storage->tensor;
+  return std::shared_ptr<executorch::aten::Tensor>(
+      std::move(storage), tensor_ptr);
+#else
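+  // ATen build: device placement is not threaded through this helper, so the
+  // device arguments are dropped and the generic overload is used instead.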
+  (void)device_type;
+  (void)device_index;
+  return make_tensor_ptr(
+      std::move(sizes),
+      data,
+      type,
+      executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
+      std::move(deleter));
+#endif // USE_ATEN_LIB
+}
+
+} // namespace
+
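+// Copies a CPU tensor into freshly allocated device memory; the returned
+// TensorPtr releases that memory through the same allocator. Illustrative
+// call, assuming a CUDA allocator has been registered with the runtime:
+//   auto gpu = clone_tensor_ptr_to_device(
+//       cpu_tensor, runtime::etensor::DeviceType::CUDA, /*device_index=*/0);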
+TensorPtr clone_tensor_ptr_to_device(
+    const TensorPtr& cpu_tensor,
+    runtime::etensor::DeviceType device_type,
+    runtime::etensor::DeviceIndex device_index) {
+  ET_CHECK_MSG(
+      device_type != runtime::etensor::DeviceType::CPU,
+      "Target device must not be CPU; use clone_tensor_ptr for CPU-to-CPU copies.");
+
+  auto* allocator = runtime::get_device_allocator(device_type);
+  ET_CHECK_MSG(
+      allocator != nullptr,
+      "No device allocator registered for device type %d",
+      static_cast<int>(device_type));
+
+  const auto nbytes = cpu_tensor->nbytes();
+  const auto* cpu_data = cpu_tensor->const_data_ptr();
+  ET_CHECK_MSG(cpu_data != nullptr, "Source tensor has no data.");
+
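+  // Allocate on the target device, then copy the host bytes across.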
+  auto result = allocator->allocate(nbytes, device_index);
+  ET_CHECK_MSG(result.ok(), "Failed to allocate device memory.");
+  void* device_data = result.get();
+
+  auto err = allocator->copy_host_to_device(
+      device_data, cpu_data, nbytes, device_index);
+  ET_CHECK_MSG(err == runtime::Error::Ok, "Host-to-device copy failed.");
+
+  std::vector<executorch::aten::SizesType> sizes(
+      cpu_tensor->sizes().begin(), cpu_tensor->sizes().end());
+
+  return make_tensor_ptr_with_device(
+      std::move(sizes),
+      device_data,
+      cpu_tensor->scalar_type(),
+      device_type,
+      device_index,
+      [allocator, device_index](void* ptr) {
+        allocator->deallocate(ptr, device_index);
+      });
+}
+
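+// Copies a device tensor back into host memory and returns an owning CPU
+// TensorPtr. Illustrative round trip with the sketch above:
+//   auto cpu_again = clone_tensor_ptr_to_cpu(gpu);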
+TensorPtr clone_tensor_ptr_to_cpu(const TensorPtr& device_tensor) {
+  const auto nbytes = device_tensor->nbytes();
+  const auto* device_data = device_tensor->const_data_ptr();
+  ET_CHECK_MSG(device_data != nullptr, "Source device tensor has no data.");
+
+#ifndef USE_ATEN_LIB
+  const auto device_type = device_tensor->unsafeGetTensorImpl()->device_type();
+  const auto device_index =
+      device_tensor->unsafeGetTensorImpl()->device_index();
+#else
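+  // Map the ATen device onto the etensor enum; only CUDA is recognized here.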
+  const auto& aten_device = device_tensor->device();
+  ET_CHECK_MSG(!aten_device.is_cpu(), "Source tensor is already on CPU.");
+  auto device_type = runtime::etensor::DeviceType::CPU;
+  if (aten_device.is_cuda()) {
+    device_type = runtime::etensor::DeviceType::CUDA;
+  }
+  const auto device_index =
+      static_cast<runtime::etensor::DeviceIndex>(aten_device.index());
+#endif
+
+  ET_CHECK_MSG(
+      device_type != runtime::etensor::DeviceType::CPU,
+      "Source tensor is already on CPU.");
+
+  auto* allocator = runtime::get_device_allocator(device_type);
+  ET_CHECK_MSG(
+      allocator != nullptr,
+      "No device allocator registered for device type %d",
+      static_cast<int>(device_type));
+
+  std::vector<uint8_t> cpu_data(nbytes);
+
+  auto err = allocator->copy_device_to_host(
+      cpu_data.data(), device_data, nbytes, device_index);
+  ET_CHECK_MSG(err == runtime::Error::Ok, "Device-to-host copy failed.");
+
+  std::vector<executorch::aten::SizesType> sizes(
+      device_tensor->sizes().begin(), device_tensor->sizes().end());
+
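+  // Empty dim_order and strides let make_tensor_ptr derive a contiguous
+  // layout from the sizes.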
+  return make_tensor_ptr(
+      std::move(sizes),
+      std::move(cpu_data),
+      {},
+      {},
+      device_tensor->scalar_type());
+}
+
 } // namespace extension
 } // namespace executorch