Skip to content

Commit 856c3ad

Browse files
hexagon: eliminate scalar VTCM loads via HVX splat helpers (ggml-org#22993)
* hexagon: add hvx_vec_repl helpers and use them for the splat-from-VTCM use case
* hmx-mm: optimize per-group scale handling
* hmx-fa: optimize slope load from VTCM
* hmx-fa: use aligned access where possible in hmx-utils
* hexagon: add hvx_vec_repl_2x_f16 helper and consolidate repl helpers
---------
Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
1 parent a9883db commit 856c3ad

6 files changed

Lines changed: 107 additions & 38 deletions

File tree

ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,8 +760,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
760760
// ALiBi slopes — only needed when has_alibi (scheme A)
761761
HVX_Vector v_slope0, v_slope1;
762762
if (args->has_alibi) {
763-
v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]);
764-
v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero();
763+
HVX_Vector v_s = hvx_vmemu(args->slopes + r);
764+
v_slope0 = hvx_vec_repl_f16(v_s);
765+
v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_repl_f16(Q6_V_vror_VR(v_s, 2)) : Q6_V_vzero();
765766
}
766767

767768
const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c)

ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,10 @@ static int hmx_compute_chunks(size_t vtcm_total,
180180
// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes.
181181
// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
182182
// of the same 32 packed bytes.
183-
static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
184-
const uint8_t *packed_32, bool upper_nibbles,
185-
const __fp16 *scale, const HVX_Vector vlut_cvt) {
183+
static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
186184
HVX_Vector vq = hvx_vmemu(packed_32);
187185
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
188-
HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
186+
HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
189187
// q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
190188
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
191189
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
@@ -223,9 +221,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
223221
HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]
224222

225223
// Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
226-
HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
227-
HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1]));
228-
HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3]));
224+
volatile HVX_Vector vscale = hvx_vmemu(scales_4);
225+
226+
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
227+
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
229228

230229
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
231230
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
@@ -237,10 +236,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
237236

238237
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
239238
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
240-
HVX_Vector vq = hvx_vmemu(quants_32);
241-
HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
242-
HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
243-
HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
239+
HVX_Vector vq = hvx_vmemu(quants_32);
240+
HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
241+
HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
242+
HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
244243
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
245244
}
246245

@@ -521,12 +520,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
521520
const uint8_t *r0 = vtcm_src + row0 * row_stride;
522521
const uint8_t *r1 = vtcm_src + row1 * row_stride;
523522

524-
HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx(
525-
(const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
526-
HVX_Vector v1 = (row1 < n_cols)
527-
? dequantize_x4x2_q8_0_group_hvx(
528-
(const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off))
529-
: Q6_V_vzero();
523+
HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
524+
HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
530525

531526
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
532527
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);

ggml/src/ggml-hexagon/htp/hmx-utils.h

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,25 +77,26 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
7777
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
7878
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
7979

80-
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
81-
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
82-
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
80+
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
81+
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
82+
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
83+
84+
assert(hex_is_aligned(p0, 128));
85+
assert(hex_is_aligned(p1, 128));
86+
assert(c_byte_step % 128 == 0);
8387

8488
if (p1) {
8589
for (int i = 0; i < n_c_iters; ++i) {
86-
HVX_Vector v0 = hvx_vmemu(p0);
87-
p0 += c_byte_step;
88-
HVX_Vector v1 = hvx_vmemu(p1);
89-
p1 += c_byte_step;
90+
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
91+
HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
9092
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
9193
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
9294
tile_base += dst_step;
9395
}
9496
} else {
9597
const HVX_Vector vzero = Q6_V_vzero();
9698
for (int i = 0; i < n_c_iters; ++i) {
97-
HVX_Vector v0 = hvx_vmemu(p0);
98-
p0 += c_byte_step;
99+
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
99100
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
100101
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
101102
tile_base += dst_step;
@@ -116,25 +117,22 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
116117
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
117118
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
118119

119-
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
120-
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
121-
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
120+
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
121+
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
122+
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
122123

123124
if (p1) {
124125
for (int i = 0; i < n_c_iters; ++i) {
125-
HVX_Vector v0 = hvx_vmemu(p0);
126-
p0 += c_byte_step;
127-
HVX_Vector v1 = hvx_vmemu(p1);
128-
p1 += c_byte_step;
126+
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
127+
HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
129128
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
130129
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
131130
tile_base += dst_step;
132131
}
133132
} else {
134133
const HVX_Vector vzero = Q6_V_vzero();
135134
for (int i = 0; i < n_c_iters; ++i) {
136-
HVX_Vector v0 = hvx_vmemu(p0);
137-
p0 += c_byte_step;
135+
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
138136
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
139137
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
140138
tile_base += dst_step;
ggml/src/ggml-hexagon/htp/hvx-repl.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#ifndef HVX_REPL_H
2+
#define HVX_REPL_H
3+
4+
#include <assert.h>
5+
#include <stddef.h>
6+
#include <stdint.h>
7+
8+
#include "hvx-base.h"
9+
10+
// Common back end for the hvx_vec_repl_* helpers below: passes v through
// Q6_V_vdelta_VV using the control vector loaded from ctrl and returns the result.
//   v    - source vector
//   ctrl - vdelta control bytes, read as one full vector via hvx_vmem
// NOTE(review): hvx_vmem looks like the aligned vector load (the helpers in this
// file all pass 128-byte-aligned, 128-byte tables) -- confirm before passing
// an arbitrary pointer here.
static inline HVX_Vector hvx_vec_repl(HVX_Vector v, const uint8_t * ctrl) {
11+
return Q6_V_vdelta_VV(v, hvx_vmem(ctrl));
12+
}
13+
14+
// Replicates the first 32-bit element of v across the whole vector and returns
// the result. The work is done by hvx_vec_repl with the constant control table
// defined below.
static inline HVX_Vector hvx_vec_repl_u32(HVX_Vector v) {
15+
// vdelta control to replicate first 4 bytes across all lanes
// (128 control bytes, 128-byte aligned because hvx_vec_repl loads it with hvx_vmem)
16+
static const uint8_t __attribute__((aligned(128))) repl[128] = {
17+
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
18+
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
19+
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
20+
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
21+
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
22+
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
23+
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
24+
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
25+
};
26+
return hvx_vec_repl(v, repl);
27+
}
28+
29+
// Replicates the first 32-bit (fp32) element of v across the whole vector and
// returns the result, via hvx_vec_repl with the constant control table below.
// NOTE(review): this table is byte-for-byte the same as the one in
// hvx_vec_repl_u32; the two helpers differ only in name, so this one could
// simply forward to hvx_vec_repl_u32 to avoid carrying a second 128-byte table.
static inline HVX_Vector hvx_vec_repl_f32(HVX_Vector v) {
30+
// vdelta control to replicate first 4 bytes across all lanes
// (128 control bytes, 128-byte aligned because hvx_vec_repl loads it with hvx_vmem)
31+
static const uint8_t __attribute__((aligned(128))) repl[128] = {
32+
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
33+
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
34+
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
35+
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
36+
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
37+
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
38+
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
39+
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
40+
};
41+
return hvx_vec_repl(v, repl);
42+
}
43+
44+
// Replicates the first 16-bit (fp16) element of v across the whole vector and
// returns the result, via hvx_vec_repl with the constant control table below.
// Callers in this commit use it as hvx_vec_repl_f16(hvx_vmemu(ptr)) in place of
// hvx_vec_splat_f16(*ptr), i.e. a vector load followed by a replicate instead
// of a scalar load followed by a splat.
static inline HVX_Vector hvx_vec_repl_f16(HVX_Vector v) {
45+
// vdelta control to replicate first two bytes across all lanes
// (128 control bytes, 128-byte aligned because hvx_vec_repl loads it with hvx_vmem)
46+
static const uint8_t __attribute__((aligned(128))) repl[128] = {
47+
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
48+
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
49+
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
50+
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
51+
0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
52+
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
53+
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
54+
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
55+
};
56+
return hvx_vec_repl(v, repl);
57+
}
58+
59+
// Builds a vector from the first two fp16 elements of v: the first half of the
// result is filled with element 0 and the second half with element 1. Done via
// hvx_vec_repl with the constant control table below.
// In this commit it replaces a Q6_V_vmux_QVV of two hvx_vec_splat_f16 results
// selected by Q6_Q_vsetq_R(64), so one call yields a two-group scale vector.
static inline HVX_Vector hvx_vec_repl_2x_f16(HVX_Vector v) {
60+
// vdelta control to splat a pair of f16s: first half = f16[0], second half = f16[1]
// (the first four table rows match hvx_vec_repl_f16; the last four rows differ)
61+
static const uint8_t __attribute__((aligned(128))) repl[128] = {
62+
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
63+
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
64+
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
65+
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
66+
0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
67+
0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
68+
0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
69+
0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
70+
};
71+
return hvx_vec_repl(v, repl);
72+
}
73+
74+
#endif // HVX_REPL_H

ggml/src/ggml-hexagon/htp/hvx-utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "hvx-types.h"
77
#include "hvx-copy.h"
8+
#include "hvx-repl.h"
89
#include "hvx-scale.h"
910
#include "hvx-exp.h"
1011
#include "hvx-inverse.h"

scripts/snapdragon/adb/run-completion.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,5 @@ adb $adbserial $adbhost shell " \
7070
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
7171
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
7272
--ctx-size 8192 --ubatch-size 256 -fa on \
73-
-ngl 99 -no-cnv --device $device $cli_opts $@ \
73+
-ngl 99 --device $device $cli_opts $@ \
7474
"

0 commit comments

Comments
 (0)