55
66
77import logging
8- from typing import Set , Type
8+ from typing import cast , Literal , Set , Type
99
1010import torch
1111from executorch .backends .arm ._passes import ArmPass
@@ -25,26 +25,54 @@ class ConvertInt64OutputOpsToInt32Pass(ArmPass):
2525 """Rewrites or removes operations that produce int64 outputs, converting
2626 them to int32 where possible.
2727
28- Currently, this pass handles casting and argmax operators:
28+ Currently, this pass handles casting, argmax and argmin operators:
2929 1. int32 -> int64:
3030 removes the cast and redirects all uses to the original int32 value.
3131 2. other types -> int64:
3232 rewrites the cast to produce int32 instead of int64.
33- 3. torch.argmax()
34- insert an int64->int32 cast after the argmax node
33+ 3. torch.argmax() / torch.argmin()
34+ insert an int64->int32 cast after the argmax/argmin node
3535
36- Future extensions may include operators that return int64 outputs by default
37- (e.g., `argmin`), rewriting them or inserting an int64 -> int32 cast to yield
38- int32 results.
36+ Future extensions may include other operators that return int64 outputs by
37+ default, rewriting them or inserting an int64 -> int32 cast to yield int32
38+ results.
3939
40- Note: Overflow checks are applied selectively in this pass. For operators without
41- such checks, it is the user's responsibility to ensure that values fit within
42- the int32 range.
40+ Args:
41+ on_overflow: Action when an argmax/argmin index cannot safely fit in
42+ int32 (i.e. the reduced dimension has more than INT32_MAX elements).
43+ ``"raise"`` (default) raises a ``RuntimeError`` at compile time.
44+ ``"warn"`` logs a warning and skips the cast for that node.
45+ ``"skip"`` silently skips the cast for that node.
4346
4447 """
4548
4649 _passes_required_after : Set [Type [ExportPass ]] = set ()
4750
51+ _INT32_MAX = torch .iinfo (torch .int32 ).max
52+
53+ def __init__ (
54+ self ,
55+ * args ,
56+ on_overflow : Literal ["raise" , "warn" , "skip" ] = "raise" ,
57+ ** kwargs ,
58+ ) -> None :
59+ super ().__init__ (* args , ** kwargs )
60+ if on_overflow not in ("raise" , "warn" , "skip" ):
61+ raise ValueError (
62+ f"on_overflow must be 'raise', 'warn', or 'skip', got { on_overflow !r} "
63+ )
64+ self .on_overflow = on_overflow
65+
66+ def _is_int32_range_safe (self , node : torch .fx .Node ) -> bool :
67+ """Return True if the argmax/argmin index output fits in int32."""
68+ input_tensor = get_first_fake_tensor (cast (torch .fx .Node , node .args [0 ]))
69+ dim = node .args [1 ] if len (node .args ) > 1 and node .args [1 ] is not None else None
70+ if dim is None :
71+ size = input_tensor .numel ()
72+ else :
73+ size = input_tensor .shape [cast (int , dim )]
74+ return size <= self ._INT32_MAX
75+
4876 aten_cast_ops = (
4977 torch .ops .aten .to .dtype ,
5078 torch .ops .aten .to .dtype_layout ,
@@ -54,8 +82,11 @@ class ConvertInt64OutputOpsToInt32Pass(ArmPass):
5482 aten_argmax_ops = (torch .ops .aten .argmax .default ,)
5583 edge_argmax_ops = (exir_ops .edge .aten .argmax .default ,)
5684
57- aten_ops = aten_cast_ops + aten_argmax_ops
58- edge_ops = edge_cast_ops + edge_argmax_ops
85+ aten_argmin_ops = (torch .ops .aten .argmin .default ,)
86+ edge_argmin_ops = (exir_ops .edge .aten .argmin .default ,)
87+
88+ aten_ops = aten_cast_ops + aten_argmax_ops + aten_argmin_ops
89+ edge_ops = edge_cast_ops + edge_argmax_ops + edge_argmin_ops
5990
6091 # dtype is specified in args
6192 cast_ops_args = (
@@ -104,7 +135,7 @@ def _convert_casting_operators(self, node: torch.fx.Node):
104135 f" { input_dtype } ->torch.int32 defined in { node .meta .get ('stack_trace' ,'[no stack trace found]' )} "
105136 )
106137
107- def _convert_argmax_operators (self , node : torch .fx .Node , graph : torch .fx .Graph ):
138+ def _cast_int64_output_to_int32 (self , node : torch .fx .Node , graph : torch .fx .Graph ):
108139 output_tensor = node
109140 to_copy_op = self ._get_decomposition (node .target )
110141 with graph .inserting_after (node ):
@@ -138,9 +169,23 @@ def call(self, graph_module: torch.fx.GraphModule):
138169
139170 if node .target in self .aten_cast_ops + self .edge_cast_ops :
140171 self ._convert_casting_operators (node )
141- elif node .target in self .aten_argmax_ops + self .edge_argmax_ops :
142- # TODO: Add range check based on the input tensor shape before casting the output
143- self ._convert_argmax_operators (node , graph )
172+ elif node .target in (
173+ self .aten_argmax_ops
174+ + self .edge_argmax_ops
175+ + self .aten_argmin_ops
176+ + self .edge_argmin_ops
177+ ):
178+ if not self ._is_int32_range_safe (node ):
179+ msg = (
180+ f"{ node .target } reduces over more than { self ._INT32_MAX } elements; "
181+ f"the int64 index cannot be safely cast to int32."
182+ )
183+ if self .on_overflow == "raise" :
184+ raise RuntimeError (msg )
185+ if self .on_overflow == "warn" :
186+ logger .warning (msg )
187+ continue
188+ self ._cast_int64_output_to_int32 (node , graph )
144189 else :
145190 raise RuntimeError (f"Unexpected target { node .target } in { node .name } " )
146191
0 commit comments