
Commit 05b0cab

[ET Device Support] PropagateDevicePass inserts H2D/D2H copy ops at delegate boundaries
Pull Request resolved: #18730

Extend PropagateDevicePass to insert explicit et_copy._h2d_copy and et_copy._d2h_copy ops at delegate boundaries, making the graph functional by explicitly transferring data between CPU and device memory.

Key changes:
- Inserts _h2d_copy before each delegate input and _d2h_copy after each delegate output
- Original input nodes stay CPU; the h2d_copy output is tagged as device
- Getitem nodes inherit the device; the d2h_copy output is tagged as CPU
- Skip-copy optimizations via skip_h2d_for_method_inputs/skip_d2h_for_method_outputs
- _parse_device_spec_value: lowercases the string, raises ValueError for unknown types
- _program.py passes the config flags to the PropagateDevicePass constructor

ghstack-source-id: 364093613
exported-using-ghexport

Differential Revision: [D99636777](https://our.internmc.facebook.com/intern/diff/D99636777/)
1 parent 4d743fc commit 05b0cab

4 files changed

Lines changed: 346 additions & 52 deletions
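
To make the change concrete before diving into the diff, here is an illustrative before/after sketch of the transformation on a toy graph. It is not taken from the commit; the node names are invented, and the delegate is assumed to carry a target_device="cuda:0" CompileSpec.

```python
# Hypothetical sketch (not from this commit) of the pass's effect on a graph
# with one delegate whose CompileSpec carries target_device="cuda:0".
#
# Before:
#   x   = placeholder()                          # spec: CPU
#   out = executorch_call_delegate(lowered, x)
#   y   = getitem(out, 0)
#   output(y)
#
# After:
#   x     = placeholder()                        # stays CPU
#   x_dev = et_copy._h2d_copy(x)                 # spec tagged cuda:0
#   out   = executorch_call_delegate(lowered, x_dev)
#   y     = getitem(out, 0)                      # inherits cuda:0
#   y_cpu = et_copy._d2h_copy(y)                 # spec tagged CPU
#   output(y_cpu)
#
# With skip_h2d_for_method_inputs the placeholder itself would be tagged as
# device and the h2d_copy omitted; with skip_d2h_for_method_outputs the
# getitem would feed output directly and the d2h_copy omitted.
```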

File tree

exir/passes/BUCK

Lines changed: 1 addition & 0 deletions
```diff
@@ -466,6 +466,7 @@ fbcode_target(_kind = runtime.python_library,
         "propagate_device_pass.py",
     ],
     deps = [
+        ":device_copy_ops_registry",
         "//caffe2:torch",
         "//executorch/exir:delegate",
         "//executorch/exir:lowered_backend_module",
```

exir/passes/propagate_device_pass.py

Lines changed: 156 additions & 51 deletions
```diff
@@ -6,9 +6,14 @@

 # pyre-strict

+import copy
 import logging
+import operator
 from typing import Optional

+# Import to register the et_copy ops so torch.ops.et_copy is available.
+import executorch.exir.passes._device_copy_ops_registry  # noqa: F401
+
 import executorch.exir.schema as schema

 import torch
```
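
The new registry import exists purely for its side effect: it has to define the ops in the `et_copy` namespace so that `torch.ops.et_copy._h2d_copy.default` resolves when the pass builds `call_function` nodes below. The registry module itself is not part of this diff; the following is only a sketch of how such ops are commonly declared with `torch.library`, and the schemas shown are assumptions, not the actual contents of `_device_copy_ops_registry.py`.

```python
# Sketch only: assumed shape of an op registry like _device_copy_ops_registry.
import torch

# Declaring schemas in the "et_copy" namespace makes torch.ops.et_copy
# resolvable once this module has been imported (even just for side effects).
_lib = torch.library.Library("et_copy", "DEF")
_lib.define("_h2d_copy(Tensor self) -> Tensor")
_lib.define("_d2h_copy(Tensor self) -> Tensor")

def _copy_passthrough(self: torch.Tensor) -> torch.Tensor:
    # Placeholder semantics so the ops are traceable at export time; real
    # backends would perform the actual host/device memory transfer.
    return self.clone()

_lib.impl("_h2d_copy", _copy_passthrough, "CompositeExplicitAutograd")
_lib.impl("_d2h_copy", _copy_passthrough, "CompositeExplicitAutograd")
```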
```diff
@@ -124,23 +129,150 @@ def _tag_specs_with_device(
     return False


+def _clone_spec_with_device(
+    spec: TensorSpec,
+    device_type: schema.DeviceType,
+    device_index: int = 0,
+) -> TensorSpec:
+    """Create a copy of a TensorSpec with a different device."""
+    new_spec = copy.copy(spec)
+    new_spec.init_mem_planning_fields()
+    _set_device_on_spec(new_spec, device_type, device_index)
+    return new_spec
+
+
 class PropagateDevicePass(PassBase):
     """
-    After to_backend, walk the graph and set device metadata on TensorSpecs
-    based on partitioner-assigned delegation info.
-
-    Rules:
-    1. Delegated nodes: Input and output tensors of a delegate call are marked
-       with the target device derived from the delegate's CompileSpec
-       (key="target_device").
-    2. Non-delegated nodes: Remain on CPU (default).
-    3. Getitem nodes that extract from a delegate call inherit the device from
-       the delegate call's output spec at the corresponding index.
+    After to_backend, walk the graph and insert H2D/D2H copy ops at delegate
+    boundaries based on partitioner-assigned device info.
+
+    When a delegate has a target_device CompileSpec (e.g., "cuda:0"):
+    - For each delegate input: insert et_copy._h2d_copy before the delegate call.
+      The original input node stays CPU; the h2d_copy output is tagged as device.
+    - For each delegate output: insert et_copy._d2h_copy after each getitem.
+      The getitem stays device; the d2h_copy output is tagged as CPU.
+    - Getitem nodes that extract from a delegate call inherit the device.
+
+    Skip-copy optimizations:
+    - skip_h2d_for_method_inputs: If the input is a graph-level placeholder
+      feeding directly to a delegate, don't insert H2D — tag the placeholder
+      as device instead (user provides GPU tensor at runtime).
+    - skip_d2h_for_method_outputs: If the getitem feeds directly to graph
+      output, don't insert D2H — the output stays on device.
     """

+    def __init__(
+        self,
+    ) -> None:
+        super().__init__()
+
+    def _is_placeholder(self, node: torch.fx.Node) -> bool:
+        """Check if a node is a graph-level input (placeholder)."""
+        return node.op == "placeholder"
+
+    def _feeds_directly_to_output(self, node: torch.fx.Node) -> bool:
+        """Check if all users of a node are output nodes."""
+        return all(user.op == "output" for user in node.users)
+
+    def _insert_h2d_copies(
+        self,
+        graph_module: torch.fx.GraphModule,
+        node: torch.fx.Node,
+        target_device_type: schema.DeviceType,
+        device_index: int,
+    ) -> bool:
+        """Insert H2D copy nodes for each tensor input to a delegate call."""
+        changed = False
+        new_args = list(node.args)
+        for i, arg in enumerate(node.args[1:], start=1):
+            if not isinstance(arg, torch.fx.Node):
+                continue
+            arg_spec = arg.meta.get("spec")
+            if not isinstance(arg_spec, TensorSpec):
+                continue
+
+            with graph_module.graph.inserting_before(node):
+                h2d_node = graph_module.graph.call_function(
+                    torch.ops.et_copy._h2d_copy.default,
+                    (arg,),
+                )
+                h2d_spec = _clone_spec_with_device(
+                    arg_spec, target_device_type, device_index
+                )
+                h2d_node.meta["spec"] = h2d_spec
+                h2d_node.meta["val"] = arg.meta.get("val")
+                if "tensor_meta" in arg.meta:
+                    h2d_node.meta["tensor_meta"] = arg.meta["tensor_meta"]
+                new_args[i] = h2d_node
+                changed = True
+
+        node.args = tuple(new_args)
+        return changed
+
+    def _insert_d2h_for_getitem(
+        self,
+        graph_module: torch.fx.GraphModule,
+        node: torch.fx.Node,
+    ) -> bool:
+        """If *node* is a getitem extracting from a delegate call, tag its spec
+        with the delegate device and insert a D2H copy after it."""
+        source_node = node.args[0]
+        if not (
+            isinstance(source_node, torch.fx.Node)
+            and source_node.op == "call_function"
+            and source_node.target == executorch_call_delegate
+        ):
+            return False
+
+        spec = node.meta.get("spec")
+        source_specs = source_node.meta.get("spec")
+        idx = node.args[1]
+        if not (
+            isinstance(spec, TensorSpec)
+            and isinstance(source_specs, (tuple, list))
+            and isinstance(idx, int)
+            and idx < len(source_specs)
+        ):
+            return False
+
+        source_spec = source_specs[idx]
+        if not isinstance(source_spec, TensorSpec):
+            return False
+
+        _set_device_on_spec(spec, source_spec.device, source_spec.device_index)
+
+        with graph_module.graph.inserting_after(node):
+            d2h_node = graph_module.graph.call_function(
+                torch.ops.et_copy._d2h_copy.default,
+                (node,),
+            )
+            d2h_spec = _clone_spec_with_device(spec, schema.DeviceType.CPU, 0)
+            d2h_node.meta["spec"] = d2h_spec
+            d2h_node.meta["val"] = node.meta.get("val")
+            if "tensor_meta" in node.meta:
+                d2h_node.meta["tensor_meta"] = node.meta["tensor_meta"]
+
+        node.replace_all_uses_with(
+            d2h_node,
+            delete_user_cb=lambda user, _d2h=d2h_node: user != _d2h,
+        )
+        return True
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        # Two-pass approach:
+        # Pass 1 – For each delegate with a target_device CompileSpec, insert
+        #          H2D copy nodes before delegate inputs and tag the delegate
+        #          output specs with the target device. Delegates without a
+        #          target_device are left untouched (no copies, specs stay CPU).
+        # Pass 2 – For each getitem that extracts from a device-tagged delegate
+        #          (tracked in device_delegates), propagate the device onto the
+        #          getitem spec and insert a D2H copy after it so downstream
+        #          non-delegated ops receive CPU tensors.
         changed = False
-        for node in graph_module.graph.nodes:
+        device_delegates: set[torch.fx.Node] = set()
+
+        # Pass 1: insert H2D copies and tag delegate output specs.
+        for node in list(graph_module.graph.nodes):
             if node.op == "call_function" and node.target == executorch_call_delegate:
                 lowered_module = _get_lowered_module(graph_module, node)
                 if lowered_module is None:
```
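
A subtlety in `_insert_d2h_for_getitem` above: a bare `node.replace_all_uses_with(d2h_node)` would also rewrite the new D2H node's own use of the getitem, leaving it consuming its own output. The `delete_user_cb` excludes the D2H node from the rewrite. Here is a standalone torch.fx demonstration of the same pattern, using stock ops rather than the et_copy ones:

```python
import torch
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
y = g.call_function(torch.relu, (x,))
z = g.call_function(torch.neg, (y,))  # a downstream consumer of y
g.output(z)

# Insert a copy node right after y, mirroring the D2H insertion above.
with g.inserting_after(y):
    copy_node = g.call_function(torch.clone, (y,))

# Rewire every user of y *except* copy_node itself; without the callback,
# copy_node would be rewritten to consume its own output.
y.replace_all_uses_with(copy_node, delete_user_cb=lambda u: u is not copy_node)

print(g)  # neg now consumes copy_node; copy_node still consumes y
```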
```diff
@@ -151,18 +283,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     continue

                 target_device_type, device_index = result
+                device_delegates.add(node)
+
+                changed |= self._insert_h2d_copies(
+                    graph_module, node, target_device_type, device_index
+                )

-                # Tag delegate input tensors.
-                # args[0] is the get_attr node for the lowered module; skip it.
-                for arg in node.args[1:]:
-                    if isinstance(arg, torch.fx.Node):
-                        changed |= _tag_specs_with_device(
-                            arg.meta.get("spec"),
-                            target_device_type,
-                            device_index,
-                        )
-
-                # Tag delegate output tensors.
                 changed |= _tag_specs_with_device(
                     node.meta.get("spec"),
                     target_device_type,
```
```diff
@@ -177,34 +303,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     lowered_module.backend_id,
                 )

-        # Second pass: propagate device through getitem nodes that extract
-        # individual outputs from a delegate call.
-        for node in graph_module.graph.nodes:
-            if node.op == "call_function" and node.target.__name__ == "getitem":
-                source_node = node.args[0]
-                if (
-                    isinstance(source_node, torch.fx.Node)
-                    and source_node.op == "call_function"
-                    and source_node.target == executorch_call_delegate
-                ):
-                    spec = node.meta.get("spec")
-                    source_specs = source_node.meta.get("spec")
-                    idx = node.args[1]
-                    if (
-                        spec is not None
-                        and isinstance(spec, TensorSpec)
-                        and source_specs is not None
-                        and isinstance(source_specs, (tuple, list))
-                        and isinstance(idx, int)
-                        and idx < len(source_specs)
-                    ):
-                        source_spec = source_specs[idx]
-                        if isinstance(source_spec, TensorSpec):
-                            _set_device_on_spec(
-                                spec,
-                                source_spec.device,
-                                source_spec.device_index,
-                            )
-                            changed = True
+        # Second pass: propagate device through getitem nodes and insert D2H
+        # only for delegates that have a target_device.
+        for node in list(graph_module.graph.nodes):
+            if node.op == "call_function" and node.target == operator.getitem:
+                source = node.args[0]
+                if isinstance(source, torch.fx.Node) and source in device_delegates:
+                    changed |= self._insert_d2h_for_getitem(graph_module, node)

+        graph_module.recompile()
         return PassResult(graph_module, changed)
```
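
For orientation, a minimal sketch of running the rewritten pass directly on a graph module after `to_backend`. In the real flow `_program.py` wires the pass into the pipeline (and, per the commit message, forwards the skip-copy config flags); the direct call below is illustrative only.

```python
import torch
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass

def apply_device_propagation(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # PassBase subclasses are callable; call() returns a PassResult carrying
    # the (mutated) graph module and a `modified` flag, which here reports
    # whether any copy op was inserted or any spec was retagged.
    result = PropagateDevicePass()(gm)
    assert result is not None
    return result.graph_module
```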

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -502,6 +502,7 @@ python_unittest(
         "//executorch/exir/backend/test:backend_with_compiler_demo",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:propagate_device_pass",
+        "//executorch/exir/passes:device_copy_ops_registry",
     ],
 )
```