Skip to content

Commit 931add3

Browse files
committed
Add Python runtime end-to-end tests for XNNPACK on Linux
Export small models (add, linear, conv2d+relu, MLP, depthwise separable conv, classifier head) with XNNPACK delegation and verify correctness through the Python runtime (executorch.runtime.Runtime). This is the Linux XNNPACK portion of the Python runtime test gap described in the issue. Each test exports a model to .pte, loads it via Runtime, and asserts that the output matches PyTorch eager mode. Part of #11225 Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
1 parent ac68932 commit 931add3

1 file changed

Lines changed: 183 additions & 0 deletions

File tree

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""End-to-end tests that export models with XNNPACK delegation and verify
8+
them through the Python runtime.
9+
10+
Covers the export → .pte → Python runtime flow that developers use to
11+
validate exported models before deploying to device. Complements the
12+
existing C++ runner tests in .ci/scripts/test_model.sh.
13+
14+
See https://github.com/pytorch/executorch/issues/11225
15+
"""
16+
17+
import tempfile
18+
import unittest
19+
from pathlib import Path
20+
21+
import torch
22+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
23+
XnnpackPartitioner,
24+
)
25+
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
26+
from executorch.runtime import Runtime, Verification
27+
from torch.export import export
28+
29+
30+
def _export_and_load(model: torch.nn.Module, example_inputs: tuple):
31+
"""Export *model* with XNNPACK, save to a temp .pte, load via Runtime."""
32+
model.eval()
33+
with torch.no_grad():
34+
aten = export(model, example_inputs, strict=True)
35+
edge = to_edge_transform_and_lower(
36+
aten,
37+
compile_config=EdgeCompileConfig(_check_ir_validity=False),
38+
partitioner=[XnnpackPartitioner()],
39+
)
40+
et = edge.to_executorch()
41+
42+
with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
43+
pte_path = f.name
44+
et.save(pte_path)
45+
46+
rt = Runtime.get()
47+
program = rt.load_program(Path(pte_path), verification=Verification.Minimal)
48+
return program.load_method("forward"), pte_path
49+
50+
51+
class TestPythonRuntimeXNNPACK(unittest.TestCase):
52+
"""Export → .pte → Python Runtime tests for XNNPACK on Linux."""
53+
54+
# ------------------------------------------------------------------
55+
# Simple arithmetic
56+
# ------------------------------------------------------------------
57+
def test_add(self):
58+
class Add(torch.nn.Module):
59+
def forward(self, x, y):
60+
return x + y
61+
62+
model = Add()
63+
inputs = (torch.randn(2, 3), torch.randn(2, 3))
64+
method, _ = _export_and_load(model, inputs)
65+
66+
expected = model(*inputs)
67+
actual = method.execute(inputs)
68+
torch.testing.assert_close(actual[0], expected, atol=1e-4, rtol=1e-4)
69+
70+
# ------------------------------------------------------------------
71+
# Linear layer (fp32)
72+
# ------------------------------------------------------------------
73+
def test_linear(self):
74+
model = torch.nn.Linear(16, 8)
75+
inputs = (torch.randn(4, 16),)
76+
method, _ = _export_and_load(model, inputs)
77+
78+
with torch.no_grad():
79+
expected = model(*inputs)
80+
actual = method.execute(inputs)
81+
torch.testing.assert_close(actual[0], expected, atol=1e-4, rtol=1e-4)
82+
83+
# ------------------------------------------------------------------
84+
# Conv2d + ReLU (common vision pattern)
85+
# ------------------------------------------------------------------
86+
def test_conv2d_relu(self):
87+
model = torch.nn.Sequential(
88+
torch.nn.Conv2d(3, 16, 3, padding=1),
89+
torch.nn.ReLU(),
90+
)
91+
inputs = (torch.randn(1, 3, 8, 8),)
92+
method, _ = _export_and_load(model, inputs)
93+
94+
with torch.no_grad():
95+
expected = model(*inputs)
96+
actual = method.execute(inputs)
97+
torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3)
98+
99+
# ------------------------------------------------------------------
100+
# Small MLP (multiple linear + activation)
101+
# ------------------------------------------------------------------
102+
def test_mlp(self):
103+
model = torch.nn.Sequential(
104+
torch.nn.Linear(32, 64),
105+
torch.nn.ReLU(),
106+
torch.nn.Linear(64, 10),
107+
)
108+
inputs = (torch.randn(2, 32),)
109+
method, _ = _export_and_load(model, inputs)
110+
111+
with torch.no_grad():
112+
expected = model(*inputs)
113+
actual = method.execute(inputs)
114+
torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3)
115+
116+
# ------------------------------------------------------------------
117+
# BatchNorm + Conv (common in MobileNet-style models)
118+
# Skipped: FuseBatchNormPass crashes on Sequential(Conv2d, BN) export.
119+
# TODO(#11225): re-enable once the XNNPACK pass is fixed.
120+
# ------------------------------------------------------------------
121+
@unittest.skip("FuseBatchNormPass bug in XNNPACK backend")
122+
def test_conv_bn(self):
123+
model = torch.nn.Sequential(
124+
torch.nn.Conv2d(3, 16, 3, padding=1),
125+
torch.nn.BatchNorm2d(16),
126+
torch.nn.ReLU(),
127+
)
128+
model.eval()
129+
inputs = (torch.randn(1, 3, 8, 8),)
130+
method, _ = _export_and_load(model, inputs)
131+
132+
with torch.no_grad():
133+
expected = model(*inputs)
134+
actual = method.execute(inputs)
135+
torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3)
136+
137+
# ------------------------------------------------------------------
138+
# Depthwise separable conv (MobileNet building block)
139+
# ------------------------------------------------------------------
140+
def test_depthwise_separable_conv(self):
141+
model = torch.nn.Sequential(
142+
# Depthwise
143+
torch.nn.Conv2d(16, 16, 3, padding=1, groups=16),
144+
torch.nn.ReLU(),
145+
# Pointwise
146+
torch.nn.Conv2d(16, 32, 1),
147+
torch.nn.ReLU(),
148+
)
149+
inputs = (torch.randn(1, 16, 8, 8),)
150+
method, _ = _export_and_load(model, inputs)
151+
152+
with torch.no_grad():
153+
expected = model(*inputs)
154+
actual = method.execute(inputs)
155+
torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3)
156+
157+
# ------------------------------------------------------------------
158+
# Avgpool + Flatten + Linear (classifier head)
159+
# ------------------------------------------------------------------
160+
def test_classifier_head(self):
161+
class ClassifierHead(torch.nn.Module):
162+
def __init__(self):
163+
super().__init__()
164+
self.pool = torch.nn.AdaptiveAvgPool2d(1)
165+
self.fc = torch.nn.Linear(32, 10)
166+
167+
def forward(self, x):
168+
x = self.pool(x)
169+
x = x.flatten(1)
170+
return self.fc(x)
171+
172+
model = ClassifierHead()
173+
inputs = (torch.randn(1, 32, 8, 8),)
174+
method, _ = _export_and_load(model, inputs)
175+
176+
with torch.no_grad():
177+
expected = model(*inputs)
178+
actual = method.execute(inputs)
179+
torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3)
180+
181+
182+
if __name__ == "__main__":
183+
unittest.main()

0 commit comments

Comments
 (0)