33#
44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
6+
67import torch
78from executorch .backends .qualcomm .builders .node_visitor import dq_ops , q_ops
89from executorch .backends .qualcomm .utils .constants import QCOM_QUANT_ATTRS
1314from .utils import get_quant_attrs
1415
1516
16- class AnnotateAdaptiveAvgPool1D (ExportPass ):
17+ class AnnotateAvgPool1D (ExportPass ):
1718 """
1819 Add "quant_attrs" to graph nodes' meta from the QDQ information
1920 generated after quantization process.
20- adaptive_avg_pool1d got decomposed to unsqueeze -> adaptive_avg_pool2d -> squeeze
21+ avg_pool1d and adaptive_avg_pool1d get decomposed to:
22+ unsqueeze -> avg_pool2d/adaptive_avg_pool2d -> squeeze
2123 """
2224
25+ _SOURCE_OPS = [
26+ torch .ops .aten .avg_pool1d .default ,
27+ torch .avg_pool1d ,
28+ torch .ops .aten .adaptive_avg_pool1d .default ,
29+ torch .adaptive_avg_pool1d ,
30+ ]
31+
2332 def __init__ (self , edge_program : torch .export .ExportedProgram ):
24- super (AnnotateAdaptiveAvgPool1D , self ).__init__ ()
33+ super (AnnotateAvgPool1D , self ).__init__ ()
2534 self .edge_program = edge_program
2635
27- def _annotate_adaptive_avg_pool1d (self , graph_module : torch .fx .GraphModule ):
36+ def _annotate (self , graph_module : torch .fx .GraphModule ):
2837 partitions = get_source_partitions (
2938 graph_module .graph ,
30- [ torch . ops . aten . adaptive_avg_pool1d . default , torch . adaptive_avg_pool1d ] ,
39+ self . _SOURCE_OPS ,
3140 )
3241 for src_partitions in partitions .values ():
3342 for src_partition in src_partitions :
@@ -44,11 +53,11 @@ def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
4453 self .edge_program , list (output .users )[0 ]
4554 )
4655 for n in src_partition .nodes :
47- # For adaptive_avg_pool2d and squeeze
56+ # For avg_pool2d/ adaptive_avg_pool2d and squeeze
4857 if n .target != exir_ops .edge .aten .unsqueeze_copy .default :
4958 n .meta [QCOM_QUANT_ATTRS ] = quant_attrs .copy ()
5059
5160 def call (self , graph_module : torch .fx .GraphModule ):
52- self ._annotate_adaptive_avg_pool1d (graph_module )
61+ self ._annotate (graph_module )
5362 graph_module .recompile ()
5463 return PassResult (graph_module , True )
0 commit comments