|
| 1 | +// Copyright 2026 Tencent |
| 2 | +// SPDX-License-Identifier: BSD-3-Clause |
| 3 | + |
| 4 | +#include "bnll_loongarch.h" |
| 5 | + |
| 6 | +#if __loongarch_sx |
| 7 | +#include <lsxintrin.h> |
| 8 | +#include "lsx_mathfun.h" |
| 9 | +#if __loongarch_asx |
| 10 | +#include <lasxintrin.h> |
| 11 | +#include "lasx_mathfun.h" |
| 12 | +#endif // __loongarch_asx |
| 13 | +#endif // __loongarch_sx |
| 14 | + |
| 15 | +namespace ncnn { |
| 16 | + |
| 17 | +BNLL_loongarch::BNLL_loongarch() |
| 18 | +{ |
| 19 | +#if __loongarch_sx |
| 20 | + support_packing = true; |
| 21 | +#endif |
| 22 | +#if NCNN_BF16 |
| 23 | + support_bf16_storage = true; |
| 24 | +#endif |
| 25 | +} |
| 26 | + |
| 27 | +int BNLL_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const |
| 28 | +{ |
| 29 | + int w = bottom_top_blob.w; |
| 30 | + int h = bottom_top_blob.h; |
| 31 | + int d = bottom_top_blob.d; |
| 32 | + int channels = bottom_top_blob.c; |
| 33 | + int elempack = bottom_top_blob.elempack; |
| 34 | + int size = w * h * d * elempack; |
| 35 | + |
| 36 | +#if NCNN_BF16 |
| 37 | + if (opt.use_bf16_storage && bottom_top_blob.elembits() == 16) |
| 38 | + return forward_inplace_bf16s(bottom_top_blob, opt); |
| 39 | +#endif |
| 40 | + |
| 41 | + #pragma omp parallel for num_threads(opt.num_threads) |
| 42 | + for (int q = 0; q < channels; q++) |
| 43 | + { |
| 44 | + float* ptr = bottom_top_blob.channel(q); |
| 45 | + |
| 46 | + int i = 0; |
| 47 | +#if __loongarch_sx |
| 48 | +#if __loongarch_asx |
| 49 | + __m256 _zero8 = (__m256)__lasx_xvreplgr2vr_w(0); |
| 50 | + __m256 _one8 = (__m256)__lasx_xvreplfr2vr_s(1.f); |
| 51 | + __m256i _abs_mask8 = __lasx_xvreplgr2vr_w(0x7fffffff); |
| 52 | + for (; i + 7 < size; i += 8) |
| 53 | + { |
| 54 | + __builtin_prefetch(ptr + 32); |
| 55 | + |
| 56 | + __m256 _p = (__m256)__lasx_xvld(ptr, 0); |
| 57 | + __m256 _abs_p = (__m256)__lasx_xvand_v((__m256i)_p, _abs_mask8); |
| 58 | + __m256 _tmp = log256_ps(__lasx_xvfadd_s(_one8, exp256_ps((__m256)__lasx_xvbitrevi_w((__m256i)_abs_p, 31)))); |
| 59 | + __m256 _outp = __lasx_xvfadd_s(__lasx_xvfmax_s(_p, _zero8), _tmp); |
| 60 | + __lasx_xvst(_outp, ptr, 0); |
| 61 | + |
| 62 | + ptr += 8; |
| 63 | + } |
| 64 | +#endif |
| 65 | + __m128 _zero4 = (__m128)__lsx_vreplgr2vr_w(0); |
| 66 | + __m128 _one4 = (__m128)__lsx_vreplfr2vr_s(1.f); |
| 67 | + __m128i _abs_mask4 = __lsx_vreplgr2vr_w(0x7fffffff); |
| 68 | + for (; i + 3 < size; i += 4) |
| 69 | + { |
| 70 | + __builtin_prefetch(ptr + 16); |
| 71 | + |
| 72 | + __m128 _p = (__m128)__lsx_vld(ptr, 0); |
| 73 | + __m128 _abs_p = (__m128)__lsx_vand_v((__m128i)_p, _abs_mask4); |
| 74 | + __m128 _tmp = log_ps(__lsx_vfadd_s(_one4, exp_ps((__m128)__lsx_vbitrevi_w((__m128i)_abs_p, 31)))); |
| 75 | + __m128 _outp = __lsx_vfadd_s(__lsx_vfmax_s(_p, _zero4), _tmp); |
| 76 | + __lsx_vst(_outp, ptr, 0); |
| 77 | + |
| 78 | + ptr += 4; |
| 79 | + } |
| 80 | +#endif |
| 81 | + for (; i < size; i++) |
| 82 | + { |
| 83 | + if (*ptr > 0.f) |
| 84 | + *ptr = *ptr + logf(1.f + expf(-*ptr)); |
| 85 | + else |
| 86 | + *ptr = logf(1.f + expf(*ptr)); |
| 87 | + |
| 88 | + ptr++; |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + return 0; |
| 93 | +} |
| 94 | + |
| 95 | +#if NCNN_BF16 |
| 96 | +int BNLL_loongarch::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const |
| 97 | +{ |
| 98 | + Option opt_cast = opt; |
| 99 | + opt_cast.blob_allocator = opt.workspace_allocator; |
| 100 | + |
| 101 | + Mat bottom_top_blob_fp32; |
| 102 | + cast_bfloat16_to_float32(bottom_top_blob, bottom_top_blob_fp32, opt_cast); |
| 103 | + if (bottom_top_blob_fp32.empty()) |
| 104 | + return -100; |
| 105 | + |
| 106 | + int ret = forward_inplace(bottom_top_blob_fp32, opt); |
| 107 | + if (ret != 0) |
| 108 | + return ret; |
| 109 | + |
| 110 | + cast_float32_to_bfloat16(bottom_top_blob_fp32, bottom_top_blob, opt); |
| 111 | + if (bottom_top_blob.empty()) |
| 112 | + return -100; |
| 113 | + |
| 114 | + return 0; |
| 115 | +} |
| 116 | +#endif // NCNN_BF16 |
| 117 | + |
| 118 | +} // namespace ncnn |
0 commit comments