@@ -180,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
180180}
181181#endif
182182
183+ #if defined(__riscv_v_intrinsic)
184+ template <> inline vfloat32m1_t madd (vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
185+ return __riscv_vfmacc_vv_f32m1 (c, a, b, __riscv_vsetvlmax_e32m1 ());
186+ }
187+ template <> inline vfloat32m2_t madd (vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
188+ return __riscv_vfmacc_vv_f32m2 (c, a, b, __riscv_vsetvlmax_e32m2 ());
189+ }
190+ template <> inline vfloat32m4_t madd (vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
191+ return __riscv_vfmacc_vv_f32m4 (c, a, b, __riscv_vsetvlmax_e32m4 ());
192+ }
193+ template <> inline vfloat32m8_t madd (vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
194+ return __riscv_vfmacc_vv_f32m8 (c, a, b, __riscv_vsetvlmax_e32m8 ());
195+ }
196+ #endif
197+
183198#if defined(__riscv_zvfh)
184- template <>
185- inline vfloat32m1_t madd (vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
199+ template <> inline vfloat32m1_t madd (vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
186200 return __riscv_vfwmacc_vv_f32m1 (c, a, b, __riscv_vsetvlmax_e32m1 ());
187201}
188- inline vfloat32m2_t madd (vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
202+ template <> inline vfloat32m2_t madd (vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
189203 return __riscv_vfwmacc_vv_f32m2 (c, a, b, __riscv_vsetvlmax_e32m2 ());
190204}
191- inline vfloat32m4_t madd (vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
205+ template <> inline vfloat32m4_t madd (vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
192206 return __riscv_vfwmacc_vv_f32m4 (c, a, b, __riscv_vsetvlmax_e32m4 ());
193207}
194- inline vfloat32m8_t madd (vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
208+ template <> inline vfloat32m8_t madd (vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
195209 return __riscv_vfwmacc_vv_f32m8 (c, a, b, __riscv_vsetvlmax_e32m8 ());
196210}
197- inline vfloat32m1_t madd (vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
198- return __riscv_vfmacc_vv_f32m1 (c, a, b, __riscv_vsetvlmax_e32m1 ());
199- }
200- inline vfloat32m2_t madd (vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
201- return __riscv_vfmacc_vv_f32m2 (c, a, b, __riscv_vsetvlmax_e32m2 ());
202- }
203- inline vfloat32m4_t madd (vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
204- return __riscv_vfmacc_vv_f32m4 (c, a, b, __riscv_vsetvlmax_e32m4 ());
205- }
206- inline vfloat32m8_t madd (vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
207- return __riscv_vfmacc_vv_f32m8 (c, a, b, __riscv_vsetvlmax_e32m8 ());
208- }
209211#endif
210212
211213#if defined(__riscv_zvfbfwma)
212- inline vfloat32m1_t madd (vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
214+ template <> inline vfloat32m1_t madd (vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
213215 return __riscv_vfwmaccbf16_vv_f32m1 (c, a, b, __riscv_vsetvlmax_e32m1 ());
214216}
215- inline vfloat32m2_t madd (vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
217+ template <> inline vfloat32m2_t madd (vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
216218 return __riscv_vfwmaccbf16_vv_f32m2 (c, a, b, __riscv_vsetvlmax_e32m2 ());
217219}
218- inline vfloat32m4_t madd (vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
220+ template <> inline vfloat32m4_t madd (vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
219221 return __riscv_vfwmaccbf16_vv_f32m4 (c, a, b, __riscv_vsetvlmax_e32m4 ());
220222}
223+ template <> inline vfloat32m8_t madd (vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) {
224+ return __riscv_vfwmaccbf16_vv_f32m8 (c, a, b, __riscv_vsetvlmax_e32m8 ());
225+ }
221226#endif
222227
223228// //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -272,7 +277,7 @@ inline float hsum(__m512 x) {
272277}
273278#endif // __AVX512F__
274279
275- #if defined(__riscv_zvfh )
280+ #if defined(__riscv_v_intrinsic )
276281inline float hsum (vfloat32m1_t x) {
277282 return __riscv_vfmv_f_s_f32m1_f32 (
278283 __riscv_vfredusum_vs_f32m1_f32m1 (x, __riscv_vfmv_v_f_f32m1 (0 , 1 ), __riscv_vsetvlmax_e32m1 ()));
@@ -379,6 +384,21 @@ template <> inline __m256bh load(const float *p) {
379384}
380385#endif
381386
387+ #if defined(__riscv_v_intrinsic)
388+ template <> inline vfloat32m1_t load (const float *p) {
389+ return __riscv_vle32_v_f32m1 (p, __riscv_vsetvlmax_e32m1 ());
390+ }
391+ template <> inline vfloat32m2_t load (const float *p) {
392+ return __riscv_vle32_v_f32m2 (p, __riscv_vsetvlmax_e32m2 ());
393+ }
394+ template <> inline vfloat32m4_t load (const float *p) {
395+ return __riscv_vle32_v_f32m4 (p, __riscv_vsetvlmax_e32m4 ());
396+ }
397+ template <> inline vfloat32m8_t load (const float *p) {
398+ return __riscv_vle32_v_f32m8 (p, __riscv_vsetvlmax_e32m8 ());
399+ }
400+ #endif
401+
382402#if defined(__riscv_zvfh)
383403template <> inline vfloat16mf2_t load (const ggml_fp16_t *p) {
384404 return __riscv_vle16_v_f16mf2 (reinterpret_cast <const _Float16 *>(p), __riscv_vsetvlmax_e16mf2 ());
@@ -392,18 +412,6 @@ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
392412template <> inline vfloat16m4_t load (const ggml_fp16_t *p) {
393413 return __riscv_vle16_v_f16m4 (reinterpret_cast <const _Float16 *>(p), __riscv_vsetvlmax_e16m4 ());
394414}
395- template <> inline vfloat32m1_t load (const float *p) {
396- return __riscv_vle32_v_f32m1 (p, __riscv_vsetvlmax_e32m1 ());
397- }
398- template <> inline vfloat32m2_t load (const float *p) {
399- return __riscv_vle32_v_f32m2 (p, __riscv_vsetvlmax_e32m2 ());
400- }
401- template <> inline vfloat32m4_t load (const float *p) {
402- return __riscv_vle32_v_f32m4 (p, __riscv_vsetvlmax_e32m4 ());
403- }
404- template <> inline vfloat32m8_t load (const float *p) {
405- return __riscv_vle32_v_f32m8 (p, __riscv_vsetvlmax_e32m8 ());
406- }
407415#endif
408416
409417#if defined(__riscv_zvfbfwma)
@@ -416,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
416424template <> inline vbfloat16m2_t load (const ggml_bf16_t *p) {
417425 return __riscv_vle16_v_bf16m2 (reinterpret_cast <const __bf16*>(p), __riscv_vsetvlmax_e16m2 ());
418426}
427+ template <> inline vbfloat16m4_t load (const ggml_bf16_t *p) {
428+ return __riscv_vle16_v_bf16m4 (reinterpret_cast <const __bf16*>(p), __riscv_vsetvlmax_e16m4 ());
429+ }
419430#endif
420431
421- #if defined(__riscv_zvfh )
432+ #if defined(__riscv_v_intrinsic )
422433template <typename T> T set_zero ();
423434
424- template <> inline vfloat16mf2_t set_zero () {
425- return __riscv_vfmv_v_f_f16mf2 (0 , __riscv_vsetvlmax_e16mf2 ());
426- }
427- template <> inline vfloat16m1_t set_zero () {
428- return __riscv_vfmv_v_f_f16m1 (0 , __riscv_vsetvlmax_e16m1 ());
429- }
430- template <> inline vfloat16m2_t set_zero () {
431- return __riscv_vfmv_v_f_f16m2 (0 , __riscv_vsetvlmax_e16m2 ());
432- }
433- template <> inline vfloat16m4_t set_zero () {
434- return __riscv_vfmv_v_f_f16m4 (0 , __riscv_vsetvlmax_e16m4 ());
435- }
436435template <> inline vfloat32m1_t set_zero () {
437436 return __riscv_vfmv_v_f_f32m1 (0 .0f , __riscv_vsetvlmax_e32m1 ());
438437}
@@ -449,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() {
449448
450449#if defined(__riscv_v_intrinsic)
451450template <typename T> size_t vlmax () {
452- if constexpr (std::is_same_v<T, vfloat16mf2_t >) { return __riscv_vsetvlmax_e16mf2 (); }
453- else if constexpr (std::is_same_v<T, vfloat16m1_t >) { return __riscv_vsetvlmax_e16m1 (); }
454- else if constexpr (std::is_same_v<T, vfloat16m2_t >) { return __riscv_vsetvlmax_e16m2 (); }
455- else if constexpr (std::is_same_v<T, vfloat16m4_t >) { return __riscv_vsetvlmax_e16m4 (); }
456- else if constexpr (std::is_same_v<T, vfloat32m1_t >) { return __riscv_vsetvlmax_e32m1 (); }
451+ if constexpr (std::is_same_v<T, vfloat32m1_t >) { return __riscv_vsetvlmax_e32m1 (); }
457452 else if constexpr (std::is_same_v<T, vfloat32m2_t >) { return __riscv_vsetvlmax_e32m2 (); }
458453 else if constexpr (std::is_same_v<T, vfloat32m4_t >) { return __riscv_vsetvlmax_e32m4 (); }
459454 else if constexpr (std::is_same_v<T, vfloat32m8_t >) { return __riscv_vsetvlmax_e32m8 (); }
455+ #if defined (__riscv_zvfh)
456+ else if constexpr (std::is_same_v<T, vfloat16mf2_t >) { return __riscv_vsetvlmax_e16mf2 (); }
457+ else if constexpr (std::is_same_v<T, vfloat16m1_t >) { return __riscv_vsetvlmax_e16m1 (); }
458+ else if constexpr (std::is_same_v<T, vfloat16m2_t >) { return __riscv_vsetvlmax_e16m2 (); }
459+ else if constexpr (std::is_same_v<T, vfloat16m4_t >) { return __riscv_vsetvlmax_e16m4 (); }
460+ #endif
461+ #if defined (__riscv_zvfbfwma)
462+ else if constexpr (std::is_same_v<T, vbfloat16mf2_t >) { return __riscv_vsetvlmax_e16mf2 (); }
463+ else if constexpr (std::is_same_v<T, vbfloat16m1_t >) { return __riscv_vsetvlmax_e16m1 (); }
464+ else if constexpr (std::is_same_v<T, vbfloat16m2_t >) { return __riscv_vsetvlmax_e16m2 (); }
465+ else if constexpr (std::is_same_v<T, vbfloat16m4_t >) { return __riscv_vsetvlmax_e16m4 (); }
466+ #endif
460467 return 0 ;
461468}
462469#endif
@@ -3740,7 +3747,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
37403747 params->ith , params->nth };
37413748 tb.matmul (m, n);
37423749 return true ;
3743- #elif defined(__riscv_zvfh )
3750+ #elif defined(__riscv_v_intrinsic )
37443751 #if LMUL == 1
37453752 tinyBLAS_RVV<vfloat32m1_t , vfloat32m1_t , float , float , float > tb{ params,
37463753 k, (const float *)A, lda,
@@ -3804,23 +3811,25 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
38043811 return true ;
38053812 }
38063813#elif defined(__riscv_zvfbfwma)
3807- #if LMUL == 1
3808- tinyBLAS_RVV<vfloat32m1_t , vbfloat16mf2_t , ggml_bf16_t , ggml_bf16_t , float > tb{ params,
3809- k, (const ggml_bf16_t *)A, lda,
3810- (const ggml_bf16_t *)B, ldb,
3811- (float *)C, ldc};
3812- #elif LMUL == 2
3813- tinyBLAS_RVV<vfloat32m2_t , vbfloat16m1_t , ggml_bf16_t , ggml_bf16_t , float > tb{ params,
3814- k, (const ggml_bf16_t *)A, lda,
3815- (const ggml_bf16_t *)B, ldb,
3816- (float *)C, ldc};
3817- #else // LMUL = 4
3818- tinyBLAS_RVV<vfloat32m4_t , vbfloat16m2_t , ggml_bf16_t , ggml_bf16_t , float > tb{ params,
3819- k, (const ggml_bf16_t *)A, lda,
3820- (const ggml_bf16_t *)B, ldb,
3821- (float *)C, ldc};
3822- #endif
3823- return tb.matmul (m, n);
3814+ if (Btype == GGML_TYPE_BF16) {
3815+ #if LMUL == 1
3816+ tinyBLAS_RVV<vfloat32m1_t , vbfloat16mf2_t , ggml_bf16_t , ggml_bf16_t , float > tb{ params,
3817+ k, (const ggml_bf16_t *)A, lda,
3818+ (const ggml_bf16_t *)B, ldb,
3819+ (float *)C, ldc};
3820+ #elif LMUL == 2
3821+ tinyBLAS_RVV<vfloat32m2_t , vbfloat16m1_t , ggml_bf16_t , ggml_bf16_t , float > tb{ params,
3822+ k, (const ggml_bf16_t *)A, lda,
3823+ (const ggml_bf16_t *)B, ldb,
3824+ (float *)C, ldc};
3825+ #else // LMUL = 4
3826+ tinyBLAS_RVV<vfloat32m4_t , vbfloat16m2_t , ggml_bf16_t , ggml_bf16_t , float > tb{ params,
3827+ k, (const ggml_bf16_t *)A, lda,
3828+ (const ggml_bf16_t *)B, ldb,
3829+ (float *)C, ldc};
3830+ #endif
3831+ return tb.matmul (m, n);
3832+ }
38243833#endif
38253834 return false ;
38263835 }
0 commit comments