|
8 | 8 | import itertools |
9 | 9 | import json |
10 | 10 | import logging |
| 11 | +import operator |
11 | 12 | import subprocess |
12 | 13 | import sys |
13 | 14 | import tempfile |
|
33 | 34 | make_quantizer, |
34 | 35 | setup_common_args_and_variables, |
35 | 36 | ) |
| 37 | +from executorch.backends.qualcomm.quantizer.rules import Q_ANNOTATION_KEY |
36 | 38 | from executorch.backends.qualcomm.serialization.qc_schema import ( |
37 | 39 | QnnExecuTorchBackendType, |
38 | 40 | QnnExecuTorchHtpPerformanceMode, |
|
97 | 99 | from executorch.examples.models.wav2letter import Wav2LetterModel |
98 | 100 | from executorch.exir import to_edge |
99 | 101 | from executorch.exir.backend.backend_api import disable_validation |
| 102 | +from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec |
100 | 103 |
|
101 | 104 |
|
102 | 105 | class TestQNNFloatingPointOperator(TestQNN): |
@@ -1730,12 +1733,16 @@ def test_qnn_backend_permute(self): |
1730 | 1733 |
|
1731 | 1734 | def test_qnn_backend_pixel_shuffle(self): |
1732 | 1735 | module = PixelShuffle(2) # noqa: F405 |
1733 | | - sample_input = (torch.ones([2, 4, 3, 3]),) |
| 1736 | + sample_input = ( |
| 1737 | + torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3), |
| 1738 | + ) |
1734 | 1739 | self.lower_module_and_test_output(module, sample_input) |
1735 | 1740 |
|
1736 | 1741 | def test_qnn_backend_pixel_unshuffle(self): |
1737 | 1742 | module = PixelUnshuffle(2) # noqa: F405 |
1738 | | - sample_input = (torch.ones([2, 2, 6, 6]),) |
| 1743 | + sample_input = ( |
| 1744 | + torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6), |
| 1745 | + ) |
1739 | 1746 | self.lower_module_and_test_output(module, sample_input) |
1740 | 1747 |
|
1741 | 1748 | def test_qnn_backend_pow_tensor_scalar(self): |
@@ -4302,16 +4309,184 @@ def test_qnn_backend_permute(self): |
4302 | 4309 |
|
4303 | 4310 | def test_qnn_backend_pixel_shuffle(self): |
4304 | 4311 | module = PixelShuffle(2) # noqa: F405 |
4305 | | - sample_input = (torch.ones([2, 4, 3, 3]),) |
| 4312 | + sample_input = ( |
| 4313 | + torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3), |
| 4314 | + ) |
4306 | 4315 | module = self.get_qdq_module(module, sample_input) |
4307 | 4316 | self.lower_module_and_test_output(module, sample_input) |
4308 | 4317 |
|
4309 | 4318 | def test_qnn_backend_pixel_unshuffle(self): |
4310 | 4319 | module = PixelUnshuffle(2) # noqa: F405 |
4311 | | - sample_input = (torch.ones([2, 2, 6, 6]),) |
| 4320 | + sample_input = ( |
| 4321 | + torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6), |
| 4322 | + ) |
4312 | 4323 | module = self.get_qdq_module(module, sample_input) |
4313 | 4324 | self.lower_module_and_test_output(module, sample_input) |
4314 | 4325 |
|
| 4326 | + def _prepare_module_for_qparam_assertions(self, module, sample_input): |
| 4327 | + backend = get_backend_type(self.backend) |
| 4328 | + quantizer = make_quantizer( |
| 4329 | + quant_dtype=QuantDtype.use_8a8w, |
| 4330 | + custom_annotations=(), |
| 4331 | + per_channel_conv=True, |
| 4332 | + per_channel_linear=False, |
| 4333 | + per_channel_embedding=False, |
| 4334 | + backend=backend, |
| 4335 | + soc_model=self.soc_model, |
| 4336 | + ) |
| 4337 | + return prepare_pt2e( |
| 4338 | + torch.export.export(module, sample_input, strict=True).module(), |
| 4339 | + quantizer, |
| 4340 | + ) |
| 4341 | + |
| 4342 | + def _assert_prepared_nodes_share_qparams( |
| 4343 | + self, module, sample_input, target_tokens |
| 4344 | + ) -> list[torch.fx.Node]: |
| 4345 | + prepared = self._prepare_module_for_qparam_assertions(module, sample_input) |
| 4346 | + matching_nodes = [ |
| 4347 | + node |
| 4348 | + for node in prepared.graph.nodes |
| 4349 | + if node.op == "call_function" |
| 4350 | + and any(target_token in str(node.target) for target_token in target_tokens) |
| 4351 | + ] |
| 4352 | + |
| 4353 | + self.assertGreater( |
| 4354 | + len(matching_nodes), |
| 4355 | + 0, |
| 4356 | + f"Failed to find node matching any of {target_tokens}", |
| 4357 | + ) |
| 4358 | + for node in matching_nodes: |
| 4359 | + self.assertIsInstance( |
| 4360 | + node.meta[Q_ANNOTATION_KEY].output_qspec, |
| 4361 | + SharedQuantizationSpec, |
| 4362 | + ) |
| 4363 | + |
| 4364 | + return matching_nodes |
| 4365 | + |
| 4366 | + def test_qnn_backend_pixel_shuffle_unshuffle_share_qparams(self): |
| 4367 | + test_cases = [ |
| 4368 | + ( |
| 4369 | + "pixel_shuffle", |
| 4370 | + PixelShuffle(2), # noqa: F405 |
| 4371 | + (torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),), |
| 4372 | + torch.ops.aten.pixel_shuffle.default, |
| 4373 | + ), |
| 4374 | + ( |
| 4375 | + "pixel_unshuffle", |
| 4376 | + PixelUnshuffle(2), # noqa: F405 |
| 4377 | + (torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),), |
| 4378 | + torch.ops.aten.pixel_unshuffle.default, |
| 4379 | + ), |
| 4380 | + ] |
| 4381 | + |
| 4382 | + for name, module, sample_input, target in test_cases: |
| 4383 | + with self.subTest(name=name): |
| 4384 | + prepared = self._prepare_module_for_qparam_assertions( |
| 4385 | + module, sample_input |
| 4386 | + ) |
| 4387 | + for node in prepared.graph.nodes: |
| 4388 | + if node.op == "call_function" and node.target == target: |
| 4389 | + self.assertIsInstance( |
| 4390 | + node.meta[Q_ANNOTATION_KEY].output_qspec, |
| 4391 | + SharedQuantizationSpec, |
| 4392 | + ) |
| 4393 | + break |
| 4394 | + else: |
| 4395 | + self.fail(f"Failed to find {target} in prepared graph") |
| 4396 | + |
| 4397 | + def test_qnn_backend_value_preserving_ops_share_qparams(self): |
| 4398 | + test_cases = [ |
| 4399 | + ( |
| 4400 | + "channel_shuffle", |
| 4401 | + ChannelShuffle(2), # noqa: F405 |
| 4402 | + (torch.randn(1, 4, 3, 3),), |
| 4403 | + ("aten.channel_shuffle",), |
| 4404 | + ), |
| 4405 | + ( |
| 4406 | + "permute", |
| 4407 | + Permute([0, 2, 3, 1]), # noqa: F405 |
| 4408 | + (torch.randn(2, 3, 4, 5),), |
| 4409 | + ("aten.permute",), |
| 4410 | + ), |
| 4411 | + ( |
| 4412 | + "pixel_shuffle", |
| 4413 | + PixelShuffle(2), # noqa: F405 |
| 4414 | + (torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),), |
| 4415 | + ("aten.pixel_shuffle",), |
| 4416 | + ), |
| 4417 | + ( |
| 4418 | + "pixel_unshuffle", |
| 4419 | + PixelUnshuffle(2), # noqa: F405 |
| 4420 | + (torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),), |
| 4421 | + ("aten.pixel_unshuffle",), |
| 4422 | + ), |
| 4423 | + ( |
| 4424 | + "repeat", |
| 4425 | + Repeat(), # noqa: F405 |
| 4426 | + (torch.randn(2, 2, 2, 2),), |
| 4427 | + ("aten.repeat",), |
| 4428 | + ), |
| 4429 | + ( |
| 4430 | + "expand_as", |
| 4431 | + ExpandAs(), # noqa: F405 |
| 4432 | + (torch.randn(3, 4),), |
| 4433 | + ("aten.expand",), |
| 4434 | + ), |
| 4435 | + ( |
| 4436 | + "reshape", |
| 4437 | + Reshape(), # noqa: F405 |
| 4438 | + (torch.randn(3, 4),), |
| 4439 | + ("aten.reshape", "aten.view"), |
| 4440 | + ), |
| 4441 | + ] |
| 4442 | + |
| 4443 | + for name, module, sample_input, target_tokens in test_cases: |
| 4444 | + with self.subTest(name=name): |
| 4445 | + self._assert_prepared_nodes_share_qparams( |
| 4446 | + module, sample_input, target_tokens |
| 4447 | + ) |
| 4448 | + |
| 4449 | + def test_qnn_backend_split_with_sizes_copy_share_qparams(self): |
| 4450 | + class SplitWithSizesCopy(torch.nn.Module): |
| 4451 | + def forward(self, x): |
| 4452 | + out = torch.ops.aten.split_with_sizes_copy.default(x, [2, 2], 1) |
| 4453 | + return out[0] + out[1] |
| 4454 | + |
| 4455 | + backend = get_backend_type(self.backend) |
| 4456 | + sample_input = ( |
| 4457 | + torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3), |
| 4458 | + ) |
| 4459 | + quantizer = make_quantizer( |
| 4460 | + quant_dtype=QuantDtype.use_8a8w, |
| 4461 | + custom_annotations=(), |
| 4462 | + per_channel_conv=True, |
| 4463 | + per_channel_linear=False, |
| 4464 | + per_channel_embedding=False, |
| 4465 | + backend=backend, |
| 4466 | + soc_model=self.soc_model, |
| 4467 | + ) |
| 4468 | + prepared = prepare_pt2e( |
| 4469 | + torch.export.export( |
| 4470 | + SplitWithSizesCopy(), sample_input, strict=True |
| 4471 | + ).module(), |
| 4472 | + quantizer, |
| 4473 | + ) |
| 4474 | + |
| 4475 | + getitem_count = 0 |
| 4476 | + for node in prepared.graph.nodes: |
| 4477 | + if ( |
| 4478 | + node.op == "call_function" |
| 4479 | + and node.target == operator.getitem |
| 4480 | + and node.args[0].target == torch.ops.aten.split_with_sizes_copy.default |
| 4481 | + ): |
| 4482 | + self.assertIsInstance( |
| 4483 | + node.meta[Q_ANNOTATION_KEY].output_qspec, |
| 4484 | + SharedQuantizationSpec, |
| 4485 | + ) |
| 4486 | + getitem_count += 1 |
| 4487 | + |
| 4488 | + self.assertGreater(getitem_count, 0) |
| 4489 | + |
4315 | 4490 | def test_qnn_backend_pow_tensor_scalar(self): |
4316 | 4491 | test_comb = [ |
4317 | 4492 | { |
|
0 commit comments