Skip to content

Commit 7871a9b

Browse files
authored
Gate device copy insertion on device memory planning (pytorch#19961)
Differential Revision: D107310726. Pull Request resolved: pytorch#19961.
1 parent 658dcd4 commit 7871a9b

3 files changed

Lines changed: 67 additions & 8 deletions

File tree

exir/passes/propagate_device_pass.py

Lines changed: 35 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -165,10 +165,12 @@ def __init__(
165165
self,
166166
skip_h2d_for_method_inputs: bool = False,
167167
skip_d2h_for_method_outputs: bool = False,
168+
enable_non_cpu_memory_planning: bool = False,
168169
) -> None:
169170
super().__init__()
170171
self.skip_h2d_for_method_inputs = skip_h2d_for_method_inputs
171172
self.skip_d2h_for_method_outputs = skip_d2h_for_method_outputs
173+
self.enable_non_cpu_memory_planning = enable_non_cpu_memory_planning
172174

173175
def _is_placeholder(self, node: torch.fx.Node) -> bool:
174176
"""Check if a node is a graph-level input (placeholder)."""
@@ -282,7 +284,7 @@ def _insert_d2h_for_getitem(
282284
)
283285
return True
284286

285-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
287+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
286288
# Two-pass approach:
287289
# Pass 1 – For each delegate with a target_device CompileSpec, insert
288290
# H2D copy nodes before delegate inputs and tag the delegate
@@ -313,9 +315,18 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
313315
target_device_type, device_index = result
314316
device_delegates.add(node)
315317

316-
changed |= self._insert_h2d_copies(
317-
graph_module, node, target_device_type, device_index
318-
)
318+
if self.enable_non_cpu_memory_planning:
319+
changed |= self._insert_h2d_copies(
320+
graph_module, node, target_device_type, device_index
321+
)
322+
else:
323+
for arg in node.args[1:]:
324+
if isinstance(arg, torch.fx.Node):
325+
changed |= _tag_specs_with_device(
326+
arg.meta.get("spec"),
327+
target_device_type,
328+
device_index,
329+
)
319330

320331
changed |= _tag_specs_with_device(
321332
node.meta.get("spec"),
@@ -337,7 +348,26 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
337348
if node.op == "call_function" and node.target == operator.getitem:
338349
source = node.args[0]
339350
if isinstance(source, torch.fx.Node) and source in device_delegates:
340-
changed |= self._insert_d2h_for_getitem(graph_module, node)
351+
if self.enable_non_cpu_memory_planning:
352+
changed |= self._insert_d2h_for_getitem(graph_module, node)
353+
else:
354+
spec = node.meta.get("spec")
355+
source_specs = source.meta.get("spec")
356+
idx = node.args[1]
357+
if (
358+
isinstance(spec, TensorSpec)
359+
and isinstance(source_specs, (tuple, list))
360+
and isinstance(idx, int)
361+
and idx < len(source_specs)
362+
):
363+
source_spec = source_specs[idx]
364+
if isinstance(source_spec, TensorSpec):
365+
_set_device_on_spec(
366+
spec,
367+
source_spec.device,
368+
source_spec.device_index,
369+
)
370+
changed = True
341371

342372
graph_module.recompile()
343373
return PassResult(graph_module, changed)

exir/program/_program.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -767,6 +767,7 @@ def edge_to_executorch_passes(
767767
PropagateDevicePass(
768768
skip_h2d_for_method_inputs=config.skip_h2d_for_method_inputs,
769769
skip_d2h_for_method_outputs=config.skip_d2h_for_method_outputs,
770+
enable_non_cpu_memory_planning=config.enable_non_cpu_memory_planning,
770771
),
771772
EdgeToBackendOpsPass(),
772773
RemoveGraphAssertsPass(),

exir/tests/test_propagate_device_pass.py

Lines changed: 31 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,7 @@ def _lower_model_to_executorch(
121121
"""Lower model all the way through to_executorch for E2E tests."""
122122
if et_config is None:
123123
et_config = ExecutorchBackendConfig(emit_stacktrace=False)
124+
124125
ep = export(model, inputs)
125126
ep_copied = deepcopy(ep)
126127

@@ -314,7 +315,10 @@ def forward(self, a, b):
314315
inputs = (torch.randn(2, 2), torch.randn(2, 2))
315316

316317
for pipeline, gm in _lower_model_to_executorch(
317-
model, inputs, DeviceAwarePartitioner("cuda:0")
318+
model,
319+
inputs,
320+
DeviceAwarePartitioner("cuda:0"),
321+
ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
318322
):
319323
with self.subTest(pipeline=pipeline):
320324
nodes = _collect_device_copy_nodes(gm)
@@ -371,7 +375,10 @@ def forward(self, a, b):
371375
inputs = (torch.randn(2, 2), torch.randn(2, 2))
372376

373377
for pipeline, gm in _lower_model_to_executorch(
374-
model, inputs, DeviceAwarePartitioner("cuda:0")
378+
model,
379+
inputs,
380+
DeviceAwarePartitioner("cuda:0"),
381+
ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
375382
):
376383
with self.subTest(pipeline=pipeline):
377384
nodes = _collect_device_copy_nodes(gm)
@@ -445,6 +452,24 @@ def forward(self, a, b):
445452
f"[{pipeline}] Unexpected D2H copy nodes when no target_device is set",
446453
)
447454

455+
def test_copy_nodes_require_non_cpu_memory_planning(self):
456+
"""Default lowering keeps legacy device tags without runtime copy ops."""
457+
458+
class Model(torch.nn.Module):
459+
def forward(self, a, b):
460+
return torch.add(a, b)
461+
462+
model = Model()
463+
inputs = (torch.randn(2, 2), torch.randn(2, 2))
464+
465+
for pipeline, gm in _lower_model_to_executorch(
466+
model, inputs, DeviceAwarePartitioner("cuda:0")
467+
):
468+
with self.subTest(pipeline=pipeline):
469+
device_copy_nodes = _collect_device_copy_nodes(gm)
470+
self.assertEqual(len(device_copy_nodes.h2d_nodes), 0)
471+
self.assertEqual(len(device_copy_nodes.d2h_nodes), 0)
472+
448473
# ---- Integration tests: device consistency after to_executorch ----
449474

450475
def test_device_consistency_cuda_1(self):
@@ -523,7 +548,10 @@ def forward(self, a, b):
523548
inputs = (torch.randn(2, 2), torch.randn(2, 2))
524549

525550
for pipeline, gm in _lower_model_to_executorch(
526-
model, inputs, DeviceAwarePartitioner("cuda:0")
551+
model,
552+
inputs,
553+
DeviceAwarePartitioner("cuda:0"),
554+
ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
527555
):
528556
with self.subTest(pipeline=pipeline):
529557
for node in gm.graph.nodes:

0 commit comments

Comments
 (0)