Commit e969a98
authored
Fix torch.split fails in to_edge with alias annotations (#18700)
Fixes #11723
## Summary
`torch.split` fails with `RuntimeError: Found a custom (non-ATen)
operator whose output has alias annotations` when used with
`to_edge_transform_and_lower` and a partitioner that requests op
preservation.
**Root cause**: `_remove_invalid_ops_for_not_decompose` relies on
`torchgen`'s `aliased_return_names()` to detect ops with aliased
returns. However, for ops returning lists of aliased tensors (e.g.,
`split.Tensor` returns `Tensor(a)[]`), `aliased_return_names()` returns
`[None]`, failing to detect the alias annotation. This lets
`split.Tensor` pass through into the `EDGE_DO_NOT_DECOMP` namespace,
where functionalization fails.
**Fix**: Add a fallback check using `op._schema.returns` directly, which
correctly reports `alias_info` on list return types. This also fixes the
same latent issue for `chunk.default` and `tensor_split.sections`.
## Test plan
- Added `test_remove_invalid_ops_filters_aliased_list_returns`
regression test
- Run: `pytest
exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns
-xvs`
- Verified existing split-related test still passes:
`test_to_out_variant_singleon_tensor_list`
- Verified existing broken ops test still passes:
`test_compile_fix_broken_ops`
<details>
<summary>Before fix</summary>
```
==================== BEFORE FIX ====================
RESULT: FAILED
RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: EDGE_DO_NOT_DECOMP::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]. We only support functionalizing operators whose outputs do not have alias annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias annotation specifies that the output Tensor shares storage with an input that has the same annotation. Please check if (1) the output needs to be an output (if not, don't return it), (2) if the output doesn't share storage with any inputs, then delete the alias annotation. (3) if the output indeed shares storage with an input, then add a .clone() before returning it to prevent storage sharing and then delete the alias annotation. Otherwise, please file an issue on GitHub.
While executing %split : [num_users=3] = call_function[target=torch.ops.EDGE_DO_NOT_DECOMP.split.Tensor](args = (%x, 2), kwargs = {})
Original traceback:
None
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
```
</details>
<details>
<summary>After fix</summary>
```
==================== AFTER FIX ====================
WARNING:root:Op aten.split.Tensor was requested for preservation by partitioner. This request is ignored because it aliases output.
Test 1: to_edge (no partitioner)
RESULT: SUCCESS - outputs match
Test 2: to_edge_transform_and_lower with split.Tensor preservation
RESULT: SUCCESS - split.Tensor correctly filtered from EDGE_DO_NOT_DECOMP
(AttributeError from dummy partitioner partition(), not from split bug)
Test 3: _remove_invalid_ops_for_not_decompose filter check
aten::split.Tensor -> FILTERED (correct)
aten::chunk -> FILTERED (correct)
aten::tensor_split.sections -> FILTERED (correct)
```
</details>
<details>
<summary>Unit test output</summary>
```
$ pytest exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns -xvs
============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-8.4.2
collected 1 item
exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns PASSED
============================== 1 passed in 6.83s ===============================
$ pytest exir/tests/test_passes.py::TestPasses::test_to_out_variant_singleon_tensor_list -xvs
PASSED
$ pytest exir/tests/test_passes.py::TestPasses::test_compile_fix_broken_ops -xvs
PASSED
```
</details>
This PR was authored with the assistance of Claude.
---------
Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>1 parent b57ac03 commit e969a98
2 files changed
Lines changed: 48 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1081 | 1081 | | |
1082 | 1082 | | |
1083 | 1083 | | |
1084 | | - | |
| 1084 | + | |
1085 | 1085 | | |
1086 | 1086 | | |
1087 | 1087 | | |
| |||
1124 | 1124 | | |
1125 | 1125 | | |
1126 | 1126 | | |
| 1127 | + | |
| 1128 | + | |
| 1129 | + | |
| 1130 | + | |
| 1131 | + | |
| 1132 | + | |
| 1133 | + | |
| 1134 | + | |
| 1135 | + | |
| 1136 | + | |
1127 | 1137 | | |
1128 | 1138 | | |
1129 | 1139 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
940 | 940 | | |
941 | 941 | | |
942 | 942 | | |
| 943 | + | |
| 944 | + | |
| 945 | + | |
| 946 | + | |
| 947 | + | |
| 948 | + | |
| 949 | + | |
| 950 | + | |
| 951 | + | |
| 952 | + | |
| 953 | + | |
| 954 | + | |
| 955 | + | |
| 956 | + | |
| 957 | + | |
| 958 | + | |
| 959 | + | |
| 960 | + | |
| 961 | + | |
| 962 | + | |
| 963 | + | |
| 964 | + | |
| 965 | + | |
| 966 | + | |
| 967 | + | |
| 968 | + | |
| 969 | + | |
| 970 | + | |
| 971 | + | |
| 972 | + | |
| 973 | + | |
| 974 | + | |
| 975 | + | |
| 976 | + | |
| 977 | + | |
| 978 | + | |
| 979 | + | |
943 | 980 | | |
944 | 981 | | |
945 | 982 | | |
| |||
0 commit comments