Skip to content

Commit a630b56

Browse files
authored
Make CUDA/AOTI partitioner composable after another delegate (pytorch#20077) (pytorch#20077)
Summary: `AotiPartitioner.partition` tagged every `call_function` node, including `executorch_call_delegate` calls already lowered by an earlier partitioner. So when `CudaPartitioner` runs as a second partitioner — e.g. after a TensorRT partition in a stacked `.pte` where TensorRT lowers the ops it can and the CUDA backend handles the rest — it tried to re-delegate the foreign delegate node, producing a malformed nested delegate. This is the blocker to composing the two backends in one `.pte`. Tag only the non-lowered nodes, reusing the existing `get_non_lowered_nodes` helper (which already excludes `executorch_call_delegate` calls and their output getitems), so the partitioner claims just the remaining ops and composes cleanly after another backend. In the single-partitioner case there are no delegate nodes, so `get_non_lowered_nodes` returns every `call_function` and behavior is unchanged. The same composition gap existed for constants: the final loop tagged every untagged param/buffer/lifted constant with this partition's tag, including ones consumed only by the foreign delegate. Backend lowering rejected those, since it requires every user of a tagged constant to share that tag while the foreign delegate's call keeps the prior one. Now only genuinely unused constants are tagged here — `tag_constant_data` already claims the ones this partition uses, and a constant feeding only a prior delegate is left untagged. Mirrored in fbcode and xplat. Reviewed By: Gasoonjia Differential Revision: D107690797
1 parent ba5ffab commit a630b56

2 files changed

Lines changed: 126 additions & 10 deletions

File tree

backends/aoti/aoti_partitioner.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
Partitioner,
1515
PartitionResult,
1616
)
17-
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
17+
from executorch.exir.backend.utils import (
18+
get_non_lowered_nodes,
19+
tag_constant_data,
20+
tag_mutated_buffer,
21+
)
1822
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
1923
from torch.export.exported_program import ExportedProgram
2024

@@ -60,8 +64,17 @@ def is_control_flow(node: torch.fx.Node) -> bool:
6064
torch.ops.higher_order.while_loop,
6165
]
6266

67+
# Nodes already lowered by an earlier partitioner (e.g. a preceding
68+
# TensorRT partition) appear as executorch_call_delegate calls and their
69+
# output getitems; re-delegating them would nest a foreign delegate. Tag
70+
# only the remaining non-lowered ops so this partitioner composes after
71+
# others.
72+
non_lowered_nodes = set(get_non_lowered_nodes(exported_program.graph))
73+
6374
for node in exported_program.graph.nodes:
6475
if node.op == "call_function":
76+
if node not in non_lowered_nodes:
77+
continue
6578
node.meta["delegation_tag"] = tag
6679
# Tag get_attr nodes that are used by control flow operations
6780
elif node.op == "get_attr":
@@ -76,17 +89,22 @@ def is_control_flow(node: torch.fx.Node) -> bool:
7689
tag_constant_data(exported_program)
7790
tag_mutated_buffer(exported_program)
7891

79-
# Tag constant placeholders that have no users
80-
# tag_constant_data only tags constants that have users with delegation_tag
81-
# but we need to tag all constants for this partition
92+
# A constant that still has users feeds only a prior delegate; tagging it
93+
# would fail backend lowering's same-tag check (its user keeps the prior
94+
# tag). tag_constant_data already claimed the ones this partition uses, so
95+
# tag only the genuinely unused constants here.
8296
for node in exported_program.graph.nodes:
83-
if node.op == "placeholder" and (
84-
is_param(exported_program, node)
85-
or is_buffer(exported_program, node)
86-
or is_lifted_tensor_constant(exported_program, node)
97+
if (
98+
node.op == "placeholder"
99+
and not node.users
100+
and "delegation_tag" not in node.meta
101+
and (
102+
is_param(exported_program, node)
103+
or is_buffer(exported_program, node)
104+
or is_lifted_tensor_constant(exported_program, node)
105+
)
87106
):
88-
if "delegation_tag" not in node.meta:
89-
node.meta["delegation_tag"] = tag
107+
node.meta["delegation_tag"] = tag
90108

91109
return PartitionResult(
92110
tagged_exported_program=exported_program, partition_tags=partition_tags

backends/cuda/tests/test_cuda_partitioner.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import operator
78
import unittest
89
from typing import Tuple
910

1011
import torch
1112
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
1213
from executorch.exir.backend.partitioner import PartitionResult
14+
from executorch.exir.delegate import executorch_call_delegate
15+
from torch._export.utils import is_buffer
1316
from torch.export import export
1417

1518

@@ -222,3 +225,98 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
222225
expected_tag,
223226
f"Constant placeholder {node.name} has tag '{actual_tag}' but expected '{expected_tag}'",
224227
)
228+
229+
def test_does_not_retag_already_lowered_delegate(self) -> None:
230+
"""
231+
A node already lowered by a previous partitioner appears as an
232+
executorch_call_delegate call plus its output getitem. The CUDA
233+
partitioner must not re-tag those, so it can run after another backend
234+
(e.g. TensorRT) and only claim the remaining ops.
235+
"""
236+
237+
class AddModule(torch.nn.Module):
238+
def forward(self, x: torch.Tensor) -> torch.Tensor:
239+
return x + x
240+
241+
exported_program = export(AddModule(), (torch.randn(3, 4),), strict=True)
242+
graph_module = exported_program.graph_module
243+
graph = graph_module.graph
244+
245+
placeholder = next(n for n in graph.nodes if n.op == "placeholder")
246+
aten_node = next(
247+
n
248+
for n in graph.nodes
249+
if n.op == "call_function" and n.target != operator.getitem
250+
)
251+
252+
# Splice in a fake, already-lowered delegate (call + output getitem), as a
253+
# preceding partitioner (e.g. TensorRT) would have produced.
254+
graph_module.lowered_module_0 = torch.nn.Module()
255+
with graph.inserting_before(aten_node):
256+
lowered = graph.get_attr("lowered_module_0")
257+
delegate = graph.call_function(
258+
executorch_call_delegate, (lowered, placeholder)
259+
)
260+
delegate_output = graph.call_function(operator.getitem, (delegate, 0))
261+
graph.lint()
262+
263+
CudaPartitioner([]).partition(exported_program)
264+
265+
self.assertNotIn("delegation_tag", delegate.meta)
266+
self.assertNotIn("delegation_tag", delegate_output.meta)
267+
self.assertIn("delegation_tag", aten_node.meta)
268+
269+
def test_does_not_tag_constant_used_only_by_prior_delegate(self) -> None:
270+
"""
271+
A constant whose only consumer is a previously lowered delegate must stay
272+
untagged. Tagging it would give it this partition's tag while its user
273+
keeps the prior delegate's, which backend lowering rejects. Only ops this
274+
partitioner claims and genuinely unused constants may be tagged.
275+
"""
276+
277+
class AddModule(torch.nn.Module):
278+
def __init__(self) -> None:
279+
super().__init__()
280+
self.register_buffer("w", torch.randn(3, 4))
281+
282+
def forward(self, x: torch.Tensor) -> torch.Tensor:
283+
return x + self.w
284+
285+
exported_program = export(AddModule(), (torch.randn(3, 4),), strict=True)
286+
graph_module = exported_program.graph_module
287+
graph = graph_module.graph
288+
289+
buffer_placeholder = next(
290+
n
291+
for n in graph.nodes
292+
if n.op == "placeholder" and is_buffer(exported_program, n)
293+
)
294+
input_placeholder = next(
295+
n
296+
for n in graph.nodes
297+
if n.op == "placeholder" and not is_buffer(exported_program, n)
298+
)
299+
aten_node = next(
300+
n
301+
for n in graph.nodes
302+
if n.op == "call_function" and n.target != operator.getitem
303+
)
304+
305+
# Make the buffer feed only a fake, already-lowered delegate (as a
306+
# preceding TensorRT partition would): rewire the aten op off the buffer,
307+
# then splice the delegate consuming it.
308+
aten_node.replace_input_with(buffer_placeholder, input_placeholder)
309+
graph_module.lowered_module_0 = torch.nn.Module()
310+
with graph.inserting_before(aten_node):
311+
lowered = graph.get_attr("lowered_module_0")
312+
delegate = graph.call_function(
313+
executorch_call_delegate, (lowered, buffer_placeholder)
314+
)
315+
graph.call_function(operator.getitem, (delegate, 0))
316+
graph.lint()
317+
318+
CudaPartitioner([]).partition(exported_program)
319+
320+
self.assertNotIn("delegation_tag", buffer_placeholder.meta)
321+
self.assertNotIn("delegation_tag", delegate.meta)
322+
self.assertIn("delegation_tag", aten_node.meta)

0 commit comments

Comments
 (0)