Skip to content

Commit 225b5eb

Browse files
authored
Add Python runtime end-to-end tests for XNNPACK on Linux (#18703)
## Summary Add Python runtime end-to-end tests for the XNNPACK backend on Linux — the first batch for the testing gap described in #11225. This update keeps the Linux-only coverage under `runtime/test/test_runtime_xnnpack.py` so the standard unittest jobs collect it cleanly. Each test exports a small model with XNNPACK delegation, saves it as `.pte`, loads it via `executorch.runtime.Runtime`, and compares the output against PyTorch eager mode. **Models tested:** - `test_add` — element-wise addition - `test_linear` — single `Linear` layer (fp32) - `test_conv2d_relu` — `Conv2d + ReLU` - `test_mlp` — multi-layer perceptron (`Linear -> ReLU -> Linear`) - `test_depthwise_separable_conv` — depthwise + pointwise conv (MobileNet building block) - `test_classifier_head` — `AdaptiveAvgPool2d + Flatten + Linear` - `test_conv_bn` — `Conv2d + BatchNorm2d + ReLU` (**skipped**: `FuseBatchNormPass` crashes; TODO to re-enable) Part of #11225 (Linux XNNPACK) <details> <summary>Before</summary> ```text $ pull / unittest-editable / linux / linux-job ==================================== ERRORS ==================================== _________ ERROR collecting test/end2end/test_python_runtime_xnnpack.py _________ ImportError while importing test module '/pytorch/executorch/test/end2end/test_python_runtime_xnnpack.py'. Traceback: E ModuleNotFoundError: No module named 'test.end2end' $ lintrunner --skip MYPY test/end2end/test_python_runtime_xnnpack.py Warning (UFMT) format - from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackPartitioner, - ) + from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner ``` </details> <details> <summary>After</summary> ```text $ lintrunner --skip MYPY runtime/test/test_runtime_xnnpack.py ok No lint issues. 
$ python -m py_compile runtime/test/test_runtime_xnnpack.py ``` </details> ## Test plan - [x] Moved the Linux-only XNNPACK runtime tests to `runtime/test/test_runtime_xnnpack.py` - [x] Added `runtime/test/TARGETS` coverage for the new Python runtime test - [x] `lintrunner --skip MYPY runtime/test/test_runtime_xnnpack.py` - [x] `python -m py_compile runtime/test/test_runtime_xnnpack.py` - [x] `test_conv_bn` remains skipped with TODO — `FuseBatchNormPass` still has a known bug on `Sequential(Conv2d, BN)` export - [ ] Full `python -m pytest runtime/test/test_runtime_xnnpack.py -v` Local blocker: this workstation does not currently have a fully provisioned ExecuTorch runtime test environment; upstream CI remains the source of truth for this suite. cc @GregoryComer @digantdesai @cbilgin @jathu @larryliu0820 > This PR was authored with the help of Claude. --------- Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
1 parent 2eff4f4 commit 225b5eb

2 files changed

Lines changed: 192 additions & 0 deletions

File tree

runtime/test/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,14 @@ runtime.python_test(
2121
"//executorch/devtools/etdump:serialize",
2222
],
2323
)
24+
25+
# Python runtime end-to-end test for the XNNPACK backend.
# NOTE: the Linux-only restriction is enforced inside the test module via a
# skipUnless guard, not here in the build target.
runtime.python_test(
    name = "test_runtime_xnnpack",
    srcs = ["test_runtime_xnnpack.py"],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
        "//executorch/exir:lib",
        "//executorch/runtime:runtime",
    ],
)
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Python runtime end-to-end tests for the XNNPACK backend on Linux.
8+
9+
Covers the export -> .pte -> Python runtime flow that developers use to
10+
validate exported models before deploying to device. Placing these tests
11+
under ``runtime/test`` lets the standard unittest jobs collect them cleanly.
12+
13+
See https://github.com/pytorch/executorch/issues/11225
14+
"""
15+
16+
import sys
17+
import tempfile
18+
import unittest
19+
from pathlib import Path
20+
21+
import torch
22+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
23+
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
24+
from executorch.runtime import Runtime, Verification
25+
from torch.export import export
26+
27+
28+
def _export_and_execute(
    model: torch.nn.Module, example_inputs: tuple[torch.Tensor, ...]
):
    """Export *model* with XNNPACK delegation, save a temp ``.pte``, and run it.

    Args:
        model: Module under test; switched to eval mode before export.
        example_inputs: Positional tensor inputs used both for tracing the
            export and for executing the loaded program.

    Returns:
        The outputs produced by executing the program's ``forward`` method
        through the Python ``Runtime``.

    Raises:
        RuntimeError: If the exported program exposes no ``forward`` method.
    """
    with tempfile.TemporaryDirectory() as temp_dir, torch.no_grad():
        model.eval()
        aten = export(model, example_inputs, strict=True)
        edge = to_edge_transform_and_lower(
            aten,
            # IR validity checking is disabled because lowered (delegated)
            # graphs may contain backend-specific nodes.
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )
        et = edge.to_executorch()

        pte_path = Path(temp_dir) / "xnnpack_runtime_test.pte"
        et.save(str(pte_path))

        runtime = Runtime.get()
        program = runtime.load_program(pte_path, verification=Verification.Minimal)
        method = program.load_method("forward")
        if method is None:
            # Explicit raise instead of `assert`: assertions are stripped
            # when Python runs with -O, which would silently skip this check.
            raise RuntimeError("forward method should exist in exported program")
        return method.execute(example_inputs)
50+
51+
52+
@unittest.skipUnless(
    sys.platform == "linux",
    "XNNPACK Python runtime end-to-end coverage in this batch targets Linux only",
)
class RuntimeXNNPACKTest(unittest.TestCase):
    """Export → .pte → Python Runtime tests for XNNPACK on Linux."""

    def _assert_matches_eager(self, module, sample, atol, rtol):
        """Compare eager-mode output of *module* against the delegated runtime."""
        with torch.no_grad():
            reference = module(*sample)
        delegated = _export_and_execute(module, sample)
        torch.testing.assert_close(delegated[0], reference, atol=atol, rtol=rtol)

    # -- Simple arithmetic ---------------------------------------------------
    def test_add(self):
        class Add(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        self._assert_matches_eager(
            Add(), (torch.randn(2, 3), torch.randn(2, 3)), atol=1e-4, rtol=1e-4
        )

    # -- Linear layer (fp32) -------------------------------------------------
    def test_linear(self):
        self._assert_matches_eager(
            torch.nn.Linear(16, 8), (torch.randn(4, 16),), atol=1e-4, rtol=1e-4
        )

    # -- Conv2d + ReLU (common vision pattern) -------------------------------
    def test_conv2d_relu(self):
        net = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, padding=1),
            torch.nn.ReLU(),
        )
        self._assert_matches_eager(
            net, (torch.randn(1, 3, 8, 8),), atol=1e-3, rtol=1e-3
        )

    # -- Small MLP (multiple linear + activation) ----------------------------
    def test_mlp(self):
        net = torch.nn.Sequential(
            torch.nn.Linear(32, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 10),
        )
        self._assert_matches_eager(net, (torch.randn(2, 32),), atol=1e-3, rtol=1e-3)

    # -- BatchNorm + Conv (common in MobileNet-style models) -----------------
    # Skipped: FuseBatchNormPass crashes on Sequential(Conv2d, BN) export.
    # TODO(#11225): re-enable once the XNNPACK pass is fixed.
    @unittest.skip("FuseBatchNormPass bug in XNNPACK backend")
    def test_conv_bn(self):
        net = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, padding=1),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
        )
        net.eval()
        self._assert_matches_eager(
            net, (torch.randn(1, 3, 8, 8),), atol=1e-3, rtol=1e-3
        )

    # -- Depthwise separable conv (MobileNet building block) -----------------
    def test_depthwise_separable_conv(self):
        net = torch.nn.Sequential(
            torch.nn.Conv2d(16, 16, 3, padding=1, groups=16),  # depthwise
            torch.nn.ReLU(),
            torch.nn.Conv2d(16, 32, 1),  # pointwise
            torch.nn.ReLU(),
        )
        self._assert_matches_eager(
            net, (torch.randn(1, 16, 8, 8),), atol=1e-3, rtol=1e-3
        )

    # -- Avgpool + Flatten + Linear (classifier head) ------------------------
    def test_classifier_head(self):
        class ClassifierHead(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.pool = torch.nn.AdaptiveAvgPool2d(1)
                self.fc = torch.nn.Linear(32, 10)

            def forward(self, x):
                return self.fc(self.pool(x).flatten(1))

        self._assert_matches_eager(
            ClassifierHead(), (torch.randn(1, 32, 8, 8),), atol=1e-3, rtol=1e-3
        )
178+
179+
180+
# Allow running this file directly: `python test_runtime_xnnpack.py`.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)