
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
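+ // <v79 has no native hf multiply: Q6_Wqf32_vmpy_VhfVhf widens to a qf32 vector pair, Q6_Vhf_equals_Wqf32 narrows back to hf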
+ #define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b))
#else
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+ #define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b)
#endif

// Compute div by scalar in f32: first widen fp16 to fp32, then convert the result back to fp16.
@@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX
    return res;
}

- #define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
-     do { \
-         dst_type * restrict vdst = (dst_type *) dst; \
-         src_type * restrict vsrc = (src_type *) src; \
-         HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
- \
-         const uint32_t nvec = n / VLEN_FP16; \
-         const uint32_t nloe = n % VLEN_FP16; \
- \
-         uint32_t i = 0; \
- \
-         _Pragma("unroll(4)") \
-         for (; i < nvec; i++) { \
-             HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
-             vdst[i] = res; \
-         } \
-         if (nloe) { \
-             HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
-             vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
-         } \
+ // Variant for <v79: use a pre-computed f16 reciprocal constant
+ static inline HVX_Vector hvx_div_mul_f16_const_using_f16(HVX_Vector vec1_hf, HVX_Vector const_inv_hf) {
+     // Multiply by the pre-computed f16 reciprocal constant
+     return HVX_OP_MUL_F16(vec1_hf, const_inv_hf);
+ }
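+ // Note: a single f16 reciprocal multiply trades some precision for speed versus the f32 path above.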
+
+ #define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
+     do { \
+         dst_type * restrict vdst = (dst_type *) dst; \
+         src_type * restrict vsrc = (src_type *) src; \
+ \
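+         /* 0x3C00 is 1.0 in IEEE fp16 */ \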
+         HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
+ \
+         const uint32_t nvec = n / VLEN_FP16; \
+         const uint32_t nloe = n % VLEN_FP16; \
+ \
+         uint32_t i = 0; \
+ \
+         _Pragma("unroll(4)") \
+         for (; i < nvec; i++) { \
+             HVX_Vector res; \
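+             /* __HVX_ARCH__ is a preprocessor constant, so this branch is resolved at compile time */ \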
+             if (__HVX_ARCH__ < 79) { \
+                 res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
+             } else { \
+                 res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+             } \
+             vdst[i] = res; \
+         } \
+         if (nloe) { \
+             HVX_Vector res; \
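+             /* same compile-time arch selection for the leftover elements */ \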
+             if (__HVX_ARCH__ < 79) { \
+                 res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
+             } else { \
+                 res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+             } \
+             vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
+         } \
    } while(0)

static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f / ((float) val));
+     const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
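+     // Only one of the two reciprocal splats is used for a given __HVX_ARCH__; the other is eliminated as dead code.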
    assert((uintptr_t) dst % 128 == 0);
    assert((uintptr_t) src % 128 == 0);
    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f / ((float) val));
+     const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
    assert((uintptr_t) dst % 128 == 0);
    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f / ((float) val));
+     const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
    assert((uintptr_t) src % 128 == 0);
    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f / ((float) val));
+     const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}

@@ -128,13 +151,25 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
    return recip;
}

+ // Hybrid approach: f16 reciprocal for <v79, f32 precision for >=v79
+ static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
+ #if __HVX_ARCH__ < 79
+     // For older architectures, use the f16 reciprocal to avoid NaN/-inf issues
+     HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask);
+     return HVX_OP_MUL_F16(vec1, vec2_inv);
+ #else
+     return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0);
+ #endif
+ }
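+ // Both masks are passed unconditionally so call sites stay arch-independent; the unused one is optimized away.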
+
#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \
    do { \
        dst_type * restrict vdst = (dst_type *) dst; \
        src0_type * restrict vsrc0 = (src0_type *) src0; \
        src1_type * restrict vsrc1 = (src1_type *) src1; \
 \
-         const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
+         const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
+         const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \
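+         /* all-ones exponent patterns: 0x7f800000 (f32) and 0x7c00 (f16) flag Inf/NaN lanes */ \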
        const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
 \
        const uint32_t nvec = n / VLEN_FP16; \
@@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
 \
        _Pragma("unroll(4)") \
        for (; i < nvec; i++) { \
-             HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+             HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
+                                                     f32_nan_inf_mask, f16_nan_inf_mask, \
+                                                     hf_one); \
            vdst[i] = res; \
        } \
        if (nloe) { \
-             HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+             HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
+                                                     f32_nan_inf_mask, f16_nan_inf_mask, \
+                                                     hf_one); \
            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
        } \
    } while(0)
@@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32)
HVX_DIV_DISPATCHER(hvx_div_f16)

#undef HVX_OP_MUL_F32
+ #undef HVX_OP_MUL_F16

#endif // HVX_DIV_H