99from executorch .backends .arm ._passes .rewrite_avg_pool2d_pass import RewriteAvgPool2dPass
1010from executorch .backends .arm .test import common
1111from executorch .backends .arm .test .tester .test_pipeline import PassPipeline
12+ from executorch .backends .test .harness .stages import StageType
13+ from executorch .exir .dialects ._ops import ops as exir_ops
1214
1315input_t = Tuple [torch .Tensor ]
1416
@@ -41,6 +43,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4143 return torch .nn .functional .avg_pool2d (x , kernel_size = [2 , 3 ])
4244
4345
46+ class AvgPool2dScalarPadding (torch .nn .Module ):
47+ def get_inputs (self ) -> input_t :
48+ return (torch .rand (1 , 3 , 8 , 8 ),)
49+
50+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
51+ return torch .nn .functional .avg_pool2d (x , kernel_size = 3 , stride = 2 , padding = 1 )
52+
53+
54+ class AvgPool2dWithEmptyStride (torch .nn .Module ):
55+ def get_inputs (self ) -> input_t :
56+ return (torch .rand (1 , 3 , 8 , 8 ),)
57+
58+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
59+ return torch .nn .functional .avg_pool2d (x , kernel_size = [2 , 3 ], stride = [])
60+
61+
4462modules : Dict [str , ModuleWithInputs ] = {
4563 "avg_pool2d_with_stride" : AvgPool2dWithStride (),
4664 "avg_pool2d_without_stride" : AvgPool2dWithoutStride (),
@@ -67,3 +85,42 @@ def test_rewrite_avg_pool2d_tosa(module: ModuleWithInputs) -> None:
6785 "run_method_and_compare_outputs"
6886 ) # Cannot run aten graph with tosa dialect ops
6987 pipeline .run ()
88+
89+
90+ def _get_tosa_avg_pool2d_node (
91+ pipeline : PassPipeline [input_t ],
92+ ) -> torch .fx .Node :
93+ exported_program = pipeline .tester .get_artifact (
94+ StageType .RUN_PASSES
95+ ).exported_program ()
96+ graph_module = exported_program .graph_module
97+
98+ tosa_nodes = [
99+ node
100+ for node in graph_module .graph .nodes
101+ if node .op == "call_function"
102+ and node .target == exir_ops .backend .tosa .AVG_POOL2D .default
103+ ]
104+ assert len (tosa_nodes ) == 1
105+ return tosa_nodes [0 ]
106+
107+
108+ def test_rewrite_avg_pool2d_tosa_empty_stride_uses_kernel_size () -> None :
109+ module = AvgPool2dWithEmptyStride ()
110+ pipeline = PassPipeline [input_t ](
111+ module ,
112+ module .get_inputs (),
113+ ops_before_pass = {
114+ "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default" : 1 ,
115+ },
116+ ops_after_pass = {
117+ "executorch_exir_dialects_backend__ops_tosa_AVG_POOL2D_default" : 1 ,
118+ "executorch_exir_dialects_edge__ops_aten_permute_copy_default" : 2 ,
119+ },
120+ pass_list = [RewriteAvgPool2dPass ],
121+ )
122+ pipeline .pop_stage ("run_method_and_compare_outputs" )
123+ pipeline .run ()
124+
125+ tosa_node = _get_tosa_avg_pool2d_node (pipeline )
126+ assert tosa_node .args [4 ] == [2 , 3 ]
0 commit comments