|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +"""End-to-end tests that export models with XNNPACK delegation and verify |
| 8 | +them through the Python runtime. |
| 9 | +
|
| 10 | +Covers the export → .pte → Python runtime flow that developers use to |
| 11 | +validate exported models before deploying to device. Complements the |
| 12 | +existing C++ runner tests in .ci/scripts/test_model.sh. |
| 13 | +
|
| 14 | +See https://github.com/pytorch/executorch/issues/11225 |
| 15 | +""" |
| 16 | + |
| 17 | +import tempfile |
| 18 | +import unittest |
| 19 | +from pathlib import Path |
| 20 | + |
| 21 | +import torch |
| 22 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( |
| 23 | + XnnpackPartitioner, |
| 24 | +) |
| 25 | +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower |
| 26 | +from executorch.runtime import Runtime, Verification |
| 27 | +from torch.export import export |
| 28 | + |
| 29 | + |
| 30 | +def _export_and_load(model: torch.nn.Module, example_inputs: tuple): |
| 31 | + """Export *model* with XNNPACK, save to a temp .pte, load via Runtime.""" |
| 32 | + model.eval() |
| 33 | + with torch.no_grad(): |
| 34 | + aten = export(model, example_inputs, strict=True) |
| 35 | + edge = to_edge_transform_and_lower( |
| 36 | + aten, |
| 37 | + compile_config=EdgeCompileConfig(_check_ir_validity=False), |
| 38 | + partitioner=[XnnpackPartitioner()], |
| 39 | + ) |
| 40 | + et = edge.to_executorch() |
| 41 | + |
| 42 | + with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f: |
| 43 | + pte_path = f.name |
| 44 | + et.save(pte_path) |
| 45 | + |
| 46 | + rt = Runtime.get() |
| 47 | + program = rt.load_program(Path(pte_path), verification=Verification.Minimal) |
| 48 | + return program.load_method("forward"), pte_path |
| 49 | + |
| 50 | + |
| 51 | +class TestPythonRuntimeXNNPACK(unittest.TestCase): |
| 52 | + """Export → .pte → Python Runtime tests for XNNPACK on Linux.""" |
| 53 | + |
| 54 | + # ------------------------------------------------------------------ |
| 55 | + # Simple arithmetic |
| 56 | + # ------------------------------------------------------------------ |
| 57 | + def test_add(self): |
| 58 | + class Add(torch.nn.Module): |
| 59 | + def forward(self, x, y): |
| 60 | + return x + y |
| 61 | + |
| 62 | + model = Add() |
| 63 | + inputs = (torch.randn(2, 3), torch.randn(2, 3)) |
| 64 | + method, _ = _export_and_load(model, inputs) |
| 65 | + |
| 66 | + expected = model(*inputs) |
| 67 | + actual = method.execute(inputs) |
| 68 | + torch.testing.assert_close(actual[0], expected, atol=1e-4, rtol=1e-4) |
| 69 | + |
| 70 | + # ------------------------------------------------------------------ |
| 71 | + # Linear layer (fp32) |
| 72 | + # ------------------------------------------------------------------ |
| 73 | + def test_linear(self): |
| 74 | + model = torch.nn.Linear(16, 8) |
| 75 | + inputs = (torch.randn(4, 16),) |
| 76 | + method, _ = _export_and_load(model, inputs) |
| 77 | + |
| 78 | + with torch.no_grad(): |
| 79 | + expected = model(*inputs) |
| 80 | + actual = method.execute(inputs) |
| 81 | + torch.testing.assert_close(actual[0], expected, atol=1e-4, rtol=1e-4) |
| 82 | + |
| 83 | + # ------------------------------------------------------------------ |
| 84 | + # Conv2d + ReLU (common vision pattern) |
| 85 | + # ------------------------------------------------------------------ |
| 86 | + def test_conv2d_relu(self): |
| 87 | + model = torch.nn.Sequential( |
| 88 | + torch.nn.Conv2d(3, 16, 3, padding=1), |
| 89 | + torch.nn.ReLU(), |
| 90 | + ) |
| 91 | + inputs = (torch.randn(1, 3, 8, 8),) |
| 92 | + method, _ = _export_and_load(model, inputs) |
| 93 | + |
| 94 | + with torch.no_grad(): |
| 95 | + expected = model(*inputs) |
| 96 | + actual = method.execute(inputs) |
| 97 | + torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3) |
| 98 | + |
| 99 | + # ------------------------------------------------------------------ |
| 100 | + # Small MLP (multiple linear + activation) |
| 101 | + # ------------------------------------------------------------------ |
| 102 | + def test_mlp(self): |
| 103 | + model = torch.nn.Sequential( |
| 104 | + torch.nn.Linear(32, 64), |
| 105 | + torch.nn.ReLU(), |
| 106 | + torch.nn.Linear(64, 10), |
| 107 | + ) |
| 108 | + inputs = (torch.randn(2, 32),) |
| 109 | + method, _ = _export_and_load(model, inputs) |
| 110 | + |
| 111 | + with torch.no_grad(): |
| 112 | + expected = model(*inputs) |
| 113 | + actual = method.execute(inputs) |
| 114 | + torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3) |
| 115 | + |
| 116 | + # ------------------------------------------------------------------ |
| 117 | + # BatchNorm + Conv (common in MobileNet-style models) |
| 118 | + # Skipped: FuseBatchNormPass crashes on Sequential(Conv2d, BN) export. |
| 119 | + # TODO(#11225): re-enable once the XNNPACK pass is fixed. |
| 120 | + # ------------------------------------------------------------------ |
| 121 | + @unittest.skip("FuseBatchNormPass bug in XNNPACK backend") |
| 122 | + def test_conv_bn(self): |
| 123 | + model = torch.nn.Sequential( |
| 124 | + torch.nn.Conv2d(3, 16, 3, padding=1), |
| 125 | + torch.nn.BatchNorm2d(16), |
| 126 | + torch.nn.ReLU(), |
| 127 | + ) |
| 128 | + model.eval() |
| 129 | + inputs = (torch.randn(1, 3, 8, 8),) |
| 130 | + method, _ = _export_and_load(model, inputs) |
| 131 | + |
| 132 | + with torch.no_grad(): |
| 133 | + expected = model(*inputs) |
| 134 | + actual = method.execute(inputs) |
| 135 | + torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3) |
| 136 | + |
| 137 | + # ------------------------------------------------------------------ |
| 138 | + # Depthwise separable conv (MobileNet building block) |
| 139 | + # ------------------------------------------------------------------ |
| 140 | + def test_depthwise_separable_conv(self): |
| 141 | + model = torch.nn.Sequential( |
| 142 | + # Depthwise |
| 143 | + torch.nn.Conv2d(16, 16, 3, padding=1, groups=16), |
| 144 | + torch.nn.ReLU(), |
| 145 | + # Pointwise |
| 146 | + torch.nn.Conv2d(16, 32, 1), |
| 147 | + torch.nn.ReLU(), |
| 148 | + ) |
| 149 | + inputs = (torch.randn(1, 16, 8, 8),) |
| 150 | + method, _ = _export_and_load(model, inputs) |
| 151 | + |
| 152 | + with torch.no_grad(): |
| 153 | + expected = model(*inputs) |
| 154 | + actual = method.execute(inputs) |
| 155 | + torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3) |
| 156 | + |
| 157 | + # ------------------------------------------------------------------ |
| 158 | + # Avgpool + Flatten + Linear (classifier head) |
| 159 | + # ------------------------------------------------------------------ |
| 160 | + def test_classifier_head(self): |
| 161 | + class ClassifierHead(torch.nn.Module): |
| 162 | + def __init__(self): |
| 163 | + super().__init__() |
| 164 | + self.pool = torch.nn.AdaptiveAvgPool2d(1) |
| 165 | + self.fc = torch.nn.Linear(32, 10) |
| 166 | + |
| 167 | + def forward(self, x): |
| 168 | + x = self.pool(x) |
| 169 | + x = x.flatten(1) |
| 170 | + return self.fc(x) |
| 171 | + |
| 172 | + model = ClassifierHead() |
| 173 | + inputs = (torch.randn(1, 32, 8, 8),) |
| 174 | + method, _ = _export_and_load(model, inputs) |
| 175 | + |
| 176 | + with torch.no_grad(): |
| 177 | + expected = model(*inputs) |
| 178 | + actual = method.execute(inputs) |
| 179 | + torch.testing.assert_close(actual[0], expected, atol=1e-3, rtol=1e-3) |
| 180 | + |
| 181 | + |
| 182 | +if __name__ == "__main__": |
| 183 | + unittest.main() |
0 commit comments