44# LICENSE file in the root directory of this source tree.
55
66import math
7+ import operator
78from copy import copy
89from typing import cast , Dict , Optional , Set , Tuple , Type
910
@@ -34,22 +35,60 @@ class InsertRescalePass(ArmPass):
3435
3536 _passes_required_after : Set [Type [ExportPass ]] = set ()
3637
38+ def _ensure_uint8_io_only (self , graph_module : GraphModule ) -> None :
39+ """Ensure uint8 tensors only appear at IO boundaries.
40+
41+ TOSA has no true uint8 tensor type; unsigned semantics are carried via
42+ RESCALE input/output flags. If uint8 appears for other nodes, it means
43+ unsigned data leaked past IO.
44+
45+ """
46+ for node in graph_module .graph .nodes :
47+ meta_val = node .meta .get ("val" )
48+ if not isinstance (meta_val , torch .Tensor ):
49+ continue
50+ if meta_val .dtype != torch .uint8 :
51+ continue
52+ if node .op in ("placeholder" , "output" ):
53+ continue
54+ if node .op == "call_function" and node .target == operator .getitem :
55+ if all (user .op == "output" for user in node .users ):
56+ continue
57+ if (
58+ node .op == "call_function"
59+ and node .target == exir_ops .backend .tosa .RESCALE .default
60+ ):
61+ continue
62+ raise ValueError (
63+ f"Found internal uint8 tensor at node { node .name } "
64+ f"({ node .target } ). Uint8 is only allowed at IO boundaries."
65+ )
66+
3767 def fold_dq_q_to_rescale (self , node : Node , user : Node , graph_module : GraphModule ):
3868 dq_args = QuantArgs .from_operator (node .target , node .args )
3969 q_args = QuantArgs .from_operator (user .target , user .args )
4070 new_scale = dq_args .scale / q_args .scale
71+ input_unsigned = dq_args .dtype == torch .uint8
72+ output_unsigned = q_args .dtype == torch .uint8
73+ # TOSA has no true uint8 tensors; unsigned semantics are handled via
74+ # the RESCALE flags, so uint8 does not propagate as a tensor dtype.
75+ output_dtype = torch .int8 if output_unsigned else q_args .dtype
4176
4277 with graph_module .graph .inserting_before (node ):
4378 rescale_node = create_node (
4479 graph_module .graph ,
4580 exir_ops .backend .tosa .RESCALE .default ,
4681 (
4782 node .all_input_nodes [0 ],
48- q_args . dtype ,
83+ output_dtype ,
4984 [new_scale ],
5085 dq_args .zp ,
5186 q_args .zp ,
5287 ),
88+ kwargs = {
89+ "input_unsigned" : input_unsigned ,
90+ "output_unsigned" : output_unsigned ,
91+ },
5392 )
5493 rescale_node .meta = copy (user .meta )
5594 user .replace_all_uses_with (rescale_node )
@@ -74,6 +113,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
74113 graph_module .recompile ()
75114 return PassResult (graph_module , modified )
76115
def ensures(self, graph_module: GraphModule) -> None:
    """Check the graph after this pass has run.

    Delegates to ``_ensure_uint8_io_only``; any exception it raises
    propagates to the caller.

    Args:
        graph_module: Graph module to check.

    """
    self._ensure_uint8_io_only(graph_module)
118+
77119
78120class InsertRescaleInt32Pass (ArmPass ):
79121 """Numerous TOSA ops require inputs and outputs to be 32-bit integers in
0 commit comments