Skip to content

Commit 35db06d

Browse files
committed
[ET Device Support] Propagate device metadata from partitioner result onto TensorSpecs
Pull Request resolved: #18078 Annotate the delegate's input and output tensors with a specific device type. The overall pipeline is: a. The partitioner uses `compile_spec` to determine which device the partitioned blob runs on. b. After the partitioned graph is lowered to the backend, the newly introduced propagate_device_pass annotates the input and output tensors of the delegate blob with the target device. ghstack-source-id: 353202790 @exported-using-ghexport Differential Revision: [D95842511](https://our.internmc.facebook.com/intern/diff/D95842511/)
1 parent 15fcc67 commit 35db06d

8 files changed

Lines changed: 734 additions & 0 deletions

File tree

exir/passes/BUCK

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,17 @@ fbcode_target(_kind = runtime.python_library,
439439
"//caffe2:torch",
440440
],
441441
)
442+
443+
fbcode_target(_kind = runtime.python_library,
444+
name = "propagate_device_pass",
445+
srcs = [
446+
"propagate_device_pass.py",
447+
],
448+
deps = [
449+
"//caffe2:torch",
450+
"//executorch/exir:delegate",
451+
"//executorch/exir:lowered_backend_module",
452+
"//executorch/exir:schema",
453+
"//executorch/exir:tensor",
454+
],
455+
)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Optional

import executorch.exir.schema as schema

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.tensor import TensorSpec
from torch.fx.passes.infra.pass_base import PassBase, PassResult

# Module-level logger; used for debug tracing and device-parse warnings.
logger: logging.Logger = logging.getLogger(__name__)

# CompileSpec key convention for specifying the target device.
# Partitioners that target a specific device should include a CompileSpec entry
# with this key and a value encoding the device string (e.g., b"cuda:0").
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"

# Mapping from torch.device type strings to schema.DeviceType.
# Any device type not listed here falls back to CPU in
# _parse_device_spec_value.
_DEVICE_STR_TO_ET_DEVICE: dict[str, schema.DeviceType] = {
    "cpu": schema.DeviceType.CPU,
    "cuda": schema.DeviceType.CUDA,
}
def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
    """
    Parse a target_device CompileSpec value (e.g., b"cuda:0") into
    (DeviceType, device_index).

    Device types without a schema.DeviceType mapping (e.g., "mps") fall back
    to CPU so an unknown target degrades gracefully, but a warning is logged
    so the silent fallback is visible to the user.
    """
    device_str = value.decode("utf-8")
    torch_device = torch.device(device_str)
    device_type = _DEVICE_STR_TO_ET_DEVICE.get(torch_device.type)
    if device_type is None:
        logger.warning(
            "PropagateDevicePass: unrecognized target device type '%s' in "
            "CompileSpec; falling back to CPU.",
            torch_device.type,
        )
        device_type = schema.DeviceType.CPU
    # torch.device("cuda") has index None; normalize to device index 0.
    device_index = torch_device.index if torch_device.index is not None else 0
    return device_type, device_index
def _get_lowered_module(
    graph_module: torch.fx.GraphModule,
    delegate_call_node: torch.fx.Node,
) -> Optional[LoweredBackendModule]:
    """
    Resolve the LoweredBackendModule referenced by an executorch_call_delegate
    node, or return None if it cannot be found.

    By convention, the delegate call's first argument is a get_attr node whose
    target names the LoweredBackendModule attribute on the graph module.
    """
    args = delegate_call_node.args
    if not args:
        return None
    attr_node = args[0]
    if isinstance(attr_node, torch.fx.Node) and attr_node.op == "get_attr":
        candidate = getattr(graph_module, attr_node.target, None)
        if isinstance(candidate, LoweredBackendModule):
            return candidate
    return None
def _get_target_device_from_compile_specs(
    lowered_module: LoweredBackendModule,
) -> Optional[tuple[schema.DeviceType, int]]:
    """
    Return the (DeviceType, device_index) encoded in the lowered module's
    CompileSpec entry keyed by TARGET_DEVICE_COMPILE_SPEC_KEY, or None if no
    such entry exists.
    """
    matching_values = (
        spec.value
        for spec in lowered_module.compile_specs
        if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY
    )
    value = next(matching_values, None)
    if value is None:
        return None
    return _parse_device_spec_value(value)
def _set_device_on_spec(
    spec: TensorSpec,
    device_type: schema.DeviceType,
) -> None:
    """Record *device_type* as the device of *spec*."""
    setattr(spec, "device", device_type)
class PropagateDevicePass(PassBase):
    """
    After to_backend, walk the graph and set device metadata on TensorSpecs
    based on partitioner-assigned delegation info.

    Rules:
    1. Delegated nodes: Output tensors of a delegate call are marked with the
       target device derived from the delegate's CompileSpec
       (key="target_device").
    2. Non-delegated nodes: Remain on CPU (default).
    3. Getitem nodes that extract from a delegate call inherit the device from
       the delegate call's output spec at the corresponding index.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Annotate delegate-related TensorSpecs with their target device.

        Returns a PassResult whose `modified` flag is True iff at least one
        TensorSpec's device was updated by either pass.
        """
        changed = self._annotate_delegate_outputs(graph_module)
        changed = self._propagate_through_getitem(graph_module) or changed
        return PassResult(graph_module, changed)

    def _annotate_delegate_outputs(
        self, graph_module: torch.fx.GraphModule
    ) -> bool:
        """First pass: mark the output TensorSpec(s) of each delegate call
        with the device named by its "target_device" CompileSpec.

        Returns True if any spec was changed.
        """
        changed = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target != executorch_call_delegate:
                continue
            lowered_module = _get_lowered_module(graph_module, node)
            if lowered_module is None:
                continue

            result = _get_target_device_from_compile_specs(lowered_module)
            if result is None:
                # No target-device CompileSpec: leave specs on the CPU default.
                continue

            # device_index is currently unused: TensorSpec records only the
            # device type, not the index.
            target_device_type, _device_index = result

            specs = node.meta.get("spec")
            if specs is None:
                continue

            # A delegate call's meta may carry a single TensorSpec or a
            # tuple/list of them (multi-output delegate).
            if isinstance(specs, TensorSpec):
                _set_device_on_spec(specs, target_device_type)
                changed = True
            elif isinstance(specs, (tuple, list)):
                for s in specs:
                    if isinstance(s, TensorSpec):
                        _set_device_on_spec(s, target_device_type)
                        changed = True

            logger.debug(
                "PropagateDevicePass: set device=%s on delegate node %s "
                "(backend=%s)",
                target_device_type,
                node.name,
                lowered_module.backend_id,
            )
        return changed

    def _propagate_through_getitem(
        self, graph_module: torch.fx.GraphModule
    ) -> bool:
        """Second pass: getitem nodes that extract an individual output from a
        delegate call inherit the device of the delegate's output spec at the
        corresponding index.

        Returns True if any spec was changed.
        """
        changed = False
        for node in graph_module.graph.nodes:
            # getattr with a default: some call_function targets may not
            # expose __name__, and a plain attribute access would raise
            # AttributeError on them.
            if node.op != "call_function" or (
                getattr(node.target, "__name__", None) != "getitem"
            ):
                continue
            source_node = node.args[0]
            if not (
                isinstance(source_node, torch.fx.Node)
                and source_node.op == "call_function"
                and source_node.target == executorch_call_delegate
            ):
                continue
            spec = node.meta.get("spec")
            source_specs = source_node.meta.get("spec")
            idx = node.args[1]
            if (
                spec is not None
                and isinstance(spec, TensorSpec)
                and source_specs is not None
                and isinstance(source_specs, (tuple, list))
                and isinstance(idx, int)
                and idx < len(source_specs)
            ):
                source_spec = source_specs[idx]
                if isinstance(source_spec, TensorSpec):
                    _set_device_on_spec(spec, source_spec.device)
                    changed = True
        return changed

exir/passes/replace_view_copy_with_view_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
110110
"mem_offset",
111111
"dtype", # property
112112
"extra_tensor_info", # property
113+
"device",
113114
]
114115

115116
# Make sure _self_fields and _base_fields are disjoint

exir/program/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library,
4040
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
4141
"//executorch/exir/passes:lib",
4242
"//executorch/exir/passes:normalize_view_copy_base_pass",
43+
"//executorch/exir/passes:propagate_device_pass",
4344
"//executorch/exir/passes:remove_graph_asserts_pass",
4445
"//executorch/exir/passes:remove_mixed_type_operators",
4546
"//executorch/exir/passes:replace_aten_with_edge_pass",

exir/program/_program.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from executorch.exir.passes.normalize_view_copy_base_pass import (
6060
NormalizeViewCopyBasePass,
6161
)
62+
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass
6263
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
6364
from executorch.exir.passes.reinplace import reinplace_pass
6465
from executorch.exir.passes.remove_graph_asserts_pass import (
@@ -848,6 +849,10 @@ def edge_to_executorch_passes(
848849
# there exists an unbacked symint operation.
849850
*config.passes,
850851
SpecPropPass(),
852+
# Propagate device metadata (e.g., CUDA) from delegate CompileSpecs onto
853+
# TensorSpecs. Must run after SpecPropPass so specs are freshly created
854+
# with correct shapes.
855+
PropagateDevicePass(),
851856
EdgeToBackendOpsPass(),
852857
RemoveGraphAssertsPass(),
853858
] + pre_memory_planning_passes(config, name)

exir/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def __init__(
172172
self.init_mem_planning_fields()
173173
self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape)
174174
self.extra_tensor_info = extra_tensor_info
175+
# device type will be only updated during PropagateDevicePass.
176+
self.device: schema.DeviceType = schema.DeviceType.CPU
175177

176178
@property
177179
def allocated_memory(self) -> int:
@@ -254,6 +256,7 @@ def __repr__(self) -> str:
254256
+ f", is_sparse={self.is_sparse}"
255257
+ f", shape_dynamism={self.shape_dynamism}"
256258
+ f", const={self.const}, requires_grad={self.requires_grad}"
259+
+ f", device={self.device.name}"
257260
+ ")"
258261
)
259262

exir/tests/TARGETS

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,3 +516,23 @@ python_unittest(
516516
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
517517
],
518518
)
519+
520+
python_unittest(
521+
name = "propagate_device_pass",
522+
srcs = [
523+
"test_propagate_device_pass.py",
524+
],
525+
deps = [
526+
"//caffe2:torch",
527+
"//executorch/exir:lib",
528+
"//executorch/exir:schema",
529+
"//executorch/exir:tensor",
530+
"//executorch/exir/backend:backend_api",
531+
"//executorch/exir/backend:compile_spec_schema",
532+
"//executorch/exir/backend:partitioner",
533+
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
534+
"//executorch/exir/backend/test:backend_with_compiler_demo",
535+
"//executorch/exir/dialects:lib",
536+
"//executorch/exir/passes:propagate_device_pass",
537+
],
538+
)

0 commit comments

Comments
 (0)