
Commit 0aa2989

Andrew Grebenisan authored and facebook-github-bot committed
Fused quant convolution channels last kernel (#19492)
Summary: Fused quant 2D convolution kernel (NHWC layout, OHWI weights) with optional dequantize/quantize. Supports 4 sets of qparams, stride/padding/dilation/groups, and per-tensor/per-channel quantization.

Reviewed By: mvartani-meta

Differential Revision: D103754965
1 parent 3504370 commit 0aa2989
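
The op keeps the arithmetic simple by running the convolution itself in float: any quantized operand is dequantized up front, and the float result is re-quantized only when out qparams are supplied. Below is a minimal sketch of the affine quantize/dequantize arithmetic this implies, assuming int8 storage and per-tensor qparams; this QParams is a hypothetical stand-in for the struct in quant_utils.h, whose definition is not part of this diff.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical stand-in for the QParams struct in quant_utils.h.
    struct QParams {
      float scale;
      int32_t zero_point;
      int64_t quant_min;
      int64_t quant_max;
    };

    // Dequantize: real = scale * (q - zero_point).
    inline float dequantize_value(int8_t q, const QParams& qp) {
      return qp.scale * static_cast<float>(static_cast<int32_t>(q) - qp.zero_point);
    }

    // Quantize: q = clamp(round(real / scale) + zero_point, quant_min, quant_max).
    inline int8_t quantize_value(float real, const QParams& qp) {
      int64_t q =
          static_cast<int64_t>(std::nearbyint(real / qp.scale)) + qp.zero_point;
      q = std::min(std::max(q, qp.quant_min), qp.quant_max);
      return static_cast<int8_t>(q);
    }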

5 files changed

Lines changed: 908 additions & 0 deletions

File tree

backends/cadence/fused_quant/op_convolution_channels_last.cpp
backends/cadence/fused_quant/op_convolution_channels_last.h
backends/cadence/fused_quant/targets.bzl
backends/cadence/fused_quant/tests/BUCK
backends/cadence/fused_quant/tests/test_op_convolution_channels_last.cpp

backends/cadence/fused_quant/op_convolution_channels_last.cpp

Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/fused_quant/op_convolution_channels_last.h>
#include <executorch/backends/cadence/fused_quant/quant_utils.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <vector>

namespace cadence {
namespace fused_quant {
namespace native {

using executorch::aten::IntArrayRef;
using executorch::aten::optional;
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;

namespace {

// Convolution kernel for NHWC input and OHWI weight layouts.
// inp: [N, H_in, W_in, C_in] (NHWC)
// weight: [C_out, kH, kW, C_in/groups] (OHWI)
// bias: [C_out]
// out: [N, H_out, W_out, C_out] (NHWC)
void conv2d_nhwc_kernel(
    const float* inp,
    const float* weight,
    const float* bias,
    float* out,
    int64_t N,
    int64_t C_in,
    int64_t H_in,
    int64_t W_in,
    int64_t C_out,
    int64_t kH,
    int64_t kW,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dil_h,
    int64_t dil_w,
    int64_t groups,
    int64_t H_out,
    int64_t W_out) {
  int64_t C_in_per_group = C_in / groups;
  int64_t C_out_per_group = C_out / groups;

  for (int64_t n = 0; n < N; ++n) {
    for (int64_t oh = 0; oh < H_out; ++oh) {
      for (int64_t ow = 0; ow < W_out; ++ow) {
        for (int64_t g = 0; g < groups; ++g) {
          for (int64_t oc = 0; oc < C_out_per_group; ++oc) {
            int64_t oc_global = g * C_out_per_group + oc;
            float sum = bias ? bias[oc_global] : 0.0f;
            for (int64_t kh = 0; kh < kH; ++kh) {
              for (int64_t kw = 0; kw < kW; ++kw) {
                int64_t ih = oh * stride_h - pad_h + kh * dil_h;
                int64_t iw = ow * stride_w - pad_w + kw * dil_w;
                if (ih >= 0 && ih < H_in && iw >= 0 && iw < W_in) {
                  for (int64_t ic = 0; ic < C_in_per_group; ++ic) {
                    int64_t ic_global = g * C_in_per_group + ic;
                    // NHWC: inp[n][ih][iw][ic_global]
                    float inp_val =
                        inp[((n * H_in + ih) * W_in + iw) * C_in + ic_global];
                    // OHWI: weight[oc_global][kh][kw][ic]
                    float w_val = weight
                        [((oc_global * kH + kh) * kW + kw) * C_in_per_group +
                         ic];
                    sum += inp_val * w_val;
                  }
                }
              }
            }
            // NHWC: out[n][oh][ow][oc_global]
            out[((n * H_out + oh) * W_out + ow) * C_out + oc_global] = sum;
          }
        }
      }
    }
  }
}

} // namespace

Tensor& convolution_channels_last_out(
    KernelRuntimeContext& ctx,
    const Tensor& inp,
    const Tensor& weight,
    const optional<Tensor>& bias,
    // inp qparams
    const optional<Tensor>& inp_scale,
    const optional<Tensor>& inp_zero_point,
    ScalarType inp_dtype,
    int64_t inp_quant_min,
    int64_t inp_quant_max,
    optional<int64_t> inp_axis,
    // weight qparams
    const optional<Tensor>& weight_scale,
    const optional<Tensor>& weight_zero_point,
    ScalarType weight_dtype,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    optional<int64_t> weight_axis,
    // bias qparams
    const optional<Tensor>& bias_scale,
    const optional<Tensor>& bias_zero_point,
    ScalarType bias_dtype,
    int64_t bias_quant_min,
    int64_t bias_quant_max,
    optional<int64_t> bias_axis,
    // out qparams
    const optional<Tensor>& out_scale,
    const optional<Tensor>& out_zero_point,
    ScalarType out_dtype,
    int64_t out_quant_min,
    int64_t out_quant_max,
    optional<int64_t> out_axis,
    // conv params
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    Tensor& out) {
  // NHWC layout: [N, H_in, W_in, C_in]
  int64_t N = inp.size(0);
  int64_t H_in = inp.size(1);
  int64_t W_in = inp.size(2);
  int64_t C_in = inp.size(3);

  // OHWI layout: [C_out, kH, kW, C_in/groups]
  int64_t C_out = weight.size(0);
  int64_t kH = weight.size(1);
  int64_t kW = weight.size(2);

  int64_t stride_h = stride[0];
  int64_t stride_w = stride[1];
  int64_t pad_h = padding[0];
  int64_t pad_w = padding[1];
  int64_t dil_h = dilation[0];
  int64_t dil_w = dilation[1];

  int64_t H_out = (H_in + 2 * pad_h - dil_h * (kH - 1) - 1) / stride_h + 1;
  int64_t W_out = (W_in + 2 * pad_w - dil_w * (kW - 1) - 1) / stride_w + 1;

  int64_t inp_numel = inp.numel();
  int64_t weight_numel = weight.numel();
  int64_t out_numel = N * H_out * W_out * C_out;

  bool inp_quantized = inp_scale.has_value();
  bool weight_quantized = weight_scale.has_value();
  bool bias_quantized = bias_scale.has_value();
  bool out_quantized = out_scale.has_value();

  // Dequantize input if needed.
  std::vector<float> inp_buf;
  const float* const inp_float = [&]() -> const float* {
    if (!inp_quantized) {
      return inp.const_data_ptr<float>();
    }
    inp_buf.resize(inp_numel);
    QParams qp = extract_qparams(
        inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
    FUSED_QUANT_DTYPE_SWITCH(
        inp.scalar_type(),
        scalar_t,
        dequantize_buffer(
            inp.const_data_ptr<scalar_t>(), inp_buf.data(), inp_numel, qp);)
    return inp_buf.data();
  }();

  // Dequantize weight if needed.
  std::vector<float> weight_buf;
  const float* const weight_float = [&]() -> const float* {
    if (!weight_quantized) {
      return weight.const_data_ptr<float>();
    }
    weight_buf.resize(weight_numel);
    QParams qp = extract_qparams(
        weight_scale,
        weight_zero_point,
        weight_quant_min,
        weight_quant_max,
        weight_axis,
        weight);
    FUSED_QUANT_DTYPE_SWITCH(
        weight.scalar_type(),
        scalar_t,
        dequantize_buffer(
            weight.const_data_ptr<scalar_t>(),
            weight_buf.data(),
            weight_numel,
            qp);)
    return weight_buf.data();
  }();

  // Dequantize bias if needed.
  bool has_bias = bias.has_value();
  std::vector<float> bias_buf;
  const float* bias_float = nullptr;
  if (has_bias) {
    const Tensor& bias_val = bias.value();
    if (bias_quantized) {
      int64_t bias_numel = bias_val.numel();
      bias_buf.resize(bias_numel);
      QParams qp = extract_qparams(
          bias_scale,
          bias_zero_point,
          bias_quant_min,
          bias_quant_max,
          bias_axis,
          bias_val);
      FUSED_QUANT_DTYPE_SWITCH(
          bias_val.scalar_type(),
          scalar_t,
          dequantize_buffer(
              bias_val.const_data_ptr<scalar_t>(),
              bias_buf.data(),
              bias_numel,
              qp);)
      bias_float = bias_buf.data();
    } else {
      bias_float = bias_val.const_data_ptr<float>();
    }
  }

  // Run convolution in float.
  if (out_quantized) {
    std::vector<float> result_float(out_numel);
    conv2d_nhwc_kernel(
        inp_float,
        weight_float,
        bias_float,
        result_float.data(),
        N,
        C_in,
        H_in,
        W_in,
        C_out,
        kH,
        kW,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        groups,
        H_out,
        W_out);

    QParams qp = extract_qparams(
        out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
    FUSED_QUANT_DTYPE_SWITCH(
        out.scalar_type(),
        scalar_t,
        quantize_buffer(
            result_float.data(),
            out.mutable_data_ptr<scalar_t>(),
            out_numel,
            qp);)
  } else {
    conv2d_nhwc_kernel(
        inp_float,
        weight_float,
        bias_float,
        out.mutable_data_ptr<float>(),
        N,
        C_in,
        H_in,
        W_in,
        C_out,
        kH,
        kW,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        groups,
        H_out,
        W_out);
  }

  return out;
}

} // namespace native
} // namespace fused_quant
} // namespace cadence
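
A quick worked check of the shape and indexing arithmetic above, using only the formulas from this file: a 5x5 input with a 3x3 kernel, stride 2, padding 1, and dilation 1 yields a 3x3 output, and the NHWC flat offset matches the row-major expansion.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
      // H_out = (H_in + 2*pad - dil*(kH - 1) - 1) / stride + 1
      int64_t H_in = 5, kH = 3, stride_h = 2, pad_h = 1, dil_h = 1;
      int64_t H_out = (H_in + 2 * pad_h - dil_h * (kH - 1) - 1) / stride_h + 1;
      assert(H_out == 3); // (5 + 2 - 2 - 1) / 2 + 1 = 3

      // Flat NHWC offset for inp[n][ih][iw][ic] with shape [N, H_in, W_in, C_in].
      int64_t W_in = 5, C_in = 4;
      int64_t n = 1, ih = 2, iw = 3, ic = 1;
      int64_t nhwc_offset = ((n * H_in + ih) * W_in + iw) * C_in + ic;
      assert(nhwc_offset == 153); // ((1*5 + 2)*5 + 3)*4 + 1

      std::printf("H_out=%lld, nhwc_offset=%lld\n",
                  (long long)H_out, (long long)nhwc_offset);
      return 0;
    }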
backends/cadence/fused_quant/op_convolution_channels_last.h

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cadence {
namespace fused_quant {
namespace native {

executorch::aten::Tensor& convolution_channels_last_out(
    executorch::runtime::KernelRuntimeContext& ctx,
    const executorch::aten::Tensor& inp,
    const executorch::aten::Tensor& weight,
    const executorch::aten::optional<executorch::aten::Tensor>& bias,
    // inp qparams (6)
    const executorch::aten::optional<executorch::aten::Tensor>& inp_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& inp_zero_point,
    executorch::aten::ScalarType inp_dtype,
    int64_t inp_quant_min,
    int64_t inp_quant_max,
    executorch::aten::optional<int64_t> inp_axis,
    // weight qparams (6)
    const executorch::aten::optional<executorch::aten::Tensor>& weight_scale,
    const executorch::aten::optional<executorch::aten::Tensor>&
        weight_zero_point,
    executorch::aten::ScalarType weight_dtype,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    executorch::aten::optional<int64_t> weight_axis,
    // bias qparams (6)
    const executorch::aten::optional<executorch::aten::Tensor>& bias_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& bias_zero_point,
    executorch::aten::ScalarType bias_dtype,
    int64_t bias_quant_min,
    int64_t bias_quant_max,
    executorch::aten::optional<int64_t> bias_axis,
    // out qparams (6)
    const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
    const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
    executorch::aten::ScalarType out_dtype,
    int64_t out_quant_min,
    int64_t out_quant_max,
    executorch::aten::optional<int64_t> out_axis,
    // conv params
    executorch::aten::IntArrayRef stride,
    executorch::aten::IntArrayRef padding,
    executorch::aten::IntArrayRef dilation,
    int64_t groups,
    executorch::aten::Tensor& out);

} // namespace native
} // namespace fused_quant
} // namespace cadence
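
Each optional *_axis selects per-channel quantization: when set, the corresponding scale and zero_point tensors carry one entry per slice along that dimension rather than a single value. The sketch below illustrates per-channel dequantization of an OHWI weight quantized along axis 0 (one scale and zero point per output channel); it is a generic illustration, not the dequantize_buffer from quant_utils.h.

    #include <cstdint>
    #include <vector>

    // Hypothetical per-channel dequantize for an OHWI weight, axis = 0 (C_out).
    // Each output channel oc gets its own scale/zero_point, mirroring the
    // per-channel path implied by weight_axis.
    void dequantize_weight_per_channel(
        const int8_t* q, float* out,
        int64_t C_out, int64_t kH, int64_t kW, int64_t C_in_per_group,
        const std::vector<float>& scales,
        const std::vector<int32_t>& zero_points) {
      int64_t per_channel = kH * kW * C_in_per_group;
      for (int64_t oc = 0; oc < C_out; ++oc) {
        for (int64_t i = 0; i < per_channel; ++i) {
          int64_t idx = oc * per_channel + i;
          out[idx] =
              scales[oc] * (static_cast<int32_t>(q[idx]) - zero_points[oc]);
        }
      }
    }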

backends/cadence/fused_quant/targets.bzl

Lines changed: 12 additions & 0 deletions
@@ -94,3 +94,15 @@ def define_common_targets():
        ],
        visibility = ["PUBLIC"],
    )

    runtime.cxx_library(
        name = "op_convolution_channels_last",
        srcs = ["op_convolution_channels_last.cpp"],
        exported_headers = ["op_convolution_channels_last.h"],
        platforms = CXX,
        deps = [
            ":quant_utils",
            "//executorch/runtime/kernel:kernel_includes",
        ],
        visibility = ["PUBLIC"],
    )

backends/cadence/fused_quant/tests/BUCK

Lines changed: 11 additions & 0 deletions
@@ -79,3 +79,14 @@ runtime.cxx_test(
        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
    ],
)

runtime.cxx_test(
    name = "test_op_convolution_channels_last",
    srcs = ["test_op_convolution_channels_last.cpp"],
    platforms = CXX,
    deps = [
        "//executorch/backends/cadence/fused_quant:op_convolution_channels_last",
        "//executorch/kernels/test:gtest_utils",
        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
    ],
)
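
The new test file itself is not part of this diff. As an illustration only, a test for the pure-float path (no qparams set) might look like the sketch below; TensorFactory, its make/zeros helpers, and the exact optional/IntArrayRef spellings are assumptions based on other ExecuTorch kernel tests.

    #include <executorch/backends/cadence/fused_quant/op_convolution_channels_last.h>
    #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
    #include <gtest/gtest.h>

    using executorch::aten::IntArrayRef;
    using executorch::aten::optional;
    using executorch::aten::ScalarType;
    using executorch::aten::Tensor;
    using torch::executor::testing::TensorFactory;

    TEST(ConvolutionChannelsLastTest, FloatPath1x1Kernel) {
      TensorFactory<ScalarType::Float> tf;
      // NHWC input [1, 2, 2, 1]; OHWI weight [1, 1, 1, 1] holding the value 2.
      Tensor inp = tf.make({1, 2, 2, 1}, {1.0, 2.0, 3.0, 4.0});
      Tensor weight = tf.make({1, 1, 1, 1}, {2.0});
      Tensor out = tf.zeros({1, 2, 2, 1});

      executorch::runtime::KernelRuntimeContext ctx;
      optional<Tensor> none;      // no bias and no qparams: pure float path
      optional<int64_t> no_axis;
      int64_t stride[] = {1, 1};
      int64_t padding[] = {0, 0};
      int64_t dilation[] = {1, 1};

      cadence::fused_quant::native::convolution_channels_last_out(
          ctx, inp, weight, none,
          none, none, ScalarType::Float, 0, 0, no_axis, // inp qparams
          none, none, ScalarType::Float, 0, 0, no_axis, // weight qparams
          none, none, ScalarType::Float, 0, 0, no_axis, // bias qparams
          none, none, ScalarType::Float, 0, 0, no_axis, // out qparams
          IntArrayRef(stride, 2),
          IntArrayRef(padding, 2),
          IntArrayRef(dilation, 2),
          /*groups=*/1,
          out);

      // A 1x1 kernel with weight 2 doubles every element.
      EXPECT_FLOAT_EQ(out.const_data_ptr<float>()[3], 8.0f);
    }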
