Skip to content

Commit 8fdf040

Browse files
committed
Arm backend: Support uint8 IO quantization for backends
Add support for IO tensors only to be uint8. In conjuction with the QuantizeInput and QuantizeOutput pass this adds the possibility to give inputs of uint8 dtype to the model directly. Change-Id: Icc08ac242e5c980f2abd484eb0e7661418873ab7 Signed-off-by: Per Åstrand <per.astrand@arm.com>
1 parent 81bc830 commit 8fdf040

11 files changed

Lines changed: 876 additions & 26 deletions

File tree

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import math
7+
import operator
78
from copy import copy
89
from typing import cast, Dict, Optional, Set, Tuple, Type
910

@@ -34,22 +35,60 @@ class InsertRescalePass(ArmPass):
3435

3536
_passes_required_after: Set[Type[ExportPass]] = set()
3637

38+
def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None:
39+
"""Ensure uint8 tensors only appear at IO boundaries.
40+
41+
TOSA has no true uint8 tensor type; unsigned semantics are carried via
42+
RESCALE input/output flags. If uint8 appears for other nodes, it means
43+
unsigned data leaked past IO.
44+
45+
"""
46+
for node in graph_module.graph.nodes:
47+
meta_val = node.meta.get("val")
48+
if not isinstance(meta_val, torch.Tensor):
49+
continue
50+
if meta_val.dtype != torch.uint8:
51+
continue
52+
if node.op in ("placeholder", "output"):
53+
continue
54+
if node.op == "call_function" and node.target == operator.getitem:
55+
if all(user.op == "output" for user in node.users):
56+
continue
57+
if (
58+
node.op == "call_function"
59+
and node.target == exir_ops.backend.tosa.RESCALE.default
60+
):
61+
continue
62+
raise ValueError(
63+
f"Found internal uint8 tensor at node {node.name} "
64+
f"({node.target}). Uint8 is only allowed at IO boundaries."
65+
)
66+
3767
def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
3868
dq_args = QuantArgs.from_operator(node.target, node.args)
3969
q_args = QuantArgs.from_operator(user.target, user.args)
4070
new_scale = dq_args.scale / q_args.scale
71+
input_unsigned = dq_args.dtype == torch.uint8
72+
output_unsigned = q_args.dtype == torch.uint8
73+
# TOSA has no true uint8 tensors; unsigned semantics are handled via
74+
# the RESCALE flags, so uint8 does not propagate as a tensor dtype.
75+
output_dtype = torch.int8 if output_unsigned else q_args.dtype
4176

4277
with graph_module.graph.inserting_before(node):
4378
rescale_node = create_node(
4479
graph_module.graph,
4580
exir_ops.backend.tosa.RESCALE.default,
4681
(
4782
node.all_input_nodes[0],
48-
q_args.dtype,
83+
output_dtype,
4984
[new_scale],
5085
dq_args.zp,
5186
q_args.zp,
5287
),
88+
kwargs={
89+
"input_unsigned": input_unsigned,
90+
"output_unsigned": output_unsigned,
91+
},
5392
)
5493
rescale_node.meta = copy(user.meta)
5594
user.replace_all_uses_with(rescale_node)
@@ -74,6 +113,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
74113
graph_module.recompile()
75114
return PassResult(graph_module, modified)
76115

116+
def ensures(self, graph_module: GraphModule) -> None:
117+
self._ensure_uint8_io_only(graph_module)
118+
77119

78120
class InsertRescaleInt32Pass(ArmPass):
79121
"""Numerous TOSA ops require inputs and outputs to be 32-bit integers in

backends/arm/operators/op_tosa_rescale.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def _build_rescale(
161161
rounding_mode: ts.RoundingMode,
162162
per_channel: bool = False,
163163
is_scale32: bool = True,
164+
input_unsigned: bool = False,
165+
output_unsigned: bool = False,
164166
):
165167
"""Insert a TOSA RESCALE operator configured for the quantized path.
166168
@@ -198,8 +200,8 @@ def _build_rescale(
198200
scale32=is_scale32,
199201
rounding_mode=rounding_mode,
200202
per_channel=per_channel,
201-
input_unsigned=False,
202-
output_unsigned=False,
203+
input_unsigned=input_unsigned,
204+
output_unsigned=output_unsigned,
203205
)
204206

205207
tosa_fb.addOperator(
@@ -228,6 +230,14 @@ def define_node(
228230
scales = cast(list[float], node.args[2])
229231
input_zp = cast(int, node.args[3])
230232
output_zp = cast(int, node.args[4])
233+
if "input_unsigned" in node.kwargs:
234+
input_unsigned = cast(bool, node.kwargs.get("input_unsigned", False))
235+
else:
236+
input_unsigned = cast(bool, node.args[5]) if len(node.args) > 5 else False
237+
if "output_unsigned" in node.kwargs:
238+
output_unsigned = cast(bool, node.kwargs.get("output_unsigned", False))
239+
else:
240+
output_unsigned = cast(bool, node.args[6]) if len(node.args) > 6 else False
231241

232242
if (
233243
input_dtype
@@ -244,7 +254,6 @@ def define_node(
244254
raise ValueError(
245255
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
246256
)
247-
248257
_build_rescale(
249258
tosa_graph,
250259
scale=scales,
@@ -255,4 +264,6 @@ def define_node(
255264
output_zp=[output_zp],
256265
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
257266
per_channel=len(scales) > 1,
267+
input_unsigned=input_unsigned,
268+
output_unsigned=output_unsigned,
258269
)

backends/arm/quantizer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
EthosUQuantizer,
1616
get_symmetric_a16w8_quantization_config,
1717
get_symmetric_quantization_config,
18+
get_uint8_io_quantization_config,
1819
TOSAQuantizer,
1920
VgfQuantizer,
2021
)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
"VgfQuantizer",
106106
"get_symmetric_a16w8_quantization_config",
107107
"get_symmetric_quantization_config",
108+
"get_uint8_io_quantization_config",
108109
]
109110

110111
logger = logging.getLogger(__name__)
@@ -234,6 +235,53 @@ def get_symmetric_quantization_config(
234235
return quantization_config
235236

236237

238+
@functools.lru_cache
239+
def get_uint8_io_quantization_config(
240+
is_qat: bool = False,
241+
is_dynamic: bool = False,
242+
eps: float = 2**-16,
243+
) -> QuantizationConfig:
244+
"""Create a uint8 IO quantization config for TOSA backends.
245+
246+
This config is intended for model inputs/outputs only. Internal tensors
247+
should remain int8 for TOSA INT lowering.
248+
249+
"""
250+
extra_args: Dict[str, Any] = {"eps": eps}
251+
if is_qat:
252+
if is_dynamic:
253+
act_observer_or_fake_quant_ctr = FakeQuantize
254+
dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
255+
averaging_constant=1
256+
)
257+
extra_args["observer"] = dynamic_quant_observer
258+
else:
259+
act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment]
260+
else:
261+
if is_dynamic:
262+
act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment]
263+
else:
264+
act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
265+
266+
act_quantization_spec = QuantizationSpec(
267+
dtype=torch.uint8,
268+
quant_min=torch.iinfo(torch.uint8).min,
269+
quant_max=torch.iinfo(torch.uint8).max,
270+
qscheme=torch.per_tensor_affine,
271+
is_dynamic=is_dynamic,
272+
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
273+
**extra_args,
274+
),
275+
)
276+
277+
return TOSAQuantizationConfig(
278+
act_quantization_spec,
279+
act_quantization_spec,
280+
None,
281+
None,
282+
)
283+
284+
237285
def get_symmetric_a8w4_quantization_config(
238286
is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False
239287
):

backends/arm/test/misc/test_rescale_range.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
from typing import Tuple
77

8+
import executorch.backends.arm.tosa.dialect # noqa: F401
9+
810
import pytest
911
import torch
1012

13+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
1114
from executorch.backends.arm.tosa.specification import (
1215
TosaLoweringContext,
1316
TosaSpecification,
@@ -128,3 +131,76 @@ def test_zp_outside_range_tosa_INT():
128131
]
129132
)
130133
)
134+
135+
136+
def test_unsigned_zp_range_tosa_INT_valid():
137+
# Validate unsigned zero-point ranges via explicit unsigned semantics.
138+
# First case: uint8 input (input_unsigned=True) uses in_zp in [0,255].
139+
# Second case: signed int8 input but unsigned output semantics (output_unsigned=True)
140+
# allow out_zp in [0,255].
141+
sample_inputs = [
142+
# (data, out_dtype, scale, in_zp, out_zp, input_unsigned, output_unsigned)
143+
(
144+
torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8),
145+
torch.int8,
146+
[0.5],
147+
255,
148+
0,
149+
True,
150+
False,
151+
),
152+
(
153+
torch.randint(low=-128, high=127, size=(4, 4, 4), dtype=torch.int8),
154+
torch.int8,
155+
[0.5],
156+
0,
157+
255,
158+
False,
159+
True,
160+
),
161+
]
162+
163+
with TosaLoweringContext(
164+
TosaSpecification.create_from_string("TOSA-1.0+INT")
165+
), FakeTensorMode() as mode:
166+
for sample_input in sample_inputs:
167+
exir_ops.backend.tosa.RESCALE.default(
168+
*tuple(
169+
[
170+
mode.from_tensor(i) if isinstance(i, torch.Tensor) else i
171+
for i in sample_input[:5]
172+
]
173+
),
174+
input_unsigned=sample_input[5],
175+
output_unsigned=sample_input[6],
176+
)
177+
178+
179+
def test_unsigned_zp_range_tosa_INT_invalid():
180+
with TosaLoweringContext(
181+
TosaSpecification.create_from_string("TOSA-1.0+INT")
182+
), FakeTensorMode() as mode:
183+
with pytest.raises(TosaValueError, match="(in_zp|input_zp).*range"):
184+
exir_ops.backend.tosa.RESCALE.default(
185+
mode.from_tensor(
186+
torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8)
187+
),
188+
torch.int8,
189+
[0.5],
190+
256,
191+
0,
192+
input_unsigned=True,
193+
output_unsigned=False,
194+
)
195+
with pytest.raises(TosaValueError, match="(out_zp|output_zp).*range"):
196+
exir_ops.backend.tosa.RESCALE.default(
197+
mode.from_tensor(
198+
torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8)
199+
),
200+
torch.int8,
201+
[0.5],
202+
0,
203+
256,
204+
input_unsigned=False,
205+
output_unsigned=True,
206+
)

0 commit comments

Comments
 (0)