Commit 8e700ab

[ET Device Support] Add ExecutorchBackendConfig flags for skipping H2D/D2H copies
Add skip_h2d_for_method_inputs and skip_d2h_for_method_outputs config flags to ExecutorchBackendConfig. These control whether PropagateDevicePass skips inserting H2D/D2H copy ops at method I/O boundaries:

- skip_h2d_for_method_inputs: the user provides GPU tensors directly
- skip_d2h_for_method_outputs: outputs stay on device for cross-method pipelines

Differential Revision: [D99636778](https://our.internmc.facebook.com/intern/diff/D99636778/)

ghstack-source-id: 364093611
Pull Request resolved: #18760

4 files changed: 413 additions & 35 deletions


exir/capture/_config.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -123,3 +123,15 @@ class ExecutorchBackendConfig:
     # vs. accelerator memory. Default False preserves the legacy behavior
     # where all tensors are planned into CPU memory regardless of device.
     enable_non_cpu_memory_planning: bool = False
+
+    # When True, method-level input tensors that feed directly into a device
+    # delegate are NOT wrapped with _h2d_copy. The user must provide tensors
+    # already on the target device. Useful for pipelines where inputs are
+    # pre-staged on GPU.
+    skip_h2d_for_method_inputs: bool = False
+
+    # When True, device delegate outputs that are directly method outputs
+    # are NOT wrapped with _d2h_copy. The method outputs stay on device.
+    # Useful for cross-method GPU pipelines where the next method consumes
+    # GPU tensors directly.
+    skip_d2h_for_method_outputs: bool = False
```
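The two flags typically travel together in a fully on-device pipeline. A minimal sketch of how they might be passed through the usual `to_edge(...).to_executorch(config)` flow; the `exported_program` here is illustrative, not from this commit:

```python
# A minimal sketch, assuming the standard to_edge/to_executorch lowering flow.
# `exported_program` is a placeholder for a torch.export output whose inputs
# the caller will pre-stage on GPU.
from executorch.exir import ExecutorchBackendConfig, to_edge

config = ExecutorchBackendConfig(
    enable_non_cpu_memory_planning=True,  # plan tensors into device memory
    skip_h2d_for_method_inputs=True,      # caller supplies on-device inputs
    skip_d2h_for_method_outputs=True,     # outputs stay on device
)
executorch_program = to_edge(exported_program).to_executorch(config)
```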

exir/passes/propagate_device_pass.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -163,8 +163,12 @@ class PropagateDevicePass(PassBase):

     def __init__(
         self,
+        skip_h2d_for_method_inputs: bool = False,
+        skip_d2h_for_method_outputs: bool = False,
     ) -> None:
         super().__init__()
+        self.skip_h2d_for_method_inputs = skip_h2d_for_method_inputs
+        self.skip_d2h_for_method_outputs = skip_d2h_for_method_outputs

     def _is_placeholder(self, node: torch.fx.Node) -> bool:
         """Check if a node is a graph-level input (placeholder)."""
@@ -191,6 +195,11 @@ def _insert_h2d_copies(
             if not isinstance(arg_spec, TensorSpec):
                 continue

+            if self.skip_h2d_for_method_inputs and self._is_placeholder(arg):
+                _set_device_on_spec(arg_spec, target_device_type, device_index)
+                changed = True
+                continue
+
             with graph_module.graph.inserting_before(node):
                 h2d_node = graph_module.graph.call_function(
                     torch.ops.et_copy._h2d_copy.default,
@@ -241,6 +250,9 @@ def _insert_d2h_for_getitem(

         _set_device_on_spec(spec, source_spec.device, source_spec.device_index)

+        if self.skip_d2h_for_method_outputs and self._feeds_directly_to_output(node):
+            return True
+
         with graph_module.graph.inserting_after(node):
             d2h_node = graph_module.graph.call_function(
                 torch.ops.et_copy._d2h_copy.default,
```
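The last hunk calls a `_feeds_directly_to_output` helper that this excerpt does not show. A hypothetical sketch of what such a check could look like, assuming a node "feeds directly" to the method output when all of its users are the graph's `output` node; this is a guess at the semantics, not code from the commit:

```python
# Hypothetical sketch of the _feeds_directly_to_output check referenced
# above; the real helper is not shown in this diff.
def _feeds_directly_to_output(self, node: torch.fx.Node) -> bool:
    # A node is a direct method output if it has at least one user and
    # every user is the graph's output node.
    return bool(node.users) and all(
        user.op == "output" for user in node.users
    )
```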

exir/program/_program.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -849,7 +849,10 @@ def edge_to_executorch_passes(
         # there exists an unbacked symint operation.
         *config.passes,
         SpecPropPass(),
-        PropagateDevicePass(),
+        PropagateDevicePass(
+            skip_h2d_for_method_inputs=config.skip_h2d_for_method_inputs,
+            skip_d2h_for_method_outputs=config.skip_d2h_for_method_outputs,
+        ),
         EdgeToBackendOpsPass(),
         RemoveGraphAssertsPass(),
     ] + pre_memory_planning_passes(config, name)
```
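The pass can also be driven standalone on an edge program's graph module. A sketch, assuming PropagateDevicePass follows the torch.fx pass-infra `__call__(graph_module) -> PassResult` contract it inherits from PassBase; `edge_program` is a placeholder:

```python
# A standalone sketch; `edge_program` stands in for a lowered edge program.
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass

device_pass = PropagateDevicePass(
    skip_h2d_for_method_inputs=True,
    skip_d2h_for_method_outputs=True,
)
result = device_pass(edge_program.graph_module)
print(result.modified)  # True if the pass changed any specs or inserted copies
```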
