Skip to content

Commit bc43dcc

Browse files
committed
cc
1 parent 1d498d6 commit bc43dcc

28 files changed

Lines changed: 2744 additions & 2808 deletions

src/layer/loongarch/batchnorm_bf16s.h

Lines changed: 0 additions & 85 deletions
This file was deleted.

src/layer/loongarch/batchnorm_loongarch.cpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,83 @@
1515
namespace ncnn {
1616

1717
#if NCNN_BF16
18-
#include "batchnorm_bf16s.h"
18+
static void batchnorm_bf16s_sse(unsigned short* ptr, const float* a, const float* b, int size, int elempack)
19+
{
20+
#if __loongarch_sx
21+
__m128 _a128 = (elempack == 4) ? (__m128)__lsx_vld(a, 0) : (__m128)__lsx_vreplfr2vr_s(a[0]);
22+
__m128 _b128 = (elempack == 4) ? (__m128)__lsx_vld(b, 0) : (__m128)__lsx_vreplfr2vr_s(b[0]);
23+
#if __loongarch_asx
24+
__m256 _a256 = (elempack == 8) ? (__m256)__lasx_xvld(a, 0) : combine4x2_ps(_a128, _a128);
25+
__m256 _b256 = (elempack == 8) ? (__m256)__lasx_xvld(b, 0) : combine4x2_ps(_b128, _b128);
26+
#endif
27+
#endif
28+
float sa = a[0];
29+
float sb = b[0];
30+
31+
int i = 0;
32+
#if __loongarch_sx
33+
#if __loongarch_asx
34+
for (; i + 7 < size; i += 8)
35+
{
36+
__m256 _p = bfloat2float_lasx((__m128i)__lsx_vld(ptr, 0));
37+
_p = __lasx_xvfmadd_s(_p, _b256, _a256);
38+
__lsx_vst(float2bfloat_lasx(_p), ptr, 0);
39+
ptr += 8;
40+
}
41+
#endif // __loongarch_asx
42+
for (; i + 3 < size; i += 4)
43+
{
44+
__m128 _p = bfloat2float_lsx((__m128i)__lsx_vld(ptr, 0));
45+
_p = __lsx_vfmadd_s(_p, _b128, _a128);
46+
__lsx_vstelm_d(float2bfloat_lsx(_p, _p), ptr, 0, 0);
47+
ptr += 4;
48+
}
49+
#endif // __loongarch_sx
50+
for (; i < size; i++)
51+
{
52+
*ptr = float32_to_bfloat16(sb * bfloat16_to_float32(*ptr) + sa);
53+
ptr++;
54+
}
55+
}
56+
57+
static void batchnorm_bf16s_per_element_sse(unsigned short* ptr, const float* a, const float* b, int size, int num_threads)
58+
{
59+
int nn_size = 0;
60+
int remain_size_start = 0;
61+
#if __loongarch_sx
62+
#if __loongarch_asx
63+
nn_size = (size - remain_size_start) / 8;
64+
#pragma omp parallel for num_threads(num_threads)
65+
for (int ii = 0; ii < nn_size; ii++)
66+
{
67+
int i = remain_size_start + ii * 8;
68+
__m256 _p = bfloat2float_lasx((__m128i)__lsx_vld(ptr + i, 0));
69+
__m256 _a0 = (__m256)__lasx_xvld(a + i, 0);
70+
__m256 _b0 = (__m256)__lasx_xvld(b + i, 0);
71+
_p = __lasx_xvfmadd_s(_p, _b0, _a0);
72+
__lsx_vst(float2bfloat_lasx(_p), ptr + i, 0);
73+
}
74+
remain_size_start += nn_size * 8;
75+
#endif // __loongarch_asx
76+
nn_size = (size - remain_size_start) / 4;
77+
#pragma omp parallel for num_threads(num_threads)
78+
for (int ii = 0; ii < nn_size; ii++)
79+
{
80+
int i = remain_size_start + ii * 4;
81+
__m128 _p = bfloat2float_lsx((__m128i)__lsx_vld(ptr + i, 0));
82+
__m128 _a0 = (__m128)__lsx_vld(a + i, 0);
83+
__m128 _b0 = (__m128)__lsx_vld(b + i, 0);
84+
_p = __lsx_vfmadd_s(_p, _b0, _a0);
85+
__lsx_vstelm_d(float2bfloat_lsx(_p, _p), ptr + i, 0, 0);
86+
}
87+
remain_size_start += nn_size * 4;
88+
#endif // __loongarch_sx
89+
#pragma omp parallel for num_threads(num_threads)
90+
for (int i = remain_size_start; i < size; i++)
91+
{
92+
ptr[i] = float32_to_bfloat16(b[i] * bfloat16_to_float32(ptr[i]) + a[i]);
93+
}
94+
}
1995
#endif
2096

2197
BatchNorm_loongarch::BatchNorm_loongarch()

0 commit comments

Comments
 (0)