Skip to content

Commit aa50b2c

Browse files
hexagon: add support for Q4_1 in MUL_MAT and MUL_MAT_ID (ggml-org#23647)
* hex-mm: add support for Q4_1 matmul/matvec, hvx-only for now
* hmx-mm: add support for Q4_1
* hex-mm: use Q8_1 dynamic quantization to avoid having to compute sums in the vec_dot
* hexagon: fix repack scratch buffer overflow
* hex-mm: fix Q4_1 repack buffer sizing
* hexagon: flip the build order for mm and fa (seems to help LTO)
* hex-mm: add vec_dot 4x1s and minor HMX cleanup after adding Q4_1
* hex-mm: fix fp16 vec_dot fallback to 2x1 and another issue that could cause incorrect output
* hexagon: resurrect early-wake and add support for polling for op-batch completions

With Q4_1, ggml-hexagon now claims pretty much the entire graph, which gives the CPU more time to chillax. This is a good thing! But it does add extra latency for the pure benchmark runs. Early wakeup helps recover the latency a bit in the normal runs, and op-batch polling is just for benchmarking.

---------

Co-authored-by: Todor Boinovski <todorb@qti.qualcomm.com>
1 parent c40006a commit aa50b2c

8 files changed

Lines changed: 2018 additions & 202 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 260 additions & 7 deletions
Large diffs are not rendered by default.

ggml/src/ggml-hexagon/htp/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,14 +59,14 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
5959
if (_hmx_idx GREATER_EQUAL 0)
6060
target_sources(${HTP_LIB} PRIVATE
6161
hmx-queue.c
62-
hmx-matmul-ops.c
6362
hmx-flash-attn-ops.c
63+
hmx-matmul-ops.c
6464
)
6565

6666
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
6767
set_source_files_properties(
68-
hmx-matmul-ops.c
6968
hmx-flash-attn-ops.c
69+
hmx-matmul-ops.c
7070
PROPERTIES COMPILE_OPTIONS "-mhmx"
7171
)
7272

ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c

Lines changed: 125 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,10 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
3434
-8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
3535
};
3636

37+
static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
38+
0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0,
39+
};
40+
3741
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
3842
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
3943
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
@@ -62,6 +66,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
6266
case HTP_TYPE_Q4_0:
6367
case HTP_TYPE_IQ4_NL:
6468
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
69+
case HTP_TYPE_Q4_1:
70+
return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb
6571
case HTP_TYPE_Q8_0:
6672
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
6773
case HTP_TYPE_MXFP4:
@@ -233,6 +239,54 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
233239
return r;
234240
}
235241

242+
static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) {
243+
HVX_Vector vq = hvx_vmemu(packed_32);
244+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
245+
HVX_Vector v_dm = hvx_vmemu(scale_offset);
246+
HVX_Vector v_scales = hvx_vec_repl_f16(v_dm);
247+
HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2));
248+
249+
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
250+
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
251+
v_quants = Q6_Vb_vshuff_Vb(v_quants);
252+
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
253+
HVX_Vector v_hf = Q6_V_lo_W(vp);
254+
255+
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets));
256+
}
257+
258+
static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
259+
const uint8_t *packed_128, bool upper_nibbles,
260+
const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) {
261+
HVX_Vector vq = hvx_vmemu(packed_128);
262+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
263+
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
264+
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
265+
266+
v_quants = Q6_Vb_vshuff_Vb(v_quants);
267+
268+
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
269+
HVX_Vector v_lo = Q6_V_lo_W(vp);
270+
HVX_Vector v_hi = Q6_V_hi_W(vp);
271+
272+
HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4);
273+
HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2);
274+
HVX_Vector vd = Q6_V_lo_W(dm_deal);
275+
HVX_Vector vm = Q6_V_hi_W(dm_deal);
276+
277+
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd);
278+
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4));
279+
280+
HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm);
281+
HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4));
282+
283+
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01));
284+
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23));
285+
286+
HVX_Vector_x2 r = { v_lo, v_hi };
287+
return r;
288+
}
289+
236290
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
237291
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
238292
HVX_Vector vq = hvx_vmemu(quants_32);
@@ -331,11 +385,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
331385
int start_tile, int end_tile) {
332386

333387
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
334-
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
388+
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL);
389+
const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1);
335390
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;
336391

337392
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
338393
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
394+
(weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) :
339395
hvx_vmem(q4_0_to_fp16_lut);
340396

341397
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
@@ -356,8 +412,10 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
356412
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
357413
bool upper = (sub_blk_base >= 4);
358414
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
359-
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
360-
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
415+
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
416+
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
417+
unsigned scale_off = qrow_size + blk_idx * dblk_size
418+
+ sub_blk_base * scale_step;
361419

362420
__fp16 *tile_bases[4];
363421
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
@@ -367,20 +425,38 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
367425
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
368426
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
369427

370-
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
371-
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
372-
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
428+
if (is_q4_1) {
429+
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
430+
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
431+
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
373432

374-
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
375-
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
433+
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
434+
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
376435

377-
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
378-
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
379-
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
436+
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
437+
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
438+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
380439

381-
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
382-
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
383-
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
440+
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
441+
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
442+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
443+
}
444+
} else {
445+
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
446+
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
447+
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
448+
449+
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
450+
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
451+
452+
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
453+
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
454+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
455+
456+
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
457+
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
458+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
459+
}
384460
}
385461

386462
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }
@@ -446,26 +522,43 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
446522
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
447523
bool upper = (sub_blk >= 4);
448524
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
449-
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
525+
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
526+
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
527+
unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step;
450528

451529
HVX_Vector v_off = v_scat_base; // reset to column 0
452530
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
453531
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
454-
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
455-
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
456-
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
457-
458-
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(
459-
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
460-
HVX_Vector v1 = (row1 < n_cols)
461-
? dequantize_x4x2_q4_0_group_hvx(
462-
r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
463-
: Q6_V_vzero();
464-
465-
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
466-
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
467-
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
468-
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
532+
if (is_q4_1) {
533+
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
534+
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
535+
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
536+
537+
HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
538+
HVX_Vector v1 = (row1 < n_cols)
539+
? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
540+
: Q6_V_vzero();
541+
542+
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
543+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
544+
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
545+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
546+
}
547+
} else {
548+
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
549+
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
550+
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
551+
552+
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
553+
HVX_Vector v1 = (row1 < n_cols)
554+
? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
555+
: Q6_V_vzero();
556+
557+
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
558+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
559+
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
560+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
561+
}
469562
}
470563
(void) *(volatile HVX_Vector *)(tile_base);
471564
} else if (weight_type == HTP_TYPE_MXFP4) {
@@ -593,6 +686,8 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
593686

594687
// --- End x4x2 dequantizers ---
595688

689+
#pragma clang diagnostic ignored "-Wbackend-plugin" // spurious warning for hmx intrinsics
690+
596691
// requires external HMX lock
597692
static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales,
598693
int n_row_tiles, int n_col_tiles, int n_dot_tiles) {

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ enum htp_data_type {
2020
HTP_TYPE_F32 = 0,
2121
HTP_TYPE_F16 = 1,
2222
HTP_TYPE_Q4_0 = 2,
23+
HTP_TYPE_Q4_1 = 3,
2324
HTP_TYPE_Q8_0 = 8,
2425
HTP_TYPE_IQ4_NL = 20,
2526
HTP_TYPE_I32 = 26,
@@ -28,6 +29,7 @@ enum htp_data_type {
2829

2930
// types used internally for repack, dyn.quant, etc
3031
HTP_TYPE_Q4_0x4x2 = 200,
32+
HTP_TYPE_Q4_1x4x2,
3133
HTP_TYPE_Q8_0x4x2,
3234
HTP_TYPE_MXFP4x4x2,
3335

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -853,6 +853,11 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
853853
for (uint32_t i=0; i < n_ops; i++) {
854854
struct profile_data prof;
855855

856+
if (i == (n_ops-1)) {
857+
// wake up the host before starting the last op
858+
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
859+
}
860+
856861
profile_start(ctx->profiler, &prof);
857862

858863
proc_op_req(octx, tens, i, &ops[i]);
@@ -869,8 +874,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
869874
}
870875
}
871876

872-
// dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);
873-
874877
struct htp_opbatch_rsp rsp;
875878
rsp.id = req.id;
876879
rsp.status = HTP_STATUS_OK;

0 commit comments

Comments
 (0)