Skip to content

Commit 6217c75

Browse files
authored
Disable 4-bit activation quant/dequant support
Differential Revision: D100258703
Pull Request resolved: #18814
1 parent 36e8ed9 commit 6217c75

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

backends/cadence/fusion_g3/operators/op_dequantize.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ using optional = std::optional<T>;
3131
* operator need to be updated accordingly
3232
*/
3333

34+
#ifdef G3_ENABLE_4BIT_QUANTIZATION
3435
enum datatype { Bits4u = 21, Bits4 = 22 };
36+
#endif
3537

3638
/**
3739
* For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -57,8 +59,10 @@ void check_dequantize_per_tensor_args(
5759
input.scalar_type() == ScalarType::Char ||
5860
input.scalar_type() == ScalarType::UInt16 ||
5961
input.scalar_type() == ScalarType::Short ||
62+
#ifdef G3_ENABLE_4BIT_QUANTIZATION
6063
input.scalar_type() == (ScalarType)Bits4 ||
6164
input.scalar_type() == (ScalarType)Bits4u ||
65+
#endif
6266
input.scalar_type() == ScalarType::Int,
6367

6468
"input.scalar_type() %" PRId8 " is not supported:",
@@ -183,6 +187,7 @@ Tensor& dequantize_impl(
183187
axis,
184188
zero_point_data,
185189
scale_data);
190+
#ifdef G3_ENABLE_4BIT_QUANTIZATION
186191
} else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
187192
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
188193
XT_KERNEL_CHECK(
@@ -209,6 +214,7 @@ Tensor& dequantize_impl(
209214
axis,
210215
zero_point_data,
211216
scale_data);
217+
#endif
212218
} else {
213219
if (axis == NULL) {
214220
// calculate the dequantized output, cast scale to float to match fbgemm
@@ -391,6 +397,7 @@ Tensor& dequantize_impl(
391397
input.dim(),
392398
axis,
393399
scale_data);
400+
#ifdef G3_ENABLE_4BIT_QUANTIZATION
394401
} else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
395402
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
396403
XT_KERNEL_CHECK(
@@ -415,6 +422,7 @@ Tensor& dequantize_impl(
415422
input.dim(),
416423
axis,
417424
scale_data);
425+
#endif
418426
} else {
419427
if (axis == NULL) {
420428
// calculate the dequantized output, cast scale to float to match fbgemm

backends/cadence/fusion_g3/operators/op_quantize.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ using ::executorch::runtime::KernelRuntimeContext;
2828
* updated to have support for below data types, these can be removed and
2929
* operator need to be updated accordingly
3030
*/
31+
#ifdef G3_ENABLE_4BIT_QUANTIZATION
3132
enum datatype { Bits4u = 21, Bits4 = 22 };
33+
#endif
3234

3335
/**
3436
* For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -78,6 +80,7 @@ void check_quantize_per_tensor_args(
7880
} else if (dtype == ScalarType::Short) {
7981
quant_min_lower_bound = std::numeric_limits<int16_t>::min();
8082
quant_max_upper_bound = std::numeric_limits<int16_t>::max();
83+
#ifdef G3_ENABLE_4BIT_QUANTIZATION
8184
} else if (dtype == (ScalarType)Bits4u) {
8285
quant_min_lower_bound = std::numeric_limits<uint8_t>::min();
8386
quant_max_upper_bound = std::numeric_limits<uint8_t>::max();
@@ -90,6 +93,7 @@ void check_quantize_per_tensor_args(
9093
/* Minimum and maximum values for signed 4-bit data type */
9194
quant_min_lower_bound = quant_min_lower_bound >> 4;
9295
quant_max_upper_bound = quant_max_upper_bound >> 4;
96+
#endif
9397
} else if (dtype == ScalarType::Int) {
9498
quant_min_lower_bound = std::numeric_limits<int32_t>::min();
9599
quant_max_upper_bound = std::numeric_limits<int32_t>::max();
@@ -243,6 +247,7 @@ Tensor& quantize_impl(
243247
zero_point_data,
244248
quant_min,
245249
quant_max);
250+
#ifdef G3_ENABLE_4BIT_QUANTIZATION
246251
} else if ((out.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
247252
uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
248253
XT_KERNEL_CHECK(
@@ -273,6 +278,7 @@ Tensor& quantize_impl(
273278
zero_point_data,
274279
quant_min,
275280
quant_max);
281+
#endif
276282
} else {
277283
if (axis == NULL) {
278284
// Vector quantization
@@ -452,6 +458,7 @@ Tensor& quantize_impl(
452458
scale_data,
453459
quant_min,
454460
quant_max);
461+
#ifdef G3_ENABLE_4BIT_QUANTIZATION
455462
} else if ((out.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
456463
uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
457464
XT_KERNEL_CHECK(
@@ -480,6 +487,7 @@ Tensor& quantize_impl(
480487
scale_data,
481488
quant_min,
482489
quant_max);
490+
#endif
483491
} else {
484492
if (axis == NULL) {
485493
// calculate the quantized input

0 commit comments

Comments (0)