Skip to content

Commit e5be6d5

Browse files
Martin Lindströmmartinlsm
authored andcommitted
Arm backend: Declare op groups as sets
In quantization_annotator.py, declare groups of ops as sets instead of lists. The reason is that they are only used for lookup, so this change improves the time complexity for such lookups. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: I93428108b863c3dd2d6f80feea08859c059c187a
1 parent 8241d86 commit e5be6d5

1 file changed

Lines changed: 23 additions & 23 deletions

File tree

backends/arm/quantizer/quantization_annotator.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import operator
1515
from dataclasses import dataclass, replace
16-
from typing import Callable, cast, List, Optional, Sequence
16+
from typing import Callable, cast, Iterable, List, Optional, Sequence
1717

1818
import torch
1919
import torch.fx
@@ -391,14 +391,16 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
391391

392392

393393
def _match_pattern(
394-
node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
394+
node: Node,
395+
pattern: Sequence[Iterable[object]],
396+
filter_fn: Optional[Callable[[Node], bool]] = None,
395397
) -> bool:
396398
"""Check whether a node chain matches a pattern.
397399
398400
Verify a chain of ancestors -> node -> descendants matches the provided
399401
``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
400-
to pass the filter. Each pattern element is a list of disjunctive node
401-
targets.
402+
to pass the filter. Each pattern element is an iterable of disjunctive
403+
node targets.
402404
403405
"""
404406
if len(pattern) < 1:
@@ -432,16 +434,16 @@ def _match_pattern(
432434
return left_condition and right_condition
433435

434436

435-
_conv_ops = [
437+
_conv_ops = {
436438
torch.ops.aten.conv1d.default,
437439
torch.ops.aten.conv2d.default,
438440
torch.ops.aten.conv2d.padding,
439441
torch.ops.aten.conv_transpose2d.input,
440442
torch.ops.aten.conv3d.default,
441443
torch.ops.aten.conv3d.padding,
442-
]
444+
}
443445

444-
_one_to_one = [
446+
_one_to_one = {
445447
torch.ops.aten.abs.default,
446448
torch.ops.aten.ceil.default,
447449
torch.ops.aten.erf.default,
@@ -479,9 +481,9 @@ def _match_pattern(
479481
torch.ops.aten.acos.default,
480482
torch.ops.aten.cumsum.default,
481483
torch.ops.aten.tan.default,
482-
]
484+
}
483485

484-
_one_to_one_shared_input_qspec = [
486+
_one_to_one_shared_input_qspec = {
485487
torch.ops.aten.squeeze.default,
486488
torch.ops.aten.squeeze_copy.default,
487489
torch.ops.aten.squeeze_copy.dim,
@@ -539,9 +541,9 @@ def _match_pattern(
539541
# dequant -> neg -> requant chain.
540542
torch.ops.aten.neg.default,
541543
torch.ops.aten.detach_copy.default,
542-
]
544+
}
543545

544-
_one_to_one_shared_input_or_input_act_qspec = [
546+
_one_to_one_shared_input_or_input_act_qspec = {
545547
torch.ops.aten.alias.default,
546548
torch.ops.aten.clone.default,
547549
torch.ops.aten.hardtanh.default,
@@ -562,7 +564,7 @@ def _match_pattern(
562564
torch.ops.aten.alias_copy.default,
563565
torch.ops.aten.pixel_shuffle.default,
564566
torch.ops.aten.pixel_unshuffle.default,
565-
]
567+
}
566568

567569

568570
def get_quant_properties( # noqa: C901
@@ -615,13 +617,13 @@ def any_or_hardtanh_min_zero(n: Node):
615617
node,
616618
[
617619
_conv_ops,
618-
[torch.ops.aten.batch_norm.default],
619-
[
620+
{torch.ops.aten.batch_norm.default},
621+
{
620622
torch.ops.aten.relu.default,
621623
torch.ops.aten.relu_.default,
622624
torch.ops.aten.hardtanh.default,
623625
torch.ops.aten.hardtanh_.default,
624-
],
626+
},
625627
],
626628
filter_fn=any_or_hardtanh_min_zero,
627629
):
@@ -644,7 +646,7 @@ def any_or_hardtanh_min_zero(n: Node):
644646
node,
645647
[
646648
_conv_ops,
647-
[torch.ops.aten.batch_norm.default],
649+
{torch.ops.aten.batch_norm.default},
648650
],
649651
):
650652
if node.target in _conv_ops:
@@ -654,23 +656,21 @@ def any_or_hardtanh_min_zero(n: Node):
654656
_QuantProperty(1, conv_weight_qspec, mark_annotated=True),
655657
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
656658
]
657-
elif node.target in [
658-
torch.ops.aten.batch_norm.default,
659-
]:
659+
elif node.target in {torch.ops.aten.batch_norm.default}:
660660
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
661661
elif not is_symmetric and _match_pattern(
662662
node,
663663
[
664-
[
664+
{
665665
*_conv_ops,
666666
torch.ops.aten.linear.default,
667-
],
668-
[
667+
},
668+
{
669669
torch.ops.aten.relu.default,
670670
torch.ops.aten.relu_.default,
671671
torch.ops.aten.hardtanh.default,
672672
torch.ops.aten.hardtanh_.default,
673-
],
673+
},
674674
],
675675
any_or_hardtanh_min_zero,
676676
):

0 commit comments

Comments
 (0)