Skip to content

Commit 5ec362a

Browse files
authored
Qualcomm AI Engine Direct - Fix MobilenetV3 and Stack Layout Transform (pytorch#16686)
### Summary - Resolve Mainline Issue: pytorch#16616 (comment) - Fix MobileNetV3 Accuracy Issue - Support drawing pydot graphs so we can draw LLM models. SVG rendering will get stuck when drawing LLM models. ### Test plan UT added.
1 parent 879f4a7 commit 5ec362a

9 files changed

Lines changed: 146 additions & 65 deletions

File tree

backends/qualcomm/_passes/layout_transform.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,8 @@ class LayoutTransform(ExportPass):
124124
exir_ops.edge.aten.sqrt.default,
125125
exir_ops.edge.aten.sub.Tensor,
126126
exir_ops.edge.aten.sum.dim_IntList,
127-
exir_ops.edge.aten.stack.default,
128127
exir_ops.edge.aten.topk.default,
129128
exir_ops.edge.aten._to_copy.default,
130-
exir_ops.edge.aten.unbind.int,
131129
exir_ops.edge.aten.where.self,
132130
_operator.getitem,
133131
torch.ops.aten.scalar_tensor.default,

backends/qualcomm/_passes/seq_mse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def _make_operator(self, aten_op):
5656
groups = 1 if len(aten_op.args) < 7 else aten_op.args[6]
5757
has_bias = self.nominal_bias is not None
5858
module = torch.nn.Conv2d(
59-
in_channels=self.nominal_weight.shape[1],
59+
in_channels=self.nominal_weight.shape[1]
60+
* groups, # equivalent to input_tensor.shape[1]
6061
out_channels=self.nominal_weight.shape[0],
6162
kernel_size=self.nominal_weight.shape[-2:],
6263
stride=stride,

backends/qualcomm/builders/op_stack.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111
import torch
12-
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1313

1414
from .node_visitor import NodeVisitor
1515
from .node_visitor_manager import register_node_visitor
@@ -50,11 +50,10 @@ def define_node(
5050
)
5151
stack_output_tensors = [output_tensor_wrapper]
5252

53+
# Don't need to check axis_order since stack is a pytorch layout op according to layout transform.
5354
dim = 0 if len(node.args) == 1 else cast(int, node.args[1])
5455
if dim < 0:
5556
dim = dim % len(output_tensor.shape)
56-
if QCOM_AXIS_ORDER in node.meta:
57-
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
5857
stack_op = PyQnnManager.PyQnnOpWrapper(
5958
node.name,
6059
QNN_OP_PACKAGE_NAME_QTI_AISW,

backends/qualcomm/builders/op_unbind.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111
import torch
12-
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1313

1414
from .node_visitor import NodeVisitor
1515
from .node_visitor_manager import register_node_visitor
@@ -52,11 +52,10 @@ def define_node(
5252
)
5353
unbind_output_tensors.append(output_tensor_wrapper)
5454

55+
# Don't need to check axis_order since unbind is a pytorch layout op according to layout transform.
5556
dim = 0 if len(node.args) == 1 else cast(int, node.args[1])
5657
if dim < 0:
5758
dim = dim % len(input_tensor.shape)
58-
if QCOM_AXIS_ORDER in node.meta:
59-
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
6059
unbind_op = PyQnnManager.PyQnnOpWrapper(
6160
node.name,
6261
QNN_OP_PACKAGE_NAME_QTI_AISW,

backends/qualcomm/quantizer/annotators.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -400,9 +400,9 @@ def annotate_abs(node: Node, quantization_config: QuantizationConfig) -> None:
400400

401401
@register_annotator(
402402
[
403-
torch.torch.ops.aten.arange.default,
404-
torch.torch.ops.aten.arange.start,
405-
torch.torch.ops.aten.arange.start_step,
403+
torch.ops.aten.arange.default,
404+
torch.ops.aten.arange.start,
405+
torch.ops.aten.arange.start_step,
406406
]
407407
)
408408
def annotate_arange(node: Node, quantization_config: QuantizationConfig) -> None:
@@ -586,13 +586,6 @@ def annotate_hardswish(node: Node, quantization_config: QuantizationConfig) -> N
586586
annotate_single_in_single_out(node, quantization_config)
587587

588588

589-
@register_annotator(
590-
[torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardsigmoid_.default]
591-
)
592-
def annotate_hardsigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
593-
annotate_single_in_single_out(node, quantization_config)
594-
595-
596589
@register_annotator([torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default])
597590
def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> None:
598591
annotate_single_in_single_out(node, quantization_config)
@@ -871,7 +864,14 @@ def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None:
871864
annotate_single_in_single_out(node, quantization_config)
872865

873866

874-
@register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default])
867+
@register_annotator(
868+
[
869+
torch.ops.aten.hardsigmoid.default,
870+
torch.ops.aten.hardsigmoid_.default,
871+
torch.ops.aten.sigmoid,
872+
torch.ops.aten.sigmoid.default,
873+
]
874+
)
875875
def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
876876
if _is_annotated([node]):
877877
return
@@ -896,7 +896,7 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non
896896

897897
scale = 1 / (q_max - q_min + 1)
898898

899-
bias_obs_ctr = observer = FixedQParamsObserver.with_args(
899+
output_obs_ctr = observer = FixedQParamsObserver.with_args(
900900
scale=scale,
901901
zero_point=0,
902902
dtype=quantization_config.output_activation.dtype,
@@ -908,7 +908,7 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non
908908
get_8a8w_qnn_qat_config(),
909909
get_16a4w_qnn_qat_config(),
910910
):
911-
bias_obs_ctr = FixedQParamsFakeQuantize.with_args(
911+
output_obs_ctr = FixedQParamsFakeQuantize.with_args(
912912
observer=observer,
913913
scale=scale,
914914
zero_point=0,
@@ -923,7 +923,7 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non
923923
dtype=quantization_config.output_activation.dtype,
924924
quant_max=q_max,
925925
quant_min=q_min,
926-
observer_or_fake_quant_ctr=bias_obs_ctr,
926+
observer_or_fake_quant_ctr=output_obs_ctr,
927927
qscheme=torch.torch.per_tensor_affine,
928928
)
929929

backends/qualcomm/tests/models.py

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,48 @@ def forward(self, x, y):
598598
return z
599599

600600

601+
class Conv2dDownUpSample(torch.nn.Module):
602+
def __init__(self, bias=True):
603+
super().__init__()
604+
self.conv = torch.nn.Conv2d(
605+
in_channels=16,
606+
out_channels=16,
607+
kernel_size=3,
608+
stride=2,
609+
padding=1,
610+
bias=bias,
611+
)
612+
self.conv_transpose = torch.nn.ConvTranspose2d(
613+
in_channels=16,
614+
out_channels=16,
615+
kernel_size=3,
616+
stride=2,
617+
padding=1,
618+
bias=bias,
619+
)
620+
621+
def forward(self, x):
622+
return self.conv_transpose(self.conv(x))
623+
624+
625+
class Conv2dFlip(torch.nn.Module):
626+
def __init__(self):
627+
super().__init__()
628+
self.conv = torch.nn.Conv2d(
629+
in_channels=16,
630+
out_channels=16,
631+
kernel_size=3,
632+
stride=2,
633+
padding=1,
634+
bias=False,
635+
)
636+
self.dims = [1, 3]
637+
638+
def forward(self, x):
639+
x = self.conv(x)
640+
return torch.flip(x, self.dims)
641+
642+
601643
class Conv2dMaxPool2d(torch.nn.Module):
602644
def __init__(self):
603645
super().__init__()
@@ -660,46 +702,14 @@ def forward(self, x):
660702
return self.conv(x)
661703

662704

663-
class Conv2dDownUpSample(torch.nn.Module):
664-
def __init__(self, bias=True):
665-
super().__init__()
666-
self.conv = torch.nn.Conv2d(
667-
in_channels=16,
668-
out_channels=16,
669-
kernel_size=3,
670-
stride=2,
671-
padding=1,
672-
bias=bias,
673-
)
674-
self.conv_transpose = torch.nn.ConvTranspose2d(
675-
in_channels=16,
676-
out_channels=16,
677-
kernel_size=3,
678-
stride=2,
679-
padding=1,
680-
bias=bias,
681-
)
682-
683-
def forward(self, x):
684-
return self.conv_transpose(self.conv(x))
685-
686-
687-
class Conv2dFlip(torch.nn.Module):
705+
class Conv2dStack(torch.nn.Module):
688706
def __init__(self):
689707
super().__init__()
690-
self.conv = torch.nn.Conv2d(
691-
in_channels=16,
692-
out_channels=16,
693-
kernel_size=3,
694-
stride=2,
695-
padding=1,
696-
bias=False,
697-
)
698-
self.dims = [1, 3]
708+
self.conv1 = torch.nn.Conv2d(3, 3, 3)
699709

700-
def forward(self, x):
701-
x = self.conv(x)
702-
return torch.flip(x, self.dims)
710+
def forward(self, x, y, z):
711+
x1 = self.conv1(x)
712+
return torch.stack((x1, y, z))
703713

704714

705715
class Conv2dSliceCopy(torch.nn.Module):
@@ -744,6 +754,16 @@ def forward(self, x):
744754
return topk_values
745755

746756

757+
class Conv2dUnbind(torch.nn.Module):
758+
def __init__(self):
759+
super().__init__()
760+
self.conv1 = torch.nn.Conv2d(3, 3, 3)
761+
762+
def forward(self, x):
763+
x1 = self.conv1(x)
764+
return torch.unbind(x1, dim=1)
765+
766+
747767
class Conv3dSequential(torch.nn.Module):
748768
def __init__(self, bias=True):
749769
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,15 @@ def test_qnn_backend_conv2d_slice_copy(self):
19821982
sample_input = (torch.randn([2, 1, 3, 3]),)
19831983
self.lower_module_and_test_output(module, sample_input)
19841984

1985+
def test_qnn_backend_conv2d_stack(self):
1986+
module = Conv2dStack() # noqa: F405
1987+
sample_input = (
1988+
torch.randn(1, 3, 5, 5),
1989+
torch.randn(1, 3, 3, 3),
1990+
torch.randn(1, 3, 3, 3),
1991+
)
1992+
self.lower_module_and_test_output(module, sample_input)
1993+
19851994
def test_qnn_backend_conv2d_sum_reduce_dim(self):
19861995
module = Conv2dSumReduceDim() # noqa: F405
19871996
sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -1992,6 +2001,14 @@ def test_qnn_backend_conv2d_topk(self):
19922001
sample_input = (torch.randn(1, 3, 32, 32),)
19932002
self.lower_module_and_test_output(module, sample_input)
19942003

2004+
# This test is to ensure unbind should be pytorch layout.
2005+
# However, unbind will be forced decomposed by executorch framework.
2006+
# Keep it here in case unbind doesn't get forced decomposed in future.
2007+
def test_qnn_backend_conv2d_unbind(self):
2008+
module = Conv2dUnbind() # noqa: F405
2009+
sample_input = (torch.randn(1, 3, 5, 5),)
2010+
self.lower_module_and_test_output(module, sample_input)
2011+
19952012
def test_qnn_backend_copy(self):
19962013
sample_inputs = [
19972014
(torch.randn(3, 4, 5),),
@@ -4365,6 +4382,16 @@ def test_qnn_backend_conv2d_slice_copy(self):
43654382
module = self.get_qdq_module(module, sample_input)
43664383
self.lower_module_and_test_output(module, sample_input)
43674384

4385+
def test_qnn_backend_conv2d_stack(self):
4386+
module = Conv2dStack() # noqa: F405
4387+
sample_input = (
4388+
torch.randn(1, 3, 5, 5),
4389+
torch.randn(1, 3, 3, 3),
4390+
torch.randn(1, 3, 3, 3),
4391+
)
4392+
module = self.get_qdq_module(module, sample_input)
4393+
self.lower_module_and_test_output(module, sample_input)
4394+
43684395
def test_qnn_backend_conv2d_sum_reduce_dim(self):
43694396
module = Conv2dSumReduceDim() # noqa: F405
43704397
sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -4377,6 +4404,15 @@ def test_qnn_backend_conv2d_topk(self):
43774404
module = self.get_qdq_module(module, sample_input)
43784405
self.lower_module_and_test_output(module, sample_input)
43794406

4407+
# This test is to ensure unbind should be pytorch layout.
4408+
# However, unbind will be forced decomposed by executorch framework.
4409+
# Keep it here in case unbind doesn't get forced decomposed in future.
4410+
def test_qnn_backend_conv2d_unbind(self):
4411+
module = Conv2dUnbind() # noqa: F405
4412+
sample_input = (torch.randn(1, 3, 5, 5),)
4413+
module = self.get_qdq_module(module, sample_input)
4414+
self.lower_module_and_test_output(module, sample_input)
4415+
43804416
def test_qnn_backend_copy(self):
43814417
sample_inputs = [
43824418
(torch.randn(3, 4, 5),),
@@ -7757,7 +7793,7 @@ def test_mobilenet_v3(self):
77577793
metric = {
77587794
# GPU has accuracy issue now
77597795
QnnExecuTorchBackendType.kGpuBackend: {"top_1": 0, "top_5": 0},
7760-
QnnExecuTorchBackendType.kHtpBackend: {"top_1": 55, "top_5": 81},
7796+
QnnExecuTorchBackendType.kHtpBackend: {"top_1": 51, "top_5": 76},
77617797
}
77627798
self.assertGreaterEqual(
77637799
msg["top_1"], metric[get_backend_type(self.backend)]["top_1"]

backends/qualcomm/utils/utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import warnings
1010
from collections import defaultdict, OrderedDict
11+
from enum import Enum
1112
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1213

1314
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
@@ -924,10 +925,24 @@ def preprocess_binary(ctx_bin, compiler_specs):
924925
return bundle_prog
925926

926927

927-
def draw_graph(title, path, graph_module: torch.fx.GraphModule):
928+
class DrawFormat(Enum):
929+
SVG = 1
930+
PYDOT = 2
931+
932+
933+
def draw_graph(title, path, graph_module: torch.fx.GraphModule, format=DrawFormat.SVG):
928934
graph = passes.graph_drawer.FxGraphDrawer(graph_module, title)
929-
with open(f"{path}/{title}.svg", "wb") as f:
930-
f.write(graph.get_dot_graph().create_svg())
935+
warnings.warn(
936+
"For large models such as LLM, it is strongly recommended to use PYDOT format.",
937+
stacklevel=1,
938+
)
939+
if format == DrawFormat.SVG:
940+
with open(f"{path}/{title}.svg", "wb") as f:
941+
f.write(graph.get_dot_graph().create_svg())
942+
elif format == DrawFormat.PYDOT:
943+
graph.get_dot_graph().write_raw(f"{path}/{title}.dot")
944+
else:
945+
raise RuntimeError(f"Unknown format {format}.")
931946

932947

933948
def generate_gpu_compiler_spec(

0 commit comments

Comments (0)