1313import logging
1414import operator
1515from dataclasses import dataclass , replace
16- from typing import Callable , cast , List , Optional , Sequence
16+ from typing import Any , Callable , cast , Iterable , List , NamedTuple , Optional , Sequence
1717
1818import torch
1919import torch .fx
2020from executorch .backends .arm .common .debug import get_node_debug_info
2121from executorch .backends .arm .common .type import ensure_type
2222from executorch .backends .arm .quantizer import QuantizationConfig
23- from torch ._subclasses import FakeTensor
2423
24+ from torch ._subclasses import FakeTensor
2525from torch .fx import Node
2626from torchao .quantization .pt2e import (
2727 FakeQuantize ,
2828 FusedMovingAvgObsFakeQuantize ,
2929 MovingAveragePerChannelMinMaxObserver ,
3030 PartialWrapper ,
3131)
32+
3233from torchao .quantization .pt2e .quantizer import (
3334 annotate_input_qspec_map ,
3435 annotate_output_qspec ,
36+ FixedQParamsQuantizationSpec ,
3537 QuantizationSpec ,
3638 QuantizationSpecBase ,
3739 SharedQuantizationSpec ,
@@ -78,6 +80,11 @@ def __init__(self):
7880 self .quant_output : Optional [_QuantProperty ] = None
7981
8082
83+ class _QParams (NamedTuple ):
84+ scale : float
85+ zero_point : int
86+
87+
8188def _as_list (x ):
8289 """Return ``x`` wrapped as a list if needed.
8390
@@ -391,14 +398,16 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
391398
392399
393400def _match_pattern (
394- node : Node , pattern : List [List ], filter_fn : Optional [Callable [[Node ], bool ]] = None
401+ node : Node ,
402+ pattern : Sequence [Iterable [object ]],
403+ filter_fn : Optional [Callable [[Node ], bool ]] = None ,
395404) -> bool :
396405 """Check whether a node chain matches a pattern.
397406
398407 Verify a chain of ancestors -> node -> descendants matches the provided
399408 ``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
400- to pass the filter. Each pattern element is a list of disjunctive node
401- targets.
409+ to pass the filter. Each pattern element is an iterable of disjunctive
410+ node targets.
402411
403412 """
404413 if len (pattern ) < 1 :
@@ -432,16 +441,39 @@ def _match_pattern(
432441 return left_condition and right_condition
433442
434443
435- _conv_ops = [
444+ _conv_ops = {
436445 torch .ops .aten .conv1d .default ,
437446 torch .ops .aten .conv2d .default ,
438447 torch .ops .aten .conv2d .padding ,
439448 torch .ops .aten .conv_transpose2d .input ,
440449 torch .ops .aten .conv3d .default ,
441450 torch .ops .aten .conv3d .padding ,
442- ]
451+ }
443452
444- _one_to_one = [
def _symmetric_qparams(limit: float, num_bits: int) -> _QParams:
    """Return fixed qparams for the interval ``[-limit, limit]``.

    The scale is the interval width divided by ``2 ** num_bits``; the zero
    point is always 0.

    """
    return _QParams((limit - (-limit)) / (1 << num_bits), 0)


# For these ops, we use fixed qspecs, meaning that the quantization params
# of their input are statically defined instead of being taken from the
# regular input activation qspec. This is to prevent issues with
# out-of-range values when using dynamic quantization.
#
# Dict of operator to a dict of num_bits to qparams for that operator.
# Only 8- and 16-bit entries are defined; the consumer indexes this table
# directly with ``[num_bits]``, so any other width raises ``KeyError``.
_fixed_input_qspec_ops: dict[Any, dict[int, _QParams]] = {
    # acos has a valid range of [-1, 1]
    torch.ops.aten.acos.default: {
        8: _symmetric_qparams(1.0, 8),
        16: _symmetric_qparams(1.0, 16),
    },
    # asin has a valid range of [-1, 1]
    torch.ops.aten.asin.default: {
        8: _symmetric_qparams(1.0, 8),
        16: _symmetric_qparams(1.0, 16),
    },
    # atanh has a valid range of (-1, 1) (excluding -1 and 1), so the
    # limits are pulled slightly inside the open interval.
    torch.ops.aten.atanh.default: {
        8: _symmetric_qparams(0.999, 8),
        16: _symmetric_qparams(0.99999, 16),
    },
}
475+
476+ _one_to_one = {
445477 torch .ops .aten .abs .default ,
446478 torch .ops .aten .ceil .default ,
447479 torch .ops .aten .erf .default ,
@@ -472,16 +504,13 @@ def _match_pattern(
472504 torch .ops .aten .log1p .default ,
473505 torch .ops .aten .acosh .default ,
474506 torch .ops .aten .sign .default ,
475- torch .ops .aten .asin .default ,
476- torch .ops .aten .atanh .default ,
477507 torch .ops .aten .asinh .default ,
478508 torch .ops .aten .cosh .default ,
479- torch .ops .aten .acos .default ,
480509 torch .ops .aten .cumsum .default ,
481510 torch .ops .aten .tan .default ,
482- ]
511+ }
483512
484- _one_to_one_shared_input_qspec = [
513+ _one_to_one_shared_input_qspec = {
485514 torch .ops .aten .squeeze .default ,
486515 torch .ops .aten .squeeze_copy .default ,
487516 torch .ops .aten .squeeze_copy .dim ,
@@ -539,9 +568,9 @@ def _match_pattern(
539568 # dequant -> neg -> requant chain.
540569 torch .ops .aten .neg .default ,
541570 torch .ops .aten .detach_copy .default ,
542- ]
571+ }
543572
544- _one_to_one_shared_input_or_input_act_qspec = [
573+ _one_to_one_shared_input_or_input_act_qspec = {
545574 torch .ops .aten .alias .default ,
546575 torch .ops .aten .clone .default ,
547576 torch .ops .aten .hardtanh .default ,
@@ -562,7 +591,7 @@ def _match_pattern(
562591 torch .ops .aten .alias_copy .default ,
563592 torch .ops .aten .pixel_shuffle .default ,
564593 torch .ops .aten .pixel_unshuffle .default ,
565- ]
594+ }
566595
567596
568597def get_quant_properties ( # noqa: C901
@@ -615,13 +644,13 @@ def any_or_hardtanh_min_zero(n: Node):
615644 node ,
616645 [
617646 _conv_ops ,
618- [ torch .ops .aten .batch_norm .default ] ,
619- [
647+ { torch .ops .aten .batch_norm .default } ,
648+ {
620649 torch .ops .aten .relu .default ,
621650 torch .ops .aten .relu_ .default ,
622651 torch .ops .aten .hardtanh .default ,
623652 torch .ops .aten .hardtanh_ .default ,
624- ] ,
653+ } ,
625654 ],
626655 filter_fn = any_or_hardtanh_min_zero ,
627656 ):
@@ -644,7 +673,7 @@ def any_or_hardtanh_min_zero(n: Node):
644673 node ,
645674 [
646675 _conv_ops ,
647- [ torch .ops .aten .batch_norm .default ] ,
676+ { torch .ops .aten .batch_norm .default } ,
648677 ],
649678 ):
650679 if node .target in _conv_ops :
@@ -654,23 +683,21 @@ def any_or_hardtanh_min_zero(n: Node):
654683 _QuantProperty (1 , conv_weight_qspec , mark_annotated = True ),
655684 _QuantProperty (2 , bias_qspec , optional = True , mark_annotated = True ),
656685 ]
657- elif node .target in [
658- torch .ops .aten .batch_norm .default ,
659- ]:
686+ elif node .target in {torch .ops .aten .batch_norm .default }:
660687 quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
661688 elif not is_symmetric and _match_pattern (
662689 node ,
663690 [
664- [
691+ {
665692 * _conv_ops ,
666693 torch .ops .aten .linear .default ,
667- ] ,
668- [
694+ } ,
695+ {
669696 torch .ops .aten .relu .default ,
670697 torch .ops .aten .relu_ .default ,
671698 torch .ops .aten .hardtanh .default ,
672699 torch .ops .aten .hardtanh_ .default ,
673- ] ,
700+ } ,
674701 ],
675702 any_or_hardtanh_min_zero ,
676703 ):
@@ -784,6 +811,25 @@ def any_or_hardtanh_min_zero(n: Node):
784811 elif node .target in _one_to_one :
785812 quant_properties .quant_inputs = [_QuantProperty (0 , input_act_qspec )]
786813 quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
814+ elif node .target in _fixed_input_qspec_ops :
815+ num_bits = torch .iinfo (input_act_qspec .dtype ).bits
816+ qparams = _fixed_input_qspec_ops [node .target ][num_bits ]
817+
818+ quant_properties .quant_inputs = [
819+ _QuantProperty (
820+ 0 ,
821+ FixedQParamsQuantizationSpec (
822+ dtype = input_act_qspec .dtype ,
823+ scale = qparams .scale ,
824+ zero_point = qparams .zero_point ,
825+ quant_min = input_act_qspec .quant_min ,
826+ quant_max = input_act_qspec .quant_max ,
827+ qscheme = input_act_qspec .qscheme ,
828+ is_dynamic = input_act_qspec .is_dynamic ,
829+ ),
830+ )
831+ ]
832+ quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
787833 elif node .target in _one_to_one_shared_input_qspec :
788834 input_node = ensure_type (Node , node .args [0 ])
789835 quant_properties .quant_inputs = [_QuantProperty (0 , input_act_qspec )]
0 commit comments