@@ -25,7 +25,7 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
 namespace {
 
 /**
- * Asserts that the parameters are valid for float to int8 quantization.
+ * Asserts that the parameters are valid for int8/int16 to float dequantization.
  */
 void check_dequantize_args(
     const Tensor& input,
@@ -34,11 +34,18 @@ void check_dequantize_args(
     int64_t quant_max,
     ScalarType dtype,
     Tensor& out) {
-  // Ensure input is char type
+  // dtype must be Char (int8) or Short (int16)
   ET_CHECK_MSG(
-      input.scalar_type() == ScalarType::Char,
-      "input.scalar_type() %" PRId8 " is not char type",
-      static_cast<int8_t>(input.scalar_type()));
+      dtype == ScalarType::Char || dtype == ScalarType::Short,
+      "dtype %" PRId8 " is not int8 (Char) or int16 (Short)",
+      static_cast<int8_t>(dtype));
+
+  // Input scalar type must match dtype
+  ET_CHECK_MSG(
+      input.scalar_type() == dtype,
+      "input.scalar_type() %" PRId8 " does not match dtype %" PRId8,
+      static_cast<int8_t>(input.scalar_type()),
+      static_cast<int8_t>(dtype));
 
   // Check zp range
   ET_CHECK_MSG(
@@ -58,26 +65,27 @@ void check_dequantize_args(
       "out.scalar_type() %" PRId8 " is not float",
       static_cast<int8_t>(out.scalar_type()));
 
-  // Check dtype is int8 (Char)
-  ET_CHECK_MSG(
-      dtype == ScalarType::Char,
-      "dtype %" PRId8 " is not int8 (Char)",
-      static_cast<int8_t>(dtype));
-
-  // Validate quant_min and quant_max for int8
-  int32_t quant_min_lower_bound = std::numeric_limits<int8_t>::min();
-  int32_t quant_max_upper_bound = std::numeric_limits<int8_t>::max();
+  // Validate quant_min and quant_max bounds per dtype
+  int32_t quant_min_lower_bound, quant_max_upper_bound;
+  if (dtype == ScalarType::Char) {
+    quant_min_lower_bound = std::numeric_limits<int8_t>::min();
+    quant_max_upper_bound = std::numeric_limits<int8_t>::max();
+  } else { // Short
+    quant_min_lower_bound = std::numeric_limits<int16_t>::min();
+    quant_max_upper_bound = std::numeric_limits<int16_t>::max();
+  }
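+  // (Char: [-128, 127]; Short: [-32768, 32767].)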
 
   ET_CHECK_MSG(
       quant_min >= quant_min_lower_bound,
-      "quant_min out of bound for int8, expected quant_min_lower_bound: %" PRId32
+      "quant_min out of bound, expected quant_min_lower_bound: %" PRId32
       " actual quant_min: %" PRId64,
       quant_min_lower_bound,
       quant_min);
 
   ET_CHECK_MSG(
       quant_max <= quant_max_upper_bound,
-      "quant_max out of bound for int8, expected quant_max_upper_bound: %" PRId32
+      "quant_max out of bound, expected quant_max_upper_bound: %" PRId32
       " actual quant_max: %" PRId64,
       quant_max_upper_bound,
       quant_max);
@@ -115,66 +123,117 @@ Tensor& dequantize_per_tensor_out(
 
   int32_t zp = static_cast<int32_t>(zero_point);
 
-  // Get pointers to input and output data
-  const int8_t* input_data = input.const_data_ptr<int8_t>();
+  // Get pointer to output data
   float* out_data = out.mutable_data_ptr<float>();
   const size_t numel = input.numel();
 
   size_t i = 0;
+
+  if (dtype == ScalarType::Char) {
+    const int8_t* input_data = input.const_data_ptr<int8_t>();
+
 #if defined(HAS_HELIUM_SIMD)
-  // Helium MVE implementation for int8 to float quantization
-  static uint8x16_t voffset{
-      0x0,
-      0x8,
-      0x4,
-      0xC,
-      0x1,
-      0x9,
-      0x5,
-      0xD,
-      0x2,
-      0xA,
-      0x6,
-      0xE,
-      0x3,
-      0xB,
-      0x7,
-      0xF};
-
-  int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
-  float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
-
-  for (; i + 15 < numel; i += 16) {
-    int8x16_t in_084C195D2A6E3B7F =
-        vldrbq_gather_offset_s8(input_data, voffset);
-
-    int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
-    int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
-
-    float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
-    float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
-    float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
-    float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
-
-    float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
-    float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
-    float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
-    float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
-
-    vstrwq_f32(out_data + 0, out_0123);
-    vstrwq_f32(out_data + 4, out_4567);
-    vstrwq_f32(out_data + 8, out_89AB);
-    vstrwq_f32(out_data + 12, out_CDEF);
-
-    input_data += 16;
-    out_data += 16;
-  }
+    // Helium MVE implementation for int8 to float dequantization
+    static uint8x16_t voffset{
+        0x0,
+        0x8,
+        0x4,
+        0xC,
+        0x1,
+        0x9,
+        0x5,
+        0xD,
+        0x2,
+        0xA,
+        0x6,
+        0xE,
+        0x3,
+        0xB,
+        0x7,
+        0xF};
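+    // As with the int16 path below, the byte-gather pattern interleaves the
+    // 16 input bytes so that the widening vmovlbq/vmovltq steps produce
+    // sequential values, allowing the plain sequential vstrwq_f32 stores at
+    // the end of the loop.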
+
+    int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
+    float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
+
+    for (; i + 15 < numel; i += 16) {
+      int8x16_t in_084C195D2A6E3B7F =
+          vldrbq_gather_offset_s8(input_data, voffset);
+
+      int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
+      int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
+
+      float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
+      float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
+      float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
+      float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
+
+      float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
+      float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
+      float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
+      float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
+
+      vstrwq_f32(out_data + 0, out_0123);
+      vstrwq_f32(out_data + 4, out_4567);
+      vstrwq_f32(out_data + 8, out_89AB);
+      vstrwq_f32(out_data + 12, out_CDEF);
+
+      input_data += 16;
+      out_data += 16;
+    }
+#endif // defined(HAS_HELIUM_SIMD)
+
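+    // Scalar tail (and the whole loop without Helium): dequantize_val applies
+    // the affine transform (value - zp) * scale element by element.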
+    for (; i < numel; i++) {
+      *out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
+      input_data++;
+      out_data++;
+    }
+  } else { // ScalarType::Short, i.e. int16 input
+    const int16_t* input_data = input.const_data_ptr<int16_t>();
+
+#if defined(HAS_HELIUM_SIMD)
+    // Helium MVE implementation for int16 to float dequantization, processing
+    // 8 elements per iteration. Mirrors the int8 byte-gather trick at halfword
+    // granularity: the gather pattern {0,4,1,5,2,6,3,7} arranges int16 lanes
+    // so that vmovlbq/vmovltq_s16 (which read the even/odd lanes of an
+    // int16x8, widened to int32x4) yield sequential values, allowing a
+    // sequential vstrwq_f32 store.
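+    // For example, gathering inputs {a0..a7} loads lanes
+    // {a0,a4,a1,a5,a2,a6,a3,a7}; vmovlbq_s16 widens the even lanes back to
+    // {a0,a1,a2,a3} and vmovltq_s16 the odd lanes to {a4,a5,a6,a7}.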
+    static uint16x8_t voffset_h{0, 4, 1, 5, 2, 6, 3, 7};
+
+    int32x4_t vzp = vdupq_n_s32(zp);
+    float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
+
+    for (; i + 7 < numel; i += 8) {
+      int16x8_t in_04152637 =
+          vldrhq_gather_shifted_offset_s16(input_data, voffset_h);
+
+      int32x4_t in_0123 = vsubq_s32(vmovlbq_s16(in_04152637), vzp);
+      int32x4_t in_4567 = vsubq_s32(vmovltq_s16(in_04152637), vzp);
+
+      float32x4_t out_0123 = vmulq_f32(vcvtq_f32_s32(in_0123), vscale);
+      float32x4_t out_4567 = vmulq_f32(vcvtq_f32_s32(in_4567), vscale);
+
+      vstrwq_f32(out_data + 0, out_0123);
+      vstrwq_f32(out_data + 4, out_4567);
+
+      input_data += 8;
+      out_data += 8;
+    }
 #endif // defined(HAS_HELIUM_SIMD)
 
-  for (; i < numel; i++) {
-    *out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
-    input_data++;
-    out_data++;
+    for (; i < numel; i++) {
+      *out_data = dequantize_val<int16_t, float>(scale, zp, *input_data);
+      input_data++;
+      out_data++;
+    }
   }
   return out;
 }
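
A usage sketch (not from this commit; the full parameter list is assumed from
the checks above, with the runtime context first and the out tensor last, and
the construction of `input`, a Short tensor of quantized values, and `out`, a
Float tensor of the same shape, left to the caller):

    KernelRuntimeContext ctx;
    dequantize_per_tensor_out(
        ctx,
        input,
        /*scale=*/0.05,
        /*zero_point=*/-3,
        /*quant_min=*/-32768,
        /*quant_max=*/32767,
        ScalarType::Short,
        out);
    // Each element q maps to (q - zero_point) * scale,
    // e.g. q = 100 -> (100 - (-3)) * 0.05 = 5.15f.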