Skip to content

Commit 034fa54

Browse files
committed
[ET Device Support] Propagate device metadata from partitioner result onto TensorSpecs
Pull Request resolved: #18078 Annotate the delegate's input and output tensors with a specific device type. The overall pipeline is: a. The partitioner uses `compile_spec` to determine which device the partitioned blob is running on. b. After lowering the partitioned graph to a backend, the newly-introduced propagate_device_pass annotates the input and output tensors of the delegate blob with the target device. ghstack-source-id: 354478927 @exported-using-ghexport Differential Revision: [D95842511](https://our.internmc.facebook.com/intern/diff/D95842511/)
1 parent ef66322 commit 034fa54

8 files changed

Lines changed: 797 additions & 0 deletions

File tree

exir/passes/BUCK

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,17 @@ fbcode_target(_kind = runtime.python_library,
439439
"//caffe2:torch",
440440
],
441441
)
442+
443+
fbcode_target(_kind = runtime.python_library,
444+
name = "propagate_device_pass",
445+
srcs = [
446+
"propagate_device_pass.py",
447+
],
448+
deps = [
449+
"//caffe2:torch",
450+
"//executorch/exir:delegate",
451+
"//executorch/exir:lowered_backend_module",
452+
"//executorch/exir:schema",
453+
"//executorch/exir:tensor",
454+
],
455+
)
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import logging
10+
from typing import Optional
11+
12+
import executorch.exir.schema as schema
13+
14+
import torch
15+
from executorch.exir.delegate import executorch_call_delegate
16+
from executorch.exir.lowered_backend_module import LoweredBackendModule
17+
from executorch.exir.tensor import TensorSpec
18+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
19+
20+
logger: logging.Logger = logging.getLogger(__name__)
21+
22+
# CompileSpec key convention for specifying the target device.
23+
# Partitioners that target a specific device should include a CompileSpec entry
24+
# with this key and a value encoding the device string (e.g., b"cuda:0").
25+
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"
26+
27+
# Mapping from torch.device type strings to schema.DeviceType.
28+
_DEVICE_STR_TO_ET_DEVICE: dict[str, schema.DeviceType] = {
29+
"cpu": schema.DeviceType.CPU,
30+
"cuda": schema.DeviceType.CUDA,
31+
}
32+
33+
34+
def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
    """
    Parse a target_device CompileSpec value (e.g., b"cuda:0") into
    (DeviceType, device_index).

    Args:
        value: UTF-8 encoded device string as stored in the CompileSpec.

    Returns:
        A (DeviceType, device_index) pair. Unrecognized device types fall
        back to CPU (preserving prior behavior) but emit a warning; a
        missing index defaults to 0.
    """
    device_str = value.decode("utf-8")
    torch_device = torch.device(device_str)
    if torch_device.type not in _DEVICE_STR_TO_ET_DEVICE:
        # Silently defaulting would hide a misconfigured partitioner
        # CompileSpec, so surface the fallback in the logs.
        logger.warning(
            "Unknown target device type '%s' in CompileSpec; falling back to CPU.",
            torch_device.type,
        )
    device_type = _DEVICE_STR_TO_ET_DEVICE.get(torch_device.type, schema.DeviceType.CPU)
    device_index = torch_device.index if torch_device.index is not None else 0
    return device_type, device_index
44+
45+
46+
def _get_lowered_module(
    graph_module: torch.fx.GraphModule,
    delegate_call_node: torch.fx.Node,
) -> Optional[LoweredBackendModule]:
    """
    Resolve the LoweredBackendModule referenced by an executorch_call_delegate
    node, or return None when it cannot be found.

    The delegate call's first argument is expected to be a get_attr node whose
    target names the LoweredBackendModule attribute on the graph module.
    """
    args = delegate_call_node.args
    if not args:
        return None
    attr_node = args[0]
    if not (isinstance(attr_node, torch.fx.Node) and attr_node.op == "get_attr"):
        return None
    candidate = getattr(graph_module, attr_node.target, None)
    return candidate if isinstance(candidate, LoweredBackendModule) else None
65+
66+
67+
def _get_target_device_from_compile_specs(
    lowered_module: LoweredBackendModule,
) -> Optional[tuple[schema.DeviceType, int]]:
    """
    Scan the module's CompileSpecs for TARGET_DEVICE_COMPILE_SPEC_KEY and
    return the parsed (DeviceType, device_index), or None when no such
    spec is present.
    """
    matching_values = (
        spec.value
        for spec in lowered_module.compile_specs
        if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY
    )
    value = next(matching_values, None)
    return None if value is None else _parse_device_spec_value(value)
78+
79+
80+
def _set_device_on_spec(
    spec: TensorSpec,
    device_type: schema.DeviceType,
) -> None:
    """Annotate a single TensorSpec with the given device type."""
    # Single mutation point so every caller tags specs the same way.
    spec.device = device_type
86+
87+
88+
def _tag_specs_with_device(
    specs: object,
    device_type: schema.DeviceType,
) -> bool:
    """Annotate a TensorSpec, or every TensorSpec in a tuple/list, with a device.

    Args:
        specs: A TensorSpec, a tuple/list possibly containing TensorSpecs,
            or None.
        device_type: The target device type to record on each spec.

    Returns:
        True if at least one spec was updated, False otherwise.
    """
    if isinstance(specs, TensorSpec):
        _set_device_on_spec(specs, device_type)
        return True
    if not isinstance(specs, (tuple, list)):
        # Covers None and any other unsupported container type.
        return False
    tensor_specs = [entry for entry in specs if isinstance(entry, TensorSpec)]
    for tensor_spec in tensor_specs:
        _set_device_on_spec(tensor_spec, device_type)
    return bool(tensor_specs)
114+
115+
116+
class PropagateDevicePass(PassBase):
    """
    After to_backend, walk the graph and set device metadata on TensorSpecs
    based on partitioner-assigned delegation info.

    Rules:
    1. Delegated nodes: Input and output tensors of a delegate call are marked
       with the target device derived from the delegate's CompileSpec
       (key="target_device").
    2. Non-delegated nodes: Remain on CPU (default).
    3. Getitem nodes that extract from a delegate call inherit the device from
       the delegate call's output spec at the corresponding index.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Annotate delegate input/output TensorSpecs with their target device.

        Returns:
            A PassResult whose modified flag is True iff at least one spec
            was updated.
        """
        changed = False
        # First pass: tag the input and output specs of every delegate call
        # whose lowered module carries a "target_device" CompileSpec.
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                lowered_module = _get_lowered_module(graph_module, node)
                if lowered_module is None:
                    continue

                result = _get_target_device_from_compile_specs(lowered_module)
                if result is None:
                    # No target_device CompileSpec: leave specs on the
                    # default (CPU) device.
                    continue

                target_device_type, _device_index = result

                # Tag delegate input tensors.
                # args[0] is the get_attr node for the lowered module; skip it.
                for arg in node.args[1:]:
                    if isinstance(arg, torch.fx.Node):
                        changed |= _tag_specs_with_device(
                            arg.meta.get("spec"), target_device_type
                        )

                # Tag delegate output tensors.
                changed |= _tag_specs_with_device(
                    node.meta.get("spec"), target_device_type
                )

                logger.debug(
                    "PropagateDevicePass: set device=%s on delegate node %s "
                    "(backend=%s)",
                    target_device_type,
                    node.name,
                    lowered_module.backend_id,
                )

        # Second pass: propagate device through getitem nodes that extract
        # individual outputs from a delegate call.
        for node in graph_module.graph.nodes:
            # Not every call_function target is guaranteed to expose a
            # __name__ attribute (e.g. some op-overload objects), so read it
            # defensively instead of node.target.__name__, which would raise
            # AttributeError on such targets.
            if (
                node.op == "call_function"
                and getattr(node.target, "__name__", None) == "getitem"
            ):
                source_node = node.args[0]
                if (
                    isinstance(source_node, torch.fx.Node)
                    and source_node.op == "call_function"
                    and source_node.target == executorch_call_delegate
                ):
                    spec = node.meta.get("spec")
                    source_specs = source_node.meta.get("spec")
                    idx = node.args[1]
                    # Guard against malformed metadata: the getitem index must
                    # be a non-negative int in range (a negative index would
                    # silently alias from the end of the spec list).
                    if (
                        isinstance(spec, TensorSpec)
                        and isinstance(source_specs, (tuple, list))
                        and isinstance(idx, int)
                        and 0 <= idx < len(source_specs)
                    ):
                        source_spec = source_specs[idx]
                        if isinstance(source_spec, TensorSpec):
                            _set_device_on_spec(spec, source_spec.device)
                            changed = True

        return PassResult(graph_module, changed)

exir/passes/replace_view_copy_with_view_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
110110
"mem_offset",
111111
"dtype", # property
112112
"extra_tensor_info", # property
113+
"device",
113114
]
114115

115116
# Make sure _self_fields and _base_fields are disjoint

exir/program/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library,
4040
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
4141
"//executorch/exir/passes:lib",
4242
"//executorch/exir/passes:normalize_view_copy_base_pass",
43+
"//executorch/exir/passes:propagate_device_pass",
4344
"//executorch/exir/passes:remove_graph_asserts_pass",
4445
"//executorch/exir/passes:remove_mixed_type_operators",
4546
"//executorch/exir/passes:replace_aten_with_edge_pass",

exir/program/_program.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from executorch.exir.passes.normalize_view_copy_base_pass import (
6060
NormalizeViewCopyBasePass,
6161
)
62+
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass
6263
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
6364
from executorch.exir.passes.reinplace import reinplace_pass
6465
from executorch.exir.passes.remove_graph_asserts_pass import (
@@ -848,6 +849,10 @@ def edge_to_executorch_passes(
848849
# there exists an unbacked symint operation.
849850
*config.passes,
850851
SpecPropPass(),
852+
# Propagate device metadata (e.g., CUDA) from delegate CompileSpecs onto
853+
# TensorSpecs. Must run after SpecPropPass so specs are freshly created
854+
# with correct shapes.
855+
PropagateDevicePass(),
851856
EdgeToBackendOpsPass(),
852857
RemoveGraphAssertsPass(),
853858
] + pre_memory_planning_passes(config, name)

exir/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def __init__(
172172
self.init_mem_planning_fields()
173173
self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape)
174174
self.extra_tensor_info = extra_tensor_info
175+
# device type will be only updated during PropagateDevicePass.
176+
self.device: schema.DeviceType = schema.DeviceType.CPU
175177

176178
@property
177179
def allocated_memory(self) -> int:
@@ -254,6 +256,7 @@ def __repr__(self) -> str:
254256
+ f", is_sparse={self.is_sparse}"
255257
+ f", shape_dynamism={self.shape_dynamism}"
256258
+ f", const={self.const}, requires_grad={self.requires_grad}"
259+
+ f", device={self.device.name}"
257260
+ ")"
258261
)
259262

exir/tests/TARGETS

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,3 +516,23 @@ python_unittest(
516516
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
517517
],
518518
)
519+
520+
python_unittest(
521+
name = "propagate_device_pass",
522+
srcs = [
523+
"test_propagate_device_pass.py",
524+
],
525+
deps = [
526+
"//caffe2:torch",
527+
"//executorch/exir:lib",
528+
"//executorch/exir:schema",
529+
"//executorch/exir:tensor",
530+
"//executorch/exir/backend:backend_api",
531+
"//executorch/exir/backend:compile_spec_schema",
532+
"//executorch/exir/backend:partitioner",
533+
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
534+
"//executorch/exir/backend/test:backend_with_compiler_demo",
535+
"//executorch/exir/dialects:lib",
536+
"//executorch/exir/passes:propagate_device_pass",
537+
],
538+
)

0 commit comments

Comments
 (0)