
Commit 7a2e230

psiddh and claude authored
Cortex-M Backend: Add tiny model tests for nn.Modules, nn.functional,… (#18297)
… and torch patterns

Add 21 new test cases across 3 files that exercise the Cortex-M quantizer and pass manager on small composite models. These mirror the Arm backend's test_nn_modules/test_nn_functional/test_torch_functions pattern but target the Cortex-M pipeline.

Tests cover: ConvBnReLU, LinearReLU, ConvTranspose2d, AdaptiveAvgPool2d, MaxPool2d, AvgPool2d, Softmax, Hardswish, Hardsigmoid, depthwise separable conv, inverted residual blocks (MobileNet-style), and multi-op functional compositions.

All tests use test_dialect(), which runs quantize→export→to_edge→run_passes→compare_outputs entirely on the host (no FVP needed).

Co-authored-by: Claude <noreply@anthropic.com>
1 parent ca2a616 commit 7a2e230
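
For context, here is a rough sketch of the quantize→export→to_edge→run_passes→compare_outputs flow that test_dialect() is described as running, assuming standard torch.export and ExecuTorch APIs. The run_dialect_check helper and the quantizer argument are illustrative stand-ins, not the actual CortexMTester implementation; the real tester also runs the Cortex-M pass manager between to_edge and the output comparison.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.exir import to_edge


def run_dialect_check(model, example_inputs, quantizer, qtol=1):
    reference = model(*example_inputs)  # eager float reference

    # quantize: annotate with the backend quantizer, calibrate, convert
    graph = torch.export.export(model.eval(), example_inputs).module()
    prepared = prepare_pt2e(graph, quantizer)
    prepared(*example_inputs)  # single calibration pass on the example inputs
    quantized = convert_pt2e(prepared)

    # export → to_edge; the real tester would run the Cortex-M pass
    # manager (run_passes) on the edge program at this point
    edge = to_edge(torch.export.export(quantized, example_inputs))

    # compare_outputs: quantized result vs. float reference; qtol is
    # interpreted loosely here as an absolute tolerance
    result = edge.exported_program().module()(*example_inputs)
    torch.testing.assert_close(result, reference, atol=qtol, rtol=0.0)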

3 files changed

Lines changed: 472 additions & 0 deletions


Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
# Copyright 2025-2026 Arm Limited and/or its affiliates.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Tests popular torch.nn and torch.nn.functional operations through the Cortex-M pipeline.

Mirrors the Arm backend test_nn_functional.py (PR #18225) but runs through
the Cortex-M quantizer and pass manager. Tests operations that commonly
appear in real models but aren't individually covered in ops/.

Functions tested:
- relu
- relu6
- avg_pool2d
- max_pool2d
- pad (constant)
- hardtanh
- BatchNorm1d (module)
- linear
"""

import torch
from executorch.backends.arm.test.common import parametrize
from executorch.backends.cortex_m.test.tester import (
    CortexMTester,
    McuTestCase,
    ramp_tensor,
)

torch.manual_seed(0)


class PadModule(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.pad(x, (1, 1, 1, 1), mode="constant", value=0)


class LinearBnModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(8, 16))
        self.bn = torch.nn.BatchNorm1d(8)

    def forward(self, x):
        return self.bn(torch.nn.functional.linear(x, self.weight))


class ConvReluMaxpool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1, bias=False)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)
        return x


class ConvRelu6Avgpool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1, bias=False)

    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu6(x)
        x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ConvHardtanhModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, 3, padding=1, bias=False)

    def forward(self, x):
        return torch.nn.functional.hardtanh(self.conv(x), min_val=-1.0, max_val=1.0)


test_cases = {
    "conv_relu_maxpool": McuTestCase(
        model=ConvReluMaxpool(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
        ),
    ),
    "conv_relu6_avgpool": McuTestCase(
        model=ConvRelu6Avgpool(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
        ),
    ),
    "conv_hardtanh": McuTestCase(
        model=ConvHardtanhModule(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
        ),
    ),
    "pad_constant": McuTestCase(
        model=PadModule(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 4, 6, 6)).to(memory_format=torch.channels_last),
        ),
    ),
    "linear_bn": McuTestCase(
        model=LinearBnModule(),
        example_inputs=(ramp_tensor(-1, 1, (2, 16)),),
    ),
}


@parametrize("test_case", test_cases)
def test_dialect_nn_functional(test_case):
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_dialect({}, {}, qtol=1)
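
A note on ramp_tensor, which the test cases above rely on: judging from its usage, it appears to produce a deterministic, evenly spaced tensor over [start, end], so calibration and the observed quantization ranges are reproducible across runs. A minimal stand-in (an assumption, not the actual tester helper) might look like:

import math
import torch


def ramp_tensor(start: float, end: float, shape: tuple) -> torch.Tensor:
    # Evenly spaced values from start to end, reshaped to the target shape.
    # Deterministic inputs keep the observed quant ranges stable across runs.
    numel = math.prod(shape)
    return torch.linspace(start, end, numel).reshape(shape)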
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
# Copyright 2025-2026 Arm Limited and/or its affiliates.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Tests popular nn.Module classes through the Cortex-M pipeline.

Mirrors the Arm backend test_nn_modules.py (PR #18225) but runs through
the Cortex-M quantizer and pass manager instead of TOSA/Ethos-U delegation.

Modules tested:
- Conv2d + BatchNorm2d + ReLU
- Linear + ReLU
- Conv2d + Add + ReLU
- AdaptiveAvgPool2d
- ConvTranspose2d
- Hardswish
- Hardsigmoid
- MaxPool2d
- AvgPool2d
- Softmax
"""

import torch
from executorch.backends.arm.test.common import parametrize
from executorch.backends.cortex_m.test.tester import (
    CortexMTester,
    McuTestCase,
    ramp_tensor,
)

torch.manual_seed(0)


class ConvBnReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1, bias=False)
        self.bn = torch.nn.BatchNorm2d(8)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class LinearReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))


class ConvTranspose2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_t = torch.nn.ConvTranspose2d(4, 2, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv_t(x)


class AdaptiveAvgPool2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        return self.pool(x)


class MaxPool2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(x)


class AvgPool2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(x)


class SoftmaxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4, bias=False)

    def forward(self, x):
        return torch.softmax(self.linear(x), dim=-1)


class HardswishModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 4, 1, bias=False)
        self.act = torch.nn.Hardswish()

    def forward(self, x):
        return self.act(self.conv(x))


class HardsigmoidModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4, bias=False)
        self.act = torch.nn.Hardsigmoid()

    def forward(self, x):
        return self.act(self.linear(x))


class ConvAddReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, 3, padding=1, bias=False)
        self.conv2 = torch.nn.Conv2d(3, 8, 3, padding=1, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv1(x) + self.conv2(x))


test_cases = {
    "conv_bn_relu": McuTestCase(
        model=ConvBnReLU(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
        ),
    ),
    "linear_relu": McuTestCase(
        model=LinearReLU(),
        example_inputs=(ramp_tensor(-1, 1, (1, 16)),),
    ),
    "conv_transpose2d": McuTestCase(
        model=ConvTranspose2dModule(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 4, 4, 4)).to(memory_format=torch.channels_last),
        ),
    ),
    "adaptive_avg_pool2d": McuTestCase(
        model=AdaptiveAvgPool2dModule(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
        ),
    ),
    "max_pool2d": McuTestCase(
        model=MaxPool2dModule(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
        ),
    ),
    "avg_pool2d": McuTestCase(
        model=AvgPool2dModule(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
        ),
    ),
    "softmax": McuTestCase(
        model=SoftmaxModule(),
        example_inputs=(ramp_tensor(-1, 1, (1, 8)),),
    ),
    "hardswish": McuTestCase(
        model=HardswishModule(),
        example_inputs=(
            ramp_tensor(-3, 3, (1, 2, 4, 4)).to(memory_format=torch.channels_last),
        ),
    ),
    "hardsigmoid": McuTestCase(
        model=HardsigmoidModule(),
        example_inputs=(ramp_tensor(-4, 4, (1, 8)),),
    ),
    "conv_add_relu": McuTestCase(
        model=ConvAddReLU(),
        example_inputs=(
            ramp_tensor(-1, 1, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
        ),
    ),
}

xfails = {
    "conv_add_relu": "Activation fusion does not support relu after add",
}


@parametrize("test_case", test_cases, xfails=xfails, strict=False)
def test_dialect_nn_modules(test_case):
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_dialect({}, {}, qtol=2)
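
The parametrize helper imported from the Arm test common module accepts an xfails map and a strict flag, as used above. A minimal pytest-based equivalent (hypothetical, for illustration only) would expand the case dict into pytest.param entries and mark the listed ids as expected failures:

import pytest


def parametrize(arg_name, test_cases, xfails=None, strict=True):
    # Expand {id: case} into pytest parameters, marking ids found in
    # xfails as expected failures with the recorded reason.
    xfails = xfails or {}
    params = [
        pytest.param(
            case,
            id=name,
            marks=(
                [pytest.mark.xfail(reason=xfails[name], strict=strict)]
                if name in xfails
                else []
            ),
        )
        for name, case in test_cases.items()
    ]
    return pytest.mark.parametrize(arg_name, params)

With strict=False, a case that unexpectedly passes is reported as XPASS rather than failing the run, which suits a known-but-unstable gap like the conv_add_relu activation fusion.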
