Skip to content

Commit 70ab8e0

Browse files
Andrew Grebenisan and facebook-github-bot
authored and committed
Fused quant convolution kernel (#19491)
Summary: Fused quant 2D convolution kernel (NCHW layout) with optional dequantize/quantize. Supports 4 sets of qparams, stride/padding/dilation/groups, and per-tensor/per-channel quantization. Reviewed By: mvartani-meta Differential Revision: D103754924
1 parent 2ea7fec commit 70ab8e0

5 files changed

Lines changed: 846 additions & 0 deletions

File tree

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cadence/fused_quant/op_convolution.h>

#include <vector>

#include <executorch/backends/cadence/fused_quant/quant_utils.h>
#include <executorch/runtime/kernel/kernel_includes.h>
13+
namespace cadence {
14+
namespace fused_quant {
15+
namespace native {
16+
17+
using executorch::aten::IntArrayRef;
18+
using executorch::aten::optional;
19+
using executorch::aten::ScalarType;
20+
using executorch::aten::Tensor;
21+
using executorch::runtime::KernelRuntimeContext;
22+
23+
namespace {

// Reference direct 2D convolution over float buffers in NCHW layout.
// Handles stride, zero padding, dilation, and grouped convolution.
//
// inp:    [N, C_in, H_in, W_in], contiguous row-major.
// weight: [C_out, C_in / groups, kH, kW], contiguous row-major.
// bias:   [C_out] or nullptr; when null the accumulator starts at zero.
// out:    [N, C_out, H_out, W_out], fully overwritten.
void conv2d_kernel(
    const float* inp,
    const float* weight,
    const float* bias,
    float* out,
    int64_t N,
    int64_t C_in,
    int64_t H_in,
    int64_t W_in,
    int64_t C_out,
    int64_t kH,
    int64_t kW,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dil_h,
    int64_t dil_w,
    int64_t groups,
    int64_t H_out,
    int64_t W_out) {
  const int64_t in_ch_per_group = C_in / groups;
  const int64_t out_ch_per_group = C_out / groups;

  for (int64_t n = 0; n < N; ++n) {
    for (int64_t g = 0; g < groups; ++g) {
      for (int64_t oc = 0; oc < out_ch_per_group; ++oc) {
        const int64_t oc_abs = g * out_ch_per_group + oc;
        for (int64_t oh = 0; oh < H_out; ++oh) {
          for (int64_t ow = 0; ow < W_out; ++ow) {
            // Seed the accumulator with the (optional) per-channel bias.
            float acc = (bias != nullptr) ? bias[oc_abs] : 0.0f;
            for (int64_t ic = 0; ic < in_ch_per_group; ++ic) {
              const int64_t ic_abs = g * in_ch_per_group + ic;
              // Hoist the per-plane base pointers out of the kernel loops.
              const float* inp_plane =
                  inp + (n * C_in + ic_abs) * H_in * W_in;
              const float* w_plane =
                  weight + (oc_abs * in_ch_per_group + ic) * kH * kW;
              for (int64_t kh = 0; kh < kH; ++kh) {
                const int64_t ih = oh * stride_h - pad_h + kh * dil_h;
                if (ih < 0 || ih >= H_in) {
                  continue; // Whole kernel row lies in the zero padding.
                }
                for (int64_t kw = 0; kw < kW; ++kw) {
                  const int64_t iw = ow * stride_w - pad_w + kw * dil_w;
                  if (iw < 0 || iw >= W_in) {
                    continue; // Zero-padded column contributes nothing.
                  }
                  acc += inp_plane[ih * W_in + iw] * w_plane[kh * kW + kw];
                }
              }
            }
            out[((n * C_out + oc_abs) * H_out + oh) * W_out + ow] = acc;
          }
        }
      }
    }
  }
}

} // namespace
82+
83+
// Fused quantized 2D convolution (NCHW layout).
//
// Each of inp / weight / bias / out carries an optional set of quantization
// parameters. Operands that have a scale are dequantized into a temporary
// float buffer, the convolution runs in float, and the result is
// re-quantized into `out` when output qparams are present; otherwise the
// float result is written to `out` directly. Per-tensor vs. per-channel
// quantization is selected by the optional `*_axis` arguments (forwarded to
// extract_qparams).
//
// NOTE(review): the `*_dtype` parameters are accepted but never read here —
// the element types are taken from the tensors themselves via
// scalar_type(); presumably the dtypes exist to match the op schema.
// Confirm against the operator registration.
Tensor& convolution_out(
    KernelRuntimeContext& ctx,
    const Tensor& inp,
    const Tensor& weight,
    const optional<Tensor>& bias,
    // inp qparams
    const optional<Tensor>& inp_scale,
    const optional<Tensor>& inp_zero_point,
    ScalarType inp_dtype,
    int64_t inp_quant_min,
    int64_t inp_quant_max,
    optional<int64_t> inp_axis,
    // weight qparams
    const optional<Tensor>& weight_scale,
    const optional<Tensor>& weight_zero_point,
    ScalarType weight_dtype,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    optional<int64_t> weight_axis,
    // bias qparams
    const optional<Tensor>& bias_scale,
    const optional<Tensor>& bias_zero_point,
    ScalarType bias_dtype,
    int64_t bias_quant_min,
    int64_t bias_quant_max,
    optional<int64_t> bias_axis,
    // out qparams
    const optional<Tensor>& out_scale,
    const optional<Tensor>& out_zero_point,
    ScalarType out_dtype,
    int64_t out_quant_min,
    int64_t out_quant_max,
    optional<int64_t> out_axis,
    // conv params
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    Tensor& out) {
  // Extract dimensions from input tensor [N, C_in, H_in, W_in]
  int64_t N = inp.size(0);
  int64_t C_in = inp.size(1);
  int64_t H_in = inp.size(2);
  int64_t W_in = inp.size(3);

  // Extract dimensions from weight tensor [C_out, C_in/groups, kH, kW]
  int64_t C_out = weight.size(0);
  int64_t kH = weight.size(2);
  int64_t kW = weight.size(3);

  // stride/padding/dilation are {height, width} pairs.
  int64_t stride_h = stride[0];
  int64_t stride_w = stride[1];
  int64_t pad_h = padding[0];
  int64_t pad_w = padding[1];
  int64_t dil_h = dilation[0];
  int64_t dil_w = dilation[1];

  // Standard conv output-size formula with dilation.
  int64_t H_out = (H_in + 2 * pad_h - dil_h * (kH - 1) - 1) / stride_h + 1;
  int64_t W_out = (W_in + 2 * pad_w - dil_w * (kW - 1) - 1) / stride_w + 1;

  int64_t inp_numel = inp.numel();
  int64_t weight_numel = weight.numel();
  // Expected element count of `out`; assumes `out` was already resized to
  // [N, C_out, H_out, W_out] by the caller — no runtime check here.
  // TODO(review): confirm shape validation happens upstream.
  int64_t out_numel = N * C_out * H_out * W_out;

  // An operand is treated as quantized iff its scale is provided.
  bool inp_quantized = inp_scale.has_value();
  bool weight_quantized = weight_scale.has_value();
  bool bias_quantized = bias_scale.has_value();
  bool out_quantized = out_scale.has_value();

  // Dequantize input if quantized
  // The buffer is declared outside the lambda (captured by reference) so the
  // pointer returned from the lambda stays valid for the rest of the call.
  std::vector<float> inp_buf;
  const float* const inp_float = [&]() -> const float* {
    if (!inp_quantized) {
      return inp.const_data_ptr<float>();
    }
    inp_buf.resize(inp_numel);
    QParams qp = extract_qparams(
        inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
    // Dispatch on the input's actual element type.
    FUSED_QUANT_DTYPE_SWITCH(
        inp.scalar_type(),
        scalar_t,
        dequantize_buffer(
            inp.const_data_ptr<scalar_t>(), inp_buf.data(), inp_numel, qp);)
    return inp_buf.data();
  }();

  // Dequantize weight if quantized
  std::vector<float> weight_buf;
  const float* const weight_float = [&]() -> const float* {
    if (!weight_quantized) {
      return weight.const_data_ptr<float>();
    }
    weight_buf.resize(weight_numel);
    QParams qp = extract_qparams(
        weight_scale,
        weight_zero_point,
        weight_quant_min,
        weight_quant_max,
        weight_axis,
        weight);
    FUSED_QUANT_DTYPE_SWITCH(weight.scalar_type(),
        scalar_t,
        dequantize_buffer(
            weight.const_data_ptr<scalar_t>(),
            weight_buf.data(),
            weight_numel,
            qp);)
    return weight_buf.data();
  }();

  // Dequantize bias if present and quantized
  // bias_float stays null when no bias was supplied; the kernel treats a
  // null bias as zero.
  std::vector<float> bias_buf;
  const float* bias_float = nullptr;
  if (bias.has_value()) {
    const Tensor& bias_tensor = bias.value();
    if (bias_quantized) {
      int64_t bias_numel = bias_tensor.numel();
      bias_buf.resize(bias_numel);
      QParams qp = extract_qparams(
          bias_scale,
          bias_zero_point,
          bias_quant_min,
          bias_quant_max,
          bias_axis,
          bias_tensor);
      FUSED_QUANT_DTYPE_SWITCH(bias_tensor.scalar_type(),
          scalar_t,
          dequantize_buffer(
              bias_tensor.const_data_ptr<scalar_t>(),
              bias_buf.data(),
              bias_numel,
              qp);)
      bias_float = bias_buf.data();
    } else {
      bias_float = bias_tensor.const_data_ptr<float>();
    }
  }

  // Run convolution
  if (out_quantized) {
    // Quantized output: convolve into a scratch float buffer, then quantize
    // the whole buffer into `out` with the output qparams.
    std::vector<float> result_float(out_numel);
    conv2d_kernel(
        inp_float,
        weight_float,
        bias_float,
        result_float.data(),
        N,
        C_in,
        H_in,
        W_in,
        C_out,
        kH,
        kW,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        groups,
        H_out,
        W_out);

    QParams qp = extract_qparams(
        out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
    FUSED_QUANT_DTYPE_SWITCH(out.scalar_type(),
        scalar_t,
        quantize_buffer(
            result_float.data(),
            out.mutable_data_ptr<scalar_t>(),
            out_numel,
            qp);)
  } else {
    // Float output: convolve directly into the output tensor's storage.
    conv2d_kernel(
        inp_float,
        weight_float,
        bias_float,
        out.mutable_data_ptr<float>(),
        N,
        C_in,
        H_in,
        W_in,
        C_out,
        kH,
        kW,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        groups,
        H_out,
        W_out);
  }

  return out;
}
281+
282+
} // namespace native
283+
} // namespace fused_quant
284+
} // namespace cadence
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cadence {
namespace fused_quant {
namespace native {

// Fused quantized 2D convolution (NCHW) with optional dequantize/quantize.
//
// Accepts four independent sets of quantization parameters — one each for
// input, weight, bias, and output. An operand whose scale optional is empty
// is treated as plain float. The optional `*_axis` selects per-channel
// quantization; when absent, quantization is per-tensor.
//
// Writes the result into `out` and returns a reference to it.
executorch::aten::Tensor& convolution_out(
    executorch::runtime::KernelRuntimeContext& ctx,
    const executorch::aten::Tensor& inp,
    const executorch::aten::Tensor& weight,
    const executorch::aten::optional<executorch::aten::Tensor>& bias,
    // inp qparams (6)
    const executorch::aten::optional<executorch::aten::Tensor>& inp_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& inp_zero_point,
    executorch::aten::ScalarType inp_dtype,
    int64_t inp_quant_min,
    int64_t inp_quant_max,
    executorch::aten::optional<int64_t> inp_axis,
    // weight qparams (6)
    const executorch::aten::optional<executorch::aten::Tensor>& weight_scale,
    const executorch::aten::optional<executorch::aten::Tensor>&
        weight_zero_point,
    executorch::aten::ScalarType weight_dtype,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    executorch::aten::optional<int64_t> weight_axis,
    // bias qparams (6)
    const executorch::aten::optional<executorch::aten::Tensor>& bias_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& bias_zero_point,
    executorch::aten::ScalarType bias_dtype,
    int64_t bias_quant_min,
    int64_t bias_quant_max,
    executorch::aten::optional<int64_t> bias_axis,
    // out qparams (6)
    const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
    executorch::aten::ScalarType out_dtype,
    int64_t out_quant_min,
    int64_t out_quant_max,
    executorch::aten::optional<int64_t> out_axis,
    // conv params: stride/padding/dilation are {height, width} pairs
    executorch::aten::IntArrayRef stride,
    executorch::aten::IntArrayRef padding,
    executorch::aten::IntArrayRef dilation,
    int64_t groups,
    executorch::aten::Tensor& out);

} // namespace native
} // namespace fused_quant
} // namespace cadence

backends/cadence/fused_quant/targets.bzl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,15 @@ def define_common_targets():
8282
],
8383
visibility = ["PUBLIC"],
8484
)
85+
86+
# Fused-quant 2D convolution operator (kernel + public header).
runtime.cxx_library(
    name = "op_convolution",
    srcs = ["op_convolution.cpp"],
    exported_headers = ["op_convolution.h"],
    platforms = CXX,
    deps = [
        ":quant_utils",
        "//executorch/runtime/kernel:kernel_includes",
    ],
    visibility = ["PUBLIC"],
)

backends/cadence/fused_quant/tests/BUCK

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,14 @@ runtime.cxx_test(
6868
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
6969
],
7070
)
71+
72+
# Unit tests for the fused-quant convolution operator.
runtime.cxx_test(
    name = "test_op_convolution",
    srcs = ["test_op_convolution.cpp"],
    platforms = CXX,
    deps = [
        "//executorch/backends/cadence/fused_quant:op_convolution",
        "//executorch/kernels/test:gtest_utils",
        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
    ],
)

0 commit comments

Comments
 (0)