|
7 | 7 | from typing import Dict |
8 | 8 |
|
9 | 9 | import torch |
| 10 | + |
10 | 11 | from executorch.backends.arm.quantizer import ( |
11 | 12 | get_symmetric_a16w8_quantization_config, |
12 | 13 | get_symmetric_quantization_config, |
|
16 | 17 | from executorch.backends.arm.test import common |
17 | 18 | from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline |
18 | 19 | from executorch.backends.arm.tosa import TosaSpecification |
| 20 | +from executorch.backends.test.harness.stages import StageType |
| 21 | +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY |
19 | 22 | from torchvision import models, transforms # type: ignore[import-untyped] |
20 | 23 | from torchvision.ops.misc import Conv2dNormActivation # type: ignore[import-untyped] |
21 | 24 |
|
22 | 25 |
|
def get_quantizer(use_composable_quantizer: bool = False):
    """Build a TOSAQuantizer for the "TOSA-1.0+INT" spec with a symmetric global config.

    Args:
        use_composable_quantizer: Forwarded unchanged to the ``TOSAQuantizer``
            constructor as its ``use_composable_quantizer`` keyword. Defaults
            to ``False`` so existing zero-argument callers keep working.

    Returns:
        The ``TOSAQuantizer`` instance after ``set_global`` has been called
        with ``get_symmetric_quantization_config()``.
    """
    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    quantizer = TOSAQuantizer(
        tosa_spec, use_composable_quantizer=use_composable_quantizer
    )
    quantizer.set_global(get_symmetric_quantization_config())
    return quantizer
28 | 33 |
|
@@ -53,6 +58,25 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
53 | 58 | return x + y |
54 | 59 |
|
55 | 60 |
|
class Cat(torch.nn.Module):
    """Module whose forward pass concatenates its two inputs along dim 1."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``torch.cat((x, y), dim=1)``."""
        return torch.cat((x, y), dim=1)
| 65 | + |
| 66 | + |
class LinearGraphTail(torch.nn.Module):
    """Linear(10, 10) followed by a relu -> sigmoid -> neg chain."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear, relu, sigmoid and neg in that order and return the result."""
        x = self.linear(x)
        x = torch.relu(x)
        x = torch.sigmoid(x)
        return torch.neg(x)
| 78 | + |
| 79 | + |
56 | 80 | class AddSoftmaxAdd(torch.nn.Module): |
57 | 81 | module_names = {"add_0": None, "add_1": None} |
58 | 82 | module_types = { |
@@ -131,6 +155,75 @@ def test_selective_quant_module_type_tosa_INT(model): |
131 | 155 | pipeline.run() |
132 | 156 |
|
133 | 157 |
|
def test_selective_quant_cat_node_target_none_tosa_INT() -> None:
    """Run the quantization pipeline on ``Cat`` with the cat node target set to ``None``.

    A composable quantizer is built and ``set_node_target`` is called with
    ``torch.ops.aten.cat.default`` and ``None``. The pipeline is handed
    ``qspecs`` mapping ``"aten.cat.default"`` to ``{None: 1}``.
    NOTE(review): presumably this means "one cat node carrying no
    quantization spec" -- confirm against ``QuantizationPipeline``.
    """
    model = Cat()
    inputs = (torch.randn(1, 2, 4), torch.randn(1, 3, 4))

    quantizer = get_quantizer(use_composable_quantizer=True)
    quantizer.set_node_target(torch.ops.aten.cat.default, None)

    pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
        model,
        inputs,
        quantizer=quantizer,
        qspecs={
            "aten.cat.default": {
                None: 1,
            },
        },
    )

    pipeline.run()
| 177 | + |
| 178 | + |
def test_composable_io_none_skips_global_tosa_INT() -> None:
    """Run the quantization pipeline on ``Add`` after calling ``set_io(None)``.

    The composable quantizer keeps the global symmetric config installed by
    ``get_quantizer`` while its IO config is set to ``None``. The pipeline is
    given ``input_qspecs={None: 2}`` and ``output_qspecs={None: 1}``.
    NOTE(review): presumably this means both inputs and the single output are
    expected to carry no quantization spec -- confirm against
    ``QuantizationPipeline``.
    """
    model = Add()
    inputs = (torch.randn(1, 10), torch.randn(1, 10))

    quantizer = get_quantizer(use_composable_quantizer=True)
    quantizer.set_io(None)

    pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
        model,
        inputs,
        quantizer=quantizer,
        input_qspecs={None: 2},
        output_qspecs={None: 1},
    )

    pipeline.run()
| 195 | + |
| 196 | + |
def test_composable_global_none_linear_graph_tail_tosa_INT() -> None:
    """Run the pipeline on ``LinearGraphTail`` with the global config set to ``None``.

    ``set_global(None)`` replaces the symmetric config installed by
    ``get_quantizer``. The pipeline is given ``qspecs`` mapping each of the
    linear, relu, sigmoid and neg targets to ``{None: 1}``. After the run, the
    graph from the QUANTIZE stage is inspected and every ``call_function``
    node must have ``Q_ANNOTATION_KEY`` in its ``meta``.
    """
    model = LinearGraphTail()
    inputs = (torch.randn(1, 10),)

    quantizer = get_quantizer(use_composable_quantizer=True)
    quantizer.set_global(None)

    pipeline = QuantizationPipeline[tuple[torch.Tensor]](
        model,
        inputs,
        quantizer=quantizer,
        qspecs={
            "aten.linear.default": {None: 1},
            "aten.relu.default": {None: 1},
            "aten.sigmoid.default": {None: 1},
            "aten.neg.default": {None: 1},
        },
    )

    pipeline.run()

    graph = pipeline.tester.get_graph(StageType.QUANTIZE)
    # Collect names rather than a boolean so a failure reports which nodes
    # were left without an annotation entry.
    unannotated_nodes = [
        node.name
        for node in graph.nodes
        if node.op == "call_function" and Q_ANNOTATION_KEY not in node.meta
    ]
    assert (
        not unannotated_nodes
    ), f"call_function nodes missing {Q_ANNOTATION_KEY!r} in meta: {unannotated_nodes}"
| 225 | + |
| 226 | + |
134 | 227 | mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights) |
135 | 228 | mv3.eval() |
136 | 229 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
0 commit comments