Skip to content

Commit 6043775

Browse files
authored
Add ExecutorchBackendConfig flags for skipping H2D/D2H copies
Differential Revision: D99636778
Pull Request resolved: pytorch#19929
1 parent f512d7e commit 6043775

4 files changed

Lines changed: 408 additions & 16 deletions

File tree

exir/capture/_config.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,3 +123,15 @@ class ExecutorchBackendConfig:
123123
# vs. accelerator memory. Default False preserves the legacy behavior
124124
# where all tensors are planned into CPU memory regardless of device.
125125
enable_non_cpu_memory_planning: bool = False
126+
127+
# When True, method-level input tensors that feed directly into a device
128+
# delegate are NOT wrapped with _h2d_copy. The user must provide tensors
129+
# already on the target device. Useful for pipelines where inputs are
130+
# pre-staged on GPU.
131+
skip_h2d_for_method_inputs: bool = False
132+
133+
# When True, device delegate outputs that are directly method outputs
134+
# are NOT wrapped with _d2h_copy. The method outputs stay on device.
135+
# Useful for cross-method GPU pipelines where the next method consumes
136+
# GPU tensors directly.
137+
skip_d2h_for_method_outputs: bool = False

exir/passes/propagate_device_pass.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,12 @@ class PropagateDevicePass(PassBase):
163163

164164
def __init__(
165165
self,
166+
skip_h2d_for_method_inputs: bool = False,
167+
skip_d2h_for_method_outputs: bool = False,
166168
) -> None:
167169
super().__init__()
170+
self.skip_h2d_for_method_inputs = skip_h2d_for_method_inputs
171+
self.skip_d2h_for_method_outputs = skip_d2h_for_method_outputs
168172

169173
def _is_placeholder(self, node: torch.fx.Node) -> bool:
170174
"""Check if a node is a graph-level input (placeholder)."""
@@ -191,6 +195,23 @@ def _insert_h2d_copies(
191195
if not isinstance(arg_spec, TensorSpec):
192196
continue
193197

198+
if self.skip_h2d_for_method_inputs and self._is_placeholder(arg):
199+
# TODO(gasoonjia): support skip_h2d_for_method_inputs for
200+
# multiple-user placeholder inputs.
201+
if len(arg.users) != 1:
202+
raise RuntimeError(
203+
f"skip_h2d_for_method_inputs=True requires placeholder "
204+
f"'{arg.name}' to have exactly one user, but it has "
205+
f"{len(arg.users)} users. The placeholder is shared by "
206+
f"multiple consumers, so its TensorSpec cannot be safely "
207+
f"mutated in-place to the delegate's device. Either disable "
208+
f"skip_h2d_for_method_inputs, or ensure the placeholder is "
209+
f"used exclusively by this delegate."
210+
)
211+
_set_device_on_spec(arg_spec, target_device_type, device_index)
212+
changed = True
213+
continue
214+
194215
with graph_module.graph.inserting_before(node):
195216
h2d_node = graph_module.graph.call_function(
196217
torch.ops.et_copy._h2d_copy.default,
@@ -241,6 +262,9 @@ def _insert_d2h_for_getitem(
241262

242263
_set_device_on_spec(spec, source_spec.device, source_spec.device_index)
243264

265+
if self.skip_d2h_for_method_outputs and self._feeds_directly_to_output(node):
266+
return True
267+
244268
with graph_module.graph.inserting_after(node):
245269
d2h_node = graph_module.graph.call_function(
246270
torch.ops.et_copy._d2h_copy.default,

exir/program/_program.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,10 @@ def edge_to_executorch_passes(
764764
# there exists an unbacked symint operation.
765765
*config.passes,
766766
SpecPropPass(),
767-
PropagateDevicePass(),
767+
PropagateDevicePass(
768+
skip_h2d_for_method_inputs=config.skip_h2d_for_method_inputs,
769+
skip_d2h_for_method_outputs=config.skip_d2h_for_method_outputs,
770+
),
768771
EdgeToBackendOpsPass(),
769772
RemoveGraphAssertsPass(),
770773
] + pre_memory_planning_passes(config, name)

0 commit comments

Comments
 (0)