Skip to content

Commit 89f10ba

Browse files
ggml-hexagon: flash-attention and reduce-sum optimizations (ggml-org#19141)
* wip
* ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation
* ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations
* wip
* ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance
* ggml-hexagon: refactor dot product functions to use a common loading function for improved readability
* optimize vector dot product functions to use unified reduction for improved performance
* wip
* ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation
* ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations
* wip
* ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance
* ggml-hexagon: refactor dot product functions to use a common loading function for improved readability
* optimize vector dot product functions to use unified reduction for improved performance
* hexagon: optimize reduce-sum for v75+
* hexagon: always keep row_sums in sf/fp32
* ggml-hexagon: enhance directory checks for HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT
* fix compiling error after rebase

---------

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
1 parent 3dd9591 commit 89f10ba

7 files changed

Lines changed: 248 additions & 92 deletions

File tree

ggml/src/ggml-hexagon/CMakeLists.txt

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,20 @@
11
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
22
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
33

4-
if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
5-
message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
4+
if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}")
5+
message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.")
6+
endif()
7+
8+
if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
9+
message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json")
10+
file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH)
11+
string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path")
12+
message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}")
13+
set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}")
14+
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
15+
if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
16+
message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
17+
endif()
618
endif()
719

820
message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")

ggml/src/ggml-hexagon/htp/flash-attn-ops.c

Lines changed: 136 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,12 @@
1717
#include "htp-msg.h"
1818
#include "htp-ops.h"
1919

20+
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
21+
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
22+
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
23+
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
24+
}
25+
2026
// Dot product of FP32 and FP16 vectors, accumulating to float
2127
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
2228
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
@@ -33,23 +39,19 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
3339
#pragma unroll(4)
3440
for (i = 0; i < nvec; i++) {
3541
// Load y (fp32) and convert into fp16
36-
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
37-
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
38-
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
42+
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
3943

4044
// Load x (fp16)
4145
HVX_Vector x_hf = vx[i];
4246

4347
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
4448

45-
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
49+
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
4650
}
4751

4852
if (nloe) {
4953
// Load y (fp32) and convert into fp16
50-
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
51-
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
52-
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
54+
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
5355

5456
// Load x (fp16)
5557
HVX_Vector x_hf = vx[i];
@@ -62,13 +64,72 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
6264

6365
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
6466

65-
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
67+
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
6668
}
6769

68-
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
69-
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
70+
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
71+
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
72+
}
73+
74+
// Dot product of FP32 and FP16 vectors, accumulating to float
75+
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
76+
const void * restrict y,
77+
const void * restrict x0,
78+
const void * restrict x1,
79+
unsigned int n,
80+
float s) {
81+
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
82+
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
83+
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
84+
85+
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
86+
uint32_t nloe = n % VLEN_FP16; // leftover elements
87+
88+
const HVX_Vector zero = Q6_V_vsplat_R(0);
89+
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
90+
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
91+
92+
uint32_t i = 0;
7093

71-
hvx_vec_store_u(r, 4, rsum);
94+
#pragma unroll(2)
95+
for (i = 0; i < nvec; i++) {
96+
// Load y (fp32) and convert into fp16
97+
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
98+
// Load x (fp16)
99+
HVX_Vector x0_hf = vx0[i];
100+
HVX_Vector x1_hf = vx1[i];
101+
102+
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
103+
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
104+
105+
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
106+
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
107+
}
108+
109+
if (nloe) {
110+
// Load y (fp32) and convert into fp16
111+
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
112+
113+
// Load x (fp16)
114+
HVX_Vector x0_hf = vx0[i];
115+
HVX_Vector x1_hf = vx1[i];
116+
117+
// Zero-out unused elements
118+
// Note that we need to clear both x and y because they may contain NANs
119+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
120+
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
121+
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
122+
y_hf = Q6_V_vand_QV(bmask, y_hf);
123+
124+
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
125+
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
126+
127+
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
128+
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
129+
}
130+
131+
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
132+
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
72133
}
73134

74135
// Dot product of two F16 vectors, accumulating to float
@@ -91,7 +152,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
91152

92153
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
93154

94-
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
155+
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
95156
}
96157

97158
if (nloe) {
@@ -103,12 +164,62 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
103164

104165
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
105166

106-
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
167+
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
168+
}
169+
170+
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
171+
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
172+
}
173+
174+
static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
175+
const void * restrict y,
176+
const void * restrict x0,
177+
const void * restrict x1,
178+
unsigned int n,
179+
float s) {
180+
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
181+
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
182+
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
183+
184+
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
185+
uint32_t nloe = n % VLEN_FP16; // leftover elements
186+
187+
const HVX_Vector zero = Q6_V_vsplat_R(0);
188+
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
189+
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
190+
191+
uint32_t i = 0;
192+
193+
#pragma unroll(4)
194+
for (i = 0; i < nvec; i++) {
195+
HVX_Vector y_hf = vy[i];
196+
HVX_Vector x0_hf = vx0[i];
197+
HVX_Vector x1_hf = vx1[i];
198+
199+
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
200+
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
201+
202+
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
203+
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
204+
}
205+
206+
if (nloe) {
207+
HVX_Vector y_hf = vy[i];
208+
209+
// Load x (fp16) and zero-out unused elements
210+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
211+
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
212+
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
213+
214+
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
215+
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
216+
217+
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
218+
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
107219
}
108220

109-
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
110-
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
111-
hvx_vec_store_u(r, 4, rsum);
221+
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
222+
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
112223
}
113224

114225
// MAD: y (F32) += x (F16) * s (float)
@@ -317,20 +428,22 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
317428
// Inner loop processing the block from VTCM
318429
uint32_t ic = 0;
319430

431+
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
432+
320433
// Process in blocks of 32 (VLEN_FP32)
321-
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
434+
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
322435
HVX_Vector_x4 scores_x4;
323436
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
324437
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
325438
// 1. Compute scores
326-
float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
327-
for (int j = 0; j < VLEN_FP32; ++j) {
439+
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
440+
for (int j = 0; j < VLEN_FP32; j += 2) {
328441
const uint32_t cur_ic = ic + j;
329442
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
330-
if (q->type == HTP_TYPE_F32) {
331-
hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
443+
if (is_q_fp32) {
444+
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
332445
} else {
333-
hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
446+
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
334447
}
335448
}
336449

@@ -403,7 +516,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
403516
float s_val;
404517
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
405518

406-
if (q->type == HTP_TYPE_F32) {
519+
if (is_q_fp32) {
407520
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
408521
} else {
409522
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);

ggml/src/ggml-hexagon/htp/hvx-dump.h

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -28,19 +28,16 @@ static void hvx_vec_dump_f16(char * pref, HVX_Vector v) {
2828
}
2929

3030
static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
31-
union {
32-
HVX_Vector v;
33-
float d[32];
34-
} u = { .v = v };
31+
HVX_VectorAlias u = { .v = v };
3532

3633
const uint32_t n0 = n / 16;
3734
const uint32_t n1 = n % 16;
3835
int i = 0;
3936
for (; i < n0; i++) {
40-
hex_dump_f32_line(pref, u.d + (16 * i), 16);
37+
hex_dump_f32_line(pref, u.fp32 + (16 * i), 16);
4138
}
4239
if (n1) {
43-
hex_dump_f32_line(pref, u.d + (16 * i), n1);
40+
hex_dump_f32_line(pref, u.fp32 + (16 * i), n1);
4441
}
4542
}
4643

ggml/src/ggml-hexagon/htp/hvx-reduce.h

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,45 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
4444
return hvx_vec_reduce_sum_n_qf32(in, 32);
4545
}
4646

47+
#if __HVX_ARCH__ > 75
48+
49+
static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
50+
HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
51+
HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
52+
53+
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
54+
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
55+
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
56+
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
57+
return sum_sf;
58+
}
59+
60+
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
61+
unsigned int total = n * 4; // total vec nbytes
62+
unsigned int width = 4; // fp32 nbytes
63+
64+
HVX_Vector sum = in, sum_t;
65+
while (width < total) {
66+
sum_t = Q6_V_vror_VR(sum, width); // rotate right
67+
sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
68+
width = width << 1;
69+
}
70+
return sum;
71+
}
72+
73+
#else
74+
75+
static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
76+
HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
77+
HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
78+
79+
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
80+
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
81+
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
82+
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
83+
return Q6_Vsf_equals_Vqf32(sum_qf);
84+
}
85+
4786
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
4887
unsigned int total = n * 4; // total vec nbytes
4988
unsigned int width = 4; // fp32 nbytes
@@ -57,6 +96,8 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n)
5796
return sum;
5897
}
5998

99+
#endif
100+
60101
static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
61102
return hvx_vec_reduce_sum_n_f32(in, 32);
62103
}

0 commit comments

Comments
 (0)