Skip to content

Commit 91f2c3a

Browse files
committed
[ET Device Support] TensorImpl carries device info
This diff extends `TensorImpl` to carry device information, enabling the runtime tensor to track which device its data resides on (CPU, CUDA, etc.). This is a prerequisite for parsing device info from the schema and allocating device memory. Differential Revision: [D93635655](https://our.internmc.facebook.com/intern/diff/D93635655/) ghstack-source-id: 342367953 Pull Request resolved: #17534
1 parent fce7663 commit 91f2c3a

3 files changed

Lines changed: 141 additions & 3 deletions

File tree

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ TensorImpl::TensorImpl(
5050
void* data,
5151
DimOrderType* dim_order,
5252
StridesType* strides,
53-
TensorShapeDynamism dynamism)
53+
TensorShapeDynamism dynamism,
54+
DeviceType device_type,
55+
DeviceIndex device_index)
5456
: sizes_(sizes),
5557
dim_order_(dim_order),
5658
strides_(strides),
@@ -59,7 +61,8 @@ TensorImpl::TensorImpl(
5961
numel_(compute_numel(sizes, dim)),
6062
numel_bound_(numel_),
6163
type_(type),
62-
shape_dynamism_(dynamism) {
64+
shape_dynamism_(dynamism),
65+
device_(device_type, device_index) {
6366
ET_CHECK_MSG(
6467
isValid(type_), "Invalid type %" PRId8, static_cast<int8_t>(type_));
6568
ET_CHECK_MSG(dim_ >= 0, "Dimension must be non-negative, got %zd", dim_);

runtime/core/portable_type/tensor_impl.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/runtime/core/array_ref.h>
1212
#include <executorch/runtime/core/error.h>
13+
#include <executorch/runtime/core/portable_type/device.h>
1314
#include <executorch/runtime/core/portable_type/scalar_type.h>
1415
#include <executorch/runtime/core/tensor_shape_dynamism.h>
1516

@@ -99,6 +100,8 @@ class TensorImpl {
99100
* @param strides Strides of the tensor at each dimension. Must contain `dim`
100101
* entries.
101102
* @param dynamism The mutability of the shape of the tensor.
103+
* @param device_type The type of device where tensor data resides.
104+
* @param device_index The device index for multi-device scenarios.
102105
*/
103106
TensorImpl(
104107
ScalarType type,
@@ -107,7 +110,9 @@ class TensorImpl {
107110
void* data = nullptr,
108111
DimOrderType* dim_order = nullptr,
109112
StridesType* strides = nullptr,
110-
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC);
113+
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC,
114+
DeviceType device_type = DeviceType::CPU,
115+
DeviceIndex device_index = -1);
111116

112117
/**
113118
* Returns the size of the tensor in bytes.
@@ -176,6 +181,21 @@ class TensorImpl {
176181
return shape_dynamism_;
177182
}
178183

184+
/// Returns the device where tensor data resides.
185+
Device device() const {
186+
return device_;
187+
}
188+
189+
/// Returns the type of device where tensor data resides.
190+
DeviceType device_type() const {
191+
return device_.type();
192+
}
193+
194+
/// Returns the device index, or -1 if default/unspecified.
195+
DeviceIndex device_index() const {
196+
return device_.index();
197+
}
198+
179199
/// Returns a pointer of type T to the constant underlying data blob.
180200
template <typename T>
181201
inline const T* data() const {
@@ -261,6 +281,9 @@ class TensorImpl {
261281

262282
/// Specifies the mutability of the shape of the tensor.
263283
const TensorShapeDynamism shape_dynamism_;
284+
285+
/// Device where tensor data resides (CPU, CUDA, etc.)
286+
Device device_;
264287
};
265288

266289
/**

runtime/core/portable_type/test/tensor_impl_test.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ using namespace ::testing;
2121
using executorch::runtime::ArrayRef;
2222
using executorch::runtime::Error;
2323
using executorch::runtime::TensorShapeDynamism;
24+
using executorch::runtime::etensor::Device;
25+
using executorch::runtime::etensor::DeviceIndex;
26+
using executorch::runtime::etensor::DeviceType;
2427
using executorch::runtime::etensor::ScalarType;
2528
using executorch::runtime::etensor::TensorImpl;
2629
using SizesType = TensorImpl::SizesType;
@@ -449,3 +452,112 @@ TEST_F(TensorImplTest, TestResizingTensorToZeroAndBack) {
449452
EXPECT_GT(t.numel(), 0);
450453
EXPECT_EQ(t.data(), data);
451454
}
455+
456+
// ============== Device Tests ==============
457+
458+
TEST_F(TensorImplTest, TestDefaultDeviceIsCPU) {
459+
// TensorImpl constructed without device parameters should default to CPU
460+
SizesType sizes[2] = {3, 2};
461+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
462+
TensorImpl t(ScalarType::Float, 2, sizes, data);
463+
464+
EXPECT_EQ(t.device_type(), DeviceType::CPU);
465+
EXPECT_EQ(t.device_index(), -1);
466+
EXPECT_EQ(t.device(), Device(DeviceType::CPU, -1));
467+
}
468+
469+
TEST_F(TensorImplTest, TestExplicitCPUDevice) {
470+
// TensorImpl constructed with explicit CPU device
471+
SizesType sizes[2] = {3, 2};
472+
DimOrderType dim_order[2] = {0, 1};
473+
StridesType strides[2] = {2, 1};
474+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
475+
TensorImpl t(
476+
ScalarType::Float,
477+
2,
478+
sizes,
479+
data,
480+
dim_order,
481+
strides,
482+
TensorShapeDynamism::STATIC,
483+
DeviceType::CPU,
484+
0);
485+
486+
EXPECT_EQ(t.device_type(), DeviceType::CPU);
487+
EXPECT_EQ(t.device_index(), 0);
488+
EXPECT_EQ(t.device(), Device(DeviceType::CPU, 0));
489+
}
490+
491+
TEST_F(TensorImplTest, TestCUDADevice) {
492+
// TensorImpl constructed with CUDA device
493+
SizesType sizes[2] = {3, 2};
494+
DimOrderType dim_order[2] = {0, 1};
495+
StridesType strides[2] = {2, 1};
496+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
497+
TensorImpl t(
498+
ScalarType::Float,
499+
2,
500+
sizes,
501+
data,
502+
dim_order,
503+
strides,
504+
TensorShapeDynamism::STATIC,
505+
DeviceType::CUDA,
506+
0);
507+
508+
EXPECT_EQ(t.device_type(), DeviceType::CUDA);
509+
EXPECT_EQ(t.device_index(), 0);
510+
EXPECT_EQ(t.device(), Device(DeviceType::CUDA, 0));
511+
}
512+
513+
TEST_F(TensorImplTest, TestCUDADeviceMultiGPU) {
514+
// TensorImpl with CUDA device index 1 (second GPU)
515+
SizesType sizes[2] = {3, 2};
516+
DimOrderType dim_order[2] = {0, 1};
517+
StridesType strides[2] = {2, 1};
518+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
519+
TensorImpl t(
520+
ScalarType::Float,
521+
2,
522+
sizes,
523+
data,
524+
dim_order,
525+
strides,
526+
TensorShapeDynamism::STATIC,
527+
DeviceType::CUDA,
528+
1);
529+
530+
EXPECT_EQ(t.device_type(), DeviceType::CUDA);
531+
EXPECT_EQ(t.device_index(), 1);
532+
EXPECT_EQ(t.device(), Device(DeviceType::CUDA, 1));
533+
}
534+
535+
TEST_F(TensorImplTest, TestDeviceWithDynamicTensor) {
536+
// Device info should work correctly with dynamic tensors
537+
SizesType sizes[2] = {3, 2};
538+
DimOrderType dim_order[2] = {0, 1};
539+
StridesType strides[2] = {2, 1};
540+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
541+
TensorImpl t(
542+
ScalarType::Float,
543+
2,
544+
sizes,
545+
data,
546+
dim_order,
547+
strides,
548+
TensorShapeDynamism::DYNAMIC_BOUND,
549+
DeviceType::CUDA,
550+
0);
551+
552+
EXPECT_EQ(t.device_type(), DeviceType::CUDA);
553+
EXPECT_EQ(t.device_index(), 0);
554+
555+
// Resize should not affect device
556+
SizesType new_sizes[2] = {2, 2};
557+
Error err = resize_tensor_impl(&t, {new_sizes, 2});
558+
EXPECT_EQ(err, Error::Ok);
559+
560+
// Device should remain unchanged after resize
561+
EXPECT_EQ(t.device_type(), DeviceType::CUDA);
562+
EXPECT_EQ(t.device_index(), 0);
563+
}

0 commit comments

Comments (0)