Skip to content

Commit 4a9e1cf

Browse files
Andrew Grebenisan authored and facebook-github-bot committed
Fused quant relu kernel (#19486)
Summary: Fused quant ReLU kernel with optional dequantize/quantize. Unary op that applies max(0, x) in float space. Supports per-tensor and per-channel quantization. Reviewed By: mvartani-meta Differential Revision: D103754745
1 parent 9b5e742 commit 4a9e1cf

5 files changed

Lines changed: 505 additions & 0 deletions

File tree

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <algorithm>
10+
11+
#include <executorch/backends/cadence/fused_quant/op_relu.h>
12+
#include <executorch/backends/cadence/fused_quant/quant_utils.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace cadence {
16+
namespace fused_quant {
17+
namespace native {
18+
19+
using executorch::aten::optional;
20+
using executorch::aten::ScalarType;
21+
using executorch::aten::Tensor;
22+
using executorch::runtime::KernelRuntimeContext;
23+
24+
namespace {

// Elementwise ReLU over a raw float buffer: out[i] = max(0, inp[i]).
// Uses std::max(0.0f, v) so NaN inputs map to 0.0f, matching the
// semantics of std::max's (a < b ? b : a) definition. A numel of 0
// is a no-op. `inp` and `out` must not partially overlap.
void relu_kernel(const float* inp, float* out, int64_t numel) {
  std::transform(
      inp, inp + numel, out, [](float v) { return std::max(0.0f, v); });
}

} // namespace
33+
34+
// Fused quantized ReLU. Optionally dequantizes the input, applies
// max(0, x) in float space, and optionally re-quantizes the result.
// Presence of `inp_scale` / `out_scale` decides whether the input /
// output side is quantized; absent scales mean the tensor is already
// float. Returns `out`.
Tensor& relu_out(
    KernelRuntimeContext& ctx,
    const Tensor& inp,
    const optional<Tensor>& inp_scale,
    const optional<Tensor>& inp_zero_point,
    ScalarType inp_dtype,
    int64_t inp_quant_min,
    int64_t inp_quant_max,
    optional<int64_t> inp_axis,
    const optional<Tensor>& out_scale,
    const optional<Tensor>& out_zero_point,
    ScalarType out_dtype,
    int64_t out_quant_min,
    int64_t out_quant_max,
    optional<int64_t> out_axis,
    Tensor& out) {
  const int64_t n = inp.numel();

  const bool dequant_input = inp_scale.has_value();
  const bool quant_output = out_scale.has_value();

  // Staging buffer; only populated when the input must be dequantized.
  std::vector<float> dequant_buf;
  const float* src = nullptr;
  if (dequant_input) {
    dequant_buf.resize(n);
    QParams in_qp = extract_qparams(
        inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
    // Dispatch on the input's runtime dtype to pick the element type.
    FUSED_QUANT_DTYPE_SWITCH(
        inp.scalar_type(),
        scalar_t,
        dequantize_buffer(
            inp.const_data_ptr<scalar_t>(), dequant_buf.data(), n, in_qp);)
    src = dequant_buf.data();
  } else {
    // Input is already float; read it directly.
    src = inp.const_data_ptr<float>();
  }

  if (quant_output) {
    // ReLU into a temporary float buffer, then quantize into `out`.
    std::vector<float> relu_buf(n);
    relu_kernel(src, relu_buf.data(), n);

    QParams out_qp = extract_qparams(
        out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
    FUSED_QUANT_DTYPE_SWITCH(
        out.scalar_type(),
        scalar_t,
        quantize_buffer(
            relu_buf.data(), out.mutable_data_ptr<scalar_t>(), n, out_qp);)
  } else {
    // Float output: write ReLU results straight into `out`.
    relu_kernel(src, out.mutable_data_ptr<float>(), n);
  }

  return out;
}
88+
89+
} // namespace native
90+
} // namespace fused_quant
91+
} // namespace cadence
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace cadence {
namespace fused_quant {
namespace native {

/// Fused quantized ReLU: computes out = max(0, inp) elementwise in
/// float space, with optional dequantize of the input and optional
/// quantize of the result.
///
/// When `inp_scale` has a value, the input is dequantized using the
/// inp_* parameters before the ReLU; when `out_scale` has a value,
/// the float result is re-quantized into `out` using the out_*
/// parameters. An absent scale means that side is already float.
/// Per-channel quantization is presumably selected via `inp_axis` /
/// `out_axis` — TODO confirm against quant_utils' extract_qparams.
///
/// @param ctx Kernel runtime context (not read by the current
///   implementation).
/// @param inp Input tensor (float, or quantized per the inp_* params).
/// @param inp_scale Optional quantization scale(s) for `inp`.
/// @param inp_zero_point Optional zero point(s) for `inp`.
/// @param inp_dtype Declared quantized dtype of `inp`.
///   NOTE(review): the implementation dispatches on
///   inp.scalar_type(), not this parameter — confirm intent.
/// @param inp_quant_min Minimum quantized value for `inp`.
/// @param inp_quant_max Maximum quantized value for `inp`.
/// @param inp_axis Optional per-channel axis for `inp`.
/// @param out_scale Optional quantization scale(s) for `out`.
/// @param out_zero_point Optional zero point(s) for `out`.
/// @param out_dtype Declared quantized dtype of `out`.
///   NOTE(review): likewise unused; out.scalar_type() is used instead.
/// @param out_quant_min Minimum quantized value for `out`.
/// @param out_quant_max Maximum quantized value for `out`.
/// @param out_axis Optional per-channel axis for `out`.
/// @param out Output tensor, written in place.
/// @returns `out`.
executorch::aten::Tensor& relu_out(
    executorch::runtime::KernelRuntimeContext& ctx,
    const executorch::aten::Tensor& inp,
    const executorch::aten::optional<executorch::aten::Tensor>& inp_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& inp_zero_point,
    executorch::aten::ScalarType inp_dtype,
    int64_t inp_quant_min,
    int64_t inp_quant_max,
    executorch::aten::optional<int64_t> inp_axis,
    const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
    executorch::aten::ScalarType out_dtype,
    int64_t out_quant_min,
    int64_t out_quant_max,
    executorch::aten::optional<int64_t> out_axis,
    executorch::aten::Tensor& out);

} // namespace native
} // namespace fused_quant
} // namespace cadence

backends/cadence/fused_quant/targets.bzl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,15 @@ def define_common_targets():
3434
],
3535
visibility = ["PUBLIC"],
3636
)
37+
38+
# Fused-quant ReLU kernel: op_relu.cpp/.h, built on the shared
# quantization helpers in :quant_utils.
runtime.cxx_library(
    name = "op_relu",
    srcs = ["op_relu.cpp"],
    exported_headers = ["op_relu.h"],
    platforms = CXX,
    deps = [
        ":quant_utils",
        "//executorch/runtime/kernel:kernel_includes",
    ],
    visibility = ["PUBLIC"],
)

backends/cadence/fused_quant/tests/BUCK

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,14 @@ runtime.cxx_test(
2424
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
2525
],
2626
)
27+
28+
# Unit tests for the fused-quant ReLU kernel (:op_relu).
runtime.cxx_test(
    name = "test_op_relu",
    srcs = ["test_op_relu.cpp"],
    platforms = CXX,
    deps = [
        "//executorch/backends/cadence/fused_quant:op_relu",
        "//executorch/kernels/test:gtest_utils",
        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
    ],
)

0 commit comments

Comments
 (0)