Skip to content

Commit 901cdc1

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Fused quant hardswish kernel (#19488)
Summary: Fused quant hardswish kernel with optional dequantize/quantize. Unary op that applies x * min(max(x+3, 0), 6) / 6. Supports per-tensor and per-channel quantization. Reviewed By: mvartani-meta Differential Revision: D103754780
1 parent ced6e17 commit 901cdc1

5 files changed

Lines changed: 557 additions & 0 deletions

File tree

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <algorithm>
10+
11+
#include <executorch/backends/cadence/fused_quant/op_hardswish.h>
12+
#include <executorch/backends/cadence/fused_quant/quant_utils.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace cadence {
16+
namespace fused_quant {
17+
namespace native {
18+
19+
using executorch::aten::optional;
20+
using executorch::aten::ScalarType;
21+
using executorch::aten::Tensor;
22+
using executorch::runtime::KernelRuntimeContext;
23+
24+
namespace {
25+
26+
void hardswish_kernel(const float* inp, float* out, int64_t numel) {
27+
for (int64_t i = 0; i < numel; ++i) {
28+
float x = inp[i];
29+
out[i] = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
30+
}
31+
}
32+
33+
} // namespace
34+
35+
Tensor& hardswish_out(
36+
KernelRuntimeContext& ctx,
37+
const Tensor& inp,
38+
const optional<Tensor>& inp_scale,
39+
const optional<Tensor>& inp_zero_point,
40+
ScalarType inp_dtype,
41+
int64_t inp_quant_min,
42+
int64_t inp_quant_max,
43+
optional<int64_t> inp_axis,
44+
const optional<Tensor>& out_scale,
45+
const optional<Tensor>& out_zero_point,
46+
ScalarType out_dtype,
47+
int64_t out_quant_min,
48+
int64_t out_quant_max,
49+
optional<int64_t> out_axis,
50+
Tensor& out) {
51+
int64_t numel = inp.numel();
52+
53+
bool inp_quantized = inp_scale.has_value();
54+
bool out_quantized = out_scale.has_value();
55+
56+
std::vector<float> inp_buf;
57+
const float* const inp_float = [&]() -> const float* {
58+
if (!inp_quantized) {
59+
return inp.const_data_ptr<float>();
60+
}
61+
inp_buf.resize(numel);
62+
QParams qp = extract_qparams(
63+
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
64+
FUSED_QUANT_DTYPE_SWITCH(
65+
inp.scalar_type(),
66+
scalar_t,
67+
dequantize_buffer(
68+
inp.const_data_ptr<scalar_t>(), inp_buf.data(), numel, qp);)
69+
return inp_buf.data();
70+
}();
71+
72+
if (out_quantized) {
73+
std::vector<float> result_float(numel);
74+
hardswish_kernel(inp_float, result_float.data(), numel);
75+
76+
QParams qp = extract_qparams(
77+
out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
78+
FUSED_QUANT_DTYPE_SWITCH(
79+
out.scalar_type(),
80+
scalar_t,
81+
quantize_buffer(
82+
result_float.data(), out.mutable_data_ptr<scalar_t>(), numel, qp);)
83+
} else {
84+
hardswish_kernel(inp_float, out.mutable_data_ptr<float>(), numel);
85+
}
86+
87+
return out;
88+
}
89+
90+
} // namespace native
91+
} // namespace fused_quant
92+
} // namespace cadence
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace cadence {
15+
namespace fused_quant {
16+
namespace native {
17+
18+
executorch::aten::Tensor& hardswish_out(
19+
executorch::runtime::KernelRuntimeContext& ctx,
20+
const executorch::aten::Tensor& inp,
21+
const executorch::aten::optional<executorch::aten::Tensor>& inp_scale,
22+
const executorch::aten::optional<executorch::aten::Tensor>& inp_zero_point,
23+
executorch::aten::ScalarType inp_dtype,
24+
int64_t inp_quant_min,
25+
int64_t inp_quant_max,
26+
executorch::aten::optional<int64_t> inp_axis,
27+
const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
28+
const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
29+
executorch::aten::ScalarType out_dtype,
30+
int64_t out_quant_min,
31+
int64_t out_quant_max,
32+
executorch::aten::optional<int64_t> out_axis,
33+
executorch::aten::Tensor& out);
34+
35+
} // namespace native
36+
} // namespace fused_quant
37+
} // namespace cadence

backends/cadence/fused_quant/targets.bzl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,15 @@ def define_common_targets():
4646
],
4747
visibility = ["PUBLIC"],
4848
)
49+
50+
runtime.cxx_library(
51+
name = "op_hardswish",
52+
srcs = ["op_hardswish.cpp"],
53+
exported_headers = ["op_hardswish.h"],
54+
platforms = CXX,
55+
deps = [
56+
":quant_utils",
57+
"//executorch/runtime/kernel:kernel_includes",
58+
],
59+
visibility = ["PUBLIC"],
60+
)

backends/cadence/fused_quant/tests/BUCK

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,14 @@ runtime.cxx_test(
3535
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
3636
],
3737
)
38+
39+
runtime.cxx_test(
40+
name = "test_op_hardswish",
41+
srcs = ["test_op_hardswish.cpp"],
42+
platforms = CXX,
43+
deps = [
44+
"//executorch/backends/cadence/fused_quant:op_hardswish",
45+
"//executorch/kernels/test:gtest_utils",
46+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
47+
],
48+
)

0 commit comments

Comments
 (0)