Skip to content

Commit c4e3db0

Browse files
authored
Qualcomm AI Engine Direct - Support 2-bits quantization 16a2w (pytorch#19632)
Qualcomm AI Engine Direct - Support 2-bits quantization 16a2w Summary: 1.Add 2-bits quantization basis 16a2w quantizer with standard symmetric 2.Support per channel and linear layers 3.Currently support soc model SM8850 Test plan: python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_16a2w_conv2d -b build-android -H ${HOST} -s ${SN} -m SM8850 python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_16a2w_linear -b build-android -H ${HOST} -s ${SN} -m SM8850 cc @cccclai @cbilgin @abhinaykukkadapu
1 parent 90cd48f commit c4e3db0

5 files changed

Lines changed: 146 additions & 17 deletions

File tree

backends/qualcomm/builders/node_visitor.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -248,16 +248,19 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
248248
quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]
249249

250250
quant_config[QCOM_SCALE_OFFSET] = scale_offset_arr
251-
# special case for 4 bits
252-
if (
253-
quant_config[QCOM_DTYPE] == torch.int8
254-
and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
255-
):
256-
quant_config[QCOM_BITWIDTH] = 4
257-
return (
258-
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
259-
quant_config,
260-
)
251+
if quant_config[QCOM_DTYPE] == torch.int8:
252+
if quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 3:
253+
quant_config[QCOM_BITWIDTH] = 2
254+
return (
255+
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
256+
quant_config,
257+
)
258+
elif quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15:
259+
quant_config[QCOM_BITWIDTH] = 4
260+
return (
261+
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
262+
quant_config,
263+
)
261264
return (
262265
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
263266
quant_config,
@@ -272,6 +275,11 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
272275
}
273276
# check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
274277
quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
278+
range_ = quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN]
279+
assert range_ > 3, (
280+
f"2-bit quantization (range={range_}) does not support per-tensor encoding. "
281+
"Use per-channel quantization instead."
282+
)
275283
# special case for 4 bits
276284
if (
277285
quant_config[QCOM_DTYPE] == torch.int8
@@ -338,6 +346,9 @@ def get_quant_tensor_value(
338346
if quant_configs.get(QCOM_BITWIDTH) == 4:
339347
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
340348
tensor = torch.bitwise_and(mask, tensor)
349+
elif quant_configs.get(QCOM_BITWIDTH) == 2:
350+
mask = torch.full(tensor.size(), 0x03, dtype=torch.int8)
351+
tensor = torch.bitwise_and(mask, tensor)
341352
return tensor
342353

343354
def get_tensor_type(

backends/qualcomm/quantizer/qconfig.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,51 @@ def get_8a4w_qnn_ptq_config(
357357
return quantization_config
358358

359359

360+
# 2 bits weight quantization only supports per channel and symmetric.
361+
def get_16a2w_qnn_ptq_config(
362+
act_symmetric: bool = False,
363+
act_observer=MovingAverageMinMaxObserver,
364+
eps: float = None,
365+
) -> QuantizationConfig:
366+
# the smallest defaults to DEFAULT_EPS_16BIT
367+
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
368+
369+
act_quantization_spec = QuantizationSpec(
370+
dtype=torch.int32,
371+
quant_min=torch.iinfo(torch.uint16).min,
372+
quant_max=torch.iinfo(torch.uint16).max,
373+
qscheme=(
374+
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
375+
),
376+
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
377+
)
378+
379+
weight_quantization_spec = QuantizationSpec(
380+
dtype=torch.int8,
381+
quant_min=-2,
382+
quant_max=1,
383+
qscheme=torch.per_tensor_symmetric,
384+
observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
385+
)
386+
387+
bias_quantization_spec = QuantizationSpec(
388+
dtype=torch.int32,
389+
quant_min=torch.iinfo(torch.int32).min,
390+
quant_max=torch.iinfo(torch.int32).max,
391+
qscheme=torch.per_tensor_symmetric,
392+
observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
393+
)
394+
395+
quantization_config = QuantizationConfig(
396+
input_activation=act_quantization_spec,
397+
output_activation=act_quantization_spec,
398+
weight=weight_quantization_spec,
399+
bias=bias_quantization_spec,
400+
)
401+
402+
return quantization_config
403+
404+
360405
# 4 bits quantization only supports specific ops.
361406
def get_16a4w_qnn_ptq_config(
362407
act_symmetric: bool = False,
@@ -573,7 +618,7 @@ def get_ptq_per_channel_quant_config(
573618
torch.int8,
574619
torch.int16,
575620
}
576-
supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
621+
supported_weight_dtypes = {torch.int2, torch.int4, torch.int8, torch.int16}
577622
assert (
578623
act_dtype in supported_act_types
579624
), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -606,12 +651,23 @@ def get_ptq_per_channel_quant_config(
606651
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
607652
)
608653

654+
q_dtype = weight_dtype
655+
if weight_dtype == torch.int4:
656+
q_dtype = torch.int8
657+
q_min = -7
658+
q_max = 7
659+
elif weight_dtype == torch.int2:
660+
q_dtype = torch.int8
661+
q_min = -2
662+
q_max = 1
663+
else:
664+
q_min = torch.iinfo(weight_dtype).min + 1
665+
q_max = torch.iinfo(weight_dtype).max
666+
609667
weight_quantization_spec = QuantizationSpec(
610-
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
611-
quant_min=(
612-
-7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
613-
),
614-
quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
668+
dtype=q_dtype,
669+
quant_min=q_min,
670+
quant_max=q_max,
615671
qscheme=torch.per_channel_symmetric,
616672
ch_axis=ch_axis,
617673
observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(**extra_args),

backends/qualcomm/quantizer/quantizer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
from .qconfig import (
4646
get_16a16w_qnn_ptq_config,
47+
get_16a2w_qnn_ptq_config,
4748
get_16a4w_qnn_ptq_config,
4849
get_16a4w_qnn_qat_config,
4950
get_16a8w_qnn_ptq_config,
@@ -69,6 +70,7 @@
6970
__all__ = [
7071
"QnnQuantizer",
7172
"QuantDtype",
73+
"get_16a2w_qnn_ptq_config",
7274
"get_16a4w_qnn_ptq_config",
7375
"get_16a8w_qnn_ptq_config",
7476
"get_16a8w_qnn_qat_config",
@@ -94,6 +96,7 @@ class QuantDtype(IntEnum):
9496
use_8a8w = 4
9597
use_8a4w = 5
9698
use_fp16a8w = 6
99+
use_16a2w = 7
97100

98101

99102
QUANT_CONFIG_DICT = {
@@ -125,6 +128,15 @@ class QuantDtype(IntEnum):
125128
),
126129
None,
127130
),
131+
(QuantDtype.use_16a2w, False): (
132+
get_16a2w_qnn_ptq_config,
133+
partial(
134+
get_ptq_per_channel_quant_config,
135+
act_dtype=torch.uint16,
136+
weight_dtype=torch.int2,
137+
),
138+
None,
139+
),
128140
(QuantDtype.use_16a4w_block, False): (
129141
get_16a4w_qnn_ptq_config,
130142
partial(

backends/qualcomm/quantizer/validators.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,12 @@ def _qspec_port_encoding_type(node: Node, qspec: QuantizationSpecBase):
283283
qscheme = qspec.qscheme
284284

285285
if qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
286-
if qspec.dtype == torch.int8 and qspec.quant_max - qspec.quant_min <= 15:
286+
range_ = qspec.quant_max - qspec.quant_min
287+
assert range_ > 3, (
288+
f"2-bit quantization (range={range_}) does not support per-tensor encoding. "
289+
"Use per-channel quantization instead."
290+
)
291+
if qspec.dtype == torch.int8 and range_ <= 15:
287292
encoding_type = (
288293
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET
289294
)
@@ -298,6 +303,10 @@ def _qspec_port_encoding_type(node: Node, qspec: QuantizationSpecBase):
298303
encoding_type = (
299304
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION
300305
)
306+
elif qspec.dtype == torch.int8 and qspec.quant_max - qspec.quant_min <= 3:
307+
encoding_type = (
308+
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET
309+
)
301310
elif qspec.dtype == torch.int8 and qspec.quant_max - qspec.quant_min <= 15:
302311
encoding_type = (
303312
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2632,6 +2632,47 @@ def setUp(self):
26322632
shared_buffer=TestQNN.shared_buffer,
26332633
)
26342634

2635+
@unittest.skipIf(
2636+
is_qnn_sdk_version_less_than("2.41"),
2637+
"UT pass after QNN 2.41.",
2638+
)
2639+
def test_qnn_backend_16a2w_conv2d(self):
2640+
modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405
2641+
torch.manual_seed(8)
2642+
sample_input = (torch.randn([1, 1, 3, 3]),)
2643+
for i, module in enumerate(modules):
2644+
with self.subTest(i=i):
2645+
qdq_module = self.get_qdq_module(
2646+
module,
2647+
sample_input,
2648+
is_linear_per_channel=True,
2649+
quant_dtype=QuantDtype.use_16a2w,
2650+
)
2651+
self.lower_module_and_test_output(qdq_module, sample_input)
2652+
2653+
@unittest.skipIf(
2654+
is_qnn_sdk_version_less_than("2.41"),
2655+
"UT pass after QNN 2.41.",
2656+
)
2657+
def test_qnn_backend_16a2w_linear(self):
2658+
torch.manual_seed(8)
2659+
sample_input = (torch.randn([3, 512]),)
2660+
for i, (per_channel, use_bias) in enumerate(
2661+
[
2662+
(True, False),
2663+
(True, True),
2664+
]
2665+
):
2666+
with self.subTest(i=i):
2667+
module = Linear(use_bias=use_bias) # noqa: F405
2668+
qdq_module = self.get_qdq_module(
2669+
module,
2670+
sample_input,
2671+
is_linear_per_channel=per_channel,
2672+
quant_dtype=QuantDtype.use_16a2w,
2673+
)
2674+
self.lower_module_and_test_output(qdq_module, sample_input)
2675+
26352676
def test_qnn_backend_16a4w_conv2d(self):
26362677
modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405
26372678
sample_input = (torch.randn([1, 1, 3, 3]),)

0 commit comments

Comments
 (0)