@@ -13,10 +13,8 @@ static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
1313 BLASLONG tail_index_32 = n & (~31 );
1414
1515 __m256 c_256 , s_256 ;
16- if (n >= 8 ) {
17- c_256 = _mm256_set1_ps (c );
18- s_256 = _mm256_set1_ps (s );
19- }
16+ c_256 = _mm256_set1_ps (c );
17+ s_256 = _mm256_set1_ps (s );
2018
2119 __m256 x0 , x1 , x2 , x3 ;
2220 __m256 y0 , y1 , y2 , y3 ;
@@ -77,10 +75,20 @@ static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
7775 _mm256_storeu_ps (& y [i ], t0 );
7876 }
7977
80- for (i = tail_index_8 ; i < n ; ++ i ) {
81- FLOAT temp = c * x [i ] + s * y [i ];
82- y [i ] = c * y [i ] - s * x [i ];
83- x [i ] = temp ;
78+ if ((n & 7 ) > 0 ) {
79+ const int32_t mask_v [16 ] = {-1 ,-1 ,-1 ,-1 , -1 ,-1 ,-1 ,-1 ,0 ,0 ,0 ,0 ,0 ,0 ,0 ,0 };
80+ __m256i tail_mask = _mm256_loadu_si256 ((__m256i * )& mask_v [8 - (n & 7 )]);
81+
82+ x0 = _mm256_maskload_ps (& x [tail_index_8 ], tail_mask );
83+ y0 = _mm256_maskload_ps (& y [tail_index_8 ], tail_mask );
84+
85+ t0 = _mm256_mul_ps (s_256 , y0 );
86+ t0 = _mm256_fmadd_ps (c_256 , x0 , t0 );
87+ _mm256_maskstore_ps (& x [tail_index_8 ], tail_mask , t0 );
88+
89+ t0 = _mm256_mul_ps (s_256 , x0 );
90+ t0 = _mm256_fmsub_ps (c_256 , y0 , t0 );
91+ _mm256_maskstore_ps (& y [tail_index_8 ], tail_mask , t0 );
8492 }
8593}
8694#endif