Skip to content

Commit ccb13dd

Browse files
Arm backend: Format docs in backends/arm/test/passes (#17376)
Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 2478798 commit ccb13dd

20 files changed

Lines changed: 100 additions & 120 deletions

backends/arm/test/passes/test_convert_expand_copy_to_repeat.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -16,9 +16,7 @@
1616

1717

1818
class Expand(torch.nn.Module):
19-
"""
20-
Basic expand model using torch.Tensor.expand function
21-
"""
19+
"""Basic expand model using torch.Tensor.expand function."""
2220

2321
def __init__(self) -> None:
2422
super().__init__()

backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -361,11 +361,13 @@ def test_convert_int64_const_ops_to_int32_tosa_FP_full(
361361
def test_convert_int64_const_ops_to_int32_tosa_INT_full(
362362
test_data: input_t2,
363363
) -> None:
364-
"""
365-
For INT profile, _lifted_tensor_constant0 is still int64 after applying ConvertInt64ConstOpsToInt32Pass().
364+
"""For INT profile, _lifted_tensor_constant0 is still int64 after applying
365+
ConvertInt64ConstOpsToInt32Pass().
366+
366367
And an int64->int32 cast is inserted at the beginning of the graph.
367368
TODO: Explore why _lifted_tensor_constant0 is handled in different ways in FP and INT profile.
368369
Find a way to optimize out the int64->int32 cast.
370+
369371
"""
370372
module = FullIncrementViewMulXLessThanY()
371373
aten_ops_checks = [

backends/arm/test/passes/test_convert_split_to_slice.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -22,9 +22,7 @@ def get_inputs(self) -> input_t: ...
2222

2323

2424
class Split(torch.nn.Module):
25-
"""
26-
Basic split model using torch.split function
27-
"""
25+
"""Basic split model using torch.split function."""
2826

2927
def get_inputs(self) -> input_t:
3028
return (torch.rand(10),)
@@ -34,9 +32,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
3432

3533

3634
class SplitTensor(torch.nn.Module):
37-
"""
38-
Basic split model using torch.Tensor.split function
39-
"""
35+
"""Basic split model using torch.Tensor.split function."""
4036

4137
def get_inputs(self) -> input_t:
4238
return (torch.rand(10),)

backends/arm/test/passes/test_decompose_avg_pool2d_pass.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -20,9 +20,7 @@ def get_inputs(self) -> input_t: ...
2020

2121

2222
class AvgPool2dWithStride(torch.nn.Module):
23-
"""
24-
avg_pool2d model with explicit stride parameter
25-
"""
23+
"""avg_pool2d model with explicit stride parameter."""
2624

2725
def get_inputs(self) -> input_t:
2826
return (torch.rand(1, 3, 8, 8),)
@@ -32,8 +30,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3230

3331

3432
class AvgPool2dWithoutStride(torch.nn.Module):
35-
"""
36-
avg_pool2d model without stride parameter (should default to kernel_size)
33+
"""avg_pool2d model without stride parameter (should default to
34+
kernel_size)
3735
"""
3836

3937
def get_inputs(self) -> input_t:
@@ -44,9 +42,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4442

4543

4644
class AvgPool2dListKernel(torch.nn.Module):
47-
"""
48-
avg_pool2d model with list kernel_size and no stride
49-
"""
45+
"""avg_pool2d model with list kernel_size and no stride."""
5046

5147
def get_inputs(self) -> input_t:
5248
return (torch.rand(1, 3, 8, 8),)
@@ -64,7 +60,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
6460

6561
@common.parametrize("module", modules)
6662
def test_decompose_avg_pool2d_tosa_FP(module: ModuleWithInputs) -> None:
67-
"""Test that DecomposeAvgPool2d pass works correctly with and without stride parameters."""
63+
"""Test that DecomposeAvgPool2d pass works correctly with and without stride
64+
parameters.
65+
"""
6866
nn_module = cast(torch.nn.Module, module)
6967
pipeline = PassPipeline[input_t](
7068
nn_module,

backends/arm/test/passes/test_decompose_div_pass.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -20,9 +20,7 @@ def get_inputs(self) -> input_t: ...
2020

2121

2222
class Div(torch.nn.Module):
23-
"""
24-
Basic div model using torch.div
25-
"""
23+
"""Basic div model using torch.div."""
2624

2725
def get_inputs(self) -> input_t:
2826
return (torch.rand(10),)
@@ -32,9 +30,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3230

3331

3432
class DivTensor(torch.nn.Module):
35-
"""
36-
Basic div model using torch.Tensor.div
37-
"""
33+
"""Basic div model using torch.Tensor.div."""
3834

3935
def get_inputs(self) -> input_t:
4036
return (torch.rand(10),)

backends/arm/test/passes/test_decompose_int_pow_pass.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -23,9 +23,7 @@ def get_inputs(self) -> input_t: ...
2323

2424

2525
class Square(torch.nn.Module):
26-
"""
27-
Basic squaring
28-
"""
26+
"""Basic squaring."""
2927

3028
def forward(self, x: torch.Tensor) -> torch.Tensor:
3129
return x.square()
@@ -35,9 +33,7 @@ def get_inputs(self) -> input_t:
3533

3634

3735
class Pow(torch.nn.Module):
38-
"""
39-
Basic squaring
40-
"""
36+
"""Basic squaring."""
4137

4238
def __init__(self, exponent: int) -> None:
4339
super().__init__()

backends/arm/test/passes/test_decompose_layernorm_pass.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -16,9 +16,7 @@
1616

1717

1818
class LayerNorm(torch.nn.Module):
19-
"""
20-
Basic layer_norm model using torch.nn.layer_norm layer
21-
"""
19+
"""Basic layer_norm model using torch.nn.layer_norm layer."""
2220

2321
def __init__(self):
2422
super(LayerNorm, self).__init__()

backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -64,12 +64,14 @@ def get_inputs(self) -> input_t:
6464

6565
@common.parametrize("module", modules)
6666
def test_decompose_linalg_vector_norm_tosa_INT(module: ModuleWithInputs) -> None:
67-
"""
68-
This test creates a PassPipeline that applies the DecomposeLinalgVectorNormPass.
67+
"""This test creates a PassPipeline that applies the
68+
DecomposeLinalgVectorNormPass.
69+
6970
The expected primitive ops vary depending on the norm order:
7071
- p == 1: should decompose to ABS and SUM.
7172
- p == 2 (default): should decompose to MUL, SUM, and SQRT.
7273
- Other p: should decompose to ABS, two instances of POW, and SUM.
74+
7375
"""
7476
ord_val = module.ord if module.ord is not None else 2.0
7577

backends/arm/test/passes/test_decompose_meandim_pass.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -27,9 +27,7 @@ def get_inputs(self) -> input_t: ...
2727

2828

2929
class MeanDim(torch.nn.Module):
30-
"""
31-
Basic mean model using torch.mean with keepdim = True
32-
"""
30+
"""Basic mean model using torch.mean with keepdim = True."""
3331

3432
ops_after_pass = u55_ops_after_pass = {
3533
"torch.ops.aten.sum.dim_IntList": 2,
@@ -53,9 +51,7 @@ def get_inputs(self) -> input_t:
5351

5452

5553
class MeanDimTensor(torch.nn.Module):
56-
"""
57-
Basic mean model using torch.Tensor.mean with keepdim = False
58-
"""
54+
"""Basic mean model using torch.Tensor.mean with keepdim = False."""
5955

6056
ops_after_pass = {
6157
"torch.ops.aten.sum.dim_IntList": 2,

backends/arm/test/passes/test_decompose_softmax_pass.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -14,9 +14,7 @@
1414

1515

1616
class Softmax(torch.nn.Module):
17-
"""
18-
Basic torch.nn.softmax layer model
19-
"""
17+
"""Basic torch.nn.softmax layer model."""
2018

2119
def __init__(self):
2220
super(Softmax, self).__init__()
@@ -31,9 +29,7 @@ def get_inputs(self) -> input_t:
3129

3230

3331
class SoftmaxLog(torch.nn.Module):
34-
"""
35-
Basic torch.nn.log_softmax layer model
36-
"""
32+
"""Basic torch.nn.log_softmax layer model."""
3733

3834
def __init__(self):
3935
super(SoftmaxLog, self).__init__()

0 commit comments

Comments
 (0)