Skip to content

Commit 4812339

Browse files
committed
[ET Device Support] Parse device info from serialized tensor in tensor_parser
Pull Request resolved: #18328 Parse device info (device_type, device_index) from the serialized ExtraTensorInfo in .pte files into TensorImpl at runtime. When a tensor's extra_tensor_info contains device annotations (e.g., CUDA), the tensor parser now reads and propagates them to the TensorImpl constructor. Tensors without extra_tensor_info default to CPU/0 for backward compatibility with older PTE files. ghstack-source-id: 354740819 @exported-using-ghexport Differential Revision: [D97199497](https://our.internmc.facebook.com/intern/diff/D97199497/)
1 parent 1f44185 commit 4812339

5 files changed

Lines changed: 377 additions & 1 deletion

File tree

runtime/executor/tensor_parser_portable.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,18 @@ Result<Tensor> parseTensor(
147147
Internal,
148148
"dim_order_to_stride returned invalid status");
149149

150+
// Extract device info from serialized tensor metadata.
151+
// Defaults to CPU/0 for backward compatibility when extra_tensor_info is
152+
// absent (e.g., older PTE files without device annotations).
153+
auto device_type = executorch::runtime::etensor::DeviceType::CPU;
154+
executorch::runtime::etensor::DeviceIndex device_index = 0;
155+
if (s_tensor->extra_tensor_info() != nullptr) {
156+
device_type = static_cast<executorch::runtime::etensor::DeviceType>(
157+
s_tensor->extra_tensor_info()->device_type());
158+
device_index = static_cast<executorch::runtime::etensor::DeviceIndex>(
159+
s_tensor->extra_tensor_info()->device_index());
160+
}
161+
150162
auto* tensor_impl = method_allocator->allocateInstance<TensorImpl>();
151163
if (tensor_impl == nullptr) {
152164
return Error::MemoryAllocationFailed;
@@ -161,7 +173,9 @@ Result<Tensor> parseTensor(
161173
/*data=*/nullptr,
162174
dim_order,
163175
strides,
164-
dynamism);
176+
dynamism,
177+
device_type,
178+
device_index);
165179

166180
// Now that we know how big the tensor is, find and assign its memory.
167181
Result<void*> data_ptr = getTensorDataPtr(

runtime/executor/test/targets.bzl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,19 @@ def define_common_targets(is_fbcode = False):
312312
],
313313
env = modules_env,
314314
)
315+
316+
runtime.cxx_test(
317+
name = "tensor_parser_device_test",
318+
srcs = [
319+
"tensor_parser_device_test.cpp",
320+
],
321+
deps = [
322+
":managed_memory_manager",
323+
"//executorch/runtime/executor:program",
324+
"//executorch/extension/data_loader:file_data_loader",
325+
"//executorch/schema:program",
326+
],
327+
env = {
328+
"ET_MODULE_ADD_WITH_DEVICE_PATH": "$(location fbcode//executorch/test/models:exported_program_with_device_info[ModuleAddWithDevice.pte])",
329+
},
330+
)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/**
10+
* Tests that device info (device_type) is correctly parsed from serialized
11+
* tensors in .pte files into TensorImpl at runtime.
12+
*
13+
* Uses a .pte exported with DeviceAwarePartitioner (CUDA device annotation)
14+
* so that delegate output tensors carry device_type=CUDA in ExtraTensorInfo.
15+
*/
16+
17+
#include <executorch/runtime/executor/tensor_parser.h>
18+
19+
#include <executorch/extension/data_loader/file_data_loader.h>
20+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
21+
#include <executorch/runtime/executor/test/managed_memory_manager.h>
22+
#include <executorch/schema/program_generated.h>
23+
24+
#include <gtest/gtest.h>
25+
26+
using executorch::aten::Tensor;
27+
using executorch::runtime::Error;
28+
using executorch::runtime::Program;
29+
using executorch::runtime::Result;
30+
using executorch::runtime::deserialization::parseTensor;
31+
using executorch::runtime::testing::ManagedMemoryManager;
32+
using torch::executor::util::FileDataLoader;
33+
34+
constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
35+
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;
36+
37+
namespace executorch {
38+
namespace runtime {
39+
namespace testing {
40+
class ProgramTestFriend final {
41+
public:
42+
const static executorch_flatbuffer::Program* GetInternalProgram(
43+
const Program* program) {
44+
return program->internal_program_;
45+
}
46+
};
47+
} // namespace testing
48+
} // namespace runtime
49+
} // namespace executorch
50+
51+
using executorch::runtime::testing::ProgramTestFriend;
52+
53+
class TensorParserDeviceTest : public ::testing::Test {
54+
protected:
55+
void SetUp() override {
56+
const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH");
57+
ASSERT_NE(path, nullptr)
58+
<< "ET_MODULE_ADD_WITH_DEVICE_PATH env var not set";
59+
Result<FileDataLoader> loader = FileDataLoader::from(path);
60+
ASSERT_EQ(loader.error(), Error::Ok);
61+
loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
62+
}
63+
64+
std::unique_ptr<FileDataLoader> loader_;
65+
};
66+
67+
TEST_F(TensorParserDeviceTest, CUDADeviceParsedFromPteFile) {
68+
Result<Program> program =
69+
Program::load(loader_.get(), Program::Verification::Minimal);
70+
ASSERT_EQ(program.error(), Error::Ok);
71+
72+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
73+
74+
const executorch_flatbuffer::Program* internal_program =
75+
ProgramTestFriend::GetInternalProgram(&program.get());
76+
auto* execution_plan =
77+
internal_program->execution_plan()->GetMutableObject(0);
78+
auto* flatbuffer_values = execution_plan->values();
79+
80+
int cuda_tensor_count = 0;
81+
int cpu_tensor_count = 0;
82+
int total_tensor_count = 0;
83+
84+
for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
85+
auto* serialization_value = flatbuffer_values->Get(i);
86+
if (serialization_value->val_type() !=
87+
executorch_flatbuffer::KernelTypes::Tensor) {
88+
continue;
89+
}
90+
total_tensor_count++;
91+
92+
auto* s_tensor = serialization_value->val_as_Tensor();
93+
94+
Result<Tensor> tensor = parseTensor(&program.get(), &mmm.get(), s_tensor);
95+
if (!tensor.ok()) {
96+
bool has_cuda = s_tensor->extra_tensor_info() != nullptr &&
97+
s_tensor->extra_tensor_info()->device_type() ==
98+
executorch_flatbuffer::DeviceType::CUDA;
99+
if (has_cuda) {
100+
cuda_tensor_count++;
101+
}
102+
continue;
103+
}
104+
105+
Tensor t = tensor.get();
106+
auto device_type = t.unsafeGetTensorImpl()->device_type();
107+
108+
if (device_type == executorch::runtime::etensor::DeviceType::CUDA) {
109+
cuda_tensor_count++;
110+
EXPECT_EQ(t.unsafeGetTensorImpl()->device_index(), 0)
111+
<< "CUDA tensor should have device_index=0";
112+
} else {
113+
EXPECT_EQ(device_type, executorch::runtime::etensor::DeviceType::CPU);
114+
EXPECT_EQ(t.unsafeGetTensorImpl()->device_index(), 0)
115+
<< "CPU tensor should have device_index=0";
116+
cpu_tensor_count++;
117+
}
118+
}
119+
120+
EXPECT_GT(total_tensor_count, 0) << "Should have at least one tensor";
121+
// The model has add(a, b) delegated to CUDA — 2 inputs + 1 output = 3 CUDA
122+
EXPECT_EQ(cuda_tensor_count, 3)
123+
<< "Expected 3 CUDA tensors (2 delegate inputs + 1 delegate output)";
124+
}
125+
126+
TEST_F(TensorParserDeviceTest, NonDelegatedTensorsDefaultToCPU) {
127+
Result<Program> program =
128+
Program::load(loader_.get(), Program::Verification::Minimal);
129+
ASSERT_EQ(program.error(), Error::Ok);
130+
131+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
132+
133+
const executorch_flatbuffer::Program* internal_program =
134+
ProgramTestFriend::GetInternalProgram(&program.get());
135+
auto* execution_plan =
136+
internal_program->execution_plan()->GetMutableObject(0);
137+
auto* flatbuffer_values = execution_plan->values();
138+
139+
for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
140+
auto* serialization_value = flatbuffer_values->Get(i);
141+
if (serialization_value->val_type() !=
142+
executorch_flatbuffer::KernelTypes::Tensor) {
143+
continue;
144+
}
145+
146+
auto* s_tensor = serialization_value->val_as_Tensor();
147+
bool has_cuda_device = s_tensor->extra_tensor_info() != nullptr &&
148+
s_tensor->extra_tensor_info()->device_type() ==
149+
executorch_flatbuffer::DeviceType::CUDA;
150+
151+
// Only check tensors that are NOT annotated as CUDA
152+
if (has_cuda_device) {
153+
continue;
154+
}
155+
156+
Result<Tensor> tensor = parseTensor(&program.get(), &mmm.get(), s_tensor);
157+
if (!tensor.ok()) {
158+
continue;
159+
}
160+
161+
Tensor t = tensor.get();
162+
EXPECT_EQ(
163+
t.unsafeGetTensorImpl()->device_type(),
164+
executorch::runtime::etensor::DeviceType::CPU)
165+
<< "Tensor at index " << i
166+
<< " without CUDA annotation should default to CPU";
167+
EXPECT_EQ(t.unsafeGetTensorImpl()->device_index(), 0)
168+
<< "Tensor at index " << i
169+
<< " without device annotation should have device_index=0";
170+
}
171+
}
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
"""Exports a simple model with device-annotated tensors for C++ testing.
10+
11+
Uses DeviceAwarePartitioner (BackendWithCompilerDemo + target_device=cuda:0)
12+
so that delegate output tensors are annotated with CUDA device in the .pte.
13+
"""
14+
15+
import argparse
16+
import os
17+
from typing import Dict, final
18+
19+
import torch
20+
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
21+
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
22+
generate_pattern_op_partitions,
23+
)
24+
from executorch.exir.backend.compile_spec_schema import CompileSpec
25+
from executorch.exir.backend.partitioner import (
26+
DelegationSpec,
27+
Partitioner,
28+
PartitionResult,
29+
)
30+
from executorch.exir.backend.test.backend_with_compiler_demo import (
31+
BackendWithCompilerDemo,
32+
)
33+
from executorch.exir.dialects._ops import ops as exir_ops
34+
from executorch.exir.passes.propagate_device_pass import TARGET_DEVICE_COMPILE_SPEC_KEY
35+
from torch import nn
36+
from torch.export import export
37+
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
38+
39+
40+
class _AddOperatorSupport(OperatorSupportBase):
41+
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
42+
return node.op == "call_function" and node.target in [
43+
exir_ops.edge.aten.add.Tensor,
44+
]
45+
46+
47+
@final
48+
class _DeviceAwarePartitioner(Partitioner):
49+
"""Partitioner that tags add ops for delegation with target_device=cuda:0."""
50+
51+
def __init__(self) -> None:
52+
super().__init__()
53+
self.delegation_spec = DelegationSpec(
54+
BackendWithCompilerDemo.__name__,
55+
[
56+
CompileSpec("max_value", bytes([4])),
57+
CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
58+
],
59+
)
60+
61+
def partition(self, exported_program) -> PartitionResult:
62+
partition_tags: Dict[str, DelegationSpec] = {}
63+
partition_list = generate_pattern_op_partitions(
64+
exported_program.graph_module,
65+
op_support=any_chain(_AddOperatorSupport()),
66+
)
67+
for partition in partition_list:
68+
for node in partition.nodes:
69+
tag = f"tag{partition.id}"
70+
node.meta["delegation_tag"] = tag
71+
partition_tags[tag] = self.delegation_spec
72+
return PartitionResult(
73+
tagged_exported_program=exported_program,
74+
partition_tags=partition_tags,
75+
)
76+
77+
78+
class ModuleAddWithDevice(nn.Module):
79+
"""Simple add model — the add op will be delegated with CUDA device annotation."""
80+
81+
def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
82+
return torch.add(a, b)
83+
84+
def get_random_inputs(self):
85+
return (torch.randn(2, 2), torch.randn(2, 2))
86+
87+
88+
def main() -> None:
89+
parser = argparse.ArgumentParser()
90+
parser.add_argument("--outdir", type=str, required=True)
91+
args = parser.parse_args()
92+
93+
torch.manual_seed(0)
94+
model = ModuleAddWithDevice()
95+
inputs = model.get_random_inputs()
96+
97+
edge = to_edge(
98+
export(model, inputs),
99+
compile_config=EdgeCompileConfig(_check_ir_validity=False),
100+
)
101+
lowered = edge.to_backend(_DeviceAwarePartitioner())
102+
et_prog = lowered.to_executorch(ExecutorchBackendConfig(emit_stacktrace=False))
103+
104+
os.makedirs(args.outdir, exist_ok=True)
105+
outfile = os.path.join(args.outdir, "ModuleAddWithDevice.pte")
106+
107+
# Verify device annotations are present in the serialized program
108+
from executorch.exir.schema import DeviceType, Tensor as SchemaTensor
109+
110+
program = et_prog._emitter_output.program
111+
plan = program.execution_plan[0]
112+
print(f"Delegates: {len(plan.delegates)}")
113+
cuda_count = 0
114+
for i, v in enumerate(plan.values):
115+
if isinstance(v.val, SchemaTensor):
116+
t = v.val
117+
eti = t.extra_tensor_info
118+
dev = eti.device_type if eti else "no_eti"
119+
print(f" Tensor[{i}]: sizes={list(t.sizes)}, device={dev}")
120+
if eti and eti.device_type == DeviceType.CUDA:
121+
cuda_count += 1
122+
print(f"CUDA tensors: {cuda_count}")
123+
124+
# Also check graph module specs
125+
from executorch.exir.delegate import executorch_call_delegate
126+
from executorch.exir.tensor import TensorSpec
127+
128+
gm = et_prog.exported_program().graph_module
129+
for node in gm.graph.nodes:
130+
if node.op == "call_function" and node.target == executorch_call_delegate:
131+
specs = node.meta.get("spec")
132+
print(
133+
f" Delegate node '{node.name}' spec.device = {specs.device if isinstance(specs, TensorSpec) else [s.device for s in specs if isinstance(s, TensorSpec)]}"
134+
)
135+
136+
with open(outfile, "wb") as fp:
137+
fp.write(et_prog.buffer)
138+
print(f"Exported ModuleAddWithDevice to {outfile}")
139+
140+
141+
if __name__ == "__main__":
142+
main()

0 commit comments

Comments
 (0)