1717#include "htp-msg.h"
1818#include "htp-ops.h"
1919
20+ static inline HVX_Vector hvx_load_f32_to_f16 (const HVX_Vector * restrict src , const HVX_Vector zero ) {
21+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf (src [0 ], zero ); // 32 elements
22+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf (src [1 ], zero ); // 32 elements
23+ return Q6_Vh_vdeal_Vh (Q6_Vhf_equals_Wqf32 (Q6_W_vcombine_VV (y1_qf , y0_qf )));
24+ }
25+
2026// Dot product of FP32 and FP16 vectors, accumulating to float
2127static inline void hvx_dot_f32_f16_aa (float * restrict r , const void * restrict y , const void * restrict x , unsigned int n , float s ) {
2228 const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y ; // fp32
@@ -33,23 +39,19 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
3339 #pragma unroll(4)
3440 for (i = 0 ; i < nvec ; i ++ ) {
3541 // Load y (fp32) and convert into fp16
36- HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf (vy [i * 2 + 0 ], zero ); // 32 elements
37- HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf (vy [i * 2 + 1 ], zero ); // 32 elements
38- HVX_Vector y_hf = Q6_Vh_vdeal_Vh (Q6_Vhf_equals_Wqf32 (Q6_W_vcombine_VV (y1_qf , y0_qf )));
42+ HVX_Vector y_hf = hvx_load_f32_to_f16 (& vy [i * 2 ], zero );
3943
4044 // Load x (fp16)
4145 HVX_Vector x_hf = vx [i ];
4246
4347 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf (x_hf , y_hf );
4448
45- rsum = Q6_Vqf32_vadd_Vqf32Vqf32 ( rsum , Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy_qf ), Q6_V_hi_W (xy_qf )));
49+ rsum = Q6_Vsf_equals_Vqf32 ( Q6_Vqf32_vadd_Vqf32Vsf ( Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy_qf ), Q6_V_hi_W (xy_qf )), rsum ));
4650 }
4751
4852 if (nloe ) {
4953 // Load y (fp32) and convert into fp16
50- HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf (vy [i * 2 + 0 ], zero ); // 32 elements
51- HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf (vy [i * 2 + 1 ], zero ); // 32 elements
52- HVX_Vector y_hf = Q6_Vh_vdeal_Vh (Q6_Vhf_equals_Wqf32 (Q6_W_vcombine_VV (y1_qf , y0_qf )));
54+ HVX_Vector y_hf = hvx_load_f32_to_f16 (& vy [i * 2 ], zero );
5355
5456 // Load x (fp16)
5557 HVX_Vector x_hf = vx [i ];
@@ -62,13 +64,72 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
6264
6365 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf (x_hf , y_hf );
6466
65- rsum = Q6_Vqf32_vadd_Vqf32Vqf32 ( rsum , Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy_qf ), Q6_V_hi_W (xy_qf )));
67+ rsum = Q6_Vsf_equals_Vqf32 ( Q6_Vqf32_vadd_Vqf32Vsf ( Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy_qf ), Q6_V_hi_W (xy_qf )), rsum ));
6668 }
6769
68- rsum = Q6_Vqf32_vmpy_VsfVsf (Q6_Vsf_equals_Vqf32 (rsum ), hvx_vec_splat_f32 (s ));
69- rsum = Q6_Vsf_equals_Vqf32 (hvx_vec_reduce_sum_qf32 (rsum ));
70+ rsum = Q6_Vqf32_vmpy_VsfVsf (hvx_vec_splat_f32 (s ), hvx_vec_reduce_sum_f32 (rsum ));
71+ hvx_vec_store_u (r , 4 , Q6_Vsf_equals_Vqf32 (rsum ));
72+ }
73+
74+ // Dot product of FP32 and FP16 vectors, accumulating to float
75+ static inline void hvx_dot_f32_f16_aa_rx2 (float * restrict r ,
76+ const void * restrict y ,
77+ const void * restrict x0 ,
78+ const void * restrict x1 ,
79+ unsigned int n ,
80+ float s ) {
81+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y ; // fp32
82+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0 ; // fp16
83+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1 ; // fp16
84+
85+ uint32_t nvec = n / VLEN_FP16 ; // num full fp16 hvx vectors
86+ uint32_t nloe = n % VLEN_FP16 ; // leftover elements
87+
88+ const HVX_Vector zero = Q6_V_vsplat_R (0 );
89+ HVX_Vector rsum0 = Q6_V_vsplat_R (0 );
90+ HVX_Vector rsum1 = Q6_V_vsplat_R (0 );
91+
92+ uint32_t i = 0 ;
7093
71- hvx_vec_store_u (r , 4 , rsum );
94+ #pragma unroll(2)
95+ for (i = 0 ; i < nvec ; i ++ ) {
96+ // Load y (fp32) and convert into fp16
97+ HVX_Vector y_hf = hvx_load_f32_to_f16 (& vy [i * 2 ], zero );
98+ // Load x (fp16)
99+ HVX_Vector x0_hf = vx0 [i ];
100+ HVX_Vector x1_hf = vx1 [i ];
101+
102+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf (x0_hf , y_hf );
103+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf (x1_hf , y_hf );
104+
105+ rsum0 = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy0_qf ), Q6_V_hi_W (xy0_qf )), rsum0 ));
106+ rsum1 = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy1_qf ), Q6_V_hi_W (xy1_qf )), rsum1 ));
107+ }
108+
109+ if (nloe ) {
110+ // Load y (fp32) and convert into fp16
111+ HVX_Vector y_hf = hvx_load_f32_to_f16 (& vy [i * 2 ], zero );
112+
113+ // Load x (fp16)
114+ HVX_Vector x0_hf = vx0 [i ];
115+ HVX_Vector x1_hf = vx1 [i ];
116+
117+ // Zero-out unused elements
118+ // Note that we need to clear both x and y because they may contain NANs
119+ HVX_VectorPred bmask = Q6_Q_vsetq_R (nloe * 2 );
120+ x0_hf = Q6_V_vand_QV (bmask , x0_hf );
121+ x1_hf = Q6_V_vand_QV (bmask , x1_hf );
122+ y_hf = Q6_V_vand_QV (bmask , y_hf );
123+
124+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf (x0_hf , y_hf );
125+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf (x1_hf , y_hf );
126+
127+ rsum0 = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy0_qf ), Q6_V_hi_W (xy0_qf )), rsum0 ));
128+ rsum1 = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy1_qf ), Q6_V_hi_W (xy1_qf )), rsum1 ));
129+ }
130+
131+ HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf (hvx_vec_splat_f32 (s ), hvx_vec_reduce_sum_f32x2 (rsum0 , rsum1 ));
132+ hvx_vec_store_u (r , 8 , Q6_Vsf_equals_Vqf32 (rsum ));
72133}
73134
74135// Dot product of two F16 vectors, accumulating to float
@@ -91,7 +152,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
91152
92153 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf (x_hf , y_hf );
93154
94- rsum = Q6_Vqf32_vadd_Vqf32Vqf32 ( rsum , Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy_qf ), Q6_V_hi_W (xy_qf )));
155+ rsum = Q6_Vsf_equals_Vqf32 ( Q6_Vqf32_vadd_Vqf32Vsf ( Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy_qf ), Q6_V_hi_W (xy_qf )), rsum ));
95156 }
96157
97158 if (nloe ) {
@@ -103,12 +164,62 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
103164
104165 HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf (x_hf , y_hf );
105166
106- rsum = Q6_Vqf32_vadd_Vqf32Vqf32 (rsum , Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy_qf ), Q6_V_hi_W (xy_qf )));
167+ rsum = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy_qf ), Q6_V_hi_W (xy_qf )), rsum ));
168+ }
169+
170+ rsum = Q6_Vqf32_vmpy_VsfVsf (hvx_vec_splat_f32 (s ), hvx_vec_reduce_sum_f32 (rsum ));
171+ hvx_vec_store_u (r , 4 , Q6_Vsf_equals_Vqf32 (rsum ));
172+ }
173+
174+ static inline void hvx_dot_f16_f16_aa_rx2 (float * restrict r ,
175+ const void * restrict y ,
176+ const void * restrict x0 ,
177+ const void * restrict x1 ,
178+ unsigned int n ,
179+ float s ) {
180+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0 ; // fp16
181+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1 ; // fp16
182+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y ; // fp16
183+
184+ uint32_t nvec = n / VLEN_FP16 ; // num full fp16 hvx vectors
185+ uint32_t nloe = n % VLEN_FP16 ; // leftover elements
186+
187+ const HVX_Vector zero = Q6_V_vsplat_R (0 );
188+ HVX_Vector rsum0 = Q6_V_vsplat_R (0 );
189+ HVX_Vector rsum1 = Q6_V_vsplat_R (0 );
190+
191+ uint32_t i = 0 ;
192+
193+ #pragma unroll(4)
194+ for (i = 0 ; i < nvec ; i ++ ) {
195+ HVX_Vector y_hf = vy [i ];
196+ HVX_Vector x0_hf = vx0 [i ];
197+ HVX_Vector x1_hf = vx1 [i ];
198+
199+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf (x0_hf , y_hf );
200+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf (x1_hf , y_hf );
201+
202+ rsum0 = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy0_qf ), Q6_V_hi_W (xy0_qf )), rsum0 ));
203+ rsum1 = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy1_qf ), Q6_V_hi_W (xy1_qf )), rsum1 ));
204+ }
205+
206+ if (nloe ) {
207+ HVX_Vector y_hf = vy [i ];
208+
209+ // Load x (fp16) and zero-out unused elements
210+ HVX_VectorPred bmask = Q6_Q_vsetq_R (nloe * 2 );
211+ HVX_Vector x0_hf = Q6_V_vand_QV (bmask , vx0 [i ]);
212+ HVX_Vector x1_hf = Q6_V_vand_QV (bmask , vx1 [i ]);
213+
214+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf (x0_hf , y_hf );
215+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf (x1_hf , y_hf );
216+
217+ rsum0 = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy0_qf ), Q6_V_hi_W (xy0_qf )), rsum0 ));
218+ rsum1 = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vadd_Vqf32Vsf (Q6_Vqf32_vadd_Vqf32Vqf32 (Q6_V_lo_W (xy1_qf ), Q6_V_hi_W (xy1_qf )), rsum1 ));
107219 }
108220
109- rsum = Q6_Vqf32_vmpy_VsfVsf (Q6_Vsf_equals_Vqf32 (rsum ), hvx_vec_splat_f32 (s ));
110- rsum = Q6_Vsf_equals_Vqf32 (hvx_vec_reduce_sum_qf32 (rsum ));
111- hvx_vec_store_u (r , 4 , rsum );
221+ HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf (hvx_vec_splat_f32 (s ), hvx_vec_reduce_sum_f32x2 (rsum0 , rsum1 ));
222+ hvx_vec_store_u (r , 8 , Q6_Vsf_equals_Vqf32 (rsum ));
112223}
113224
114225// MAD: y (F32) += x (F16) * s (float)
@@ -317,20 +428,22 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
317428 // Inner loop processing the block from VTCM
318429 uint32_t ic = 0 ;
319430
431+ const bool is_q_fp32 = (q -> type == HTP_TYPE_F32 );
432+
320433 // Process in blocks of 32 (VLEN_FP32)
321- static_assert (FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 = = 4 , "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage" );
434+ static_assert (FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 < = 4 , "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage" );
322435 HVX_Vector_x4 scores_x4 ;
323436 HVX_Vector v_max = hvx_vec_splat_f32 (- INFINITY );
324437 for (uint32_t iv = 0 ; ic + VLEN_FP32 <= current_block_size ; ic += VLEN_FP32 , ++ iv ) {
325438 // 1. Compute scores
326- float __attribute__((aligned (VLEN ))) scores_arr [FLASH_ATTN_BLOCK_SIZE ];
327- for (int j = 0 ; j < VLEN_FP32 ; ++ j ) {
439+ float __attribute__((aligned (VLEN ))) scores_arr [VLEN_FP32 ];
440+ for (int j = 0 ; j < VLEN_FP32 ; j += 2 ) {
328441 const uint32_t cur_ic = ic + j ;
329442 const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded ;
330- if (q -> type == HTP_TYPE_F32 ) {
331- hvx_dot_f32_f16_aa (& scores_arr [j ], q_ptr_vtcm , k_ptr , DK , scale );
443+ if (is_q_fp32 ) {
444+ hvx_dot_f32_f16_aa_rx2 (& scores_arr [j ], q_ptr_vtcm , k_ptr , k_ptr + size_k_row_padded , DK , scale );
332445 } else {
333- hvx_dot_f16_f16_aa (& scores_arr [j ], q_ptr_vtcm , k_ptr , DK , scale );
446+ hvx_dot_f16_f16_aa_rx2 (& scores_arr [j ], q_ptr_vtcm , k_ptr , k_ptr + size_k_row_padded , DK , scale );
334447 }
335448 }
336449
@@ -403,7 +516,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
403516 float s_val ;
404517 const uint8_t * k_ptr = k_base + ic * size_k_row_padded ;
405518
406- if (q -> type == HTP_TYPE_F32 ) {
519+ if (is_q_fp32 ) {
407520 hvx_dot_f32_f16_aa (& s_val , q_ptr_vtcm , k_ptr , DK , scale );
408521 } else {
409522 hvx_dot_f16_f16_aa (& s_val , q_ptr_vtcm , k_ptr , DK , scale );
0 commit comments