
Commit e969a98

Fix torch.split fails in to_edge with alias annotations (#18700)
Fixes #11723

## Summary

`torch.split` fails with `RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations` when used with `to_edge_transform_and_lower` and a partitioner that requests op preservation.

**Root cause**: `_remove_invalid_ops_for_not_decompose` relies on `torchgen`'s `aliased_return_names()` to detect ops with aliased returns. However, for ops returning lists of aliased tensors (e.g., `split.Tensor` returns `Tensor(a)[]`), `aliased_return_names()` returns `[None]`, failing to detect the alias annotation. This lets `split.Tensor` pass through into the `EDGE_DO_NOT_DECOMP` namespace, where functionalization fails.

**Fix**: Add a fallback check using `op._schema.returns` directly, which correctly reports `alias_info` on list return types. This also fixes the same latent issue for `chunk.default` and `tensor_split.sections`.

## Test plan

- Added `test_remove_invalid_ops_filters_aliased_list_returns` regression test
- Run: `pytest exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns -xvs`
- Verified existing split-related test still passes: `test_to_out_variant_singleon_tensor_list`
- Verified existing broken ops test still passes: `test_compile_fix_broken_ops`

<details>
<summary>Before fix</summary>

```
==================== BEFORE FIX ====================
RESULT: FAILED
RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: EDGE_DO_NOT_DECOMP::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]. We only support functionalizing operators whose outputs do not have alias annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias annotation specifies that the output Tensor shares storage with an input that has the same annotation. Please check if (1) the output needs to be an output (if not, don't return it), (2) if the output doesn't share storage with any inputs, then delete the alias annotation. (3) if the output indeed shares storage with an input, then add a .clone() before returning it to prevent storage sharing and then delete the alias annotation. Otherwise, please file an issue on GitHub.

While executing %split : [num_users=3] = call_function[target=torch.ops.EDGE_DO_NOT_DECOMP.split.Tensor](args = (%x, 2), kwargs = {})
Original traceback:
None
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
```

</details>

<details>
<summary>After fix</summary>

```
==================== AFTER FIX ====================
WARNING:root:Op aten.split.Tensor was requested for preservation by partitioner. This request is ignored because it aliases output.

Test 1: to_edge (no partitioner)
  RESULT: SUCCESS - outputs match
Test 2: to_edge_transform_and_lower with split.Tensor preservation
  RESULT: SUCCESS - split.Tensor correctly filtered from EDGE_DO_NOT_DECOMP
  (AttributeError from dummy partitioner partition(), not from split bug)
Test 3: _remove_invalid_ops_for_not_decompose filter check
  aten::split.Tensor -> FILTERED (correct)
  aten::chunk -> FILTERED (correct)
  aten::tensor_split.sections -> FILTERED (correct)
```

</details>

<details>
<summary>Unit test output</summary>

```
$ pytest exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns -xvs
============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-8.4.2
collected 1 item

exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns PASSED

============================== 1 passed in 6.83s ===============================

$ pytest exir/tests/test_passes.py::TestPasses::test_to_out_variant_singleon_tensor_list -xvs
PASSED

$ pytest exir/tests/test_passes.py::TestPasses::test_compile_fix_broken_ops -xvs
PASSED
```

</details>

This PR was authored with the assistance of Claude.

---------

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
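The alias-annotation notation described in the error message (`Tensor(a)` versus `Tensor`) can be illustrated with a small string check. This is only a toy sketch for reading schema strings: the helper name `returns_have_alias_annotation` is hypothetical, it assumes the return side contains no ` -> `, and it is not how torchgen or ExecuTorch inspect schemas (the fix uses `op._schema.returns`).

```python
import re


def returns_have_alias_annotation(schema: str) -> bool:
    """Report whether the return side of a schema string has an alias annotation.

    Hypothetical helper for illustration; not part of PyTorch or ExecuTorch.
    """
    # Treat everything after the last " -> " as the return declaration.
    # Assumption: the return side itself contains no " -> ".
    _, _, returns = schema.rpartition(" -> ")
    # An alias annotation is a parenthesis directly after a type name, as in
    # "Tensor(a)". A tuple return such as "(Tensor, Tensor)" begins with "("
    # and is therefore not matched.
    return re.search(r"\w\(", returns) is not None


# Schema taken from the error message above.
SPLIT = "split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"
# A schema whose return carries no alias annotation.
LINEAR = "linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"

print(returns_have_alias_annotation(SPLIT))
print(returns_have_alias_annotation(LINEAR))
```

The first call reports an annotated return and the second does not, which matches the distinction the error message draws between `Tensor(a)[]` and a plain `Tensor`.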
1 parent b57ac03 commit e969a98

2 files changed

Lines changed: 48 additions & 1 deletion


exir/program/_program.py

Lines changed: 11 additions & 1 deletion
@@ -1081,7 +1081,7 @@ def _sanity_check_graph_for_non_decomp_ops(
                     logging.warning(warning_str)


-def _remove_invalid_ops_for_not_decompose(
+def _remove_invalid_ops_for_not_decompose(  # noqa: C901
     preserve_ops: List[torch._ops.OpOverload],
 ) -> List[torch._ops.OpOverload]:
     _logged_warnings = set()
@@ -1124,6 +1124,16 @@ def keep(op):
                 )
                 return False

+        # Fallback: torchgen does not detect alias annotations on ops
+        # returning lists of aliased tensors (e.g. split.Tensor returns
+        # Tensor(a)[]). Check op._schema.returns directly.
+        for ret in schema.returns:
+            if ret.alias_info is not None:
+                log_warning(
+                    f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output."
+                )
+                return False
+
         # Explicit block list of ops that don't work if asked for
         # preservation
         if op in [
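The fallback in this hunk can be tried in isolation with stand-in objects. The sketch below is a minimal model, not the ExecuTorch implementation: `StubReturn`, `StubSchema`, `StubOp`, and `remove_aliased_ops` are hypothetical names, and the de-duplicating `log_warning` helper is an assumption about how `_logged_warnings` is used. Only the loop over `schema.returns` and the `alias_info is not None` test mirror the diff.

```python
import logging
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple


@dataclass(frozen=True)
class StubReturn:
    """Hypothetical stand-in for one entry of op._schema.returns."""
    alias_info: Optional[str] = None


@dataclass(frozen=True)
class StubSchema:
    """Hypothetical stand-in for op._schema."""
    returns: Tuple[StubReturn, ...]


@dataclass(frozen=True)
class StubOp:
    """Hypothetical stand-in for a torch._ops.OpOverload."""
    name: str
    _schema: StubSchema

    def __str__(self) -> str:
        return self.name


def remove_aliased_ops(preserve_ops: List[StubOp]) -> List[StubOp]:
    """Drop ops whose schema reports alias info on any return value."""
    logged: Set[str] = set()

    def log_warning(msg: str) -> None:
        # Assumed behavior: emit each distinct warning only once.
        if msg not in logged:
            logging.warning(msg)
            logged.add(msg)

    def keep(op: StubOp) -> bool:
        # Same shape as the fallback in the diff: inspect the schema's
        # returns directly instead of relying on torchgen.
        for ret in op._schema.returns:
            if ret.alias_info is not None:
                log_warning(
                    f"Op {op} was requested for preservation by partitioner. "
                    "This request is ignored because it aliases output."
                )
                return False
        return True

    return [op for op in preserve_ops if keep(op)]


# Models a Tensor(a)[] return and a plain Tensor return.
split = StubOp("aten.split.Tensor", StubSchema((StubReturn("a"),)))
linear = StubOp("aten.linear.default", StubSchema((StubReturn(),)))

kept = remove_aliased_ops([split, linear])
print([str(op) for op in kept])
```

In this model only the op without alias info survives the filter, which is the behavior the regression test asserts for the real operators.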

exir/tests/test_passes.py

Lines changed: 37 additions & 0 deletions
@@ -940,6 +940,43 @@ def body(i, h, h_accum):
             torch.allclose(prog.exported_program().module()(inp), model(inp))
         )

+    def test_remove_invalid_ops_filters_aliased_list_returns(self) -> None:
+        """Verify _remove_invalid_ops_for_not_decompose filters ops that return
+        aliased tensor lists (e.g. split, chunk) even when torchgen's
+        aliased_return_names() fails to detect them. Regression test for
+        https://github.com/pytorch/executorch/issues/11723
+        """
+        from executorch.exir.program._program import (
+            _remove_invalid_ops_for_not_decompose,
+        )
+
+        # These ops return Tensor(a)[] — a list of aliased views.
+        # torchgen's aliased_return_names() misses the alias annotation on
+        # list returns, so the fallback check on op._schema.returns is needed.
+        aliased_list_ops = [
+            torch.ops.aten.split.Tensor,
+            torch.ops.aten.chunk.default,
+            torch.ops.aten.tensor_split.sections,
+            torch.ops.aten.split_with_sizes.default,
+        ]
+        for op in aliased_list_ops:
+            result = _remove_invalid_ops_for_not_decompose([op])
+            self.assertNotIn(
+                op,
+                result,
+                f"{op} should be filtered out because it returns aliased tensors",
+            )
+
+        # Non-aliased ops should be preserved.
+        preserved_ops = [torch.ops.aten.linear.default]
+        for op in preserved_ops:
+            result = _remove_invalid_ops_for_not_decompose([op])
+            self.assertIn(
+                op,
+                result,
+                f"{op} should be preserved because it has no aliased returns",
+            )
+
     def test_convert_symb_ops(self) -> None:
         class Foo(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
