Skip to content

Commit e56c7c3

Browse files
authored
Arm backend: add argmin support and int32 overflow guard to ConvertIn… (pytorch#19918)
## Summary Follow-up to pytorch#13803. Two changes to `ConvertInt64OutputOpsToInt32Pass`. ## 1. argmin support `ConvertInt64OutputOpsToInt32Pass` inserts an `int64 → int32` cast after `aten.argmax` nodes so that the index output (TOSA has no int64) becomes int32 and downstream consumers can be delegated. `aten.argmin` returns int64 identically but was not handled — the committer explicitly deferred it as a future extension: > *"Future extensions may include operators that return int64 outputs by default (e.g., `argmin`) …"* ```mermaid flowchart LR subgraph before["Before"] direction LR A1["argmin\nint64"]:::cpu --> B1["mul\nint64"]:::blocked --> C1["add\nint64"]:::blocked end subgraph after["After"] direction LR A2["argmin\nint64"]:::cpu --> T["to_int32"]:::cpu T --> B2["mul\nint32"]:::delegated --> C2["add\nint32"]:::delegated end before ~~~ after classDef cpu fill:#f5c542,stroke:#b8962e,color:#000 classDef blocked fill:#e05c5c,stroke:#a33,color:#fff classDef delegated fill:#4caf7d,stroke:#2d7a54,color:#fff ``` **Changes:** Mirror the existing argmax registration to cover argmin. Rename the cast helper — it operates on the node's output dtype, not the op name, so the old name was misleading once argmin was added. --- ## 2. int32 overflow guard The pass previously had an open TODO: ```python # TODO: Add range check based on the input tensor shape before casting the output ``` `argmax`/`argmin` return an index in `[0, size)` where `size` is the number of elements searched. If `size > INT32_MAX`, casting to int32 silently truncates, producing a wrong index with no error. **Changes:** Add a compile-time shape check (`shape[dim]` or `numel()` for the no-dim form) and an `on_overflow` constructor param (`"raise"` / `"warn"` / `"skip"`, default `"raise"`). A compile-time error is preferable to a silent wrong result at runtime. --- ## Tests ```bash $ python -m pytest backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py -v 9 passed # 5 existing + 2 parametrized [argmax]/[argmin] delegation + 4 overflow (raise/warn/skip/invalid) $ lintrunner backends/arm/_passes/convert_int64_output_ops_to_int32.py \ backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py ok No lint issues. ``` The argmax and argmin delegation cases are unified into a single `@pytest.mark.parametrize` test. Signed-off-by: Youngsik Yang <vacu9708@gmail.com>
1 parent f0d9991 commit e56c7c3

2 files changed

Lines changed: 139 additions & 50 deletions

File tree

backends/arm/_passes/convert_int64_output_ops_to_int32.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
import logging
8-
from typing import Set, Type
8+
from typing import cast, Literal, Set, Type
99

1010
import torch
1111
from executorch.backends.arm._passes import ArmPass
@@ -25,26 +25,54 @@ class ConvertInt64OutputOpsToInt32Pass(ArmPass):
2525
"""Rewrites or removes operations that produce int64 outputs, converting
2626
them to int32 where possible.
2727
28-
Currently, this pass handles casting and argmax operators:
28+
Currently, this pass handles casting, argmax and argmin operators:
2929
1. int32 -> int64:
3030
removes the cast and redirects all uses to the original int32 value.
3131
2. other types -> int64:
3232
rewrites the cast to produce int32 instead of int64.
33-
3. torch.argmax()
34-
insert an int64->int32 cast after the argmax node
33+
3. torch.argmax() / torch.argmin()
34+
insert an int64->int32 cast after the argmax/argmin node
3535
36-
Future extensions may include operators that return int64 outputs by default
37-
(e.g., `argmin`), rewriting them or inserting an int64 -> int32 cast to yield
38-
int32 results.
36+
Future extensions may include other operators that return int64 outputs by
37+
default, rewriting them or inserting an int64 -> int32 cast to yield int32
38+
results.
3939
40-
Note: Overflow checks are applied selectively in this pass. For operators without
41-
such checks, it is the user's responsibility to ensure that values fit within
42-
the int32 range.
40+
Args:
41+
on_overflow: Action when an argmax/argmin index cannot safely fit in
42+
int32 (i.e. the reduced dimension has more than INT32_MAX elements).
43+
``"raise"`` (default) raises a ``RuntimeError`` at compile time.
44+
``"warn"`` logs a warning and skips the cast for that node.
45+
``"skip"`` silently skips the cast for that node.
4346
4447
"""
4548

4649
_passes_required_after: Set[Type[ExportPass]] = set()
4750

51+
_INT32_MAX = torch.iinfo(torch.int32).max
52+
53+
def __init__(
54+
self,
55+
*args,
56+
on_overflow: Literal["raise", "warn", "skip"] = "raise",
57+
**kwargs,
58+
) -> None:
59+
super().__init__(*args, **kwargs)
60+
if on_overflow not in ("raise", "warn", "skip"):
61+
raise ValueError(
62+
f"on_overflow must be 'raise', 'warn', or 'skip', got {on_overflow!r}"
63+
)
64+
self.on_overflow = on_overflow
65+
66+
def _is_int32_range_safe(self, node: torch.fx.Node) -> bool:
67+
"""Return True if the argmax/argmin index output fits in int32."""
68+
input_tensor = get_first_fake_tensor(cast(torch.fx.Node, node.args[0]))
69+
dim = node.args[1] if len(node.args) > 1 and node.args[1] is not None else None
70+
if dim is None:
71+
size = input_tensor.numel()
72+
else:
73+
size = input_tensor.shape[cast(int, dim)]
74+
return size <= self._INT32_MAX
75+
4876
aten_cast_ops = (
4977
torch.ops.aten.to.dtype,
5078
torch.ops.aten.to.dtype_layout,
@@ -54,8 +82,11 @@ class ConvertInt64OutputOpsToInt32Pass(ArmPass):
5482
aten_argmax_ops = (torch.ops.aten.argmax.default,)
5583
edge_argmax_ops = (exir_ops.edge.aten.argmax.default,)
5684

57-
aten_ops = aten_cast_ops + aten_argmax_ops
58-
edge_ops = edge_cast_ops + edge_argmax_ops
85+
aten_argmin_ops = (torch.ops.aten.argmin.default,)
86+
edge_argmin_ops = (exir_ops.edge.aten.argmin.default,)
87+
88+
aten_ops = aten_cast_ops + aten_argmax_ops + aten_argmin_ops
89+
edge_ops = edge_cast_ops + edge_argmax_ops + edge_argmin_ops
5990

6091
# dtype is specified in args
6192
cast_ops_args = (
@@ -104,7 +135,7 @@ def _convert_casting_operators(self, node: torch.fx.Node):
104135
f" {input_dtype}->torch.int32 defined in {node.meta.get('stack_trace','[no stack trace found]')}"
105136
)
106137

107-
def _convert_argmax_operators(self, node: torch.fx.Node, graph: torch.fx.Graph):
138+
def _cast_int64_output_to_int32(self, node: torch.fx.Node, graph: torch.fx.Graph):
108139
output_tensor = node
109140
to_copy_op = self._get_decomposition(node.target)
110141
with graph.inserting_after(node):
@@ -138,9 +169,23 @@ def call(self, graph_module: torch.fx.GraphModule):
138169

139170
if node.target in self.aten_cast_ops + self.edge_cast_ops:
140171
self._convert_casting_operators(node)
141-
elif node.target in self.aten_argmax_ops + self.edge_argmax_ops:
142-
# TODO: Add range check based on the input tensor shape before casting the output
143-
self._convert_argmax_operators(node, graph)
172+
elif node.target in (
173+
self.aten_argmax_ops
174+
+ self.edge_argmax_ops
175+
+ self.aten_argmin_ops
176+
+ self.edge_argmin_ops
177+
):
178+
if not self._is_int32_range_safe(node):
179+
msg = (
180+
f"{node.target} reduces over more than {self._INT32_MAX} elements; "
181+
f"the int64 index cannot be safely cast to int32."
182+
)
183+
if self.on_overflow == "raise":
184+
raise RuntimeError(msg)
185+
if self.on_overflow == "warn":
186+
logger.warning(msg)
187+
continue
188+
self._cast_int64_output_to_int32(node, graph)
144189
else:
145190
raise RuntimeError(f"Unexpected target {node.target} in {node.name}")
146191

backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py

Lines changed: 78 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
from typing import Callable, Dict, Tuple
77

8+
import pytest
89
import torch
910
from executorch.backends.arm._passes import ConvertInt64OutputOpsToInt32Pass
1011

1112
from executorch.backends.arm.test import common
1213

1314
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineFP
15+
from torch.fx import Graph, GraphModule
1416

1517
input_t1 = Tuple[torch.Tensor] # Input x
1618

@@ -86,44 +88,86 @@ def test_convert_int64_output_ops_to_int32_tosa_FP_remove_casting(
8688
pipeline.run()
8789

8890

89-
#####################################################
90-
## Test arange(dtype=int64) -> arange(dtype=int32) ##
91-
#####################################################
91+
##########################################################
92+
## Test argmax/argmin int64 output -> int32 cast ##
93+
##########################################################
9294

9395

94-
class Int64OutputModel(torch.nn.Module):
96+
@pytest.mark.parametrize(
97+
"arg_op, aten_op_str",
98+
[
99+
(torch.argmax, "torch.ops.aten.argmax.default"),
100+
(torch.argmin, "torch.ops.aten.argmin.default"),
101+
],
102+
ids=["argmax", "argmin"],
103+
)
104+
def test_convert_int64_output_ops_to_int32_tosa_FP_insert_cast(arg_op, aten_op_str):
105+
class ArgOpModel(torch.nn.Module):
106+
def forward(self, x: torch.Tensor) -> torch.Tensor:
107+
return (10 * arg_op(x, dim=-1) + 10) + 1.5
95108

96-
def forward(self, x: torch.Tensor) -> torch.Tensor:
97-
# return torch.argmax(x) # RuntimeError: Int did not match Long; But this is expected as we expect _argmax_i32 to generate int32 output
98-
# return (10 * torch.argmax(x) + 10).to(dtype=torch.int32) # [1]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (function _resize_output_check)
99-
return (10 * torch.argmax(x, dim=-1) + 10) + 1.5
100-
101-
def get_inputs(self) -> input_t1:
102-
return (
103-
torch.randint(
104-
0,
105-
10,
106-
(2, 4, 6, 8),
107-
),
108-
)
109-
110-
111-
def test_convert_int64_output_ops_to_int32_tosa_FP_insert_cast():
112-
module = Int64OutputModel()
113-
aten_ops_checks = [
114-
"torch.ops.aten.argmax.default",
115-
"torch.ops.aten.mul.Tensor",
116-
"torch.ops.aten.add.Tensor",
117-
]
118-
exir_ops_checks = [
119-
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
120-
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
121-
]
122109
pipeline = TosaPipelineFP[input_t1](
123-
module,
124-
module.get_inputs(),
125-
aten_op=aten_ops_checks,
126-
exir_op=exir_ops_checks,
110+
ArgOpModel(),
111+
(torch.randint(0, 10, (2, 4, 6, 8)),),
112+
aten_op=[aten_op_str, "torch.ops.aten.mul.Tensor", "torch.ops.aten.add.Tensor"],
113+
exir_op=[
114+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
115+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
116+
],
127117
transform_passes=[ConvertInt64OutputOpsToInt32Pass()],
128118
)
129119
pipeline.run()
120+
121+
122+
##############################################################
123+
## Test on_overflow range check for argmax/argmin ##
124+
##############################################################
125+
126+
_OVERFLOW_DIM = torch.iinfo(torch.int32).max + 1
127+
128+
129+
def _make_argmax_graph_large_dim() -> GraphModule:
130+
"""Construct a minimal graph with an argmax over a dimension > INT32_MAX.
131+
132+
Uses FakeTensorMode so no memory is allocated for the large dimension.
133+
134+
"""
135+
from torch._subclasses import FakeTensorMode
136+
137+
graph = Graph()
138+
with FakeTensorMode():
139+
fake_input = torch.empty(_OVERFLOW_DIM, dtype=torch.float32)
140+
fake_output = torch.empty((), dtype=torch.int64)
141+
x = graph.placeholder("x")
142+
x.meta["val"] = fake_input
143+
out = graph.call_function(torch.ops.aten.argmax.default, (x, 0))
144+
out.meta["val"] = fake_output
145+
graph.output(out)
146+
return GraphModule(torch.nn.Module(), graph)
147+
148+
149+
def test_on_overflow_raise():
150+
gm = _make_argmax_graph_large_dim()
151+
with pytest.raises(RuntimeError, match="cannot be safely cast to int32"):
152+
ConvertInt64OutputOpsToInt32Pass(on_overflow="raise").call(gm)
153+
154+
155+
def test_on_overflow_warn(caplog):
156+
import logging
157+
158+
gm = _make_argmax_graph_large_dim()
159+
with caplog.at_level(logging.WARNING):
160+
result = ConvertInt64OutputOpsToInt32Pass(on_overflow="warn").call(gm)
161+
assert not result.modified
162+
assert "cannot be safely cast to int32" in caplog.text
163+
164+
165+
def test_on_overflow_skip():
166+
gm = _make_argmax_graph_large_dim()
167+
result = ConvertInt64OutputOpsToInt32Pass(on_overflow="skip").call(gm)
168+
assert not result.modified
169+
170+
171+
def test_on_overflow_invalid():
172+
with pytest.raises(ValueError, match="on_overflow must be"):
173+
ConvertInt64OutputOpsToInt32Pass(on_overflow="blah")

0 commit comments

Comments
 (0)