Skip to content

Commit 53259a5

Browse files
committed
[ET Device Support] Propagate device metadata from partitioner result onto TensorSpecs
Pull Request resolved: #18078 Annotate the delegate's input and output tensors as a specific device type. The overall pipeline is: a. The partitioner uses `compile_spec` to determine which device the partitioned blob is running on. b. After lowering the partitioned graph to the backend, the newly introduced propagate_device_pass annotates the input and output tensors of the delegate blob with the target device and the correct device index. ghstack-source-id: 355133241 @exported-using-ghexport Differential Revision: [D95842511](https://our.internmc.facebook.com/intern/diff/D95842511/)
1 parent 1d70eb6 commit 53259a5

8 files changed

Lines changed: 879 additions & 0 deletions

File tree

exir/passes/BUCK

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,17 @@ fbcode_target(_kind = runtime.python_library,
439439
"//caffe2:torch",
440440
],
441441
)
442+
443+
# Library target for the device-propagation pass. Depends on the exir modules
# that define the delegate call op, LoweredBackendModule, schema, and TensorSpec.
fbcode_target(_kind = runtime.python_library,
    name = "propagate_device_pass",
    srcs = [
        "propagate_device_pass.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:delegate",
        "//executorch/exir:lowered_backend_module",
        "//executorch/exir:schema",
        "//executorch/exir:tensor",
    ],
)
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
from typing import Optional
11+
12+
import executorch.exir.schema as schema
13+
14+
import torch
15+
from executorch.exir.delegate import executorch_call_delegate
16+
from executorch.exir.lowered_backend_module import LoweredBackendModule
17+
from executorch.exir.tensor import TensorSpec
18+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
19+
20+
# Module-level logger; used by PropagateDevicePass for debug output.
logger: logging.Logger = logging.getLogger(__name__)

# CompileSpec key convention for specifying the target device.
# Partitioners that target a specific device should include a CompileSpec entry
# with this key and a value encoding the device string (e.g., b"cuda:0").
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"

# Mapping from torch.device type strings to schema.DeviceType.
# Device type strings not listed here fall back to CPU when parsed (see
# _parse_device_spec_value).
_DEVICE_STR_TO_ET_DEVICE: dict[str, schema.DeviceType] = {
    "cpu": schema.DeviceType.CPU,
    "cuda": schema.DeviceType.CUDA,
}
33+
34+
def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
    """
    Parse a target_device CompileSpec value (e.g., b"cuda:0") into
    (DeviceType, device_index).

    Args:
        value: UTF-8 encoded device string accepted by ``torch.device``.

    Returns:
        A (DeviceType, device_index) tuple. Device types with no
        schema.DeviceType mapping fall back to CPU (with a warning); a
        missing index defaults to 0.
    """
    device_str = value.decode("utf-8")
    torch_device = torch.device(device_str)
    device_type = _DEVICE_STR_TO_ET_DEVICE.get(torch_device.type)
    if device_type is None:
        # Don't silently mis-route tensors: an unmapped device type (e.g.
        # "mps") most likely indicates a partitioner misconfiguration, so
        # surface it before defaulting to CPU.
        logger.warning(
            "Unrecognized target device type '%s' in CompileSpec; "
            "defaulting to CPU.",
            torch_device.type,
        )
        device_type = schema.DeviceType.CPU
    device_index = torch_device.index if torch_device.index is not None else 0
    return device_type, device_index
45+
46+
def _get_lowered_module(
    graph_module: torch.fx.GraphModule,
    delegate_call_node: torch.fx.Node,
) -> Optional[LoweredBackendModule]:
    """
    Resolve the LoweredBackendModule referenced by an executorch_call_delegate
    node, or return None when the node does not have the expected shape.

    The delegate call's first argument is a get_attr node whose target names
    the LoweredBackendModule attribute on the owning graph module.
    """
    if not delegate_call_node.args:
        return None
    first_arg = delegate_call_node.args[0]
    is_get_attr = isinstance(first_arg, torch.fx.Node) and first_arg.op == "get_attr"
    if not is_get_attr:
        return None
    candidate = getattr(graph_module, first_arg.target, None)
    return candidate if isinstance(candidate, LoweredBackendModule) else None
66+
67+
def _get_target_device_from_compile_specs(
    lowered_module: LoweredBackendModule,
) -> Optional[tuple[schema.DeviceType, int]]:
    """
    Scan the lowered module's CompileSpecs for an entry keyed by
    TARGET_DEVICE_COMPILE_SPEC_KEY and return the decoded
    (DeviceType, device_index), or None when no such entry exists.
    """
    matching_spec = next(
        (
            spec
            for spec in lowered_module.compile_specs
            if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY
        ),
        None,
    )
    if matching_spec is None:
        return None
    return _parse_device_spec_value(matching_spec.value)
79+
80+
def _set_device_on_spec(
    spec: TensorSpec,
    device_type: schema.DeviceType,
    device_index: int = 0,
) -> None:
    """Annotate a single TensorSpec with the given device type and index."""
    spec.device, spec.device_index = device_type, device_index
89+
90+
def _tag_specs_with_device(
    specs: object,
    device_type: schema.DeviceType,
    device_index: int = 0,
) -> bool:
    """Apply device annotation to a TensorSpec or a collection of TensorSpecs.

    Args:
        specs: A TensorSpec, a tuple/list of TensorSpecs, or None.
        device_type: The target device type to set.
        device_index: The device index (e.g., 0 for cuda:0, 1 for cuda:1).

    Returns:
        True if any spec was modified, False otherwise.
    """
    if isinstance(specs, TensorSpec):
        _set_device_on_spec(specs, device_type, device_index)
        return True
    if not isinstance(specs, (tuple, list)):
        # Covers None as well as any unexpected meta value.
        return False
    tensor_specs = [entry for entry in specs if isinstance(entry, TensorSpec)]
    for entry in tensor_specs:
        _set_device_on_spec(entry, device_type, device_index)
    return bool(tensor_specs)
119+
120+
class PropagateDevicePass(PassBase):
    """
    After to_backend, walk the graph and set device metadata on TensorSpecs
    based on partitioner-assigned delegation info.

    Rules:
    1. Delegated nodes: Input and output tensors of a delegate call are marked
       with the target device derived from the delegate's CompileSpec
       (key="target_device").
    2. Non-delegated nodes: Remain on CPU (default).
    3. Getitem nodes that extract from a delegate call inherit the device from
       the delegate call's output spec at the corresponding index.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Annotate delegate input/output TensorSpecs with their target device.

        Args:
            graph_module: The graph produced by to_backend; mutated in place.

        Returns:
            PassResult wrapping graph_module; `modified` is True when at least
            one spec's device was changed.
        """
        changed = False
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                lowered_module = _get_lowered_module(graph_module, node)
                if lowered_module is None:
                    continue

                result = _get_target_device_from_compile_specs(lowered_module)
                if result is None:
                    # Delegate did not request a specific device; leave the
                    # specs at their default (CPU).
                    continue

                target_device_type, device_index = result

                # Tag delegate input tensors.
                # args[0] is the get_attr node for the lowered module; skip it.
                for arg in node.args[1:]:
                    if isinstance(arg, torch.fx.Node):
                        changed |= _tag_specs_with_device(
                            arg.meta.get("spec"),
                            target_device_type,
                            device_index,
                        )

                # Tag delegate output tensors.
                changed |= _tag_specs_with_device(
                    node.meta.get("spec"),
                    target_device_type,
                    device_index,
                )

                logger.debug(
                    "PropagateDevicePass: set device=%s on delegate node %s "
                    "(backend=%s)",
                    target_device_type,
                    node.name,
                    lowered_module.backend_id,
                )

        # Second pass: propagate device through getitem nodes that extract
        # individual outputs from a delegate call.
        for node in graph_module.graph.nodes:
            # Some call_function targets (e.g. functools.partial objects) have
            # no __name__ attribute; use getattr so the pass never raises
            # AttributeError on such nodes.
            if (
                node.op == "call_function"
                and getattr(node.target, "__name__", None) == "getitem"
            ):
                source_node = node.args[0]
                if (
                    isinstance(source_node, torch.fx.Node)
                    and source_node.op == "call_function"
                    and source_node.target == executorch_call_delegate
                ):
                    spec = node.meta.get("spec")
                    source_specs = source_node.meta.get("spec")
                    idx = node.args[1]
                    # Require a non-negative in-range index: a negative idx
                    # (e.g. -5 with 3 specs) would pass a bare `idx < len(...)`
                    # check and then raise IndexError on the lookup.
                    if (
                        isinstance(spec, TensorSpec)
                        and isinstance(source_specs, (tuple, list))
                        and isinstance(idx, int)
                        and 0 <= idx < len(source_specs)
                    ):
                        source_spec = source_specs[idx]
                        if isinstance(source_spec, TensorSpec):
                            _set_device_on_spec(
                                spec,
                                source_spec.device,
                                source_spec.device_index,
                            )
                            changed = True

        return PassResult(graph_module, changed)

exir/passes/replace_view_copy_with_view_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
110110
"mem_offset",
111111
"dtype", # property
112112
"extra_tensor_info", # property
113+
"device",
113114
]
114115

115116
# Make sure _self_fields and _base_fields are disjoint

exir/program/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library,
4040
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
4141
"//executorch/exir/passes:lib",
4242
"//executorch/exir/passes:normalize_view_copy_base_pass",
43+
"//executorch/exir/passes:propagate_device_pass",
4344
"//executorch/exir/passes:remove_graph_asserts_pass",
4445
"//executorch/exir/passes:remove_mixed_type_operators",
4546
"//executorch/exir/passes:replace_aten_with_edge_pass",

exir/program/_program.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from executorch.exir.passes.normalize_view_copy_base_pass import (
6060
NormalizeViewCopyBasePass,
6161
)
62+
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass
6263
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
6364
from executorch.exir.passes.reinplace import reinplace_pass
6465
from executorch.exir.passes.remove_graph_asserts_pass import (
@@ -848,6 +849,10 @@ def edge_to_executorch_passes(
848849
# there exists an unbacked symint operation.
849850
*config.passes,
850851
SpecPropPass(),
852+
# Propagate device metadata (e.g., CUDA) from delegate CompileSpecs onto
853+
# TensorSpecs. Must run after SpecPropPass so specs are freshly created
854+
# with correct shapes.
855+
PropagateDevicePass(),
851856
EdgeToBackendOpsPass(),
852857
RemoveGraphAssertsPass(),
853858
] + pre_memory_planning_passes(config, name)

exir/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ def __init__(
172172
self.init_mem_planning_fields()
173173
self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape)
174174
self.extra_tensor_info = extra_tensor_info
175+
# device type will be only updated during PropagateDevicePass.
176+
self.device: schema.DeviceType = schema.DeviceType.CPU
177+
self.device_index: int = 0
175178

176179
@property
177180
def allocated_memory(self) -> int:
@@ -254,6 +257,7 @@ def __repr__(self) -> str:
254257
+ f", is_sparse={self.is_sparse}"
255258
+ f", shape_dynamism={self.shape_dynamism}"
256259
+ f", const={self.const}, requires_grad={self.requires_grad}"
260+
+ f", device={self.device.name}:{self.device_index}"
257261
+ ")"
258262
)
259263

exir/tests/TARGETS

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,3 +516,23 @@ python_unittest(
516516
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
517517
],
518518
)
519+
520+
# Unit-test target for PropagateDevicePass. Pulls in the demo backend and
# partitioner helpers needed to build a lowered graph inside the test.
python_unittest(
    name = "propagate_device_pass",
    srcs = [
        "test_propagate_device_pass.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir:schema",
        "//executorch/exir:tensor",
        "//executorch/exir/backend:backend_api",
        "//executorch/exir/backend:compile_spec_schema",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
        "//executorch/exir/backend/test:backend_with_compiler_demo",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/passes:propagate_device_pass",
    ],
)

0 commit comments

Comments
 (0)