Skip to content

Commit 4330ac2

Browse files
committed
Fix torch.split fails in to_edge with alias annotations
Fixes #11723 _remove_invalid_ops_for_not_decompose relied on torchgen's aliased_return_names() to detect ops with aliased returns, but it returns [None] for ops returning lists of aliased tensors (e.g., split.Tensor returns Tensor(a)[]). This let split.Tensor through into the EDGE_DO_NOT_DECOMP namespace where functionalization failed. Add a fallback check using op._schema.returns directly, which correctly reports alias_info on list return types. This also fixes the same latent issue for chunk and tensor_split. Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
1 parent 6020c29 commit 4330ac2

2 files changed

Lines changed: 46 additions & 0 deletions

File tree

exir/program/_program.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,16 @@ def keep(op):
11221122
)
11231123
return False
11241124

1125+
# Fallback: torchgen may fail to detect alias annotations on ops
1126+
# returning lists of tensors (e.g. split.Tensor returns Tensor(a)[]).
1127+
# Check op._schema.returns directly as a more reliable source.
1128+
for ret in schema.returns:
1129+
if ret.alias_info is not None:
1130+
log_warning(
1131+
f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output."
1132+
)
1133+
return False
1134+
11251135
# Explicit block list of ops that don't work if asked for
11261136
# preservation
11271137
if op in [

exir/tests/test_passes.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,42 @@ def body(i, h, h_accum):
940940
torch.allclose(prog.exported_program().module()(inp), model(inp))
941941
)
942942

943+
def test_remove_invalid_ops_filters_aliased_list_returns(self) -> None:
944+
"""Verify _remove_invalid_ops_for_not_decompose filters ops that return
945+
aliased tensor lists (e.g. split, chunk) even when torchgen's
946+
aliased_return_names() fails to detect them. Regression test for
947+
https://github.com/pytorch/executorch/issues/11723
948+
"""
949+
from executorch.exir.program._program import (
950+
_remove_invalid_ops_for_not_decompose,
951+
)
952+
953+
# These ops return Tensor(a)[] — a list of aliased views.
954+
# torchgen's aliased_return_names() misses the alias annotation on
955+
# list returns, so the fallback check on op._schema.returns is needed.
956+
aliased_list_ops = [
957+
torch.ops.aten.split.Tensor,
958+
torch.ops.aten.chunk.default,
959+
torch.ops.aten.tensor_split.sections,
960+
]
961+
for op in aliased_list_ops:
962+
result = _remove_invalid_ops_for_not_decompose([op])
963+
self.assertNotIn(
964+
op,
965+
result,
966+
f"{op} should be filtered out because it returns aliased tensors",
967+
)
968+
969+
# Non-aliased ops should be preserved.
970+
preserved_ops = [torch.ops.aten.linear.default]
971+
for op in preserved_ops:
972+
result = _remove_invalid_ops_for_not_decompose([op])
973+
self.assertIn(
974+
op,
975+
result,
976+
f"{op} should be preserved because it has no aliased returns",
977+
)
978+
943979
def test_convert_symb_ops(self) -> None:
944980
class Foo(torch.nn.Module):
945981
def forward(self, x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)