1313import logging
1414import operator
1515from dataclasses import dataclass , replace
16- from typing import Callable , cast , List , Optional , Sequence
16+ from typing import Callable , cast , Iterable , List , Optional , Sequence
1717
1818import torch
1919import torch .fx
@@ -391,14 +391,16 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
391391
392392
393393def _match_pattern (
394- node : Node , pattern : List [List ], filter_fn : Optional [Callable [[Node ], bool ]] = None
394+ node : Node ,
395+ pattern : Sequence [Iterable [object ]],
396+ filter_fn : Optional [Callable [[Node ], bool ]] = None ,
395397) -> bool :
396398 """Check whether a node chain matches a pattern.
397399
398400 Verify a chain of ancestors -> node -> descendants matches the provided
399401 ``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
400- to pass the filter. Each pattern element is a list of disjunctive node
401- targets.
402+ to pass the filter. Each pattern element is an iterable of disjunctive
403+ node targets.
402404
403405 """
404406 if len (pattern ) < 1 :
@@ -432,16 +434,16 @@ def _match_pattern(
432434 return left_condition and right_condition
433435
434436
435- _conv_ops = [
437+ _conv_ops = {
436438 torch .ops .aten .conv1d .default ,
437439 torch .ops .aten .conv2d .default ,
438440 torch .ops .aten .conv2d .padding ,
439441 torch .ops .aten .conv_transpose2d .input ,
440442 torch .ops .aten .conv3d .default ,
441443 torch .ops .aten .conv3d .padding ,
442- ]
444+ }
443445
444- _one_to_one = [
446+ _one_to_one = {
445447 torch .ops .aten .abs .default ,
446448 torch .ops .aten .ceil .default ,
447449 torch .ops .aten .erf .default ,
@@ -479,9 +481,9 @@ def _match_pattern(
479481 torch .ops .aten .acos .default ,
480482 torch .ops .aten .cumsum .default ,
481483 torch .ops .aten .tan .default ,
482- ]
484+ }
483485
484- _one_to_one_shared_input_qspec = [
486+ _one_to_one_shared_input_qspec = {
485487 torch .ops .aten .squeeze .default ,
486488 torch .ops .aten .squeeze_copy .default ,
487489 torch .ops .aten .squeeze_copy .dim ,
@@ -539,9 +541,9 @@ def _match_pattern(
539541 # dequant -> neg -> requant chain.
540542 torch .ops .aten .neg .default ,
541543 torch .ops .aten .detach_copy .default ,
542- ]
544+ }
543545
544- _one_to_one_shared_input_or_input_act_qspec = [
546+ _one_to_one_shared_input_or_input_act_qspec = {
545547 torch .ops .aten .alias .default ,
546548 torch .ops .aten .clone .default ,
547549 torch .ops .aten .hardtanh .default ,
@@ -562,7 +564,7 @@ def _match_pattern(
562564 torch .ops .aten .alias_copy .default ,
563565 torch .ops .aten .pixel_shuffle .default ,
564566 torch .ops .aten .pixel_unshuffle .default ,
565- ]
567+ }
566568
567569
568570def get_quant_properties ( # noqa: C901
@@ -615,13 +617,13 @@ def any_or_hardtanh_min_zero(n: Node):
615617 node ,
616618 [
617619 _conv_ops ,
618- [ torch .ops .aten .batch_norm .default ] ,
619- [
620+ { torch .ops .aten .batch_norm .default } ,
621+ {
620622 torch .ops .aten .relu .default ,
621623 torch .ops .aten .relu_ .default ,
622624 torch .ops .aten .hardtanh .default ,
623625 torch .ops .aten .hardtanh_ .default ,
624- ] ,
626+ } ,
625627 ],
626628 filter_fn = any_or_hardtanh_min_zero ,
627629 ):
@@ -644,7 +646,7 @@ def any_or_hardtanh_min_zero(n: Node):
644646 node ,
645647 [
646648 _conv_ops ,
647- [ torch .ops .aten .batch_norm .default ] ,
649+ { torch .ops .aten .batch_norm .default } ,
648650 ],
649651 ):
650652 if node .target in _conv_ops :
@@ -654,23 +656,21 @@ def any_or_hardtanh_min_zero(n: Node):
654656 _QuantProperty (1 , conv_weight_qspec , mark_annotated = True ),
655657 _QuantProperty (2 , bias_qspec , optional = True , mark_annotated = True ),
656658 ]
657- elif node .target in [
658- torch .ops .aten .batch_norm .default ,
659- ]:
659+ elif node .target in {torch .ops .aten .batch_norm .default }:
660660 quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
661661 elif not is_symmetric and _match_pattern (
662662 node ,
663663 [
664- [
664+ {
665665 * _conv_ops ,
666666 torch .ops .aten .linear .default ,
667- ] ,
668- [
667+ } ,
668+ {
669669 torch .ops .aten .relu .default ,
670670 torch .ops .aten .relu_ .default ,
671671 torch .ops .aten .hardtanh .default ,
672672 torch .ops .aten .hardtanh_ .default ,
673- ] ,
673+ } ,
674674 ],
675675 any_or_hardtanh_min_zero ,
676676 ):
0 commit comments