11#include "common.h"
22#include <riscv_vector.h>
33
4+ #define BF16_WIDEN_ONE /* Compile-time switch for the K loops below: when defined, each iteration casts the B scalars to float, widens the bf16 A vector to f32 once via vfwcvtbf16, and accumulates with f32 vfmacc; when undefined, the loops use the widening vfwmaccbf16 multiply-accumulate on bf16 operands instead. */
5+
46int CNAME (BLASLONG M , BLASLONG N , BLASLONG K , FLOAT alpha , IFLOAT * A , IFLOAT * B , FLOAT * C , BLASLONG ldc )
57{
68 BLASLONG gvl = 0 ;
@@ -28,6 +30,30 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
2830 vfloat32m2_t result7 = __riscv_vfmv_v_f_f32m2 (0.0f , gvl );
2931
3032 for (BLASLONG k = 0 ; k < K ; k ++ ) {
33+ #ifdef BF16_WIDEN_ONE
34+ float B0 = (float )(BB [bi + 0 ]);
35+ float B1 = (float )(BB [bi + 1 ]);
36+ float B2 = (float )(BB [bi + 2 ]);
37+ float B3 = (float )(BB [bi + 3 ]);
38+ float B4 = (float )(BB [bi + 4 ]);
39+ float B5 = (float )(BB [bi + 5 ]);
40+ float B6 = (float )(BB [bi + 6 ]);
41+ float B7 = (float )(BB [bi + 7 ]);
42+ bi += 8 ;
43+
44+ vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1 ( & AA [ai + 0 * gvl ], gvl );
45+ vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2 (A00 , gvl );
46+ ai += 16 ;
47+
48+ result0 = __riscv_vfmacc_vf_f32m2 (result0 , B0 , A0 , gvl );
49+ result1 = __riscv_vfmacc_vf_f32m2 (result1 , B1 , A0 , gvl );
50+ result2 = __riscv_vfmacc_vf_f32m2 (result2 , B2 , A0 , gvl );
51+ result3 = __riscv_vfmacc_vf_f32m2 (result3 , B3 , A0 , gvl );
52+ result4 = __riscv_vfmacc_vf_f32m2 (result4 , B4 , A0 , gvl );
53+ result5 = __riscv_vfmacc_vf_f32m2 (result5 , B5 , A0 , gvl );
54+ result6 = __riscv_vfmacc_vf_f32m2 (result6 , B6 , A0 , gvl );
55+ result7 = __riscv_vfmacc_vf_f32m2 (result7 , B7 , A0 , gvl );
56+ #else
3157 __bf16 B0 = BB [bi + 0 ];
3258 __bf16 B1 = BB [bi + 1 ];
3359 __bf16 B2 = BB [bi + 2 ];
@@ -49,6 +75,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
4975 result5 = __riscv_vfwmaccbf16_vf_f32m2 (result5 , B5 , A0 , gvl );
5076 result6 = __riscv_vfwmaccbf16_vf_f32m2 (result6 , B6 , A0 , gvl );
5177 result7 = __riscv_vfwmaccbf16_vf_f32m2 (result7 , B7 , A0 , gvl );
78+ #endif
5279 }
5380
5481 BLASLONG ci = n_top * ldc + m_top ;
@@ -102,6 +129,30 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
102129 vfloat32m1_t result7 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
103130
104131 for (BLASLONG k = 0 ; k < K ; k ++ ) {
132+ #ifdef BF16_WIDEN_ONE
133+ float B0 = (float )(BB [bi + 0 ]);
134+ float B1 = (float )(BB [bi + 1 ]);
135+ float B2 = (float )(BB [bi + 2 ]);
136+ float B3 = (float )(BB [bi + 3 ]);
137+ float B4 = (float )(BB [bi + 4 ]);
138+ float B5 = (float )(BB [bi + 5 ]);
139+ float B6 = (float )(BB [bi + 6 ]);
140+ float B7 = (float )(BB [bi + 7 ]);
141+ bi += 8 ;
142+
143+ vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2 ( & AA [ai + 0 * gvl ], gvl );
144+ vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1 (A00 , gvl );
145+ ai += 8 ;
146+
147+ result0 = __riscv_vfmacc_vf_f32m1 (result0 , B0 , A0 , gvl );
148+ result1 = __riscv_vfmacc_vf_f32m1 (result1 , B1 , A0 , gvl );
149+ result2 = __riscv_vfmacc_vf_f32m1 (result2 , B2 , A0 , gvl );
150+ result3 = __riscv_vfmacc_vf_f32m1 (result3 , B3 , A0 , gvl );
151+ result4 = __riscv_vfmacc_vf_f32m1 (result4 , B4 , A0 , gvl );
152+ result5 = __riscv_vfmacc_vf_f32m1 (result5 , B5 , A0 , gvl );
153+ result6 = __riscv_vfmacc_vf_f32m1 (result6 , B6 , A0 , gvl );
154+ result7 = __riscv_vfmacc_vf_f32m1 (result7 , B7 , A0 , gvl );
155+ #else
105156 __bf16 B0 = BB [bi + 0 ];
106157 __bf16 B1 = BB [bi + 1 ];
107158 __bf16 B2 = BB [bi + 2 ];
@@ -123,6 +174,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
123174 result5 = __riscv_vfwmaccbf16_vf_f32m1 (result5 , B5 , A0 , gvl );
124175 result6 = __riscv_vfwmaccbf16_vf_f32m1 (result6 , B6 , A0 , gvl );
125176 result7 = __riscv_vfwmaccbf16_vf_f32m1 (result7 , B7 , A0 , gvl );
177+ #endif
126178 }
127179
128180 BLASLONG ci = n_top * ldc + m_top ;
@@ -174,6 +226,30 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
174226 vfloat32m1_t result7 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
175227
176228 for (BLASLONG k = 0 ; k < K ; ++ k ) {
229+ #ifdef BF16_WIDEN_ONE
230+ float B0 = (float )(BB [bi + 0 ]);
231+ float B1 = (float )(BB [bi + 1 ]);
232+ float B2 = (float )(BB [bi + 2 ]);
233+ float B3 = (float )(BB [bi + 3 ]);
234+ float B4 = (float )(BB [bi + 4 ]);
235+ float B5 = (float )(BB [bi + 5 ]);
236+ float B6 = (float )(BB [bi + 6 ]);
237+ float B7 = (float )(BB [bi + 7 ]);
238+ bi += 8 ;
239+
240+ vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4 ( & AA [ai + 0 * gvl ], gvl );
241+ vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1 (__riscv_vfwcvtbf16_f_f_v_f32mf2 (A00 , gvl ));
242+ ai += 4 ;
243+
244+ result0 = __riscv_vfmacc_vf_f32m1 (result0 , B0 , A0 , gvl );
245+ result1 = __riscv_vfmacc_vf_f32m1 (result1 , B1 , A0 , gvl );
246+ result2 = __riscv_vfmacc_vf_f32m1 (result2 , B2 , A0 , gvl );
247+ result3 = __riscv_vfmacc_vf_f32m1 (result3 , B3 , A0 , gvl );
248+ result4 = __riscv_vfmacc_vf_f32m1 (result4 , B4 , A0 , gvl );
249+ result5 = __riscv_vfmacc_vf_f32m1 (result5 , B5 , A0 , gvl );
250+ result6 = __riscv_vfmacc_vf_f32m1 (result6 , B6 , A0 , gvl );
251+ result7 = __riscv_vfmacc_vf_f32m1 (result7 , B7 , A0 , gvl );
252+ #else
177253 __bf16 B0 = BB [bi + 0 ];
178254 __bf16 B1 = BB [bi + 1 ];
179255 __bf16 B2 = BB [bi + 2 ];
@@ -195,6 +271,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
195271 result5 = __riscv_vfwmaccbf16_vf_f32m1 (result5 , B5 , A0 , gvl );
196272 result6 = __riscv_vfwmaccbf16_vf_f32m1 (result6 , B6 , A0 , gvl );
197273 result7 = __riscv_vfwmaccbf16_vf_f32m1 (result7 , B7 , A0 , gvl );
274+ #endif
198275 }
199276
200277 BLASLONG ci = n_top * ldc + m_top ;
@@ -356,6 +433,22 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
356433 vfloat32m2_t result3 = __riscv_vfmv_v_f_f32m2 (0.0f , gvl );
357434
358435 for (BLASLONG k = 0 ; k < K ; k ++ ) {
436+ #ifdef BF16_WIDEN_ONE
437+ float B0 = (float )(BB [bi + 0 ]);
438+ float B1 = (float )(BB [bi + 1 ]);
439+ float B2 = (float )(BB [bi + 2 ]);
440+ float B3 = (float )(BB [bi + 3 ]);
441+ bi += 4 ;
442+
443+ vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1 ( & AA [ai + 0 * gvl ], gvl );
444+ vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2 (A00 , gvl );
445+ ai += 16 ;
446+
447+ result0 = __riscv_vfmacc_vf_f32m2 (result0 , B0 , A0 , gvl );
448+ result1 = __riscv_vfmacc_vf_f32m2 (result1 , B1 , A0 , gvl );
449+ result2 = __riscv_vfmacc_vf_f32m2 (result2 , B2 , A0 , gvl );
450+ result3 = __riscv_vfmacc_vf_f32m2 (result3 , B3 , A0 , gvl );
451+ #else
359452 __bf16 B0 = BB [bi + 0 ];
360453 __bf16 B1 = BB [bi + 1 ];
361454 __bf16 B2 = BB [bi + 2 ];
@@ -369,6 +462,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
369462 result1 = __riscv_vfwmaccbf16_vf_f32m2 (result1 , B1 , A0 , gvl );
370463 result2 = __riscv_vfwmaccbf16_vf_f32m2 (result2 , B2 , A0 , gvl );
371464 result3 = __riscv_vfwmaccbf16_vf_f32m2 (result3 , B3 , A0 , gvl );
465+ #endif
372466 }
373467
374468 BLASLONG ci = n_top * ldc + m_top ;
@@ -403,6 +497,22 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
403497 vfloat32m1_t result3 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
404498
405499 for (BLASLONG k = 0 ; k < K ; k ++ ) {
500+ #ifdef BF16_WIDEN_ONE
501+ float B0 = (float )(BB [bi + 0 ]);
502+ float B1 = (float )(BB [bi + 1 ]);
503+ float B2 = (float )(BB [bi + 2 ]);
504+ float B3 = (float )(BB [bi + 3 ]);
505+ bi += 4 ;
506+
507+ vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2 ( & AA [ai + 0 * gvl ], gvl );
508+ vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1 (A00 , gvl );
509+ ai += 8 ;
510+
511+ result0 = __riscv_vfmacc_vf_f32m1 (result0 , B0 , A0 , gvl );
512+ result1 = __riscv_vfmacc_vf_f32m1 (result1 , B1 , A0 , gvl );
513+ result2 = __riscv_vfmacc_vf_f32m1 (result2 , B2 , A0 , gvl );
514+ result3 = __riscv_vfmacc_vf_f32m1 (result3 , B3 , A0 , gvl );
515+ #else
406516 __bf16 B0 = BB [bi + 0 ];
407517 __bf16 B1 = BB [bi + 1 ];
408518 __bf16 B2 = BB [bi + 2 ];
@@ -416,6 +526,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
416526 result1 = __riscv_vfwmaccbf16_vf_f32m1 (result1 , B1 , A0 , gvl );
417527 result2 = __riscv_vfwmaccbf16_vf_f32m1 (result2 , B2 , A0 , gvl );
418528 result3 = __riscv_vfwmaccbf16_vf_f32m1 (result3 , B3 , A0 , gvl );
529+ #endif
419530 }
420531
421532 BLASLONG ci = n_top * ldc + m_top ;
@@ -451,6 +562,22 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
451562 vfloat32m1_t result3 = __riscv_vfmv_v_f_f32m1 (0.0f , gvl );
452563
453564 for (BLASLONG k = 0 ; k < K ; ++ k ) {
565+ #ifdef BF16_WIDEN_ONE
566+ float B0 = (float )(BB [bi + 0 ]);
567+ float B1 = (float )(BB [bi + 1 ]);
568+ float B2 = (float )(BB [bi + 2 ]);
569+ float B3 = (float )(BB [bi + 3 ]);
570+ bi += 4 ;
571+
572+ vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4 ( & AA [ai + 0 * gvl ], gvl );
573+ vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1 (__riscv_vfwcvtbf16_f_f_v_f32mf2 (A00 , gvl ));
574+ ai += 4 ;
575+
576+ result0 = __riscv_vfmacc_vf_f32m1 (result0 , B0 , A0 , gvl );
577+ result1 = __riscv_vfmacc_vf_f32m1 (result1 , B1 , A0 , gvl );
578+ result2 = __riscv_vfmacc_vf_f32m1 (result2 , B2 , A0 , gvl );
579+ result3 = __riscv_vfmacc_vf_f32m1 (result3 , B3 , A0 , gvl );
580+ #else
454581 __bf16 B0 = BB [bi + 0 ];
455582 __bf16 B1 = BB [bi + 1 ];
456583 __bf16 B2 = BB [bi + 2 ];
@@ -464,6 +591,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
464591 result1 = __riscv_vfwmaccbf16_vf_f32m1 (result1 , B1 , A0 , gvl );
465592 result2 = __riscv_vfwmaccbf16_vf_f32m1 (result2 , B2 , A0 , gvl );
466593 result3 = __riscv_vfwmaccbf16_vf_f32m1 (result3 , B3 , A0 , gvl );
594+ #endif
467595 }
468596
469597 BLASLONG ci = n_top * ldc + m_top ;
0 commit comments