Skip to content

Commit 7509a83

Browse files
committed
[ET Device Support] Schema changes: device info on Tensor
Pull Request resolved: #17533 This diff adds device placement information to the ExecuTorch schema to support representing tensor-level device type information, which will be the basic requirement for the following tensor_parser updates. This is part of the Phase 1 implementation to make ET device type work E2E without user-specified device placement. Design doc: https://docs.google.com/document/d/1lwd9BlohmwkN5EEvRulO_b-XnZBwv1nMb5l2K3jfuwA/edit?tab=t.0#heading=h.o6anuvkix4bu ghstack-source-id: 353202792 @exported-using-ghexport Differential Revision: [D93635657](https://our.internmc.facebook.com/intern/diff/D93635657/)
1 parent 22174fa commit 7509a83

2 files changed

Lines changed: 22 additions & 0 deletions

File tree

exir/schema.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ class TensorDataLocation(IntEnum):
4848
EXTERNAL = 1
4949

5050

51+
class DeviceType(IntEnum):
52+
CPU = 0
53+
CUDA = 1
54+
55+
5156
@dataclass
5257
class ExtraTensorInfo:
5358
"""
@@ -57,6 +62,8 @@ class ExtraTensorInfo:
5762
mutable_data_segments_idx: int = 0
5863
fully_qualified_name: Optional[str] = None
5964
location: TensorDataLocation = TensorDataLocation.SEGMENT
65+
device_type: DeviceType = DeviceType.CPU
66+
device_index: int = -1
6067

6168

6269
@dataclass

schema/program.fbs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ enum TensorDataLocation : byte {
6262
EXTERNAL = 1,
6363
}
6464

65+
// Device type enum indicating where a tensor resides or should be allocated.
66+
enum DeviceType : byte {
67+
CPU = 0,
68+
CUDA = 1,
69+
}
70+
6571
// Table to put additional information about tensors in that is not applicable
6672
// to the vast majority of tensors in the vast majority of programs.
6773
table ExtraTensorInfo {
@@ -80,6 +86,15 @@ table ExtraTensorInfo {
8086
// must be non-empty, and is used as a key to find the tensor's external
8187
// data. Tensor.data_buffer_idx is ignored.
8288
location: TensorDataLocation;
89+
90+
// [Optional] The device type where this tensor resides or should be allocated.
91+
// Defaults to CPU for backward compatibility with existing PTE files.
92+
device_type: DeviceType = CPU;
93+
94+
// [Optional] The device index for multi-device scenarios (e.g., cuda:0, cuda:1).
95+
// A value of -1 indicates the default device. Defaults to -1 for backward
96+
// compatibility.
97+
device_index: byte = -1;
8398
}
8499

85100
table Tensor {

0 commit comments

Comments
 (0)