Skip to content

Commit 5563ee9

Browse files
Arm backend: TOSAQuantizerV2 fixes (pytorch#20031)
Break out fixes from pytorch#19758 as discussed in pytorch#19966 --------- Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Co-authored-by: RJ Ascani <rja@meta.com>
1 parent 01b3568 commit 5563ee9

9 files changed

Lines changed: 173 additions & 70 deletions

File tree

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 85 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,18 @@ class PatternQuantizer(Quantizer, QuantizerReporterUser):
243243
244244
"""
245245

246+
PARAMETER_TARGETS = {
247+
torch.ops.aten.linear.default,
248+
torch.ops.aten.convolution.default,
249+
torch.ops.aten.conv1d.default,
250+
torch.ops.aten.conv1d.padding,
251+
torch.ops.aten.conv2d.default,
252+
torch.ops.aten.conv2d.padding,
253+
torch.ops.aten.conv3d.default,
254+
torch.ops.aten.conv3d.padding,
255+
torch.ops.aten.conv_transpose2d.input,
256+
}
257+
246258
def __init__(
247259
self,
248260
quantization_config: QuantizationConfig | None,
@@ -275,75 +287,59 @@ def get_quantizer_info(self):
275287
support_config_path,
276288
)
277289

278-
def is_parameter(self, node: Node, model: torch.fx.GraphModule) -> bool:
279-
"""Returns True if the given node is a parameter of the model."""
280-
try:
281-
_ = model.get_parameter(node.target) # type: ignore[arg-type]
282-
return True
283-
except Exception:
290+
def is_weight(self, node: Node) -> bool:
291+
"""Returns True if node is used as a weight by all users."""
292+
if node.op != "get_attr":
284293
return False
285294

286-
def is_weight(
287-
self, node: Node, params: list[Node], model: torch.fx.GraphModule
288-
) -> bool:
289-
"""Returns True if node is the first parameter of the given
290-
parameters.
291-
"""
292-
return len(params) > 0 and node == params[0]
295+
# Ensure that the node is used as a weight by all users
296+
for user_node in node.users:
297+
if user_node.target not in self.PARAMETER_TARGETS:
298+
return False
293299

294-
def is_bias(
295-
self, node: Node, params: list[Node], model: torch.fx.GraphModule
296-
) -> bool:
297-
"""Returns True if node is the second parameter of the given
298-
parameters.
299-
"""
300-
return len(params) == 2 and node == params[1]
300+
args = list(user_node.args)
301+
if not (len(args) > 1 and node == args[1]):
302+
return False
303+
304+
return True
305+
306+
def is_bias(self, node: Node) -> bool:
307+
"""Returns True if node is used as a bias by all users."""
308+
if node.op != "get_attr":
309+
return False
310+
311+
# Ensure that the node is used as a bias by all users
312+
for user_node in node.users:
313+
if user_node.target not in self.PARAMETER_TARGETS:
314+
return False
315+
316+
args = list(user_node.args)
317+
if not (len(args) > 2 and node == args[2]):
318+
return False
319+
320+
return True
301321

302322
def annotate_match(
303323
self,
304324
match: list[Node],
305325
config: QuantizationConfig | None,
306-
model: torch.fx.GraphModule,
307326
) -> None:
308327
"""Annotates a matched pattern according to the given quantization
309328
config.
310329
"""
311-
parameter_targets = {
312-
torch.ops.aten.linear.default,
313-
torch.ops.aten.convolution.default,
314-
torch.ops.aten.conv1d.default,
315-
torch.ops.aten.conv1d.padding,
316-
torch.ops.aten.conv2d.default,
317-
torch.ops.aten.conv2d.padding,
318-
torch.ops.aten.conv3d.default,
319-
torch.ops.aten.conv3d.padding,
320-
torch.ops.aten.conv_transpose2d.input,
321-
}
322330

323331
for node in match:
324332
input_qspec_map = {}
325333
output_qspec = None
326334

327-
params = [n for n in node.all_input_nodes if self.is_parameter(n, model)]
328-
if node.target in parameter_targets:
329-
if len(params) == 0 or len(params) > 2:
330-
logger.warning(
331-
f"{node.name} is expected to have parameter tensors for weight/bias but no such inputs found, which may cause unexpected quantization annotations. This is likely caused by incorrect tensor instantiations or non-constant weight/biases."
332-
)
333-
else:
334-
if len(params) > 0:
335-
logger.warning(
336-
f"{node.name} is not expected to not have parameter tensors but found {[n.name for n in params]}, which may cause unexpected quantization annotations."
337-
)
338-
339335
for input_node in node.all_input_nodes:
340336
if not has_float_output(input_node):
341337
continue
342-
if self.is_weight(input_node, params, model):
338+
if self.is_weight(input_node):
343339
input_qspec_map[input_node] = (
344340
config.get_weight_qspec(node) if config else None
345341
)
346-
elif self.is_bias(input_node, params, model):
342+
elif self.is_bias(input_node):
347343
input_qspec_map[input_node] = (
348344
config.get_bias_qspec(node) if config else None # type: ignore[assignment]
349345
)
@@ -370,7 +366,7 @@ def annotate(self, model: torch.fx.GraphModule) -> None: # type: ignore[overrid
370366
)
371367
for result in matches:
372368
if result.accepted:
373-
self.annotate_match(result.pattern, self.quantization_config, model)
369+
self.annotate_match(result.pattern, self.quantization_config)
374370
self.report_accept(result.pattern)
375371
else:
376372
self.report_reject(
@@ -424,6 +420,9 @@ class SharedQspecQuantizer(Quantizer, QuantizerReporterUser):
424420
torch.ops.aten.flip.default,
425421
torch.ops.aten.index_select.default,
426422
torch.ops.aten.index_put.default,
423+
torch.ops.aten.index_put_.default,
424+
torch.ops.aten.index_copy.default,
425+
torch.ops.aten.index_copy_.default,
427426
torch.ops.aten.contiguous.default,
428427
torch.ops.aten.as_strided_copy.default,
429428
torch.ops.aten.pixel_shuffle.default,
@@ -571,6 +570,42 @@ def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]:
571570

572571
return shared_nodes, adjacent_qspecs
573572

573+
def _should_skip_while_shared_qspec(self, node: Node) -> bool:
574+
return node.target == torch.ops.higher_order.while_loop and bool(
575+
node.meta.get("additional_inputs")
576+
)
577+
578+
def _annotate_while_with_additional_inputs(
579+
self,
580+
root_node: Node,
581+
adjacent_qspecs: list[Any],
582+
) -> bool:
583+
if not self._should_skip_while_shared_qspec(root_node):
584+
return False
585+
if len(adjacent_qspecs) == 0:
586+
self.report_reject(
587+
[root_node],
588+
"Couldn't find any adjacent quantization spec to annotate while_loop.",
589+
)
590+
return True
591+
592+
input_qspec = adjacent_qspecs[0]
593+
input_qspec_map: dict[Node, Optional[QuantizationSpec]] = {
594+
n: input_qspec for n in self._get_input_nodes_with_float_output(root_node)
595+
}
596+
output_qspec: Optional[QuantizationSpec] = None
597+
if len(self._get_user_nodes_with_float_input(root_node)) > 0:
598+
output_qspec = input_qspec
599+
600+
_mark_node_as_quantized(
601+
root_node,
602+
input_qspec_map,
603+
output_qspec,
604+
is_quantized=True,
605+
)
606+
self.report_accept([root_node])
607+
return True
608+
574609
def _annotate_shared_cluster(self, root_node: Node) -> None:
575610
if (
576611
len(self._get_input_nodes_with_float_output(root_node)) == 0
@@ -592,9 +627,11 @@ def _annotate_shared_cluster(self, root_node: Node) -> None:
592627
node_order = {node: index for index, node in enumerate(root_node.graph.nodes)}
593628
ordered_nodes = sorted(shared_nodes, key=lambda node: node_order.get(node, 0))
594629

630+
if self._annotate_while_with_additional_inputs(root_node, adjacent_qspecs):
631+
return
632+
595633
# Ensure the root node is the first one in the graph.
596634
root_node = ordered_nodes[0]
597-
598635
if len(adjacent_qspecs) > 0:
599636
root_node_float_inputs = self._get_input_nodes_with_float_output(root_node)
600637
if len(root_node_float_inputs) > 0:

backends/arm/quantizer/quantization_annotator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from executorch.backends.arm.common.type import ensure_type
2222
from executorch.backends.arm.quantizer import QuantizationConfig
2323

24+
from torch._ops import OpOverload
2425
from torch._subclasses import FakeTensor
2526
from torch.fx import Node
2627
from torchao.quantization.pt2e import (
@@ -441,7 +442,7 @@ def _match_pattern(
441442
return left_condition and right_condition
442443

443444

444-
_conv_ops = {
445+
_conv_ops: set[OpOverload] = {
445446
torch.ops.aten.conv1d.default,
446447
torch.ops.aten.conv2d.default,
447448
torch.ops.aten.conv2d.padding,
@@ -473,7 +474,7 @@ def _match_pattern(
473474
},
474475
}
475476

476-
_one_to_one = {
477+
_one_to_one: set[OpOverload] = {
477478
torch.ops.aten.abs.default,
478479
torch.ops.aten.ceil.default,
479480
torch.ops.aten.erf.default,
@@ -514,7 +515,7 @@ def _match_pattern(
514515
torch.ops.aten.tan.default,
515516
}
516517

517-
_one_to_one_shared_input_qspec = {
518+
_one_to_one_shared_input_qspec: set[OpOverload] = {
518519
torch.ops.aten.squeeze.default,
519520
torch.ops.aten.squeeze_copy.default,
520521
torch.ops.aten.squeeze_copy.dim,
@@ -574,7 +575,7 @@ def _match_pattern(
574575
torch.ops.aten.detach_copy.default,
575576
}
576577

577-
_one_to_one_shared_input_or_input_act_qspec = {
578+
_one_to_one_shared_input_or_input_act_qspec: set[OpOverload] = {
578579
torch.ops.aten.alias.default,
579580
torch.ops.aten.clone.default,
580581
torch.ops.aten.hardtanh.default,

backends/arm/quantizer/quantization_config.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from torchao.quantization.pt2e.quantizer import (
2323
DerivedQuantizationSpec,
24+
FixedQParamsQuantizationSpec,
2425
QuantizationSpec,
2526
QuantizationSpecBase,
2627
SharedQuantizationSpec,
@@ -284,10 +285,18 @@ def get_input_act_qspec(self, node=None, input_node=None):
284285
285286
For comparison operators, make sure that both inputs share the same
286287
quantization spec, by returning a SharedQuantizationSpec that ties the
287-
quantization of both inputs together. For other operators, return the
288-
default input activation spec.
288+
quantization of both inputs together.
289+
290+
For trigonometric ops, ensure that input spec has fixed qparams.
291+
292+
For other operators, return the default input activation spec.
289293
290294
"""
295+
# MLETORCH-1853: Fix lazy import when moving files around
296+
from executorch.backends.arm.quantizer.quantization_annotator import (
297+
_fixed_input_qspec_ops,
298+
)
299+
291300
if node is None or input_node is None:
292301
return super().get_input_act_qspec(node, input_node)
293302

@@ -296,6 +305,29 @@ def get_input_act_qspec(self, node=None, input_node=None):
296305
return super().get_input_act_qspec(node, input_node)
297306
else:
298307
return SharedQuantizationSpec((node.args[0], node))
308+
elif node.target in _fixed_input_qspec_ops:
309+
310+
input_act_qspec = super().get_input_act_qspec(node, input_node)
311+
if not hasattr(input_act_qspec, "dtype") or not isinstance(
312+
input_act_qspec.dtype, torch.dtype
313+
):
314+
raise ValueError(
315+
f"{node.target} requires an input activation quantization "
316+
"spec to use fixed input qparams."
317+
)
318+
dtype = getattr(input_act_qspec, "dtype", None)
319+
num_bits = torch.iinfo(dtype).bits
320+
321+
qparams = _fixed_input_qspec_ops[node.target][num_bits]
322+
return FixedQParamsQuantizationSpec(
323+
dtype=dtype,
324+
scale=qparams.scale,
325+
zero_point=qparams.zero_point,
326+
quant_min=input_act_qspec.quant_min,
327+
quant_max=input_act_qspec.quant_max,
328+
qscheme=input_act_qspec.qscheme,
329+
is_dynamic=input_act_qspec.is_dynamic,
330+
)
299331

300332
return super().get_input_act_qspec(node, input_node)
301333

backends/arm/quantizer/quantizer_support.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ def check_pattern(cls, pattern):
7777
torch.ops.aten.relu_.default,
7878
torch.ops.aten.hardtanh.default,
7979
torch.ops.aten.hardtanh_.default,
80-
torch.ops.aten.hardsigmoid.default,
81-
torch.ops.aten.hardsigmoid_.default,
8280
torch.ops.aten.clamp.default,
8381
torch.ops.aten.clamp_.default,
8482
]
@@ -168,6 +166,14 @@ def check_pattern(cls, pattern):
168166
(torch.ops.aten.ge.Scalar,),
169167
(torch.ops.aten.eq.Scalar,),
170168
(torch.ops.aten.ne.Scalar,),
169+
(torch.ops.aten.lstm.input,),
170+
(torch.ops.aten.rnn_tanh.input,),
171+
(torch.ops.aten.rnn_relu.input,),
172+
(torch.ops.aten.gru.input,),
173+
(torch.ops.aten.asin.default,),
174+
(torch.ops.aten.acos.default,),
175+
(torch.ops.aten.atanh.default,),
176+
(torch.ops.aten.einsum.default,),
171177
]
172178
)
173179
TOSA_QUANTIZER_SUPPORT_DICT: dict[tuple[OpOverload, ...], type[PatternCheck] | None] = {

backends/arm/scripts/docgen/docgen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def get_docstring(obj) -> str:
4646

4747
lines = docstring.split("\n")
4848
for line in lines:
49-
if ":" in line and line.startswith(" "):
49+
# Only first-level arg lines should become bullets.
50+
is_arg_line = line.startswith(" ") and not line.startswith(" ")
51+
if ":" in line and is_arg_line:
5052
new_line = line.strip()
5153
pos = new_line.index(":")
5254
new_line = f"- **{new_line[:pos]}**" + new_line[pos:]

backends/cortex_m/test/misc/test_portable_int8.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,36 @@ def _quantize_and_export(
301301
(torch.randn(6), torch.randn(6)),
302302
torch.int64,
303303
),
304+
"index_put_": OpCase(
305+
torch.ops.aten.index_put_.default,
306+
_build_module(
307+
lambda x, y: torch.ops.aten.index_put_.default(
308+
x, (torch.tensor([1, 3]),), torch.tensor([1.0, 2.0]), False
309+
)
310+
),
311+
(torch.randn(6), torch.randn(6)),
312+
torch.int64,
313+
),
314+
"index_copy": OpCase(
315+
torch.ops.aten.index_copy.default,
316+
_build_module(
317+
lambda x, y: torch.ops.aten.index_copy.default(
318+
x, 0, torch.tensor([0, 2]), y
319+
)
320+
),
321+
(torch.randn(4, 5), torch.randn(2, 5)),
322+
torch.int64,
323+
),
324+
"index_copy_": OpCase(
325+
torch.ops.aten.index_copy_.default,
326+
_build_module(
327+
lambda x, y: torch.ops.aten.index_copy_.default(
328+
x, 0, torch.tensor([0, 2]), y
329+
)
330+
),
331+
(torch.randn(4, 5), torch.randn(2, 5)),
332+
torch.int64,
333+
),
304334
"contiguous": OpCase(
305335
torch.ops.aten.contiguous.default,
306336
_build_module(lambda x, y: torch.ops.aten.contiguous.default(x)),

docs/source/backends/arm-ethos-u/tutorials/ethos-u-getting-started.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ In this tutorial you will learn how to export a simple PyTorch model for the Exe
2020
```{tip}
2121
If you are already familiar with this delegate, you may want to jump directly to the examples:
2222
* [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm)
23-
* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
23+
* [A commandline compiler for quick tests and example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
2424
```
2525

2626
This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on Arm&reg; Ethos&trade;-U targets. It is based on `ethos_u_minimal_example.ipynb`, provided in Arm’s examples folder.
@@ -142,9 +142,10 @@ save_pte_program(executorch_program_manager, "ethos_u_minimal_example.pte")
142142

143143

144144
```{tip}
145-
For a quick start, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
145+
For a quick test, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
146146
To produce a pte file equivalent to the one above, run
147-
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte`
147+
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte`.
148+
For production use, you should instead use the stable Python API shown above.
148149
```
149150

150151
### Runtime:

0 commit comments

Comments
 (0)