@@ -34,6 +34,10 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
3434 -8 , 0 , -7 , 0 , -6 , 0 , -5 , 0 , -4 , 0 , -3 , 0 , -2 , 0 , -1 , 0 , 0 , 0 , 1 , 0 , 2 , 0 , 3 , 0 , 4 , 0 , 5 , 0 , 6 , 0 , 7 , 0 ,
3535};
3636
37+ static const __fp16 q4_1_to_fp16_lut [64 ] __attribute__((aligned (VLEN ))) = {
38+ 0 , 0 , 1 , 0 , 2 , 0 , 3 , 0 , 4 , 0 , 5 , 0 , 6 , 0 , 7 , 0 , 8 , 0 , 9 , 0 , 10 , 0 , 11 , 0 , 12 , 0 , 13 , 0 , 14 , 0 , 15 , 0 ,
39+ };
40+
3741// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
3842// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
3943static const __fp16 mxfp4_to_fp16_lut [64 ] __attribute__((aligned (VLEN ))) = {
@@ -62,6 +66,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
6266 case HTP_TYPE_Q4_0 :
6367 case HTP_TYPE_IQ4_NL :
6468 return (size_t ) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE ); // 144 * nb
69+ case HTP_TYPE_Q4_1 :
70+ return (size_t ) nb * (QK_Q4_0x4x2 / 2 + 32 ); // 160 * nb
6571 case HTP_TYPE_Q8_0 :
6672 return (size_t ) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE ); // 272 * nb
6773 case HTP_TYPE_MXFP4 :
@@ -233,6 +239,54 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
233239 return r ;
234240}
235241
242+ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx (const uint8_t * packed_32 , bool upper_nibbles , const __fp16 * scale_offset , const HVX_Vector vlut_cvt ) {
243+ HVX_Vector vq = hvx_vmemu (packed_32 );
244+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R (0x0F );
245+ HVX_Vector v_dm = hvx_vmemu (scale_offset );
246+ HVX_Vector v_scales = hvx_vec_repl_f16 (v_dm );
247+ HVX_Vector v_offsets = hvx_vec_repl_f16 (Q6_V_vror_VR (v_dm , 2 ));
248+
249+ HVX_Vector v_quants = Q6_Vub_vlsr_VubR (vq , 4 * upper_nibbles );
250+ v_quants = Q6_V_vand_VV (v_quants , mask_h4 );
251+ v_quants = Q6_Vb_vshuff_Vb (v_quants );
252+ HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR (v_quants , vlut_cvt , 0 );
253+ HVX_Vector v_hf = Q6_V_lo_W (vp );
254+
255+ return Q6_Vhf_equals_Vqf16 (Q6_Vqf16_vadd_Vqf16Vhf (Q6_Vqf16_vmpy_VhfVhf (v_hf , v_scales ), v_offsets ));
256+ }
257+
258+ static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx (
259+ const uint8_t * packed_128 , bool upper_nibbles ,
260+ const __fp16 * scales_offsets_4 , const HVX_Vector vlut_cvt ) {
261+ HVX_Vector vq = hvx_vmemu (packed_128 );
262+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R (0x0F );
263+ HVX_Vector v_quants = Q6_Vub_vlsr_VubR (vq , 4 * upper_nibbles );
264+ v_quants = Q6_V_vand_VV (v_quants , mask_h4 );
265+
266+ v_quants = Q6_Vb_vshuff_Vb (v_quants );
267+
268+ HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR (v_quants , vlut_cvt , 0 );
269+ HVX_Vector v_lo = Q6_V_lo_W (vp );
270+ HVX_Vector v_hi = Q6_V_hi_W (vp );
271+
272+ HVX_Vector vscale_offset = hvx_vmemu (scales_offsets_4 );
273+ HVX_VectorPair dm_deal = Q6_W_vdeal_VVR (vscale_offset , vscale_offset , -2 );
274+ HVX_Vector vd = Q6_V_lo_W (dm_deal );
275+ HVX_Vector vm = Q6_V_hi_W (dm_deal );
276+
277+ HVX_Vector v_sc01 = hvx_vec_repl_2x_f16 (vd );
278+ HVX_Vector v_sc23 = hvx_vec_repl_2x_f16 (Q6_V_vror_VR (vd , 4 ));
279+
280+ HVX_Vector v_os01 = hvx_vec_repl_2x_f16 (vm );
281+ HVX_Vector v_os23 = hvx_vec_repl_2x_f16 (Q6_V_vror_VR (vm , 4 ));
282+
283+ v_lo = Q6_Vhf_equals_Vqf16 (Q6_Vqf16_vadd_Vqf16Vhf (Q6_Vqf16_vmpy_VhfVhf (v_lo , v_sc01 ), v_os01 ));
284+ v_hi = Q6_Vhf_equals_Vqf16 (Q6_Vqf16_vadd_Vqf16Vhf (Q6_Vqf16_vmpy_VhfVhf (v_hi , v_sc23 ), v_os23 ));
285+
286+ HVX_Vector_x2 r = { v_lo , v_hi };
287+ return r ;
288+ }
289+
236290// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
237291static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx (const int8_t * quants_32 , const __fp16 * scale ) {
238292 HVX_Vector vq = hvx_vmemu (quants_32 );
@@ -331,11 +385,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
331385 int start_tile , int end_tile ) {
332386
333387 const int n_k_tiles = (unsigned )k_block / HMX_FP16_TILE_N_COLS ;
334- const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL );
388+ const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL );
389+ const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1 );
335390 const int qrow_size = is_q4 ? ((unsigned )k_block / 2 ) : k_block ;
336391
337392 const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL ) ? hvx_vmem (iq4_nl_to_fp16_lut ) :
338393 (weight_type == HTP_TYPE_MXFP4 ) ? hvx_vmem (mxfp4_to_fp16_lut ) :
394+ (weight_type == HTP_TYPE_Q4_1 ) ? hvx_vmem (q4_1_to_fp16_lut ) :
339395 hvx_vmem (q4_0_to_fp16_lut );
340396
341397 // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
@@ -356,8 +412,10 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
356412 unsigned sub_blk_base = ((kt * 32 ) % QK_Q4_0x4x2 ) / 32 ; // 0 or 4
357413 bool upper = (sub_blk_base >= 4 );
358414 unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2 ); // 128 contiguous packed bytes
359- unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
360- + sub_blk_base * (int )sizeof (__fp16 ); // 4 consecutive scales
415+ unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE ;
416+ unsigned scale_step = is_q4_1 ? 4 : (int )sizeof (__fp16 );
417+ unsigned scale_off = qrow_size + blk_idx * dblk_size
418+ + sub_blk_base * scale_step ;
361419
362420 __fp16 * tile_bases [4 ];
363421 for (unsigned g = 0 ; g < 4 ; g ++ ) { tile_bases [g ] = vtcm_dst + (t + g ) * HMX_FP16_TILE_N_ELMS ; }
@@ -367,20 +425,38 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
367425 unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride ;
368426 unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1 ;
369427
370- for (int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r += 2 , row1 += 2 ) {
371- const uint8_t * r0 = vtcm_src + row_offset ; row_offset += row_stride ;
372- const uint8_t * r1 = vtcm_src + row_offset ; row_offset += row_stride ;
428+ if (is_q4_1 ) {
429+ for (int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r += 2 , row1 += 2 ) {
430+ const uint8_t * r0 = vtcm_src + row_offset ; row_offset += row_stride ;
431+ const uint8_t * r1 = vtcm_src + row_offset ; row_offset += row_stride ;
373432
374- HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx (r0 + packed_off , upper , (const __fp16 * )(r0 + scale_off ), vlut_cvt );
375- HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx (r1 + packed_off , upper , (const __fp16 * )(r1 + scale_off ), vlut_cvt );
433+ HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx (r0 + packed_off , upper , (const __fp16 * )(r0 + scale_off ), vlut_cvt );
434+ HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx (r1 + packed_off , upper , (const __fp16 * )(r1 + scale_off ), vlut_cvt );
376435
377- Q6_vscatter_RMVwV ((size_t )tile_bases [0 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv0 .v [0 ]);
378- Q6_vscatter_RMVwV ((size_t )tile_bases [2 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv0 .v [1 ]);
379- v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
436+ Q6_vscatter_RMVwV ((size_t )tile_bases [0 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv0 .v [0 ]);
437+ Q6_vscatter_RMVwV ((size_t )tile_bases [2 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv0 .v [1 ]);
438+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
380439
381- Q6_vscatter_RMVwV ((size_t )tile_bases [0 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv1 .v [0 ]);
382- Q6_vscatter_RMVwV ((size_t )tile_bases [2 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv1 .v [1 ]);
383- v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
440+ Q6_vscatter_RMVwV ((size_t )tile_bases [0 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv1 .v [0 ]);
441+ Q6_vscatter_RMVwV ((size_t )tile_bases [2 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv1 .v [1 ]);
442+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
443+ }
444+ } else {
445+ for (int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r += 2 , row1 += 2 ) {
446+ const uint8_t * r0 = vtcm_src + row_offset ; row_offset += row_stride ;
447+ const uint8_t * r1 = vtcm_src + row_offset ; row_offset += row_stride ;
448+
449+ HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx (r0 + packed_off , upper , (const __fp16 * )(r0 + scale_off ), vlut_cvt );
450+ HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx (r1 + packed_off , upper , (const __fp16 * )(r1 + scale_off ), vlut_cvt );
451+
452+ Q6_vscatter_RMVwV ((size_t )tile_bases [0 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv0 .v [0 ]);
453+ Q6_vscatter_RMVwV ((size_t )tile_bases [2 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv0 .v [1 ]);
454+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
455+
456+ Q6_vscatter_RMVwV ((size_t )tile_bases [0 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv1 .v [0 ]);
457+ Q6_vscatter_RMVwV ((size_t )tile_bases [2 ], 2 * HMX_FP16_TILE_SIZE - 1 , v_off , dv1 .v [1 ]);
458+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
459+ }
384460 }
385461
386462 for (int g = 0 ; g < 4 ; g ++ ) { (void ) * (volatile HVX_Vector * )(tile_bases [g ]); }
@@ -446,26 +522,43 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
446522 unsigned sub_blk = ((kt * 32 ) % QK_Q4_0x4x2 ) / 32 ;
447523 bool upper = (sub_blk >= 4 );
448524 unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2 ) + (upper ? (sub_blk - 4 ) : sub_blk ) * 32 ;
449- unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int )sizeof (__fp16 );
525+ unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE ;
526+ unsigned scale_step = is_q4_1 ? 4 : (int )sizeof (__fp16 );
527+ unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step ;
450528
451529 HVX_Vector v_off = v_scat_base ; // reset to column 0
452530 unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride ;
453531 unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1 ;
454- for (int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r += 2 , row1 += 2 ) {
455- const uint8_t * r0 = vtcm_src + row_offset ; row_offset += row_stride ;
456- const uint8_t * r1 = vtcm_src + row_offset ; row_offset += row_stride ;
457-
458- HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx (
459- r0 + byte_off , upper , (const __fp16 * )(r0 + scale_off ), vlut_cvt );
460- HVX_Vector v1 = (row1 < n_cols )
461- ? dequantize_x4x2_q4_0_group_hvx (
462- r1 + byte_off , upper , (const __fp16 * )(r1 + scale_off ), vlut_cvt )
463- : Q6_V_vzero ();
464-
465- Q6_vscatter_QRMVwV (q_mask64 , (size_t )tile_base , HMX_FP16_TILE_SIZE - 1 , v_off , v0 );
466- v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
467- Q6_vscatter_QRMVwV (q_mask64 , (size_t )tile_base , HMX_FP16_TILE_SIZE - 1 , v_off , v1 );
468- v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
532+ if (is_q4_1 ) {
533+ for (int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r += 2 , row1 += 2 ) {
534+ const uint8_t * r0 = vtcm_src + row_offset ; row_offset += row_stride ;
535+ const uint8_t * r1 = vtcm_src + row_offset ; row_offset += row_stride ;
536+
537+ HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx (r0 + byte_off , upper , (const __fp16 * )(r0 + scale_off ), vlut_cvt );
538+ HVX_Vector v1 = (row1 < n_cols )
539+ ? dequantize_x4x2_q4_1_group_hvx (r1 + byte_off , upper , (const __fp16 * )(r1 + scale_off ), vlut_cvt )
540+ : Q6_V_vzero ();
541+
542+ Q6_vscatter_QRMVwV (q_mask64 , (size_t )tile_base , HMX_FP16_TILE_SIZE - 1 , v_off , v0 );
543+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
544+ Q6_vscatter_QRMVwV (q_mask64 , (size_t )tile_base , HMX_FP16_TILE_SIZE - 1 , v_off , v1 );
545+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
546+ }
547+ } else {
548+ for (int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r += 2 , row1 += 2 ) {
549+ const uint8_t * r0 = vtcm_src + row_offset ; row_offset += row_stride ;
550+ const uint8_t * r1 = vtcm_src + row_offset ; row_offset += row_stride ;
551+
552+ HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx (r0 + byte_off , upper , (const __fp16 * )(r0 + scale_off ), vlut_cvt );
553+ HVX_Vector v1 = (row1 < n_cols )
554+ ? dequantize_x4x2_q4_0_group_hvx (r1 + byte_off , upper , (const __fp16 * )(r1 + scale_off ), vlut_cvt )
555+ : Q6_V_vzero ();
556+
557+ Q6_vscatter_QRMVwV (q_mask64 , (size_t )tile_base , HMX_FP16_TILE_SIZE - 1 , v_off , v0 );
558+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
559+ Q6_vscatter_QRMVwV (q_mask64 , (size_t )tile_base , HMX_FP16_TILE_SIZE - 1 , v_off , v1 );
560+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
561+ }
469562 }
470563 (void ) * (volatile HVX_Vector * )(tile_base );
471564 } else if (weight_type == HTP_TYPE_MXFP4 ) {
@@ -593,6 +686,8 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
593686
594687// --- End x4x2 dequantizers ---
595688
689+ #pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics
690+
596691// requires external HMX lock
597692static void core_dot_chunk_fp16 (__fp16 * restrict output , const __fp16 * restrict activation , const __fp16 * restrict weight , const __fp16 * restrict scales ,
598693 int n_row_tiles , int n_col_tiles , int n_dot_tiles ) {
0 commit comments