Skip to content

Commit d75e665

Browse files
authored
Qualcomm AI Engine Direct - Enable per-channel quantization for embedding op (#18433)
1 parent 573f930 commit d75e665

8 files changed

Lines changed: 147 additions & 8 deletions

File tree

backends/qualcomm/builders/node_visitor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
227227
# skip dequantize op, e.g. frozen_param -> dq -> conv2d
228228
user_0 = self.get_first_user(node)
229229
# Memory layout of QNN conv weights always ends with the output channel, e.g. conv2d weights are HWIO
230-
if user_0.target == exir_ops.edge.aten.convolution.default:
230+
if user_0.target in {
231+
exir_ops.edge.aten.convolution.default,
232+
}:
231233
quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
232234
else:
233235
quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

backends/qualcomm/builders/op_embedding.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,22 @@
99

1010
import numpy as np
1111
import torch
12-
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
12+
from executorch.backends.qualcomm.utils.constants import (
13+
QCOM_DATA,
14+
QCOM_DTYPE,
15+
QCOM_ENCODING,
16+
QCOM_QUANT_ATTRS,
17+
QCOM_QUANT_MAX,
18+
QCOM_QUANT_MIN,
19+
QCOM_SCALE,
20+
QCOM_SCALES,
21+
QCOM_ZERO_POINT,
22+
QCOM_ZERO_POINTS,
23+
)
1324

14-
from .node_visitor import NodeVisitor
25+
from .node_visitor import NodeVisitor, PER_CHANNEL_ENCODING, QNN_QUANT_TYPE_MAP
1526
from .node_visitor_manager import register_node_visitor
16-
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
27+
from .qnn_constants import OpConvert, OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
1728
from .utils import get_parameter
1829

1930

@@ -30,6 +41,9 @@ def define_node(
3041
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
3142
) -> PyQnnManager.PyQnnOpWrapper:
3243
weight_node = self.get_node(node.args[0])
44+
is_pcq_embedding = QCOM_QUANT_ATTRS in weight_node.meta and weight_node.meta[
45+
QCOM_QUANT_ATTRS
46+
][QCOM_ENCODING] in (PER_CHANNEL_ENCODING)
3347
weight_tensor = get_parameter(weight_node, self.edge_program)
3448
weight_tensor_wrapper = self.define_tensor(
3549
weight_node,
@@ -52,17 +66,41 @@ def define_node(
5266
gather_input_tensors = [weight_tensor_wrapper, indices_tensor_wrapper]
5367

5468
output_tensor = self.get_tensor(node, node)
69+
node_name = node.name
70+
if is_pcq_embedding:
71+
node_quant_attrs = node.meta[QCOM_QUANT_ATTRS].copy()
72+
intermediate_quant_attrs = node.meta[QCOM_QUANT_ATTRS].copy()
73+
# Based on QNN HTP quantization constraints,
74+
# we should set the scale to the max of the per-channel scales and use per-tensor quantization for the embedding op
75+
intermediate_quant_attrs[QCOM_SCALE] = (
76+
weight_node.meta[QCOM_QUANT_ATTRS][QCOM_SCALES].max().item()
77+
)
78+
intermediate_quant_attrs[QCOM_ZERO_POINT] = (
79+
weight_node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINTS].max().item()
80+
)
81+
intermediate_quant_attrs[QCOM_DTYPE] = weight_node.meta[QCOM_QUANT_ATTRS][
82+
QCOM_DTYPE
83+
]
84+
intermediate_quant_attrs[QCOM_QUANT_MAX] = weight_node.meta[
85+
QCOM_QUANT_ATTRS
86+
][QCOM_QUANT_MAX]
87+
intermediate_quant_attrs[QCOM_QUANT_MIN] = weight_node.meta[
88+
QCOM_QUANT_ATTRS
89+
][QCOM_QUANT_MIN]
90+
node.meta[QCOM_QUANT_ATTRS] = intermediate_quant_attrs
91+
node_name += "_intermediate"
5592
output_tensor_wrapper = self.define_tensor(
5693
node,
5794
node,
5895
output_tensor,
5996
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
6097
nodes_to_wrappers,
98+
node_name=node_name,
6199
)
62100
gather_output_tensors = [output_tensor_wrapper]
63101

64102
gather_op = PyQnnManager.PyQnnOpWrapper(
65-
node.name,
103+
node_name,
66104
QNN_OP_PACKAGE_NAME_QTI_AISW,
67105
OpGather.op_name,
68106
)
@@ -76,4 +114,36 @@ def define_node(
76114
{QCOM_DATA: np.int32(0)},
77115
)
78116

79-
return gather_op
117+
op_wrapper_list = [gather_op]
118+
119+
if is_pcq_embedding:
120+
node.meta[QCOM_QUANT_ATTRS] = node_quant_attrs
121+
act_quant_encoding, act_quant_configs = self.get_quant_encoding_conf(
122+
node, node
123+
)
124+
act_dtype = (
125+
torch.uint16
126+
if act_quant_configs[QCOM_DTYPE] == torch.int32
127+
else act_quant_configs[QCOM_DTYPE]
128+
)
129+
convert_tensor_wrapper = self.define_custom_tensor_wrapper(
130+
node_name=node.name,
131+
tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
132+
dtype=QNN_QUANT_TYPE_MAP[act_dtype],
133+
quant_encoding=act_quant_encoding,
134+
quant_configs=act_quant_configs,
135+
dims=output_tensor.size(),
136+
tensor=output_tensor,
137+
is_fake_tensor=True,
138+
nodes_to_wrappers=nodes_to_wrappers,
139+
)
140+
convert_op = PyQnnManager.PyQnnOpWrapper(
141+
node.name + "_convert",
142+
QNN_OP_PACKAGE_NAME_QTI_AISW,
143+
OpConvert.op_name,
144+
)
145+
convert_op.AddInputTensors(gather_output_tensors)
146+
convert_op.AddOutputTensors([convert_tensor_wrapper])
147+
op_wrapper_list.append(convert_op)
148+
149+
return op_wrapper_list

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,14 +535,13 @@ class Elu(GeneralOpDef):
535535
# TODO: Embedding op cannot directly map to OpGather because the index input in torch is not a tensor.
536536
@register_annotator(
537537
[
538-
torch.ops.aten.embedding.default,
539538
torch.ops.aten.gather.default,
540539
torch.ops.aten.index.Tensor,
541540
torch.ops.aten.index_select.default,
542541
],
543542
qnn_op=None,
544543
)
545-
class Embedding(GeneralOpDef):
544+
class Gather(GeneralOpDef):
546545
@staticmethod
547546
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
548547
# args[2] = indices, which should be int
@@ -551,6 +550,40 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
551550
annotate_single_in_share_out(node, quantization_config)
552551

553552

553+
@register_annotator(
554+
[
555+
torch.ops.aten.embedding.default,
556+
],
557+
qnn_op=None,
558+
)
559+
class Embedding(GeneralOpDef):
560+
@staticmethod
561+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
562+
weight = node.args[0]
563+
564+
# Only quantize if input is a float tensor
565+
if _is_annotated([node]) or not _is_float_tensor(weight):
566+
return
567+
568+
is_pcq_embedding = quantization_config.per_channel_embedding
569+
input_qspec_map = {}
570+
input_qspec_map[weight] = (
571+
quantization_config.weight
572+
if is_pcq_embedding
573+
else quantization_config.input_activation
574+
)
575+
output_qspec = (
576+
quantization_config.input_activation
577+
if is_pcq_embedding
578+
else SharedQuantizationSpec((weight, node))
579+
)
580+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
581+
input_qspec_map=input_qspec_map,
582+
output_qspec=output_qspec,
583+
_annotated=True,
584+
)
585+
586+
554587
@register_annotator([torch.ops.aten.eq.Tensor], QnnConstants.OpElementWiseEqual.op_name)
555588
class Equal(GeneralOpDef):
556589
@staticmethod

backends/qualcomm/quantizer/qconfig.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class QuantizationConfig:
3939
weight: Optional[QuantizationSpec]
4040
bias: Optional[QuantizationSpec | Callable]
4141
block_size: Optional[Tuple] = None
42+
per_channel_embedding: bool = False
4243

4344

4445
def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:

backends/qualcomm/quantizer/quantizer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ class ModuleQConfig:
174174
is_qat: bool = False
175175
is_conv_per_channel: bool = False
176176
is_linear_per_channel: bool = False
177+
is_embedding_per_channel: bool = False
177178
act_observer: Optional[UniformQuantizationObserverBase] = None
178179
act_symmetric: bool = False
179180
eps: Optional[float] = None
@@ -226,6 +227,7 @@ def __post_init__(self):
226227
torch.ops.aten.conv_transpose2d.input: 1,
227228
torch.ops.aten.conv_transpose3d.input: 1,
228229
torch.ops.aten.linear.default: 0,
230+
torch.ops.aten.embedding.default: 0,
229231
}
230232

231233
self.use_per_channel_weight_quant_ops = {}
@@ -245,6 +247,17 @@ def __post_init__(self):
245247
self.use_per_channel_weight_quant_ops.update(
246248
{k: self.op_axis_dict[k] for k in linear_ops if k in self.op_axis_dict}
247249
)
250+
if self.is_embedding_per_channel:
251+
embedding_ops = [torch.ops.aten.embedding.default]
252+
self.use_per_channel_weight_quant_ops.update(
253+
{
254+
k: self.op_axis_dict[k]
255+
for k in embedding_ops
256+
if k in self.op_axis_dict
257+
}
258+
)
259+
for pcq_config in self.per_channel_quant_config_list:
260+
pcq_config.per_channel_embedding = True
248261

249262
if per_block_quant_config_func:
250263
self.per_block_quant_config_list = []
@@ -533,6 +546,7 @@ def set_default_quant_config(
533546
is_qat=False,
534547
is_conv_per_channel=False,
535548
is_linear_per_channel=False,
549+
is_embedding_per_channel=False,
536550
act_observer=None,
537551
act_symmetric=False,
538552
eps=None,
@@ -545,6 +559,7 @@ def set_default_quant_config(
545559
is_qat (bool, optional): Enables Quantization-Aware Training (QAT) mode. Defaults to Post-Training Quantization (PTQ) mode.
546560
is_conv_per_channel (bool, optional): Enables per-channel quantization for convolution operations.
547561
is_linear_per_channel (bool, optional): Enables per-channel quantization for linear (fully connected) operations.
562+
is_embedding_per_channel (bool, optional): Enables per-channel quantization for embedding operations.
548563
act_observer (Optional[UniformQuantizationObserverBase], optional): Custom observer for activation quantization. If not specified, the default observer is determined by `QUANT_CONFIG_DICT`.
549564
550565
"""
@@ -553,6 +568,7 @@ def set_default_quant_config(
553568
is_qat=is_qat,
554569
is_conv_per_channel=is_conv_per_channel,
555570
is_linear_per_channel=is_linear_per_channel,
571+
is_embedding_per_channel=is_embedding_per_channel,
556572
act_observer=act_observer,
557573
act_symmetric=act_symmetric,
558574
eps=eps,

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3089,6 +3089,19 @@ def test_qnn_backend_embedding(self):
30893089
)
30903090
self.lower_module_and_test_output(modules[i], sample_input)
30913091

3092+
# TODO: Once the accuracy issue is fixed, enable this test.
3093+
@unittest.skip("Bad accuracy for HTP")
3094+
def test_qnn_backend_embedding_per_channel(self):
3095+
module = Embedding() # noqa: F405
3096+
sample_input = (torch.Tensor([1, 2, 4, 5]).to(torch.int32),)
3097+
qdq_module = self.get_qdq_module(
3098+
module,
3099+
sample_input,
3100+
quant_dtype=QuantDtype.use_16a8w,
3101+
is_embedding_per_channel=True,
3102+
)
3103+
self.lower_module_and_test_output(qdq_module, sample_input)
3104+
30923105
def test_qnn_backend_equal(self):
30933106
test_comb = [
30943107
{

backends/qualcomm/tests/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ def get_qdq_module(
643643
inputs: Tuple[torch.Tensor],
644644
is_conv_per_channel: Optional[bool] = True,
645645
is_linear_per_channel: Optional[bool] = False,
646+
is_embedding_per_channel: Optional[bool] = False,
646647
custom_quant_annotations: Tuple[Callable] = (),
647648
quant_dtype: QuantDtype = QuantDtype.use_8a8w,
648649
dynamic_shapes: Dict = None,
@@ -659,6 +660,7 @@ def get_qdq_module(
659660
custom_annotations=custom_quant_annotations,
660661
per_channel_conv=is_conv_per_channel,
661662
per_channel_linear=is_linear_per_channel,
663+
per_channel_embedding=is_embedding_per_channel,
662664
submodule_qconfig_list=submodule_qconfig_list,
663665
backend=get_backend_type(self.backend),
664666
soc_model=self.model,

examples/qualcomm/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ def make_quantizer(
366366
custom_annotations=(),
367367
per_channel_conv=True,
368368
per_channel_linear=False,
369+
per_channel_embedding=False,
369370
act_observer=MovingAverageMinMaxObserver,
370371
act_symmetric=False,
371372
is_qat=False,
@@ -381,6 +382,7 @@ def make_quantizer(
381382
is_qat=is_qat,
382383
is_conv_per_channel=per_channel_conv,
383384
is_linear_per_channel=per_channel_linear,
385+
is_embedding_per_channel=per_channel_embedding,
384386
act_observer=act_observer,
385387
act_symmetric=act_symmetric,
386388
eps=eps,

0 commit comments

Comments (0)