Skip to content

Commit 762a9cb

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Fused quant bmm kernel
Summary: Fused quant batch matrix multiply kernel with optional dequantize/quantize. Binary op on 3D tensors [B,M,K] x [B,K,N] -> [B,M,N]. Supports per-tensor and per-channel quantization. Reviewed By: mvartani-meta Differential Revision: D103754815
1 parent 901cdc1 commit 762a9cb

5 files changed

Lines changed: 575 additions & 0 deletions

File tree

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cadence/fused_quant/op_bmm.h>
10+
#include <executorch/backends/cadence/fused_quant/quant_utils.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
namespace cadence {
14+
namespace fused_quant {
15+
namespace native {
16+
17+
using executorch::aten::optional;
18+
using executorch::aten::ScalarType;
19+
using executorch::aten::Tensor;
20+
using executorch::runtime::KernelRuntimeContext;
21+
22+
namespace {
23+
24+
void bmm_kernel(
25+
const float* inp,
26+
const float* other,
27+
float* out,
28+
int64_t batch,
29+
int64_t M,
30+
int64_t K,
31+
int64_t N) {
32+
for (int64_t b = 0; b < batch; ++b) {
33+
for (int64_t m = 0; m < M; ++m) {
34+
for (int64_t n = 0; n < N; ++n) {
35+
float sum = 0.0f;
36+
for (int64_t k = 0; k < K; ++k) {
37+
sum += inp[b * M * K + m * K + k] * other[b * K * N + k * N + n];
38+
}
39+
out[b * M * N + m * N + n] = sum;
40+
}
41+
}
42+
}
43+
}
44+
45+
} // namespace
46+
47+
Tensor& bmm_out(
48+
KernelRuntimeContext& ctx,
49+
const Tensor& inp,
50+
const Tensor& other,
51+
const optional<Tensor>& inp_scale,
52+
const optional<Tensor>& inp_zero_point,
53+
ScalarType inp_dtype,
54+
int64_t inp_quant_min,
55+
int64_t inp_quant_max,
56+
optional<int64_t> inp_axis,
57+
const optional<Tensor>& other_scale,
58+
const optional<Tensor>& other_zero_point,
59+
ScalarType other_dtype,
60+
int64_t other_quant_min,
61+
int64_t other_quant_max,
62+
optional<int64_t> other_axis,
63+
const optional<Tensor>& out_scale,
64+
const optional<Tensor>& out_zero_point,
65+
ScalarType out_dtype,
66+
int64_t out_quant_min,
67+
int64_t out_quant_max,
68+
optional<int64_t> out_axis,
69+
Tensor& out) {
70+
int64_t batch = inp.size(0);
71+
int64_t M = inp.size(1);
72+
int64_t K = inp.size(2);
73+
int64_t N = other.size(2);
74+
int64_t inp_numel = inp.numel();
75+
int64_t other_numel = other.numel();
76+
int64_t out_numel = batch * M * N;
77+
78+
bool inp_quantized = inp_scale.has_value();
79+
bool other_quantized = other_scale.has_value();
80+
bool out_quantized = out_scale.has_value();
81+
82+
// Dequantize inp
83+
std::vector<float> inp_buf;
84+
const float* const inp_float = [&]() -> const float* {
85+
if (!inp_quantized) {
86+
return inp.const_data_ptr<float>();
87+
}
88+
inp_buf.resize(inp_numel);
89+
QParams qp = extract_qparams(
90+
inp_scale, inp_zero_point, inp_quant_min, inp_quant_max, inp_axis, inp);
91+
FUSED_QUANT_DTYPE_SWITCH(
92+
inp.scalar_type(),
93+
scalar_t,
94+
dequantize_buffer(
95+
inp.const_data_ptr<scalar_t>(), inp_buf.data(), inp_numel, qp);)
96+
return inp_buf.data();
97+
}();
98+
99+
// Dequantize other
100+
std::vector<float> other_buf;
101+
const float* const other_float = [&]() -> const float* {
102+
if (!other_quantized) {
103+
return other.const_data_ptr<float>();
104+
}
105+
other_buf.resize(other_numel);
106+
QParams qp = extract_qparams(
107+
other_scale,
108+
other_zero_point,
109+
other_quant_min,
110+
other_quant_max,
111+
other_axis,
112+
other);
113+
FUSED_QUANT_DTYPE_SWITCH(other.scalar_type(),
114+
scalar_t,
115+
dequantize_buffer(
116+
other.const_data_ptr<scalar_t>(),
117+
other_buf.data(),
118+
other_numel,
119+
qp);)
120+
return other_buf.data();
121+
}();
122+
123+
// BMM in float, then optionally quantize output
124+
if (out_quantized) {
125+
std::vector<float> result_float(out_numel);
126+
bmm_kernel(inp_float, other_float, result_float.data(), batch, M, K, N);
127+
128+
QParams qp = extract_qparams(
129+
out_scale, out_zero_point, out_quant_min, out_quant_max, out_axis, out);
130+
FUSED_QUANT_DTYPE_SWITCH(out.scalar_type(),
131+
scalar_t,
132+
quantize_buffer(
133+
result_float.data(),
134+
out.mutable_data_ptr<scalar_t>(),
135+
out_numel,
136+
qp);)
137+
} else {
138+
bmm_kernel(
139+
inp_float, other_float, out.mutable_data_ptr<float>(), batch, M, K, N);
140+
}
141+
142+
return out;
143+
}
144+
145+
} // namespace native
146+
} // namespace fused_quant
147+
} // namespace cadence
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace cadence {
15+
namespace fused_quant {
16+
namespace native {
17+
18+
executorch::aten::Tensor& bmm_out(
19+
executorch::runtime::KernelRuntimeContext& ctx,
20+
const executorch::aten::Tensor& inp,
21+
const executorch::aten::Tensor& other,
22+
const executorch::aten::optional<executorch::aten::Tensor>& inp_scale,
23+
const executorch::aten::optional<executorch::aten::Tensor>& inp_zero_point,
24+
executorch::aten::ScalarType inp_dtype,
25+
int64_t inp_quant_min,
26+
int64_t inp_quant_max,
27+
executorch::aten::optional<int64_t> inp_axis,
28+
const executorch::aten::optional<executorch::aten::Tensor>& other_scale,
29+
const executorch::aten::optional<executorch::aten::Tensor>&
30+
other_zero_point,
31+
executorch::aten::ScalarType other_dtype,
32+
int64_t other_quant_min,
33+
int64_t other_quant_max,
34+
executorch::aten::optional<int64_t> other_axis,
35+
const executorch::aten::optional<executorch::aten::Tensor>& out_scale,
36+
const executorch::aten::optional<executorch::aten::Tensor>& out_zero_point,
37+
executorch::aten::ScalarType out_dtype,
38+
int64_t out_quant_min,
39+
int64_t out_quant_max,
40+
executorch::aten::optional<int64_t> out_axis,
41+
executorch::aten::Tensor& out);
42+
43+
} // namespace native
44+
} // namespace fused_quant
45+
} // namespace cadence

backends/cadence/fused_quant/targets.bzl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,15 @@ def define_common_targets():
5858
],
5959
visibility = ["PUBLIC"],
6060
)
61+
62+
runtime.cxx_library(
63+
name = "op_bmm",
64+
srcs = ["op_bmm.cpp"],
65+
exported_headers = ["op_bmm.h"],
66+
platforms = CXX,
67+
deps = [
68+
":quant_utils",
69+
"//executorch/runtime/kernel:kernel_includes",
70+
],
71+
visibility = ["PUBLIC"],
72+
)

backends/cadence/fused_quant/tests/BUCK

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,14 @@ runtime.cxx_test(
4646
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
4747
],
4848
)
49+
50+
runtime.cxx_test(
51+
name = "test_op_bmm",
52+
srcs = ["test_op_bmm.cpp"],
53+
platforms = CXX,
54+
deps = [
55+
"//executorch/backends/cadence/fused_quant:op_bmm",
56+
"//executorch/kernels/test:gtest_utils",
57+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
58+
],
59+
)

0 commit comments

Comments
 (0)