Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions backends/arm/scripts/aot_arm_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,8 +847,8 @@ def _to_edge_TOSA_delegate(
)

# Replace quantized_decomposed::{quantize,dequantize}_per_tensor nodes
# with cortex_m:: equivalents for int8 QDQ ops remaining outside the
# delegated subgraph.
# with cortex_m:: equivalents for int8/int16 QDQ ops remaining outside
# the delegated subgraph.
edge = _apply_replace_quant_nodes(edge, target, direct_drive)

return model_quant, edge
Expand Down Expand Up @@ -955,8 +955,8 @@ def _to_edge_no_delegate(
)

# Replace quantized_decomposed::{quantize,dequantize}_per_tensor nodes
# with cortex_m:: equivalents for int8 QDQ ops remaining outside the
# delegated subgraph.
# with cortex_m:: equivalents for int8/int16 QDQ ops remaining outside
# the delegated subgraph.
edge = _apply_replace_quant_nodes(edge, args.target, args.direct_drive)

return model_quant, edge
Expand Down
187 changes: 118 additions & 69 deletions backends/cortex_m/ops/op_dequantize_per_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
namespace {

/**
* Asserts that the parameters are valid for float to int8 quantization.
* Asserts that the parameters are valid for int8/int16 to float dequantization.
*/
void check_dequantize_args(
const Tensor& input,
Expand All @@ -34,11 +34,18 @@ void check_dequantize_args(
int64_t quant_max,
ScalarType dtype,
Tensor& out) {
// Ensure input is char type
// dtype must be Char (int8) or Short (int16)
ET_CHECK_MSG(
input.scalar_type() == ScalarType::Char,
"input.scalar_type() %" PRId8 " is not char type",
static_cast<int8_t>(input.scalar_type()));
dtype == ScalarType::Char || dtype == ScalarType::Short,
"dtype %" PRId8 " is not int8 (Char) or int16 (Short)",
static_cast<int8_t>(dtype));

// Input scalar type must match dtype
ET_CHECK_MSG(
input.scalar_type() == dtype,
"input.scalar_type() %" PRId8 " does not match dtype %" PRId8,
static_cast<int8_t>(input.scalar_type()),
static_cast<int8_t>(dtype));

// Check zp range
ET_CHECK_MSG(
Expand All @@ -58,26 +65,26 @@ void check_dequantize_args(
"out.scalar_type() %" PRId8 " is not float",
static_cast<int8_t>(out.scalar_type()));

// Check dtype is int8 (Char)
ET_CHECK_MSG(
dtype == ScalarType::Char,
"dtype %" PRId8 " is not int8 (Char)",
static_cast<int8_t>(dtype));

// Validate quant_min and quant_max for int8
int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min();
int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max();
// Validate quant_min and quant_max bounds per dtype
int32_t quant_min_lower_bound, quant_max_upper_bound;
if (dtype == ScalarType::Char) {
quant_min_lower_bound = std::numeric_limits<int8_t>::min();
quant_max_upper_bound = std::numeric_limits<int8_t>::max();
} else { // Short
quant_min_lower_bound = std::numeric_limits<int16_t>::min();
quant_max_upper_bound = std::numeric_limits<int16_t>::max();
}

ET_CHECK_MSG(
quant_min >= quant_min_lower_bound,
"quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32
"quant_min out of bound, expected quant_min_lower_bound: %" PRId32
" actual quant_min: %" PRId64,
quant_min_lower_bound,
quant_min);

ET_CHECK_MSG(
quant_max <= quant_max_upper_bound,
"quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32
"quant_max out of bound, expected quant_max_upper_bound: %" PRId32
" actual quant_max: %" PRId64,
quant_max_upper_bound,
quant_max);
Expand Down Expand Up @@ -115,66 +122,108 @@ Tensor& dequantize_per_tensor_out(

int32_t zp = static_cast<int32_t>(zero_point);

// Get pointers to input and output data
const int8_t* input_data = input.const_data_ptr<int8_t>();
// Get pointer to output data
float* out_data = out.mutable_data_ptr<float>();
const size_t numel = input.numel();

size_t i = 0;

if (dtype == ScalarType::Char) {
const int8_t* input_data = input.const_data_ptr<int8_t>();

#if defined(HAS_HELIUM_SIMD)
// Helium MVE implementation for int8 to float quantization
static uint8x16_t voffset{
0x0,
0x8,
0x4,
0xC,
0x1,
0x9,
0x5,
0xD,
0x2,
0xA,
0x6,
0xE,
0x3,
0xB,
0x7,
0xF};

int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));

for (; i + 15 < numel; i += 16) {
int8x16_t in_084C195D2A6E3B7F =
vldrbq_gather_offset_s8(input_data, voffset);

int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);

float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));

float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);

vstrwq_f32(out_data + 0, out_0123);
vstrwq_f32(out_data + 4, out_4567);
vstrwq_f32(out_data + 8, out_89AB);
vstrwq_f32(out_data + 12, out_CDEF);

input_data += 16;
out_data += 16;
}
// Helium MVE implementation for int8 to float quantization
static uint8x16_t voffset{
0x0,
0x8,
0x4,
0xC,
0x1,
0x9,
0x5,
0xD,
0x2,
0xA,
0x6,
0xE,
0x3,
0xB,
0x7,
0xF};

int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));

for (; i + 15 < numel; i += 16) {
int8x16_t in_084C195D2A6E3B7F =
vldrbq_gather_offset_s8(input_data, voffset);

int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);

float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));

float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);

vstrwq_f32(out_data + 0, out_0123);
vstrwq_f32(out_data + 4, out_4567);
vstrwq_f32(out_data + 8, out_89AB);
vstrwq_f32(out_data + 12, out_CDEF);

input_data += 16;
out_data += 16;
}
#endif // defined(HAS_HELIUM_SIMD)

for (; i < numel; i++) {
*out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
input_data++;
out_data++;
}
} else { // ScalarType::Short — int16 input
const int16_t* input_data = input.const_data_ptr<int16_t>();

#if defined(HAS_HELIUM_SIMD)
// Helium MVE implementation for int16 to float dequantization, processing
// 8 elements per iteration. Mirrors the int8 byte-gather trick at halfword
// granularity: the gather pattern {0,4,1,5,2,6,3,7} arranges int16 lanes
// so that vmovlbq/vmovltq_s16 (which read even/odd lanes of int16x8
// widened to int32x4) yield sequential values, allowing a sequential
// vstrwq_f32 store.
static uint16x8_t voffset_h{0, 4, 1, 5, 2, 6, 3, 7};

int32x4_t vzp = vdupq_n_s32(zp);
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));

for (; i + 7 < numel; i += 8) {
int16x8_t in_04152637 =
vldrhq_gather_shifted_offset_s16(input_data, voffset_h);

int32x4_t in_0123 = vsubq_s32(vmovlbq_s16(in_04152637), vzp);
int32x4_t in_4567 = vsubq_s32(vmovltq_s16(in_04152637), vzp);

float32x4_t out_0123 = vmulq_f32(vcvtq_f32_s32(in_0123), vscale);
float32x4_t out_4567 = vmulq_f32(vcvtq_f32_s32(in_4567), vscale);

vstrwq_f32(out_data + 0, out_0123);
vstrwq_f32(out_data + 4, out_4567);

input_data += 8;
out_data += 8;
}
#endif // defined(HAS_HELIUM_SIMD)

for (; i < numel; i++) {
*out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
input_data++;
out_data++;
for (; i < numel; i++) {
*out_data = dequantize_val<int16_t, float>(scale, zp, *input_data);
input_data++;
out_data++;
}
}
return out;
}
Expand Down
Loading
Loading