Skip to content

Commit 94d2881

Browse files
authored
Add cortex_m MVE/Helium int16 quantize/dequantize support (#19218)
Differential Revision: D103129855 Pull Request resolved: #19218
1 parent 2108986 commit 94d2881

6 files changed

Lines changed: 395 additions & 185 deletions

File tree

backends/arm/scripts/aot_arm_compiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -847,8 +847,8 @@ def _to_edge_TOSA_delegate(
847847
)
848848

849849
# Replace quantized_decomposed::{quantize,dequantize}_per_tensor nodes
850-
# with cortex_m:: equivalents for int8 QDQ ops remaining outside the
851-
# delegated subgraph.
850+
# with cortex_m:: equivalents for int8/int16 QDQ ops remaining outside
851+
# the delegated subgraph.
852852
edge = _apply_replace_quant_nodes(edge, target, direct_drive)
853853

854854
return model_quant, edge
@@ -955,8 +955,8 @@ def _to_edge_no_delegate(
955955
)
956956

957957
# Replace quantized_decomposed::{quantize,dequantize}_per_tensor nodes
958-
# with cortex_m:: equivalents for int8 QDQ ops remaining outside the
959-
# delegated subgraph.
958+
# with cortex_m:: equivalents for int8/int16 QDQ ops remaining outside
959+
# the delegated subgraph.
960960
edge = _apply_replace_quant_nodes(edge, args.target, args.direct_drive)
961961

962962
return model_quant, edge

backends/cortex_m/ops/op_dequantize_per_tensor.cpp

Lines changed: 118 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
2525
namespace {
2626

2727
/**
28-
* Asserts that the parameters are valid for float to int8 quantization.
28+
* Asserts that the parameters are valid for int8/int16 to float dequantization.
2929
*/
3030
void check_dequantize_args(
3131
const Tensor& input,
@@ -34,11 +34,18 @@ void check_dequantize_args(
3434
int64_t quant_max,
3535
ScalarType dtype,
3636
Tensor& out) {
37-
// Ensure input is char type
37+
// dtype must be Char (int8) or Short (int16)
3838
ET_CHECK_MSG(
39-
input.scalar_type() == ScalarType::Char,
40-
"input.scalar_type() %" PRId8 " is not char type",
41-
static_cast<int8_t>(input.scalar_type()));
39+
dtype == ScalarType::Char || dtype == ScalarType::Short,
40+
"dtype %" PRId8 " is not int8 (Char) or int16 (Short)",
41+
static_cast<int8_t>(dtype));
42+
43+
// Input scalar type must match dtype
44+
ET_CHECK_MSG(
45+
input.scalar_type() == dtype,
46+
"input.scalar_type() %" PRId8 " does not match dtype %" PRId8,
47+
static_cast<int8_t>(input.scalar_type()),
48+
static_cast<int8_t>(dtype));
4249

4350
// Check zp range
4451
ET_CHECK_MSG(
@@ -58,26 +65,26 @@ void check_dequantize_args(
5865
"out.scalar_type() %" PRId8 " is not float",
5966
static_cast<int8_t>(out.scalar_type()));
6067

61-
// Check dtype is int8 (Char)
62-
ET_CHECK_MSG(
63-
dtype == ScalarType::Char,
64-
"dtype %" PRId8 " is not int8 (Char)",
65-
static_cast<int8_t>(dtype));
66-
67-
// Validate quant_min and quant_max for int8
68-
int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min();
69-
int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max();
68+
// Validate quant_min and quant_max bounds per dtype
69+
int32_t quant_min_lower_bound, quant_max_upper_bound;
70+
if (dtype == ScalarType::Char) {
71+
quant_min_lower_bound = std::numeric_limits<int8_t>::min();
72+
quant_max_upper_bound = std::numeric_limits<int8_t>::max();
73+
} else { // Short
74+
quant_min_lower_bound = std::numeric_limits<int16_t>::min();
75+
quant_max_upper_bound = std::numeric_limits<int16_t>::max();
76+
}
7077

7178
ET_CHECK_MSG(
7279
quant_min >= quant_min_lower_bound,
73-
"quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32
80+
"quant_min out of bound, expected quant_min_lower_bound: %" PRId32
7481
" actual quant_min: %" PRId64,
7582
quant_min_lower_bound,
7683
quant_min);
7784

7885
ET_CHECK_MSG(
7986
quant_max <= quant_max_upper_bound,
80-
"quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32
87+
"quant_max out of bound, expected quant_max_upper_bound: %" PRId32
8188
" actual quant_max: %" PRId64,
8289
quant_max_upper_bound,
8390
quant_max);
@@ -115,66 +122,108 @@ Tensor& dequantize_per_tensor_out(
115122

116123
int32_t zp = static_cast<int32_t>(zero_point);
117124

118-
// Get pointers to input and output data
119-
const int8_t* input_data = input.const_data_ptr<int8_t>();
125+
// Get pointer to output data
120126
float* out_data = out.mutable_data_ptr<float>();
121127
const size_t numel = input.numel();
122128

123129
size_t i = 0;
130+
131+
if (dtype == ScalarType::Char) {
132+
const int8_t* input_data = input.const_data_ptr<int8_t>();
133+
124134
#if defined(HAS_HELIUM_SIMD)
125-
// Helium MVE implementation for int8 to float quantization
126-
static uint8x16_t voffset{
127-
0x0,
128-
0x8,
129-
0x4,
130-
0xC,
131-
0x1,
132-
0x9,
133-
0x5,
134-
0xD,
135-
0x2,
136-
0xA,
137-
0x6,
138-
0xE,
139-
0x3,
140-
0xB,
141-
0x7,
142-
0xF};
143-
144-
int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
145-
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
146-
147-
for (; i + 15 < numel; i += 16) {
148-
int8x16_t in_084C195D2A6E3B7F =
149-
vldrbq_gather_offset_s8(input_data, voffset);
150-
151-
int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
152-
int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
153-
154-
float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
155-
float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
156-
float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
157-
float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
158-
159-
float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
160-
float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
161-
float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
162-
float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
163-
164-
vstrwq_f32(out_data + 0, out_0123);
165-
vstrwq_f32(out_data + 4, out_4567);
166-
vstrwq_f32(out_data + 8, out_89AB);
167-
vstrwq_f32(out_data + 12, out_CDEF);
168-
169-
input_data += 16;
170-
out_data += 16;
171-
}
135+
// Helium MVE implementation for int8 to float dequantization
136+
static uint8x16_t voffset{
137+
0x0,
138+
0x8,
139+
0x4,
140+
0xC,
141+
0x1,
142+
0x9,
143+
0x5,
144+
0xD,
145+
0x2,
146+
0xA,
147+
0x6,
148+
0xE,
149+
0x3,
150+
0xB,
151+
0x7,
152+
0xF};
153+
154+
int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
155+
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
156+
157+
for (; i + 15 < numel; i += 16) {
158+
int8x16_t in_084C195D2A6E3B7F =
159+
vldrbq_gather_offset_s8(input_data, voffset);
160+
161+
int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
162+
int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
163+
164+
float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
165+
float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
166+
float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
167+
float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
168+
169+
float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
170+
float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
171+
float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
172+
float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
173+
174+
vstrwq_f32(out_data + 0, out_0123);
175+
vstrwq_f32(out_data + 4, out_4567);
176+
vstrwq_f32(out_data + 8, out_89AB);
177+
vstrwq_f32(out_data + 12, out_CDEF);
178+
179+
input_data += 16;
180+
out_data += 16;
181+
}
182+
#endif // defined(HAS_HELIUM_SIMD)
183+
184+
for (; i < numel; i++) {
185+
*out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
186+
input_data++;
187+
out_data++;
188+
}
189+
} else { // ScalarType::Short — int16 input
190+
const int16_t* input_data = input.const_data_ptr<int16_t>();
191+
192+
#if defined(HAS_HELIUM_SIMD)
193+
// Helium MVE implementation for int16 to float dequantization, processing
194+
// 8 elements per iteration. Mirrors the int8 byte-gather trick at halfword
195+
// granularity: the gather pattern {0,4,1,5,2,6,3,7} arranges int16 lanes
196+
// so that vmovlbq/vmovltq_s16 (which read even/odd lanes of int16x8
197+
// widened to int32x4) yield sequential values, allowing a sequential
198+
// vstrwq_f32 store.
199+
static uint16x8_t voffset_h{0, 4, 1, 5, 2, 6, 3, 7};
200+
201+
int32x4_t vzp = vdupq_n_s32(zp);
202+
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
203+
204+
for (; i + 7 < numel; i += 8) {
205+
int16x8_t in_04152637 =
206+
vldrhq_gather_shifted_offset_s16(input_data, voffset_h);
207+
208+
int32x4_t in_0123 = vsubq_s32(vmovlbq_s16(in_04152637), vzp);
209+
int32x4_t in_4567 = vsubq_s32(vmovltq_s16(in_04152637), vzp);
210+
211+
float32x4_t out_0123 = vmulq_f32(vcvtq_f32_s32(in_0123), vscale);
212+
float32x4_t out_4567 = vmulq_f32(vcvtq_f32_s32(in_4567), vscale);
213+
214+
vstrwq_f32(out_data + 0, out_0123);
215+
vstrwq_f32(out_data + 4, out_4567);
216+
217+
input_data += 8;
218+
out_data += 8;
219+
}
172220
#endif // defined(HAS_HELIUM_SIMD)
173221

174-
for (; i < numel; i++) {
175-
*out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
176-
input_data++;
177-
out_data++;
222+
for (; i < numel; i++) {
223+
*out_data = dequantize_val<int16_t, float>(scale, zp, *input_data);
224+
input_data++;
225+
out_data++;
226+
}
178227
}
179228
return out;
180229
}

0 commit comments

Comments
 (0)