Commit 8710e5f

hexagon: improve RMS_NORM and DIV accuracy (ggml-org#21251)
* hexagon-rms_norm: fix RMS_NORM for non-aligned tensor sizes
* hexagon-div: perform DIV in fp16 domain for lower dsp archs

Co-authored-by: Krishna Sridhar <srsr@qti.qualcomm.com>
1 parent 1d6d4cf commit 8710e5f

2 files changed: 97 additions & 30 deletions

ggml/src/ggml-hexagon/htp/hvx-div.h

Lines changed: 63 additions & 23 deletions
@@ -16,8 +16,10 @@

 #if __HVX_ARCH__ < 79
 #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b))
 #else
 #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b)
 #endif

 // Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32.
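The new HVX_OP_MUL_F16 macro gives the DIV code an fp16 multiply on every supported DSP generation: on v79 and newer it maps to the native hf multiply, while on older architectures it widens into the qf32 domain and converts the product back to hf. Below is a minimal scalar sketch of the per-lane semantics, assuming a compiler with _Float16 support; ref_mul_f16 is an illustrative name, not part of the patch.

#include <stddef.h>

// Scalar reference for what HVX_OP_MUL_F16 computes per lane: an fp16 product.
// The <v79 HVX path widens internally and converts back to hf, the >=v79 path
// multiplies natively in hf; either way the result is an fp16 value per lane.
static void ref_mul_f16(_Float16 * dst, const _Float16 * a, const _Float16 * b, size_t n) {
    for (size_t i = 0; i < n; i++) {
        dst[i] = (_Float16) ((float) a[i] * (float) b[i]); // widen, multiply, round back to fp16
    }
}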
@@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX
     return res;
 }

-#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
-    do { \
-        dst_type * restrict vdst = (dst_type *) dst; \
-        src_type * restrict vsrc = (src_type *) src; \
-        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
-        \
-        const uint32_t nvec = n / VLEN_FP16; \
-        const uint32_t nloe = n % VLEN_FP16; \
-        \
-        uint32_t i = 0; \
-        \
-        _Pragma("unroll(4)") \
-        for (; i < nvec; i++) { \
-            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
-            vdst[i] = res; \
-        } \
-        if (nloe) { \
-            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
-            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
-        } \
+// Variant for <v79: Use pre-computed f16 reciprocal constant
+static inline HVX_Vector hvx_div_mul_f16_const_using_f16(HVX_Vector vec1_hf, HVX_Vector const_inv_hf) {
+    // Multiply by pre-computed f16 reciprocal constant
+    return HVX_OP_MUL_F16(vec1_hf, const_inv_hf);
+}
+
+#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
+    do { \
+        dst_type * restrict vdst = (dst_type *) dst; \
+        src_type * restrict vsrc = (src_type *) src; \
+        \
+        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
+        \
+        const uint32_t nvec = n / VLEN_FP16; \
+        const uint32_t nloe = n % VLEN_FP16; \
+        \
+        uint32_t i = 0; \
+        \
+        _Pragma("unroll(4)") \
+        for (; i < nvec; i++) { \
+            HVX_Vector res; \
+            if (__HVX_ARCH__ < 79) { \
+                res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
+            } else { \
+                res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+            } \
+            vdst[i] = res; \
+        } \
+        if (nloe) { \
+            HVX_Vector res; \
+            if (__HVX_ARCH__ < 79) { \
+                res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
+            } else { \
+                res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+            } \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
+        } \
     } while(0)

 static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
     const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
     assert((uintptr_t) dst % 128 == 0);
     assert((uintptr_t) src % 128 == 0);
     hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
 }
 static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
     const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
     assert((uintptr_t) dst % 128 == 0);
     hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
 }
 static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
     const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
     assert((uintptr_t) src % 128 == 0);
     hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
 }
 static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
     const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
     hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
 }
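The scalar-divide helpers above now pre-compute the reciprocal of the divisor twice, once as fp32 (val_vec_f32) and once as fp16 (val_vec_f16), and the loop body picks the fp16 path on architectures older than v79, so each lane does a single multiply instead of a divide. A rough scalar equivalent follows, assuming _Float16 support; the function name and the use_f16_path flag are illustrative stand-ins for the __HVX_ARCH__ < 79 check, not code from the patch.

#include <stddef.h>

// Divide a buffer of fp16 values by a constant via a pre-computed reciprocal.
// use_f16_path mirrors the __HVX_ARCH__ < 79 selection in the loop body above.
static void ref_div_scalar_f16(_Float16 * dst, const _Float16 * src, _Float16 val,
                               size_t n, int use_f16_path) {
    const float    inv_f32 = 1.0f / (float) val;              // val_vec_f32 analogue
    const _Float16 inv_f16 = (_Float16) (1.0f / (float) val); // val_vec_f16 analogue
    for (size_t i = 0; i < n; i++) {
        if (use_f16_path) {
            dst[i] = (_Float16) (src[i] * inv_f16);               // stay in the fp16 domain
        } else {
            dst[i] = (_Float16) ((float) src[i] * inv_f32);       // go through fp32
        }
    }
}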

@@ -128,13 +151,25 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
     return recip;
 }

+// Hybrid approach: f16 reciprocal for <v79, f32 precision for >=v79
+static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
+#if __HVX_ARCH__ < 79
+    // For older architectures, use f16 reciprocal to avoid NaN/-inf issues
+    HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask);
+    return HVX_OP_MUL_F16(vec1, vec2_inv);
+#else
+    return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0);
+#endif
+}
+
 #define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \
     do { \
         dst_type * restrict vdst = (dst_type *) dst; \
         src0_type * restrict vsrc0 = (src0_type *) src0; \
         src1_type * restrict vsrc1 = (src1_type *) src1; \
         \
-        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
+        const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
+        const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \
         const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
         \
         const uint32_t nvec = n / VLEN_FP16; \
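hvx_vec_hybrid_div_f16 keeps the existing fp32 path on v79 and newer, and on older architectures multiplies by an fp16 reciprocal that is guarded against Inf/NaN divisors via the 0x7c00 exponent mask. The scalar sketch below illustrates that guard idea under the assumption that hvx_vec_inverse_f16_guard falls back to plain IEEE behaviour for Inf/NaN divisors and only applies the fast reciprocal to finite values (the actual HVX helper may differ); fast_recip_f16 is a hypothetical stand-in for the approximation.

#include <stdint.h>
#include <string.h>

// dst = a / b computed as a * (1/b). Divisors whose fp16 exponent field is all ones
// (mask 0x7c00, i.e. Inf or NaN) are handled with an ordinary IEEE division instead of
// the reciprocal approximation. Assumed behaviour, for illustration only.
static _Float16 ref_hybrid_div_f16(_Float16 a, _Float16 b,
                                   _Float16 (*fast_recip_f16)(_Float16)) {
    uint16_t bits;
    memcpy(&bits, &b, sizeof(bits));               // reinterpret the fp16 divisor as raw bits
    if ((bits & 0x7c00) == 0x7c00) {               // divisor is Inf or NaN
        return (_Float16) ((float) a / (float) b); // a/Inf -> 0, a/NaN -> NaN
    }
    return (_Float16) (a * fast_recip_f16(b));     // finite divisor: multiply by reciprocal
}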
@@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
         \
         _Pragma("unroll(4)") \
         for (; i < nvec; i++) { \
-            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+            HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
+                                                    f32_nan_inf_mask, f16_nan_inf_mask, \
+                                                    hf_one); \
             vdst[i] = res; \
         } \
         if (nloe) { \
-            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+            HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
+                                                    f32_nan_inf_mask, f16_nan_inf_mask, \
+                                                    hf_one); \
             vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
         } \
     } while(0)
@@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32)
 HVX_DIV_DISPATCHER(hvx_div_f16)

 #undef HVX_OP_MUL_F32
+#undef HVX_OP_MUL_F16

 #endif // HVX_DIV_H

ggml/src/ggml-hexagon/htp/unary-ops.c

Lines changed: 34 additions & 7 deletions
@@ -67,34 +67,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
                                   uint8_t * restrict pad,
                                   const int num_elems,
                                   float epsilon) {
+    (void) pad;
+
     const HVX_Vector * restrict v_src = (HVX_Vector *) src;
     HVX_Vector * restrict v_dst = (HVX_Vector *) dst;

-    HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
+    const int nvec = num_elems / VLEN_FP32; // number of full vectors
+    const int nloe = num_elems % VLEN_FP32; // leftover elements
+
+    // Compute sum of squares for full vectors
+    HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
     HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);

-    int step_of_1 = num_elems >> 5;
 #pragma unroll(4)
-    for (int i = 0; i < step_of_1; i++) {
+    for (int i = 0; i < nvec; i++) {
         HVX_Vector v1 = v_src[i];
         HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
-        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
+        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
+    }
+
+    // Handle tail elements using vectorized ops with masking
+    if (nloe > 0) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
+        HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
+        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
     }

-    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
+    // Reduce HVX sum
+    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));

     HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
     HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
     HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
     HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);

+    // Scale full vectors
     HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));

 #pragma unroll(4)
-    for (int i = 0; i < step_of_1; i++) {
+    for (int i = 0; i < nvec; i++) {
         HVX_Vector v1 = v_src[i];
         HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
-        v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
+        v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
+    }
+
+    // Handle tail elements using vectorized ops with masking
+    if (nloe > 0) {
+
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
+        HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
+        HVX_Vector result = Q6_Vsf_equals_Vqf32(v2);
+
+        // Store with masking to avoid overwriting memory beyond the tensor
+        hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
     }
 }
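The RMS_NORM fix replaces the old num_elems >> 5 full-vector count with an explicit nvec/nloe split, so rows whose length is not a multiple of 32 fp32 elements have their tail included in the sum of squares through a Q6_Q_vsetq_R byte mask and written back with a masked store instead of being dropped or overrunning the destination. Below is a plain C reference of the same computation, useful for checking the HVX results on non-aligned sizes; it is a comparison sketch, not the HTP code.

#include <math.h>

// Reference RMS_NORM over num_elems floats: mean of squares plus epsilon,
// reciprocal square root, then scale every element (tail included).
static void ref_rms_norm_f32(float * dst, const float * src, int num_elems, float epsilon) {
    float sum_sq = 0.0f;
    for (int i = 0; i < num_elems; i++) {   // every element counts, including the tail
        sum_sq += src[i] * src[i];
    }
    const float mean  = sum_sq / (float) num_elems;
    const float scale = 1.0f / sqrtf(mean + epsilon);
    for (int i = 0; i < num_elems; i++) {
        dst[i] = src[i] * scale;            // no writes past num_elems
    }
}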
