Skip to content

Commit 0792840

Browse files
Improve accuracy for models using shuffle, unshuffle, cat ops (#19159)
Summary: Replace the Qualcomm concat observer path with an explicit same-domain-or-requantize model for `aten.cat`. Preserve shared qparams for `pixel_shuffle` and `pixel_unshuffle`, extend `split_with_sizes_copy` coverage, and add regressions for mismatched `cat` branches plus value-preserving ops that must use `SharedQuantizationSpec`. Differential Revision: D102626539
1 parent bf8abb6 commit 0792840

3 files changed

Lines changed: 202 additions & 8 deletions

File tree

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
260260
}
261261
),
262262
)
263+
263264
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
264265
input_qspec_map=input_qspec_map,
265266
output_qspec=output_qspec,
@@ -295,6 +296,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
295296
@register_annotator(
296297
[
297298
torch.ops.aten.split_with_sizes.default,
299+
torch.ops.aten.split_with_sizes_copy.default,
298300
torch.ops.aten.split.Tensor,
299301
torch.ops.aten.chunk.default,
300302
],
@@ -1203,14 +1205,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
12031205
[torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name
12041206
)
12051207
class PixelShuffle(GeneralOpDef):
1206-
pass
1208+
@staticmethod
1209+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
1210+
annotate_in_out_obs_sharing_op(node, quantization_config)
1211+
if not _is_annotated([node]):
1212+
annotate_single_in_share_out(node, quantization_config)
12071213

12081214

12091215
@register_annotator(
12101216
[torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name
12111217
)
12121218
class PixelUnshuffle(GeneralOpDef):
1213-
pass
1219+
@staticmethod
1220+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
1221+
annotate_in_out_obs_sharing_op(node, quantization_config)
1222+
if not _is_annotated([node]):
1223+
annotate_single_in_share_out(node, quantization_config)
12141224

12151225

12161226
@register_annotator(

backends/qualcomm/quantizer/annotators/lpai_rules.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -223,6 +223,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
223223
@register_annotator(
224224
[
225225
torch.ops.aten.split_with_sizes.default,
226+
torch.ops.aten.split_with_sizes_copy.default,
226227
torch.ops.aten.split.Tensor,
227228
torch.ops.aten.chunk.default,
228229
],
@@ -705,14 +706,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
705706
[torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name
706707
)
707708
class PixelShuffle(GeneralOpDef):
708-
pass
709+
@staticmethod
710+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
711+
annotate_in_out_obs_sharing_op(node, quantization_config)
712+
if not _is_annotated([node]):
713+
annotate_single_in_share_out(node, quantization_config)
709714

710715

711716
@register_annotator(
712717
[torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name
713718
)
714719
class PixelUnshuffle(GeneralOpDef):
715-
pass
720+
@staticmethod
721+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
722+
annotate_in_out_obs_sharing_op(node, quantization_config)
723+
if not _is_annotated([node]):
724+
annotate_single_in_share_out(node, quantization_config)
716725

717726

718727
@register_annotator(

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 179 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import itertools
99
import json
1010
import logging
11+
import operator
1112
import subprocess
1213
import sys
1314
import tempfile
@@ -33,6 +34,7 @@
3334
make_quantizer,
3435
setup_common_args_and_variables,
3536
)
37+
from executorch.backends.qualcomm.quantizer.rules import Q_ANNOTATION_KEY
3638
from executorch.backends.qualcomm.serialization.qc_schema import (
3739
QnnExecuTorchBackendType,
3840
QnnExecuTorchHtpPerformanceMode,
@@ -97,6 +99,7 @@
9799
from executorch.examples.models.wav2letter import Wav2LetterModel
98100
from executorch.exir import to_edge
99101
from executorch.exir.backend.backend_api import disable_validation
102+
from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec
100103

101104

102105
class TestQNNFloatingPointOperator(TestQNN):
@@ -1730,12 +1733,16 @@ def test_qnn_backend_permute(self):
17301733

17311734
def test_qnn_backend_pixel_shuffle(self):
17321735
module = PixelShuffle(2) # noqa: F405
1733-
sample_input = (torch.ones([2, 4, 3, 3]),)
1736+
sample_input = (
1737+
torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),
1738+
)
17341739
self.lower_module_and_test_output(module, sample_input)
17351740

17361741
def test_qnn_backend_pixel_unshuffle(self):
17371742
module = PixelUnshuffle(2) # noqa: F405
1738-
sample_input = (torch.ones([2, 2, 6, 6]),)
1743+
sample_input = (
1744+
torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),
1745+
)
17391746
self.lower_module_and_test_output(module, sample_input)
17401747

17411748
def test_qnn_backend_pow_tensor_scalar(self):
@@ -4302,16 +4309,184 @@ def test_qnn_backend_permute(self):
43024309

43034310
def test_qnn_backend_pixel_shuffle(self):
43044311
module = PixelShuffle(2) # noqa: F405
4305-
sample_input = (torch.ones([2, 4, 3, 3]),)
4312+
sample_input = (
4313+
torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),
4314+
)
43064315
module = self.get_qdq_module(module, sample_input)
43074316
self.lower_module_and_test_output(module, sample_input)
43084317

43094318
def test_qnn_backend_pixel_unshuffle(self):
43104319
module = PixelUnshuffle(2) # noqa: F405
4311-
sample_input = (torch.ones([2, 2, 6, 6]),)
4320+
sample_input = (
4321+
torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),
4322+
)
43124323
module = self.get_qdq_module(module, sample_input)
43134324
self.lower_module_and_test_output(module, sample_input)
43144325

4326+
def _prepare_module_for_qparam_assertions(self, module, sample_input):
4327+
backend = get_backend_type(self.backend)
4328+
quantizer = make_quantizer(
4329+
quant_dtype=QuantDtype.use_8a8w,
4330+
custom_annotations=(),
4331+
per_channel_conv=True,
4332+
per_channel_linear=False,
4333+
per_channel_embedding=False,
4334+
backend=backend,
4335+
soc_model=self.soc_model,
4336+
)
4337+
return prepare_pt2e(
4338+
torch.export.export(module, sample_input, strict=True).module(),
4339+
quantizer,
4340+
)
4341+
4342+
def _assert_prepared_nodes_share_qparams(
4343+
self, module, sample_input, target_tokens
4344+
) -> list[torch.fx.Node]:
4345+
prepared = self._prepare_module_for_qparam_assertions(module, sample_input)
4346+
matching_nodes = [
4347+
node
4348+
for node in prepared.graph.nodes
4349+
if node.op == "call_function"
4350+
and any(target_token in str(node.target) for target_token in target_tokens)
4351+
]
4352+
4353+
self.assertGreater(
4354+
len(matching_nodes),
4355+
0,
4356+
f"Failed to find node matching any of {target_tokens}",
4357+
)
4358+
for node in matching_nodes:
4359+
self.assertIsInstance(
4360+
node.meta[Q_ANNOTATION_KEY].output_qspec,
4361+
SharedQuantizationSpec,
4362+
)
4363+
4364+
return matching_nodes
4365+
4366+
def test_qnn_backend_pixel_shuffle_unshuffle_share_qparams(self):
4367+
test_cases = [
4368+
(
4369+
"pixel_shuffle",
4370+
PixelShuffle(2), # noqa: F405
4371+
(torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),),
4372+
torch.ops.aten.pixel_shuffle.default,
4373+
),
4374+
(
4375+
"pixel_unshuffle",
4376+
PixelUnshuffle(2), # noqa: F405
4377+
(torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),),
4378+
torch.ops.aten.pixel_unshuffle.default,
4379+
),
4380+
]
4381+
4382+
for name, module, sample_input, target in test_cases:
4383+
with self.subTest(name=name):
4384+
prepared = self._prepare_module_for_qparam_assertions(
4385+
module, sample_input
4386+
)
4387+
for node in prepared.graph.nodes:
4388+
if node.op == "call_function" and node.target == target:
4389+
self.assertIsInstance(
4390+
node.meta[Q_ANNOTATION_KEY].output_qspec,
4391+
SharedQuantizationSpec,
4392+
)
4393+
break
4394+
else:
4395+
self.fail(f"Failed to find {target} in prepared graph")
4396+
4397+
def test_qnn_backend_value_preserving_ops_share_qparams(self):
4398+
test_cases = [
4399+
(
4400+
"channel_shuffle",
4401+
ChannelShuffle(2), # noqa: F405
4402+
(torch.randn(1, 4, 3, 3),),
4403+
("aten.channel_shuffle",),
4404+
),
4405+
(
4406+
"permute",
4407+
Permute([0, 2, 3, 1]), # noqa: F405
4408+
(torch.randn(2, 3, 4, 5),),
4409+
("aten.permute",),
4410+
),
4411+
(
4412+
"pixel_shuffle",
4413+
PixelShuffle(2), # noqa: F405
4414+
(torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),),
4415+
("aten.pixel_shuffle",),
4416+
),
4417+
(
4418+
"pixel_unshuffle",
4419+
PixelUnshuffle(2), # noqa: F405
4420+
(torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),),
4421+
("aten.pixel_unshuffle",),
4422+
),
4423+
(
4424+
"repeat",
4425+
Repeat(), # noqa: F405
4426+
(torch.randn(2, 2, 2, 2),),
4427+
("aten.repeat",),
4428+
),
4429+
(
4430+
"expand_as",
4431+
ExpandAs(), # noqa: F405
4432+
(torch.randn(3, 4),),
4433+
("aten.expand",),
4434+
),
4435+
(
4436+
"reshape",
4437+
Reshape(), # noqa: F405
4438+
(torch.randn(3, 4),),
4439+
("aten.reshape", "aten.view"),
4440+
),
4441+
]
4442+
4443+
for name, module, sample_input, target_tokens in test_cases:
4444+
with self.subTest(name=name):
4445+
self._assert_prepared_nodes_share_qparams(
4446+
module, sample_input, target_tokens
4447+
)
4448+
4449+
def test_qnn_backend_split_with_sizes_copy_share_qparams(self):
4450+
class SplitWithSizesCopy(torch.nn.Module):
4451+
def forward(self, x):
4452+
out = torch.ops.aten.split_with_sizes_copy.default(x, [2, 2], 1)
4453+
return out[0] + out[1]
4454+
4455+
backend = get_backend_type(self.backend)
4456+
sample_input = (
4457+
torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),
4458+
)
4459+
quantizer = make_quantizer(
4460+
quant_dtype=QuantDtype.use_8a8w,
4461+
custom_annotations=(),
4462+
per_channel_conv=True,
4463+
per_channel_linear=False,
4464+
per_channel_embedding=False,
4465+
backend=backend,
4466+
soc_model=self.soc_model,
4467+
)
4468+
prepared = prepare_pt2e(
4469+
torch.export.export(
4470+
SplitWithSizesCopy(), sample_input, strict=True
4471+
).module(),
4472+
quantizer,
4473+
)
4474+
4475+
getitem_count = 0
4476+
for node in prepared.graph.nodes:
4477+
if (
4478+
node.op == "call_function"
4479+
and node.target == operator.getitem
4480+
and node.args[0].target == torch.ops.aten.split_with_sizes_copy.default
4481+
):
4482+
self.assertIsInstance(
4483+
node.meta[Q_ANNOTATION_KEY].output_qspec,
4484+
SharedQuantizationSpec,
4485+
)
4486+
getitem_count += 1
4487+
4488+
self.assertGreater(getitem_count, 0)
4489+
43154490
def test_qnn_backend_pow_tensor_scalar(self):
43164491
test_comb = [
43174492
{

0 commit comments

Comments (0)