Skip to content

Commit bad96b5

Browse files
Nicoshevmeta-codesync[bot]
authored andcommitted
Implement PackDepthwiseConvMatrix in NEON + deprecate aarch64 compat layers (#5779)
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2709 Pull Request resolved: #5779 Add a NEON-based aarch64 implementation of the `PackedDepthWiseConvMatrix` constructor in `PackDepthwiseConvMatrix.cc`, alongside the existing AVX2 x86 implementation. The constructor packs depthwise convolution weight matrices into a SIMD-friendly interleaved layout. Rename depthwise-convolution related files, as NEON and AVX2 implementations already co-exist Remove compilation of avx2 source files for aarch64 targets and remove usage of aarch64 compat layers Reviewed By: q10, YifanYuan3 Differential Revision: D106137964
1 parent 6c71acd commit bad96b5

12 files changed

Lines changed: 133 additions & 16 deletions

bench/Depthwise3DBenchmark.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
#include "./AlignedVec.h"
2121
#include "./BenchUtils.h"
22-
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
22+
#include "fbgemm/FbgemmI8Depthwise.h"
2323
#include "fbgemm/Utils.h"
2424
#include "src/RefImplementations.h" // @manual
2525

bench/DepthwiseBenchmark.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
#include "./AlignedVec.h"
2222
#include "./BenchUtils.h"
23-
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
23+
#include "fbgemm/FbgemmI8Depthwise.h"
2424
#include "fbgemm/Utils.h"
2525
#include "src/RefImplementations.h" // @manual
2626

defs.bzl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def get_fbgemm_public_headers():
106106
"include/fbgemm/FbgemmFP32.h",
107107
"include/fbgemm/FbgemmFPCommon.h",
108108
"include/fbgemm/FbgemmI64.h",
109-
"include/fbgemm/FbgemmI8DepthwiseAvx2.h",
109+
"include/fbgemm/FbgemmI8Depthwise.h",
110110
"include/fbgemm/FbgemmI8DirectconvAvx2.h",
111111
"include/fbgemm/FbgemmI8Spmdm.h",
112112
"include/fbgemm/FbgemmPackMatrixB.h",
@@ -132,9 +132,9 @@ def get_fbgemm_avx2_srcs(buck = False):
132132
# downstream targets pull in both fbgemm_avx2 and fbgemm_sve (the
133133
# latter is selected by the main fbgemm target on arm64).
134134
depthwise_srcs = [
135-
"src/FbgemmI8Depthwise3DAvx2.cc",
136-
"src/FbgemmI8DepthwiseAvx2.cc",
137-
"src/PackDepthwiseConvMatrixAvx2.cc",
135+
"src/FbgemmI8Depthwise3D.cc",
136+
"src/FbgemmI8Depthwise.cc",
137+
"src/PackDepthwiseConvMatrix.cc",
138138
]
139139

140140
common_srcs = [
@@ -211,6 +211,9 @@ def get_fbgemm_inline_sve_srcs(msvc = False, buck = False):
211211
"src/FbgemmFP16UKernelsSve128.cc",
212212
"src/UtilsSve.cc",
213213
"src/FbgemmFloat16ConvertSVE.cc",
214+
"src/PackDepthwiseConvMatrix.cc",
215+
"src/FbgemmI8Depthwise3D.cc",
216+
"src/FbgemmI8Depthwise.cc",
214217
]
215218

216219
if buck:

include/fbgemm/Fbgemm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include "./ConvUtils.h" // @manual
1818
#include "./FbgemmBuild.h" // @manual
1919
#include "./FbgemmEmbedding.h" // @manual
20-
#include "./FbgemmI8DepthwiseAvx2.h" // @manual
20+
#include "./FbgemmI8Depthwise.h" // @manual
2121
#include "./FbgemmI8DirectconvAvx2.h" // @manual
2222
#include "./FbgemmI8Spmdm.h" // @manual
2323
#include "./FloatConversion.h" // @manual
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
*/
88

99
#define FBGEMM_EXPORTS
10-
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
10+
#include "fbgemm/FbgemmI8Depthwise.h"
1111

1212
#include <stdexcept> // for logic_error
1313
#include <string>
1414

15-
#include "./FbgemmI8Depthwise2DAvx2-inl.h" // @manual
15+
#include "./FbgemmI8Depthwise2D-inl.h" // @manual
1616

1717
using namespace std;
1818

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
*/
88

99
#define FBGEMM_EXPORTS
10-
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
10+
#include "fbgemm/FbgemmI8Depthwise.h"
1111

1212
#include <stdexcept> // for logic_error
1313
#include <string>

src/FbgemmI8DepthwisePerChannelQuantAvx2.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
*/
88

99
#define FBGEMM_EXPORTS
10-
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
10+
#include "fbgemm/FbgemmI8Depthwise.h"
1111

12-
#include "./FbgemmI8Depthwise2DAvx2-inl.h" // @manual
12+
#include "./FbgemmI8Depthwise2D-inl.h" // @manual
1313

1414
namespace fbgemm {
1515

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,24 @@
77
*/
88

99
#define FBGEMM_EXPORTS
10-
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
10+
#include "fbgemm/FbgemmI8Depthwise.h"
1111

1212
#if defined(__x86_64__) || defined(__i386__) || \
1313
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
1414
#include <immintrin.h>
15+
#include "./MaskAvx2.h" // @manual
16+
#elif defined(__aarch64__)
17+
#include <arm_neon.h>
18+
#include <cstring>
1519
#endif
1620

17-
#include "./MaskAvx2.h" // @manual
1821
#include "fbgemm/UtilsAvx2.h"
1922

2023
namespace fbgemm {
2124

25+
#if defined(__x86_64__) || defined(__i386__) || \
26+
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
27+
2228
PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
2329
int OC,
2430
int kernel_prod,
@@ -159,6 +165,114 @@ PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
159165
}
160166
}
161167

168+
#elif defined(__aarch64__)
169+
170+
namespace {
171+
struct neon_256i {
172+
int8x16_t lo, hi;
173+
};
174+
} // namespace
175+
176+
PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
177+
int OC,
178+
int kernel_prod,
179+
const int8_t* smat)
180+
: OC_(OC), kernel_prod_(kernel_prod) {
181+
auto smat_transposed_owner =
182+
makeAlignedUniquePtr<int8_t>(64, OC * kernel_prod);
183+
int8_t* smat_transposed = smat_transposed_owner.get();
184+
for (int i = 0; i < kernel_prod; ++i) {
185+
for (int j = 0; j < OC; ++j) {
186+
smat_transposed[i * OC + j] = smat[i + j * kernel_prod];
187+
}
188+
}
189+
190+
int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2;
191+
pmat_ = static_cast<int8_t*>(fbgemmAlignedAlloc(
192+
64, ((OC + 31) / 32) * kernel_prod_aligned * 32 * sizeof(int8_t)));
193+
194+
auto b_v_owner = makeAlignedUniquePtr<neon_256i>(64, kernel_prod);
195+
auto b_v = b_v_owner.get();
196+
auto b_interleaved_epi16_owner =
197+
makeAlignedUniquePtr<neon_256i>(64, kernel_prod_aligned);
198+
auto b_interleaved_epi16 = b_interleaved_epi16_owner.get();
199+
auto b_interleaved_epi32_owner =
200+
makeAlignedUniquePtr<neon_256i>(64, kernel_prod_aligned);
201+
auto b_interleaved_epi32 = b_interleaved_epi32_owner.get();
202+
203+
for (int k1 = 0; k1 < OC; k1 += 32) {
204+
int remainder = OC - k1;
205+
if (remainder < 32) {
206+
for (int i = 0; i < kernel_prod; ++i) {
207+
alignas(16) int8_t tmp[32] = {};
208+
int valid_bytes = (remainder / 4) * 4;
209+
memcpy(tmp, smat_transposed + i * OC + k1, valid_bytes);
210+
b_v[i].lo = vld1q_s8(tmp);
211+
b_v[i].hi = vld1q_s8(tmp + 16);
212+
}
213+
} else {
214+
for (int i = 0; i < kernel_prod; ++i) {
215+
const int8_t* src = smat_transposed + i * OC + k1;
216+
b_v[i].lo = vld1q_s8(src);
217+
b_v[i].hi = vld1q_s8(src + 16);
218+
}
219+
}
220+
221+
neon_256i zero_v;
222+
zero_v.lo = vdupq_n_s8(0);
223+
zero_v.hi = vdupq_n_s8(0);
224+
for (int i = 0; i < kernel_prod_aligned / 2; ++i) {
225+
neon_256i a = b_v[2 * i];
226+
neon_256i b_val = (2 * i + 1 >= kernel_prod) ? zero_v : b_v[2 * i + 1];
227+
b_interleaved_epi16[2 * i].lo = vzip1q_s8(a.lo, b_val.lo);
228+
b_interleaved_epi16[2 * i].hi = vzip1q_s8(a.hi, b_val.hi);
229+
b_interleaved_epi16[2 * i + 1].lo = vzip2q_s8(a.lo, b_val.lo);
230+
b_interleaved_epi16[2 * i + 1].hi = vzip2q_s8(a.hi, b_val.hi);
231+
}
232+
233+
for (int i = 0; i < kernel_prod_aligned / 4; ++i) {
234+
int16x8_t a_lo = vreinterpretq_s16_s8(b_interleaved_epi16[4 * i].lo);
235+
int16x8_t a_hi = vreinterpretq_s16_s8(b_interleaved_epi16[4 * i].hi);
236+
int16x8_t c_lo = vreinterpretq_s16_s8(b_interleaved_epi16[4 * i + 2].lo);
237+
int16x8_t c_hi = vreinterpretq_s16_s8(b_interleaved_epi16[4 * i + 2].hi);
238+
239+
b_interleaved_epi32[4 * i].lo =
240+
vreinterpretq_s8_s16(vzip1q_s16(a_lo, c_lo));
241+
b_interleaved_epi32[4 * i].hi =
242+
vreinterpretq_s8_s16(vzip1q_s16(a_hi, c_hi));
243+
b_interleaved_epi32[4 * i + 1].lo =
244+
vreinterpretq_s8_s16(vzip2q_s16(a_lo, c_lo));
245+
b_interleaved_epi32[4 * i + 1].hi =
246+
vreinterpretq_s8_s16(vzip2q_s16(a_hi, c_hi));
247+
248+
int16x8_t b_lo = vreinterpretq_s16_s8(b_interleaved_epi16[4 * i + 1].lo);
249+
int16x8_t b_hi = vreinterpretq_s16_s8(b_interleaved_epi16[4 * i + 1].hi);
250+
int16x8_t d_lo = vreinterpretq_s16_s8(b_interleaved_epi16[4 * i + 3].lo);
251+
int16x8_t d_hi = vreinterpretq_s16_s8(b_interleaved_epi16[4 * i + 3].hi);
252+
253+
b_interleaved_epi32[4 * i + 2].lo =
254+
vreinterpretq_s8_s16(vzip1q_s16(b_lo, d_lo));
255+
b_interleaved_epi32[4 * i + 2].hi =
256+
vreinterpretq_s8_s16(vzip1q_s16(b_hi, d_hi));
257+
b_interleaved_epi32[4 * i + 3].lo =
258+
vreinterpretq_s8_s16(vzip2q_s16(b_lo, d_lo));
259+
b_interleaved_epi32[4 * i + 3].hi =
260+
vreinterpretq_s8_s16(vzip2q_s16(b_hi, d_hi));
261+
}
262+
for (int i = kernel_prod_aligned / 4 * 4; i < kernel_prod_aligned; ++i) {
263+
b_interleaved_epi32[i] = b_interleaved_epi16[i];
264+
}
265+
266+
for (int i = 0; i < kernel_prod_aligned; ++i) {
267+
int8_t* dst = &pmat_[((k1 / 32) * kernel_prod_aligned + i) * 32];
268+
vst1q_s8(dst, b_interleaved_epi32[i].lo);
269+
vst1q_s8(dst + 16, b_interleaved_epi32[i].hi);
270+
}
271+
}
272+
}
273+
274+
#endif
275+
162276
int PackedDepthWiseConvMatrix::addr(int r, int c) {
163277
int kernel_prod_aligned = (kernel_prod_ + 1) / 2 * 2;
164278
if (c >= kernel_prod_ / 4 * 4 &&

0 commit comments

Comments
 (0)