
Commit aa4e489

Qualcomm AI Engine Direct - AOT Lowering Time Optimization (#18516)
1 parent d75e665 commit aa4e489

14 files changed: 172 additions & 82 deletions


backends/qualcomm/_passes/seq_mse.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -11,8 +11,10 @@
 from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
     PerBlockParamObserver,
 )
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
 from executorch.exir.pass_base import ExportPass, PassResult
-from torchao.quantization.pt2e import PerChannelMinMaxObserver


 class SeqMseModule(torch.nn.Module):
@@ -97,7 +99,7 @@ def _per_channel_qdq(self, scale, zero_point):

     def _fake_quant(self, scale, zero_point):
         dispatcher = {
-            PerChannelMinMaxObserver: self._per_channel_qdq,
+            PerChannelParamObserver: self._per_channel_qdq,
             PerBlockParamObserver: self._per_block_qdq,
         }
         return dispatcher[type(self.observer)](scale, zero_point)
```
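
Note that `_fake_quant` dispatches on the observer's exact runtime type, which is why the key must track the class the quantizer actually instantiates: a `dict` keyed on `type(obj)` never falls back to a parent class. A minimal sketch of that pitfall, with hypothetical stand-in classes:

```python
# Minimal sketch of exact-type dispatch, with hypothetical stand-in
# classes: a dict keyed on type(obj) misses subclass instances even
# when isinstance() would accept them.
class BaseObserver: ...
class DerivedObserver(BaseObserver): ...

dispatcher = {BaseObserver: lambda: "base handler"}

obs = DerivedObserver()
assert isinstance(obs, BaseObserver)   # passes
assert type(obs) not in dispatcher     # exact-type lookup misses
```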

backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp

Lines changed: 37 additions & 20 deletions
```diff
@@ -27,27 +27,37 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
     quantize_param_wrapper = std::make_unique<UndefinedQuantizeParamsWrapper>();
   } else if (encoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
     int32_t axis = quant_info["axis"].cast<int32_t>();
-    std::vector<Qnn_ScaleOffset_t> scale_offset =
-        quant_info["scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();
-
+    auto so_arr =
+        quant_info["scale_offset"].cast<py::array_t<Qnn_ScaleOffset_t>>();
+    auto so_buf = so_arr.request();
+    const Qnn_ScaleOffset_t* so_ptr =
+        static_cast<const Qnn_ScaleOffset_t*>(so_buf.ptr);
+    std::vector<Qnn_ScaleOffset_t> scale_offset(so_ptr, so_ptr + so_buf.size);
     quantize_param_wrapper =
         std::make_unique<AxisScaleOffsetQuantizeParamsWrapper>(
-            axis, scale_offset);
+            axis, std::move(scale_offset));
   } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) {
     uint32_t bitwidth = quant_info["bitwidth"].cast<uint32_t>();
     int32_t axis = quant_info["axis"].cast<int32_t>();
-    std::vector<Qnn_ScaleOffset_t> scale_offset =
-        quant_info["scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();
-    uint32_t num_elements = scale_offset.size();
-    std::vector<float> scales;
-    std::vector<int32_t> offsets;
-    for (const auto& scale_offset : scale_offset) {
-      scales.push_back(scale_offset.scale);
-      offsets.push_back(scale_offset.offset);
+    auto so_arr =
+        quant_info["scale_offset"].cast<py::array_t<Qnn_ScaleOffset_t>>();
+    auto so_buf = so_arr.request();
+    const Qnn_ScaleOffset_t* so_ptr =
+        static_cast<const Qnn_ScaleOffset_t*>(so_buf.ptr);
+    uint32_t num_elements = static_cast<uint32_t>(so_buf.size);
+    std::vector<float> scales(num_elements);
+    std::vector<int32_t> offsets(num_elements);
+    for (uint32_t i = 0; i < num_elements; ++i) {
+      scales[i] = so_ptr[i].scale;
+      offsets[i] = so_ptr[i].offset;
     }
     quantize_param_wrapper =
         std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(
-            bitwidth, axis, num_elements, scales, offsets);
+            bitwidth,
+            axis,
+            num_elements,
+            std::move(scales),
+            std::move(offsets));
   } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) {
     uint32_t bitwidth = quant_info["bitwidth"].cast<uint32_t>();
     float scale = quant_info["scale"].cast<float>();
@@ -62,26 +72,32 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
         std::make_unique<ScaleOffsetQuantizeParamsWrapper>(scale, offset);
   } else if (encoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION) {
     int32_t axis = quant_info["axis"].cast<int32_t>();
-    std::vector<Qnn_ScaleOffset_t> scale_offset =
-        quant_info["block_scale_offset"].cast<std::vector<Qnn_ScaleOffset_t>>();
+    auto so_arr =
+        quant_info["block_scale_offset"].cast<py::array_t<Qnn_ScaleOffset_t>>();
+    auto so_buf = so_arr.request();
+    const Qnn_ScaleOffset_t* so_ptr =
+        static_cast<const Qnn_ScaleOffset_t*>(so_buf.ptr);
+    std::vector<Qnn_ScaleOffset_t> scale_offset(so_ptr, so_ptr + so_buf.size);
     uint32_t num_blocks_per_axis =
         quant_info["num_blocks_per_axis"].cast<uint32_t>();
     uint32_t block_scale_bitwidth =
         quant_info["block_scale_bitwidth"].cast<uint32_t>();
     Qnn_BlockwiseExpansionBlockScaleStorageType_t block_storage_type =
         quant_info["block_storage_type"]
             .cast<Qnn_BlockwiseExpansionBlockScaleStorageType_t>();
-    std::vector<uint8_t> buf =
-        quant_info["block_scales"].cast<std::vector<uint8_t>>();
+    py::array_t<uint8_t> block_scales_arr =
+        quant_info["block_scales"].cast<py::array_t<uint8_t>>();
+    auto buf_info = block_scales_arr.request();
+    const uint8_t* ptr = static_cast<const uint8_t*>(buf_info.ptr);
+    std::vector<uint8_t> block_scales_vec(ptr, ptr + buf_info.size);
     quantize_param_wrapper =
         std::make_unique<BlockwiseExpansionQuantizeParamsWrapper>(
             axis,
-            scale_offset,
+            std::move(scale_offset),
             num_blocks_per_axis,
             block_scale_bitwidth,
             block_storage_type,
-            buf.data(),
-            buf.size());
+            std::move(block_scales_vec));
   } else {
     QNN_EXECUTORCH_LOG_ERROR(
         "Unknown the encoding of quantization: %d", encoding);
@@ -196,6 +212,7 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) {
   // TODO: Add related documents for configurations listed below
   using namespace qnn_delegate;
   PYBIND11_NUMPY_DTYPE(PyQnnTensorWrapper::EncodingData, scale, offset);
+  PYBIND11_NUMPY_DTYPE(Qnn_ScaleOffset_t, scale, offset);

   m.def("GetQNNCtxBinAlignment", &GetQNNCtxBinAlignment);
   m.def("GetQnnSdkBuildId", &GetQnnSdkBuildId);
```

backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -77,15 +77,17 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
         quantization.blockwiseExpansion->scaleOffsets,
         quantization.blockwiseExpansion->scaleOffsets +
             QNN_TENSOR_VER_PTR(tensor)->dimensions[ch_axis]);
+    std::vector<uint8_t> block_scales(
+        quantization.blockwiseExpansion->blocksScale8,
+        quantization.blockwiseExpansion->blocksScale8 + block_scales_sz);
     quantize_param_wrapper =
         std::make_unique<BlockwiseExpansionQuantizeParamsWrapper>(
             quantization.blockwiseExpansion->axis,
             scale_offsets,
             quantization.blockwiseExpansion->numBlocksPerAxis,
             quantization.blockwiseExpansion->blockScaleBitwidth,
             quantization.blockwiseExpansion->blockScaleStorageType,
-            quantization.blockwiseExpansion->blocksScale8,
-            block_scales_sz);
+            block_scales);
   } else {
     QNN_EXECUTORCH_LOG_ERROR(
         "Unknown the encoding of quantization: %d",
```

backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h

Lines changed: 8 additions & 13 deletions
```diff
@@ -92,8 +92,8 @@ class BwAxisScaleOffsetQuantizeParamsWrapper final
         bitwidth_(bitwidth),
         axis_(axis),
         num_elements_(num_elements),
-        scales_(scales),
-        offsets_(offsets) {}
+        scales_(std::move(scales)),
+        offsets_(std::move(offsets)) {}

   BwAxisScaleOffsetQuantizeParamsWrapper(
       const BwAxisScaleOffsetQuantizeParamsWrapper& rhs)
@@ -235,12 +235,12 @@ class AxisScaleOffsetQuantizeParamsWrapper final
  public:
   explicit AxisScaleOffsetQuantizeParamsWrapper(
       std::int32_t axis,
-      const std::vector<Qnn_ScaleOffset_t>& scale_offsets)
+      std::vector<Qnn_ScaleOffset_t> scale_offsets)
       : QuantizeParamsWrapper(
             QNN_DEFINITION_DEFINED,
             QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET),
         axis_(axis),
-        scale_offsets_(scale_offsets) {}
+        scale_offsets_(std::move(scale_offsets)) {}

   AxisScaleOffsetQuantizeParamsWrapper(
       const AxisScaleOffsetQuantizeParamsWrapper& rhs)
@@ -249,8 +249,6 @@ class AxisScaleOffsetQuantizeParamsWrapper final
             rhs.GetQuantizationEncoding()),
         axis_(rhs.axis_),
         scale_offsets_(rhs.scale_offsets_) {}
-  AxisScaleOffsetQuantizeParamsWrapper(
-      AxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
   AxisScaleOffsetQuantizeParamsWrapper& operator=(
       const AxisScaleOffsetQuantizeParamsWrapper& rhs) = delete;
   AxisScaleOffsetQuantizeParamsWrapper& operator=(
@@ -286,21 +284,20 @@ class BlockwiseExpansionQuantizeParamsWrapper final
  public:
   explicit BlockwiseExpansionQuantizeParamsWrapper(
       std::int32_t axis,
-      const std::vector<Qnn_ScaleOffset_t>& scale_offsets,
+      std::vector<Qnn_ScaleOffset_t> scale_offsets,
       std::uint32_t num_blocks_per_axis,
       std::uint32_t block_scale_bitwidth,
       Qnn_BlockwiseExpansionBlockScaleStorageType_t storage_type,
-      const uint8_t* block_scales_ptr,
-      std::uint32_t block_scales_size)
+      std::vector<uint8_t> block_scales)
       : QuantizeParamsWrapper(
             QNN_DEFINITION_DEFINED,
             QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION),
         axis_(axis),
-        scale_offsets_(scale_offsets),
+        scale_offsets_(std::move(scale_offsets)),
         num_blocks_per_axis_(num_blocks_per_axis),
         block_scale_bitwidth_(block_scale_bitwidth),
         block_storage_type_(storage_type),
-        block_scales_(block_scales_ptr, block_scales_ptr + block_scales_size) {}
+        block_scales_(std::move(block_scales)) {}

   BlockwiseExpansionQuantizeParamsWrapper(
       const BlockwiseExpansionQuantizeParamsWrapper& rhs)
@@ -314,8 +311,6 @@
         block_storage_type_(rhs.block_storage_type_),
         block_scales_(rhs.block_scales_) {}

-  BlockwiseExpansionQuantizeParamsWrapper(
-      BlockwiseExpansionQuantizeParamsWrapper&& rhs) = delete;
   BlockwiseExpansionQuantizeParamsWrapper& operator=(
       const BlockwiseExpansionQuantizeParamsWrapper& rhs) = delete;
   BlockwiseExpansionQuantizeParamsWrapper& operator=(
```

backends/qualcomm/builders/node_visitor.py

Lines changed: 41 additions & 24 deletions
```diff
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import copy
 from typing import Any, Dict, Tuple

 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
@@ -151,8 +150,12 @@ def _get_tensor(node, index):
     def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         import math

-        quant_config = copy.deepcopy(quant_attrs)
-        scales, scale_offset, quantized_scales = quant_attrs[QCOM_SCALE], [], []
+        quant_config = {
+            QCOM_DTYPE: quant_attrs[QCOM_DTYPE],
+            QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN],
+            QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX],
+        }
+        scales = quant_attrs[QCOM_SCALE]
         # channel in observers defaults to zero
         num_channels = node.meta["val"].shape[0]
         user_0 = self.get_first_user(node)
@@ -170,17 +173,23 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
             PyQnnManager.Qnn_BlockwiseExpansionBlockScaleStorageType_t.QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8
         )

+        scale_offset_arr = np.empty(
+            num_channels, dtype=[("scale", np.float32), ("offset", np.int32)]
+        )
+        # move channel axis to dim 0 for transpose_conv case
+        candidates = scales if ch_axis == 0 else scales.transpose(0, 1)
+        candidates = candidates.reshape(num_channels, -1)
+        # find max scale per channel
+        max_scales = candidates.amax(dim=-1) / num_steps
+        # quantize scales per channel
+        q_scales = torch.clamp(
+            input=torch.round(input=candidates / max_scales.unsqueeze(-1)),
+            min=1,
+            max=2**bitwidth_of_scale,
+        ).to(quant_scales_dtype)
+        # symmetric quantization is required
         for ch in range(num_channels):
-            candidates = scales[ch] if ch_axis == 0 else scales[:, ch, ...]
-            max_scale = candidates.reshape(1, -1).amax(dim=-1) / num_steps
-            q_scales = torch.clamp(
-                input=torch.round(input=candidates / max_scale),
-                min=1,
-                max=2**bitwidth_of_scale,
-            ).to(quant_scales_dtype)
-            quantized_scales.append(q_scales)
-            # symmetric quantization is required
-            scale_offset.append(PyQnnManager.Qnn_ScaleOffset_t(max_scale, 0))
+            scale_offset_arr[ch] = (float(max_scales[ch]), 0)

         # skip dequantize op, e.g. frozen_param -> dq -> conv2d
         user_0 = self.get_first_user(node)
@@ -195,9 +204,9 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         else:
             raise AttributeError("undetermined axis for block quantization")

-        quant_config[QCOM_NUM_BLOCKS_PER_AXIS] = quantized_scales[0].shape.numel()
-        quant_config[QCOM_BLOCK_SCALE_OFFSET] = scale_offset
-        quant_config[QCOM_BLOCK_SCALES] = torch.cat(quantized_scales).detach().numpy()
+        quant_config[QCOM_NUM_BLOCKS_PER_AXIS] = q_scales.shape[1]
+        quant_config[QCOM_BLOCK_SCALE_OFFSET] = scale_offset_arr
+        quant_config[QCOM_BLOCK_SCALES] = q_scales.flatten().detach().numpy()
         # e.g. if use 16 bit for quantized scales, we need to expand 16 - 4 = 12 bits
         quant_config[QCOM_BLOCK_SCALE_BITWIDTH] = (
             int(math.log2(torch.iinfo(quant_scales_dtype).max + 1)) - bitwidth_of_scale
@@ -209,20 +218,23 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         )

     def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
-        quant_config = copy.deepcopy(quant_attrs)
+        quant_config = {
+            QCOM_DTYPE: quant_attrs[QCOM_DTYPE],
+            QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX],
+            QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN],
+        }

         scales = quant_attrs[QCOM_SCALES]
         zero_points = quant_attrs[QCOM_ZERO_POINTS]
         assert len(scales) == len(
             zero_points
         ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"

-        scale_offset = []
+        scale_offset_arr = np.empty(
+            len(scales), dtype=[("scale", np.float32), ("offset", np.int32)]
+        )
         for i in range(len(scales)):
-            # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
-            scale_offset.append(
-                PyQnnManager.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
-            )
+            scale_offset_arr[i] = (float(scales[i]), int(-zero_points[i]))

         # skip dequantize op, e.g. frozen_param -> dq -> conv2d
         user_0 = self.get_first_user(node)
@@ -234,7 +246,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
         else:
             quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

-        quant_config[QCOM_SCALE_OFFSET] = scale_offset
+        quant_config[QCOM_SCALE_OFFSET] = scale_offset_arr
         # special case for 4 bits
         if (
             quant_config[QCOM_DTYPE] == torch.int8
@@ -251,7 +263,12 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
         )

     def make_qnn_per_tensor_config(self, quant_attrs: Dict):
-        quant_config = copy.deepcopy(quant_attrs)
+        quant_config = {
+            QCOM_DTYPE: quant_attrs[QCOM_DTYPE],
+            QCOM_SCALE: quant_attrs[QCOM_SCALE],
+            QCOM_QUANT_MAX: quant_attrs[QCOM_QUANT_MAX],
+            QCOM_QUANT_MIN: quant_attrs[QCOM_QUANT_MIN],
+        }
         # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
         quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
         # special case for 4 bits
```
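
The rewritten `make_qnn_per_block_config` hoists the scale quantization out of the per-channel Python loop: one `amax` reduction and one `round`/`clamp` over a 2-D tensor replace `num_channels` separate small-tensor ops. A standalone sketch of the same vectorization, with illustrative shapes and bit widths rather than values from a real model:

```python
import torch

# Illustrative values only; the real configs come from the quantizer.
num_channels, num_blocks = 8, 16
bitwidth_of_scale = 4
num_steps = 2**bitwidth_of_scale
quant_scales_dtype = torch.uint8

scales = torch.rand(num_channels, num_blocks) + 1e-3

# one reduction computes every channel's max block scale at once
candidates = scales.reshape(num_channels, -1)
max_scales = candidates.amax(dim=-1) / num_steps

# a single round/clamp over the 2-D tensor quantizes all block scales,
# avoiding num_channels Python-level iterations
q_scales = torch.clamp(
    torch.round(candidates / max_scales.unsqueeze(-1)),
    min=1,
    max=2**bitwidth_of_scale,
).to(quant_scales_dtype)

assert q_scales.shape == (num_channels, num_blocks)
```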

backends/qualcomm/quantizer/README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -128,7 +128,7 @@ def ptq_per_channel_quant_config(
         quant_max=torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
-        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
+        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(**extra_args),
     )

     bias_quantization_spec = _derived_bias_quant_spec
@@ -142,7 +142,7 @@ def ptq_per_channel_quant_config(

     return quantization_config
 ```
-Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of IO activation and apply rules to `weight` w/`PerChannelMinMaxObserver`, `bias` w/`_derived_bias_quant_spec` (a callable method to calculate encoding in desired way) to meet aforementioned constraints. The well-defined `quantizaton_config` will then be shipped to callback for annotation.<br/>
+Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of IO activation and apply rules to `weight` w/`PerChannelParamObserver`, `bias` w/`_derived_bias_quant_spec` (a callable method to calculate encoding in desired way) to meet aforementioned constraints. The well-defined `quantizaton_config` will then be shipped to callback for annotation.<br/>

 Now, we can start to fill in the function body:
 - Register annotator
````
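
The `with_args` call in the README snippet returns a factory that remembers constructor kwargs, so the spec can defer observer construction until annotation time. A small sketch, assuming `MinMaxObserver` is importable from `torchao.quantization.pt2e` as the other observers in this commit are:

```python
from torchao.quantization.pt2e import MinMaxObserver

# with_args captures constructor kwargs now and builds the observer later,
# when the quantizer materializes the spec during annotation
obs_ctr = MinMaxObserver.with_args(eps=2**-12)
observer = obs_ctr()  # instantiated the way the annotator would do it
print(type(observer).__name__, float(observer.eps))
```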

backends/qualcomm/quantizer/observers/concat_observer.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+from executorch.backends.qualcomm.utils.constants import DEFAULT_EPS_FP32
 from torchao.quantization.pt2e import UniformQuantizationObserverBase


@@ -23,7 +24,7 @@ def __init__(
         quant_min=None,
         quant_max=None,
         factory_kwargs=None,
-        eps=torch.finfo(torch.float32).eps,  # noqa: B008
+        eps=DEFAULT_EPS_FP32,
         is_dynamic=False,
         **kwargs,
     ) -> None:
@@ -49,8 +50,9 @@ def __init__(

     def forward(self, x_orig):
         # calculate the min / max first
-        self.min_val = min(self.min_val, x_orig.min())
-        self.max_val = max(self.max_val, x_orig.max())
+        min_val, max_val = torch.aminmax(x_orig.detach())
+        self.min_val = min(self.min_val, min_val)
+        self.max_val = max(self.max_val, max_val)

         if len(self.input_observers) == 0:
             # collect observers first if they are not cached
```
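
`torch.aminmax` returns both extrema from a single traversal of the tensor, and detaching first keeps the running range out of the autograd graph. A tiny sketch of the equivalence:

```python
import torch

x = torch.randn(4, 8)

# two separate reductions traverse the tensor twice
lo, hi = x.min(), x.max()

# aminmax produces both results in one pass
min_val, max_val = torch.aminmax(x.detach())
assert torch.equal(lo, min_val) and torch.equal(hi, max_val)
```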
