1- # Copyright 2025 Arm Limited and/or its affiliates.
1+ # Copyright 2025-2026 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
@@ -20,9 +20,7 @@ def get_inputs(self) -> input_t: ...
2020
2121
2222class AvgPool2dWithStride (torch .nn .Module ):
23- """
24- avg_pool2d model with explicit stride parameter
25- """
23+ """avg_pool2d model with explicit stride parameter."""
2624
2725 def get_inputs (self ) -> input_t :
2826 return (torch .rand (1 , 3 , 8 , 8 ),)
@@ -32,8 +30,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3230
3331
3432class AvgPool2dWithoutStride (torch .nn .Module ):
35- """
36- avg_pool2d model without stride parameter (should default to kernel_size)
33+ """avg_pool2d model without stride parameter (should default to
34+ kernel_size)
3735 """
3836
3937 def get_inputs (self ) -> input_t :
@@ -44,9 +42,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4442
4543
4644class AvgPool2dListKernel (torch .nn .Module ):
47- """
48- avg_pool2d model with list kernel_size and no stride
49- """
45+ """avg_pool2d model with list kernel_size and no stride."""
5046
5147 def get_inputs (self ) -> input_t :
5248 return (torch .rand (1 , 3 , 8 , 8 ),)
@@ -64,7 +60,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
6460
6561@common .parametrize ("module" , modules )
6662def test_decompose_avg_pool2d_tosa_FP (module : ModuleWithInputs ) -> None :
67- """Test that DecomposeAvgPool2d pass works correctly with and without stride parameters."""
63+ """Test that DecomposeAvgPool2d pass works correctly with and without stride
64+ parameters.
65+ """
6866 nn_module = cast (torch .nn .Module , module )
6967 pipeline = PassPipeline [input_t ](
7068 nn_module ,
0 commit comments