Skip to content

Commit 4b25bb4

Browse files
committed
[ET Device Support] Propagate device info from TensorSpec into serialized Tensor
Pull Request resolved: #18079 Propagate device information from `TensorSpec.device` (set by `PropagateDevicePass`) to the serialized `schema.Tensor` in the emitted PTE file, so that the runtime is aware of it. ghstack-source-id: 366850771 @exported-using-ghexport Differential Revision: [D95899706](https://our.internmc.facebook.com/intern/diff/D95899706/)
1 parent 3c6c38b commit 4b25bb4

3 files changed

Lines changed: 146 additions & 1 deletion

File tree

exir/emit/test/BUCK

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,15 @@ fbcode_target(_kind = runtime.python_test,
2626
"//executorch/exir:schema",
2727
"//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
2828
"//executorch/exir/backend:backend_api",
29+
"//executorch/exir/backend:compile_spec_schema",
30+
"//executorch/exir/backend:partitioner",
31+
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
32+
"//executorch/exir/backend/test:backend_with_compiler_demo",
2933
"//executorch/exir/emit:lib",
3034
"//executorch/exir/passes:const_prop_pass",
3135
"//executorch/exir/passes:constant_prop_pass",
3236
"//executorch/exir/passes:init_mutable_pass",
37+
"//executorch/exir/passes:propagate_device_pass",
3338
"//executorch/exir/tests:lib",
3439
"//executorch/exir/tests:models",
3540
"//executorch/extension/pybindings:portable_lib",

exir/emit/test/test_emit.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2518,3 +2518,128 @@ def forward(self):
25182518
for j in range(2):
25192519
expected_storage.append(j * 16 + i)
25202520
self.assertEqual([int(v) for v in storage_values], expected_storage)
2521+
2522+
def test_emit_device_info_propagated_to_serialized_tensor(self) -> None:
    """Check that a device set by PropagateDevicePass ends up on the
    serialized tensor as ExtraTensorInfo.device_type / device_index."""
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddSupport(OperatorSupportBase):
        # Restrict delegation to aten.add.Tensor nodes only.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            if node.op != "call_function":
                return False
            return node.target in [exir_ops.edge.aten.add.Tensor]

    class DevicePartitioner(Partitioner):
        # Tags add ops for the demo backend and carries a cuda:0 compile spec
        # that PropagateDevicePass will pick up.
        def __init__(self):
            super().__init__()
            specs = [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__, specs
            )

        def partition(self, exported_program) -> PartitionResult:
            tags = {}
            found = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddSupport()),
            )
            for part in found:
                label = f"tag{part.id}"
                for fx_node in part.nodes:
                    fx_node.meta["delegation_tag"] = label
                    tags[label] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=tags,
            )

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))

    edge = to_edge(
        export(Model(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    executorch_prog = edge.to_backend(DevicePartitioner()).to_executorch()
    plan = executorch_prog._emitter_output.program.execution_plan[0]
    self.assertGreater(len(plan.delegates), 0)

    # Collect every serialized tensor whose extra info marks it as CUDA.
    cuda_tensors = []
    for value in plan.values:
        candidate = value.val
        if not isinstance(candidate, Tensor):
            continue
        info = candidate.extra_tensor_info
        if info is not None and info.device_type == schema.DeviceType.CUDA:
            cuda_tensors.append(candidate)

    # The delegated add(a, b) contributes 2 inputs + 1 output = 3 CUDA tensors.
    self.assertEqual(
        len(cuda_tensors),
        3,
        f"Expected exactly 3 CUDA tensors (2 inputs + 1 output for delegated add), got {len(cuda_tensors)}",
    )
    # cuda:0 should serialize with device_index == 0.
    for cuda_tensor in cuda_tensors:
        self.assertEqual(
            cuda_tensor.extra_tensor_info.device_index,
            0,
            "CUDA tensor device_index should be 0 for cuda:0",
        )
2613+
2614+
def test_emit_cpu_tensors_no_extra_device_info(self) -> None:
    """A pure-CPU model must not grow ExtraTensorInfo just to carry device
    info — activation tensors should keep extra_tensor_info as None."""

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    sample = (torch.randn(2, 2), torch.randn(2, 2))

    edge = to_edge(
        export(Model(), sample),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    plan = edge.to_executorch()._emitter_output.program.execution_plan[0]

    # Any serialized tensor carrying a device_type would be a regression here.
    offenders = []
    for value in plan.values:
        maybe_tensor = value.val
        if not isinstance(maybe_tensor, Tensor):
            continue
        info = maybe_tensor.extra_tensor_info
        if info is not None and info.device_type is not None:
            offenders.append(maybe_tensor)

    self.assertEqual(
        len(offenders),
        0,
        "No tensor should have extra device info when model runs entirely on CPU",
    )

exir/tensor.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,21 @@ def to_list(
366366
tensor_size = to_list(spec.shape)
367367
tensor_dim_order = to_list(spec.dim_order)
368368

369+
extra_tensor_info = spec.extra_tensor_info
370+
# Propagate device from TensorSpec into ExtraTensorInfo for serialization.
371+
# Note: we don't propagate the device for CPU; when no device info is present,
372+
# tensor_parser automatically treats the tensor as CPU:0, which avoids PTE
373+
# size regression as much as possible.
374+
if spec.device != schema.DeviceType.CPU:
375+
if extra_tensor_info is None:
376+
extra_tensor_info = schema.ExtraTensorInfo(
377+
device_type=spec.device,
378+
device_index=spec.device_index,
379+
)
380+
else:
381+
extra_tensor_info.device_type = spec.device
382+
extra_tensor_info.device_index = spec.device_index
383+
369384
flatbuffer_tensor = schema.Tensor(
370385
scalar_type=scalar_type_enum(spec.scalar_type),
371386
# The runtime currently only supports tensors with offsets of zero.
@@ -377,7 +392,7 @@ def to_list(
377392
allocation_info=allocation_info,
378393
layout=layout_enum(spec.layout),
379394
shape_dynamism=spec.shape_dynamism,
380-
extra_tensor_info=spec.extra_tensor_info,
395+
extra_tensor_info=extra_tensor_info,
381396
)
382397
return flatbuffer_tensor
383398

0 commit comments

Comments
 (0)