|
15 | 15 | namespace ncnn { |
16 | 16 |
|
17 | 17 | #if NCNN_BF16 |
18 | | -#include "batchnorm_bf16s.h" |
| 18 | +static void batchnorm_bf16s_sse(unsigned short* ptr, const float* a, const float* b, int size, int elempack) |
| 19 | +{ |
| 20 | +#if __loongarch_sx |
| 21 | + __m128 _a128 = (elempack == 4) ? (__m128)__lsx_vld(a, 0) : (__m128)__lsx_vreplfr2vr_s(a[0]); |
| 22 | + __m128 _b128 = (elempack == 4) ? (__m128)__lsx_vld(b, 0) : (__m128)__lsx_vreplfr2vr_s(b[0]); |
| 23 | +#if __loongarch_asx |
| 24 | + __m256 _a256 = (elempack == 8) ? (__m256)__lasx_xvld(a, 0) : combine4x2_ps(_a128, _a128); |
| 25 | + __m256 _b256 = (elempack == 8) ? (__m256)__lasx_xvld(b, 0) : combine4x2_ps(_b128, _b128); |
| 26 | +#endif |
| 27 | +#endif |
| 28 | + float sa = a[0]; |
| 29 | + float sb = b[0]; |
| 30 | + |
| 31 | + int i = 0; |
| 32 | +#if __loongarch_sx |
| 33 | +#if __loongarch_asx |
| 34 | + for (; i + 7 < size; i += 8) |
| 35 | + { |
| 36 | + __m256 _p = bfloat2float_lasx((__m128i)__lsx_vld(ptr, 0)); |
| 37 | + _p = __lasx_xvfmadd_s(_p, _b256, _a256); |
| 38 | + __lsx_vst(float2bfloat_lasx(_p), ptr, 0); |
| 39 | + ptr += 8; |
| 40 | + } |
| 41 | +#endif // __loongarch_asx |
| 42 | + for (; i + 3 < size; i += 4) |
| 43 | + { |
| 44 | + __m128 _p = bfloat2float_lsx((__m128i)__lsx_vld(ptr, 0)); |
| 45 | + _p = __lsx_vfmadd_s(_p, _b128, _a128); |
| 46 | + __lsx_vstelm_d(float2bfloat_lsx(_p, _p), ptr, 0, 0); |
| 47 | + ptr += 4; |
| 48 | + } |
| 49 | +#endif // __loongarch_sx |
| 50 | + for (; i < size; i++) |
| 51 | + { |
| 52 | + *ptr = float32_to_bfloat16(sb * bfloat16_to_float32(*ptr) + sa); |
| 53 | + ptr++; |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +static void batchnorm_bf16s_per_element_sse(unsigned short* ptr, const float* a, const float* b, int size, int num_threads) |
| 58 | +{ |
| 59 | + int nn_size = 0; |
| 60 | + int remain_size_start = 0; |
| 61 | +#if __loongarch_sx |
| 62 | +#if __loongarch_asx |
| 63 | + nn_size = (size - remain_size_start) / 8; |
| 64 | + #pragma omp parallel for num_threads(num_threads) |
| 65 | + for (int ii = 0; ii < nn_size; ii++) |
| 66 | + { |
| 67 | + int i = remain_size_start + ii * 8; |
| 68 | + __m256 _p = bfloat2float_lasx((__m128i)__lsx_vld(ptr + i, 0)); |
| 69 | + __m256 _a0 = (__m256)__lasx_xvld(a + i, 0); |
| 70 | + __m256 _b0 = (__m256)__lasx_xvld(b + i, 0); |
| 71 | + _p = __lasx_xvfmadd_s(_p, _b0, _a0); |
| 72 | + __lsx_vst(float2bfloat_lasx(_p), ptr + i, 0); |
| 73 | + } |
| 74 | + remain_size_start += nn_size * 8; |
| 75 | +#endif // __loongarch_asx |
| 76 | + nn_size = (size - remain_size_start) / 4; |
| 77 | + #pragma omp parallel for num_threads(num_threads) |
| 78 | + for (int ii = 0; ii < nn_size; ii++) |
| 79 | + { |
| 80 | + int i = remain_size_start + ii * 4; |
| 81 | + __m128 _p = bfloat2float_lsx((__m128i)__lsx_vld(ptr + i, 0)); |
| 82 | + __m128 _a0 = (__m128)__lsx_vld(a + i, 0); |
| 83 | + __m128 _b0 = (__m128)__lsx_vld(b + i, 0); |
| 84 | + _p = __lsx_vfmadd_s(_p, _b0, _a0); |
| 85 | + __lsx_vstelm_d(float2bfloat_lsx(_p, _p), ptr + i, 0, 0); |
| 86 | + } |
| 87 | + remain_size_start += nn_size * 4; |
| 88 | +#endif // __loongarch_sx |
| 89 | + #pragma omp parallel for num_threads(num_threads) |
| 90 | + for (int i = remain_size_start; i < size; i++) |
| 91 | + { |
| 92 | + ptr[i] = float32_to_bfloat16(b[i] * bfloat16_to_float32(ptr[i]) + a[i]); |
| 93 | + } |
| 94 | +} |
19 | 95 | #endif |
20 | 96 |
|
21 | 97 | BatchNorm_loongarch::BatchNorm_loongarch() |
|
0 commit comments