Skip to content

Commit c57549a

Browse files
Andrew Grebenisan authored and meta-codesync[bot] committed
Fused quant linear kernel
Summary: Fused quant linear kernel (out = inp @ weight^T + bias) with optional dequantize/quantize. Supports 4 sets of qparams (inp, weight, bias, out), optional bias, and per-tensor/per-channel quantization.
Reviewed By: mvartani-meta
Differential Revision: D103754853
1 parent e4381ce commit c57549a

5 files changed

Lines changed: 733 additions & 0 deletions

File tree

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cadence/fused_quant/op_linear.h>
10+
#include <executorch/backends/cadence/fused_quant/quant_utils.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
namespace cadence {
14+
namespace fused_quant {
15+
namespace native {
16+
17+
using executorch::aten::optional;
18+
using executorch::aten::ScalarType;
19+
using executorch::aten::Tensor;
20+
using executorch::runtime::KernelRuntimeContext;
21+
22+
namespace {
23+
24+
// Reference float linear layer: out = inp @ weight^T (+ bias).
//
// All buffers are dense and row-major:
//   inp    : num_rows x in_features
//   weight : out_features x in_features (one row per output feature)
//   bias   : out_features entries, or nullptr to start every sum at 0.0f
//   out    : num_rows x out_features, fully overwritten
//
// The accumulator is seeded with the bias and the products are added in
// increasing `i` order. When in_features is 0 each output is just its bias
// (or 0.0f). No bounds checking is done here; callers must size the buffers.
void linear_kernel(
    const float* inp,
    const float* weight,
    const float* bias,
    float* out,
    int64_t num_rows,
    int64_t in_features,
    int64_t out_features) {
  for (int64_t r = 0; r < num_rows; ++r) {
    // Row bases are loop-invariant for the inner loops, so compute them once.
    const float* const inp_row = inp + r * in_features;
    float* const out_row = out + r * out_features;
    for (int64_t o = 0; o < out_features; ++o) {
      const float* const weight_row = weight + o * in_features;
      float sum = bias ? bias[o] : 0.0f;
      for (int64_t i = 0; i < in_features; ++i) {
        sum += inp_row[i] * weight_row[i];
      }
      out_row[o] = sum;
    }
  }
}
42+
43+
} // namespace
44+
45+
/**
 * Fused quantized linear: out = inp @ weight^T + bias.
 *
 * Each of inp / weight / bias / out has its own optional set of qparams. An
 * operand is treated as quantized exactly when its `*_scale` tensor is
 * present: quantized inputs are dequantized into temporary float buffers, the
 * matmul runs in float via linear_kernel(), and the result is quantized into
 * `out` when `out_scale` is present. Operands without a scale are read (or,
 * for `out`, written) directly as float data.
 *
 * Layout consumed by the kernel: `inp` is flattened to
 * [num_rows, in_features] with in_features = inp's last dimension and
 * num_rows = product of the leading dimensions; `weight` is indexed as
 * [out_features, in_features] with out_features = weight.size(0); `bias`,
 * when present, is indexed by output feature.
 *
 * Validation: if a rank, element count, or float-dtype requirement below is
 * violated, `ctx` is failed with InvalidArgument and `out` is returned
 * without being written.
 *
 * @param ctx Kernel runtime context; used to report invalid arguments.
 * @param inp Input activations (rank >= 1).
 * @param weight Weight matrix (rank >= 1, out_features * in_features
 *     elements).
 * @param bias Optional bias with out_features elements.
 * @param inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis
 *     Qparams forwarded to extract_qparams() for `inp`.
 * @param weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
 *     weight_axis Qparams forwarded to extract_qparams() for `weight`.
 * @param bias_scale, bias_zero_point, bias_quant_min, bias_quant_max,
 *     bias_axis Qparams forwarded to extract_qparams() for `bias`.
 * @param out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis
 *     Qparams forwarded to extract_qparams() for `out`.
 * @param inp_dtype, weight_dtype, bias_dtype, out_dtype Part of the
 *     signature but not read by this implementation; the element type is
 *     taken from each tensor's scalar_type().
 * @param out Destination tensor with num_rows * out_features elements.
 * @return `out`.
 */
Tensor& linear_out(
    KernelRuntimeContext& ctx,
    const Tensor& inp,
    const Tensor& weight,
    const optional<Tensor>& bias,
    // inp qparams
    const optional<Tensor>& inp_scale,
    const optional<Tensor>& inp_zero_point,
    ScalarType inp_dtype,
    int64_t inp_quant_min,
    int64_t inp_quant_max,
    optional<int64_t> inp_axis,
    // weight qparams
    const optional<Tensor>& weight_scale,
    const optional<Tensor>& weight_zero_point,
    ScalarType weight_dtype,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    optional<int64_t> weight_axis,
    // bias qparams
    const optional<Tensor>& bias_scale,
    const optional<Tensor>& bias_zero_point,
    ScalarType bias_dtype,
    int64_t bias_quant_min,
    int64_t bias_quant_max,
    optional<int64_t> bias_axis,
    // out qparams
    const optional<Tensor>& out_scale,
    const optional<Tensor>& out_zero_point,
    ScalarType out_dtype,
    int64_t out_quant_min,
    int64_t out_quant_max,
    optional<int64_t> out_axis,
    Tensor& out) {
  // size(dim() - 1) and size(0) below need at least one dimension.
  ET_KERNEL_CHECK(ctx, inp.dim() >= 1, InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, weight.dim() >= 1, InvalidArgument, out);

  const int64_t in_features = inp.size(inp.dim() - 1);
  const int64_t out_features = weight.size(0);

  // Product of the leading dimensions. Equivalent to numel / in_features for
  // non-empty rows, but stays well-defined when in_features is 0.
  int64_t num_rows = 1;
  for (int64_t d = 0; d + 1 < inp.dim(); ++d) {
    num_rows *= inp.size(d);
  }
  const int64_t out_numel = num_rows * out_features;

  const bool inp_quantized = inp_scale.has_value();
  const bool weight_quantized = weight_scale.has_value();
  const bool bias_quantized = bias_scale.has_value();
  const bool out_quantized = out_scale.has_value();

  // The kernel indexes raw pointers, so every buffer must have exactly the
  // number of elements it will touch.
  ET_KERNEL_CHECK(
      ctx, weight.numel() == out_features * in_features, InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx,
      !bias.has_value() || bias.value().numel() == out_features,
      InvalidArgument,
      out);
  ET_KERNEL_CHECK(ctx, out.numel() == out_numel, InvalidArgument, out);

  // Operands without a scale are reinterpreted as float below.
  ET_KERNEL_CHECK(
      ctx,
      inp_quantized || inp.scalar_type() == ScalarType::Float,
      InvalidArgument,
      out);
  ET_KERNEL_CHECK(
      ctx,
      weight_quantized || weight.scalar_type() == ScalarType::Float,
      InvalidArgument,
      out);
  ET_KERNEL_CHECK(
      ctx,
      !bias.has_value() || bias_quantized ||
          bias.value().scalar_type() == ScalarType::Float,
      InvalidArgument,
      out);
  ET_KERNEL_CHECK(
      ctx,
      out_quantized || out.scalar_type() == ScalarType::Float,
      InvalidArgument,
      out);

  // Returns float data for `t`: the tensor's own storage when no scale is
  // given, otherwise `scratch` filled with the dequantized values.
  const auto to_float = [](const Tensor& t,
                           const optional<Tensor>& scale,
                           const optional<Tensor>& zero_point,
                           int64_t quant_min,
                           int64_t quant_max,
                           optional<int64_t> axis,
                           std::vector<float>& scratch) -> const float* {
    if (!scale.has_value()) {
      return t.const_data_ptr<float>();
    }
    const int64_t numel = t.numel();
    scratch.resize(numel);
    QParams qp =
        extract_qparams(scale, zero_point, quant_min, quant_max, axis, t);
    FUSED_QUANT_DTYPE_SWITCH(
        t.scalar_type(),
        scalar_t,
        dequantize_buffer(
            t.const_data_ptr<scalar_t>(), scratch.data(), numel, qp);)
    return scratch.data();
  };

  // Scratch buffers must outlive the kernel call, so they live at function
  // scope rather than inside the helper.
  std::vector<float> inp_buf;
  const float* const inp_float = to_float(
      inp,
      inp_scale,
      inp_zero_point,
      inp_quant_min,
      inp_quant_max,
      inp_axis,
      inp_buf);

  std::vector<float> weight_buf;
  const float* const weight_float = to_float(
      weight,
      weight_scale,
      weight_zero_point,
      weight_quant_min,
      weight_quant_max,
      weight_axis,
      weight_buf);

  // A null bias pointer tells linear_kernel to accumulate from zero.
  std::vector<float> bias_buf;
  const float* const bias_float = bias.has_value()
      ? to_float(
            bias.value(),
            bias_scale,
            bias_zero_point,
            bias_quant_min,
            bias_quant_max,
            bias_axis,
            bias_buf)
      : nullptr;

  if (!out_quantized) {
    // Float output: write straight into the destination tensor.
    linear_kernel(
        inp_float,
        weight_float,
        bias_float,
        out.mutable_data_ptr<float>(),
        num_rows,
        in_features,
        out_features);
    return out;
  }

  // Quantized output: compute in a float staging buffer, then quantize.
  std::vector<float> result_float(out_numel);
  linear_kernel(
      inp_float,
      weight_float,
      bias_float,
      result_float.data(),
      num_rows,
      in_features,
      out_features);
  QParams qp = extract_qparams(
      out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
  FUSED_QUANT_DTYPE_SWITCH(
      out.scalar_type(),
      scalar_t,
      quantize_buffer(
          result_float.data(),
          out.mutable_data_ptr<scalar_t>(),
          out_numel,
          qp);)

  return out;
}
192+
193+
} // namespace native
194+
} // namespace fused_quant
195+
} // namespace cadence
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace cadence {
15+
namespace fused_quant {
16+
namespace native {
17+
18+
/**
 * Fused quantized linear kernel: out = inp @ weight^T + bias, with optional
 * dequantization of the inputs and optional quantization of the output.
 *
 * The arguments come in four parallel qparam groups (inp, weight, bias, out),
 * each made of an optional scale tensor, an optional zero-point tensor, a
 * dtype, a [quant_min, quant_max] range and an optional axis.
 *
 * @param ctx Kernel runtime context.
 * @param inp Input activations.
 * @param weight Weight tensor.
 * @param bias Optional bias tensor.
 * @param inp_scale, inp_zero_point, inp_dtype, inp_quant_min, inp_quant_max,
 *     inp_axis Quantization parameters for `inp`.
 * @param weight_scale, weight_zero_point, weight_dtype, weight_quant_min,
 *     weight_quant_max, weight_axis Quantization parameters for `weight`.
 * @param bias_scale, bias_zero_point, bias_dtype, bias_quant_min,
 *     bias_quant_max, bias_axis Quantization parameters for `bias`.
 * @param out_scale, out_zero_point, out_dtype, out_quant_min, out_quant_max,
 *     out_axis Quantization parameters for `out`.
 * @param out Destination tensor.
 * @return Reference to a Tensor (the out-variant convention is to return
 *     `out`).
 */
executorch::aten::Tensor& linear_out(
    executorch::runtime::KernelRuntimeContext& ctx,
    const executorch::aten::Tensor& inp,
    const executorch::aten::Tensor& weight,
    const executorch::aten::optional<executorch::aten::Tensor>& bias,
    // inp qparams
    const executorch::aten::optional<executorch::aten::Tensor>& inp_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& inp_zero_point,
    executorch::aten::ScalarType inp_dtype,
    int64_t inp_quant_min,
    int64_t inp_quant_max,
    executorch::aten::optional<int64_t> inp_axis,
    // weight qparams
    const executorch::aten::optional<executorch::aten::Tensor>& weight_scale,
    const executorch::aten::optional<executorch::aten::Tensor>&
        weight_zero_point,
    executorch::aten::ScalarType weight_dtype,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    executorch::aten::optional<int64_t> weight_axis,
    // bias qparams
    const executorch::aten::optional<executorch::aten::Tensor>& bias_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& bias_zero_point,
    executorch::aten::ScalarType bias_dtype,
    int64_t bias_quant_min,
    int64_t bias_quant_max,
    executorch::aten::optional<int64_t> bias_axis,
    // out qparams
    const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
    executorch::aten::ScalarType out_dtype,
    int64_t out_quant_min,
    int64_t out_quant_max,
    executorch::aten::optional<int64_t> out_axis,
    executorch::aten::Tensor& out);
53+
54+
} // namespace native
55+
} // namespace fused_quant
56+
} // namespace cadence

backends/cadence/fused_quant/targets.bzl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,15 @@ def define_common_targets():
7070
],
7171
visibility = ["PUBLIC"],
7272
)
73+
74+
runtime.cxx_library(
75+
name = "op_linear",
76+
srcs = ["op_linear.cpp"],
77+
exported_headers = ["op_linear.h"],
78+
platforms = CXX,
79+
deps = [
80+
":quant_utils",
81+
"//executorch/runtime/kernel:kernel_includes",
82+
],
83+
visibility = ["PUBLIC"],
84+
)

backends/cadence/fused_quant/tests/BUCK

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,14 @@ runtime.cxx_test(
5757
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
5858
],
5959
)
60+
61+
runtime.cxx_test(
62+
name = "test_op_linear",
63+
srcs = ["test_op_linear.cpp"],
64+
platforms = CXX,
65+
deps = [
66+
"//executorch/backends/cadence/fused_quant:op_linear",
67+
"//executorch/kernels/test:gtest_utils",
68+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
69+
],
70+
)

0 commit comments

Comments
 (0)