Skip to content

Commit c9af27e

Browse files
authored
Qualcomm AI Engine Direct - Add fp16a8w quantization config (pytorch#19537)
### Summary: - Add an fp16a8w quantization config - Note that QNN HTP supports fp16a8w only for Conv2d (kernel size = 1) and Linear - Add a pass `insert_cast_for_fp_act_quantized_weight.py` that casts fp32 -> fp16 to satisfy a constraint in QNN HTP - Add a test case that runs conv2d and linear with fp16a8w ### Test plan ``` python3 backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_fp16a8w_simple_model -b build-android -H ${HOST} -s ${DEVICE} -m SM8750 -r /path/to/executorch -a /path/to/artifacts ```
1 parent fbc952c commit c9af27e

10 files changed

Lines changed: 639 additions & 197 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from .fuse_consecutive_cast import FuseConsecutiveCast
4545
from .fuse_consecutive_transpose import FuseConsecutiveTranspose
4646
from .i64_to_i32 import I64toI32
47+
from .insert_cast_for_fp_act_quantized_weight import InsertCastForFpActQuantizedWeight
4748
from .insert_io_qdq import InsertIOQDQ
4849
from .insert_requantize import InsertRequantize
4950
from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps
@@ -102,6 +103,7 @@
102103
FuseConsecutiveCast,
103104
FuseConsecutiveTranspose,
104105
I64toI32,
106+
InsertCastForFpActQuantizedWeight,
105107
InsertIOQDQ,
106108
InsertReshapeForReduceOps,
107109
InsertRequantize,
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
9+
from executorch.backends.qualcomm.builders.utils import is_parameter
10+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
from executorch.exir.passes import dead_code_elimination_pass
14+
15+
from .utils import copy_meta
16+
17+
# Edge-dialect ops that InsertCastForFpActQuantizedWeight inspects for the
# "floating-point activation + quantized weight" pattern.
TARGET_OPS = {
    exir_ops.edge.aten.linear.default,
    exir_ops.edge.aten.convolution.default,
}
21+
22+
23+
class InsertCastForFpActQuantizedWeight(ExportPass):
    """
    Make conv/linear nodes that combine a floating-point activation with a
    quantized weight execute in fp16 by surrounding them with dtype casts.

    Why this is needed:
    PyTorch accepts a conv/linear whose activation is fp32 while its weight is
    stored as int8 (e.g. after fp16a8w quantization), because the weight is
    dequantized to fp32 before the multiply-accumulate. QNN HTP is stricter:
    when the weight is quantized (int8/int4) the activation has to be fp16,
    and feeding an fp32 activation results in a QNN compilation error.

    What the pass does:
    The activation is cast fp32 -> fp16 right before the node and the result
    is cast fp16 -> fp32 right after it, so only the node itself computes in
    fp16 and the rest of the graph keeps seeing fp32 tensors.

    Before: [fp32 act] -> conv/linear(w=int8) -> [fp32 out]
    After:  [fp32 act] -> cast(fp16) -> conv/linear(w=int8) -> cast(fp32) -> [fp32 out]

    A node is rewritten when all of the following hold:
    - its target is one of TARGET_OPS (convolution, linear)
    - it carries no QCOM_QUANT_ATTRS (its activation is not quantized)
    - args[1] (the weight) is a parameter carrying QCOM_QUANT_ATTRS, either
      directly or behind a dequantize op
    - the "val" meta of args[0] (the activation) has dtype fp32

    If the node has an fp32 bias, the bias meta["val"] is switched to fp16 as
    well so that it matches the fp16 compute domain of the node.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        # Kept so is_parameter() can tell parameters from other graph nodes.
        self.edge_program = edge_program

    def _get_weight_param_node(self, weight: torch.fx.Node):
        """
        Resolve a weight argument to its parameter node.

        Returns `weight` itself when it is a parameter, the first argument of
        `weight` when `weight` is a dequantize op applied to a parameter, and
        None in every other case.
        """
        if is_parameter(weight, self.edge_program):
            return weight
        if weight.target not in dq_ops:
            return None
        dq_source = weight.args[0]
        if not isinstance(dq_source, torch.fx.Node):
            return None
        return dq_source if is_parameter(dq_source, self.edge_program) else None

    def _has_quantized_weight(self, node: torch.fx.Node) -> bool:
        """
        Tell whether `node` is a TARGET_OPS node whose weight (args[1])
        resolves to a parameter that carries QCOM_QUANT_ATTRS.
        """
        if node.target not in TARGET_OPS:
            return False
        if len(node.args) < 2:
            return False
        weight_arg = node.args[1]
        if not isinstance(weight_arg, torch.fx.Node):
            return False
        weight_param = self._get_weight_param_node(weight_arg)
        if weight_param is None:
            return False
        return bool(weight_param.meta.get(QCOM_QUANT_ATTRS))

    def _insert_fp32_fp16_casts(
        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
    ):
        """
        Surround `node` with a cast to fp16 on its activation (args[0]) and a
        cast back to fp32 on its output, updating the "val" meta entries of
        the node, the new casts and (when fp32) the bias accordingly.
        """
        graph = graph_module.graph
        to_copy = exir_ops.edge.aten._to_copy.default
        activation = node.args[0]

        # Input side: activation -> cast(fp16) -> node.
        with graph.inserting_before(node):
            down_cast = graph.create_node(
                "call_function",
                to_copy,
                (activation,),
                {"dtype": torch.float16},
            )
        down_cast.meta = copy_meta(
            node.meta,
            lambda meta: dict(meta, val=activation.meta["val"].to(torch.float16)),
        )
        node.replace_input_with(activation, down_cast)

        # Keep an fp32 bias (args[2], if any) consistent with fp16 compute.
        bias = node.args[2] if len(node.args) > 2 else None
        if isinstance(bias, torch.fx.Node) and "val" in bias.meta:
            bias_val = bias.meta["val"]
            if bias_val.dtype == torch.float32:
                bias.meta["val"] = bias_val.to(torch.float16)

        # Snapshot the consumers before the output cast exists; the cast
        # itself consumes `node` and must not be redirected to itself.
        consumers = list(node.users)
        original_out = node.meta["val"]
        node.meta["val"] = original_out.to(torch.float16)

        # Output side: node -> cast(fp32) -> original consumers.
        with graph.inserting_after(node):
            up_cast = graph.create_node(
                "call_function",
                to_copy,
                (node,),
                {"dtype": torch.float32},
            )
        up_cast.meta = copy_meta(
            node.meta,
            lambda meta: dict(meta, val=original_out.to(torch.float32)),
        )

        for consumer in consumers:
            consumer.replace_input_with(node, up_cast)

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Wrap every matching node of `graph_module` with fp16/fp32 casts, then
        clean up and recompile the graph. Always reports the graph as modified.
        """
        # Iterate over a snapshot because casts are inserted while walking.
        for candidate in tuple(graph_module.graph.nodes):
            if candidate.meta.get(QCOM_QUANT_ATTRS) or not self._has_quantized_weight(
                candidate
            ):
                continue
            activation = candidate.args[0]
            if not isinstance(activation, torch.fx.Node):
                continue
            activation_val = activation.meta.get("val")
            if activation_val is None or activation_val.dtype != torch.float32:
                continue
            self._insert_fp32_fp16_casts(graph_module, candidate)

        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
FuseConsecutiveCast,
5050
FuseConsecutiveTranspose,
5151
I64toI32,
52+
InsertCastForFpActQuantizedWeight,
5253
InsertIOQDQ,
5354
InsertRequantize,
5455
InsertReshapeForReduceOps,
@@ -120,6 +121,7 @@ def get_capture_program_passes():
120121
(FixedLinearKeepDim, True),
121122
(FoldQDQ, True),
122123
(I64toI32, True),
124+
(InsertCastForFpActQuantizedWeight, True),
123125
(LayoutTransform, True),
124126
(RecomposePadMaxPool2d, True),
125127
(RecomposePixelUnshuffle, True),

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def get_passes_dependency_for_capture_program():
8080
FixedLinearKeepDim,
8181
FoldQDQ,
8282
I64toI32,
83+
InsertCastForFpActQuantizedWeight,
8384
LayoutTransform,
8485
RecomposePadMaxPool2d,
8586
RecomposePixelUnshuffle,
@@ -114,6 +115,7 @@ def get_passes_dependency_for_capture_program():
114115
FixedLinearKeepDim: [FoldQDQ],
115116
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
116117
I64toI32: [RemoveRedundancy],
118+
InsertCastForFpActQuantizedWeight: [FoldQDQ, LayoutTransform],
117119
LayoutTransform: [
118120
AnnotateQuantAttrs,
119121
ExpandBroadcastTensorShape,

0 commit comments

Comments
 (0)