@@ -648,9 +648,9 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
648648 assert (n_cols % HMX_FP16_TILE_N_COLS == 0 );
649649 assert (k_block % HMX_FP16_TILE_N_COLS == 0 );
650650
651- int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS ;
652- int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS ;
653- int n_tot_tiles = n_col_tiles * n_k_tiles ;
651+ size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS ;
652+ size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS ;
653+ size_t n_tot_tiles = n_col_tiles * n_k_tiles ;
654654
655655 size_t n_tiles_per_task = hmx_ceil_div (n_tot_tiles , ctx -> n_threads );
656656
@@ -678,9 +678,8 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict
678678 __builtin_assume (n_dot_tiles > 0 );
679679
680680 Q6_bias_mxmem2_A ((void * )scales );
681-
682681 for (int r = 0 ; r < n_row_tiles ; ++ r ) {
683- for (int c = 0 ; c < n_col_tiles ; ++ c ) {
682+ for (size_t c = 0 ; c < n_col_tiles ; ++ c ) {
684683 Q6_mxclracc_hf ();
685684
686685 const __fp16 * row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS ;
@@ -738,25 +737,25 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
738737
739738static void transfer_output_chunk_fp16_to_fp32 (float * restrict dst , const __fp16 * restrict vtcm_src , int n_rows , int n_cols , int n ) {
740739 assert (n_cols % HMX_FP16_TILE_N_COLS == 0 );
741- const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS ;
740+ const size_t tile_row_stride = ( n_cols / HMX_FP16_TILE_N_COLS ) * HMX_FP16_TILE_N_ELMS ;
742741
743742 const HVX_Vector one = hvx_vec_splat_f16 (1.0 );
744743
745- for (int r = 0 ; r < n_rows ; r += 2 ) {
746- int r0 = r / HMX_FP16_TILE_N_ROWS ;
747- int r1 = r % HMX_FP16_TILE_N_ROWS ;
744+ for (size_t r = 0 ; r < n_rows ; r += 2 ) {
745+ const size_t r0 = r / HMX_FP16_TILE_N_ROWS ;
746+ const size_t r1 = (r % HMX_FP16_TILE_N_ROWS ) / 2 ; // index of the row pair within the tile
747+ const __fp16 * row_base = vtcm_src + r0 * tile_row_stride ;
748+ float * output_row_base = dst + r * n ; // global memory row base for row r (and r+1)
748749
749750 #pragma unroll(4)
750- for (int c = 0 ; c < n_cols ; c += HMX_FP16_TILE_N_COLS ) {
751- int c0 = c / HMX_FP16_TILE_N_COLS ;
752-
753- const __fp16 * tile = vtcm_src + (r0 * n_col_tiles + c0 ) * HMX_FP16_TILE_N_ELMS ;
754-
755- HVX_Vector v = ((const HVX_Vector * ) tile )[r1 / 2 ];
751+ for (size_t c = 0 ; c < n_cols ; c += HMX_FP16_TILE_N_COLS ) {
752+ const size_t c0 = c / HMX_FP16_TILE_N_COLS ;
753+ const __fp16 * tile = row_base + c0 * HMX_FP16_TILE_N_ELMS ;
754+ HVX_Vector v = ((const HVX_Vector * ) tile )[r1 ];
756755 HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf (v , one );
757756
758- volatile HVX_Vector * pv_out0 = (volatile HVX_Vector * ) (dst + ( r * n + c + 0 ) );
759- volatile HVX_Vector * pv_out1 = (volatile HVX_Vector * ) (dst + ( r * n + c + n ) ); // next row in global memory
757+ volatile HVX_Vector * pv_out0 = (volatile HVX_Vector * ) (output_row_base + c + 0 );
758+ volatile HVX_Vector * pv_out1 = (volatile HVX_Vector * ) (output_row_base + c + n ); // next row in global memory
760759
761760 * pv_out0 = Q6_Vsf_equals_Vqf32 (Q6_V_lo_W (vp ));
762761 if (r + 1 < n_rows ) {
@@ -794,7 +793,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
794793 assert (n_cols % HMX_FP16_TILE_N_COLS == 0 );
795794
796795 size_t n_tot_chunks = n_rows ;
797- size_t n_chunks_per_task = 32 ; // must be multiple of HMX_FP16_TILE_N_ROWS (32)
796+ size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS ; // must be multiple of HMX_FP16_TILE_N_ROWS (32)
798797
799798 output_transfer_task_state_t state ;
800799 state .n_tasks = (n_tot_chunks + n_chunks_per_task - 1 ) / n_chunks_per_task ;
@@ -926,7 +925,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
926925 return hmx_mat_mul_permuted_w16a32_batched_legacy (ctx , params );
927926 }
928927
929- hmx_init_column_scales (vtcm_scales , Q6_V_vsplat_R (0x3c00 )); // fp16 : 1.0
928+ hmx_init_column_scales (vtcm_scales , Q6_V_vsplat_R (0x3c00 )); // scale : 1.0, bias: 0.0 in FP16
930929
931930 FARF (MEDIUM , "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu" ,
932931 __func__ , params -> m , params -> k , params -> n , group_size , params -> ne13 ,
@@ -944,12 +943,15 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
944943 const size_t fp16_row_bytes = (size_t ) params -> k * sizeof (__fp16 );
945944 const size_t weight_row_bytes = (size_t ) params -> weight_stride * sizeof (__fp16 );
946945
946+ HAP_compute_res_hmx_lock (ctx -> vtcm_rctx );
947+
947948 for (int b3 = 0 ; b3 < params -> ne13 ; ++ b3 ) {
948949 for (int b2_base = 0 ; b2_base < params -> ne12 ; b2_base += group_size ) {
949950 const __fp16 * weight_group = hmx_matmul_weight_batch_ptr (params , b2_base , b3 );
950951
951952 for (size_t mr = 0 ; mr < (size_t ) params -> m ; mr += m_chunk_n_rows ) {
952953 const size_t n_rows = hex_smin ((size_t ) params -> m - mr , m_chunk_n_rows );
954+ const size_t n_row_tiles = hmx_ceil_div ((int ) n_rows , HMX_FP16_TILE_N_ROWS );
953955
954956 // Pre-load activations for all heads in the group (once per m_chunk).
955957 // When the source is strided (permuted Q), use 2D DMA to gather
@@ -987,10 +989,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
987989 fp16_row_bytes , weight_row_bytes , fp16_row_bytes , n_cols_first );
988990 }
989991
990- HAP_compute_res_hmx_lock (ctx -> vtcm_rctx );
991-
992992 for (size_t nc = 0 ; nc < (size_t ) params -> n ; nc += n_chunk_n_cols ) {
993993 const size_t n_cols = hex_smin ((size_t ) params -> n - nc , n_chunk_n_cols );
994+ const size_t n_col_tiles = hmx_ceil_div ((int ) n_cols , HMX_FP16_TILE_N_COLS );
994995
995996 TIMER_START (weight_load );
996997 {
@@ -1014,11 +1015,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
10141015 for (int g = 0 ; g < group_size ; ++ g ) {
10151016 TIMER_START (hmx_core );
10161017 {
1017- const __fp16 * vtcm_act_g = vtcm_activation + (size_t ) g * act_head_stride ;
1018- const int n_row_tiles = hmx_ceil_div ((int ) n_rows , HMX_FP16_TILE_N_ROWS );
1019- const int n_col_tiles = hmx_ceil_div ((int ) n_cols , HMX_FP16_TILE_N_COLS );
1020- core_dot_chunk_fp16 (vtcm_output , vtcm_act_g , vtcm_weight , vtcm_scales ,
1021- n_row_tiles , n_col_tiles , params -> k / 32 );
1018+ const __fp16 * vtcm_act_g = vtcm_activation + (size_t ) g * act_head_stride ;
1019+ core_dot_chunk_fp16 (vtcm_output , vtcm_act_g , vtcm_weight , vtcm_scales , n_row_tiles , n_col_tiles ,
1020+ params -> k / 32 );
10221021 }
10231022 TIMER_STOP (hmx_core );
10241023
@@ -1030,12 +1029,12 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
10301029 TIMER_STOP (output_store );
10311030 }
10321031 }
1033-
1034- HAP_compute_res_hmx_unlock (ctx -> vtcm_rctx );
10351032 }
10361033 }
10371034 }
10381035
1036+ HAP_compute_res_hmx_unlock (ctx -> vtcm_rctx );
1037+
10391038 TIMER_STOP (total );
10401039
10411040#if defined(ENABLE_PROFILE_TIMERS )
@@ -1103,7 +1102,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
11031102 return -1 ;
11041103 }
11051104
1106- hmx_init_column_scales (vtcm_scales , Q6_V_vsplat_R (0x3c00 )); // fp16 : 1.0
1105+ hmx_init_column_scales (vtcm_scales , Q6_V_vsplat_R (0x3c00 )); // scale : 1.0, bias: 0.0 in FP16
11071106
11081107 FARF (MEDIUM , "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu" ,
11091108 __func__ , m , k , n , m_chunk_n_rows , n_chunk_n_cols ,
@@ -1121,7 +1120,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
11211120
11221121 for (size_t mr = 0 ; mr < m ; mr += m_chunk_n_rows ) {
11231122 // transfer activation matrix chunk into VTCM
1124- size_t n_rows = hex_smin (m - mr , m_chunk_n_rows );
1123+ const size_t n_rows = hex_smin (m - mr , m_chunk_n_rows );
1124+ const size_t n_row_tiles = hmx_ceil_div (n_rows , HMX_FP16_TILE_N_ROWS );
11251125
11261126 TIMER_START (activation_load );
11271127 {
@@ -1159,7 +1159,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
11591159 }
11601160
11611161 for (size_t nc = 0 ; nc < n ; nc += n_chunk_n_cols ) {
1162- size_t n_cols = hex_smin (n - nc , n_chunk_n_cols );
1162+ const size_t n_cols = hex_smin (n - nc , n_chunk_n_cols );
1163+ const size_t n_col_tiles = hmx_ceil_div (n_cols , HMX_FP16_TILE_N_COLS );
11631164
11641165 TIMER_START (weight_load );
11651166 {
@@ -1184,8 +1185,6 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
11841185
11851186 TIMER_START (hmx_core );
11861187 {
1187- const int n_row_tiles = hmx_ceil_div (n_rows , HMX_FP16_TILE_N_ROWS );
1188- const int n_col_tiles = hmx_ceil_div (n_cols , HMX_FP16_TILE_N_COLS );
11891188 core_dot_chunk_fp16 (vtcm_output , vtcm_activation , vtcm_weight , vtcm_scales , n_row_tiles , n_col_tiles , k / 32 );
11901189 }
11911190 TIMER_STOP (hmx_core );
@@ -1307,7 +1306,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
13071306 return -1 ;
13081307 }
13091308
1310- hmx_init_column_scales (vtcm_scales , Q6_V_vsplat_R (0x3c00 )); // fp16 : 1.0
1309+ hmx_init_column_scales (vtcm_scales , Q6_V_vsplat_R (0x3c00 )); // scale : 1.0, bias: 0.0 in FP16
13111310
13121311 FARF (MEDIUM , "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu" ,
13131312 __func__ , m , k , n , weight_type , use_pipeline ,
@@ -1330,7 +1329,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
13301329 HAP_compute_res_hmx_lock (ctx -> vtcm_rctx );
13311330 for (size_t mr = 0 ; mr < m ; mr += m_chunk_n_rows ) {
13321331 // transfer activation matrix chunk into VTCM
1333- size_t n_rows = hex_smin (m - mr , m_chunk_n_rows );
1332+ const size_t n_rows = hex_smin (m - mr , m_chunk_n_rows );
1333+ const size_t n_row_tiles = hmx_ceil_div (n_rows , HMX_FP16_TILE_N_ROWS );
13341334
13351335 TIMER_START (activation_load );
13361336 {
@@ -1348,7 +1348,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
13481348 }
13491349
13501350 for (size_t nc = 0 ; nc < n ; nc += n_chunk_n_cols ) {
1351- size_t n_cols = hex_smin (n - nc , n_chunk_n_cols );
1351+ const size_t n_cols = hex_smin (n - nc , n_chunk_n_cols );
1352+ const size_t n_col_tiles = hmx_ceil_div (n_cols , HMX_FP16_TILE_N_COLS );
13521353
13531354 TIMER_START (weight_load );
13541355 {
@@ -1373,8 +1374,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
13731374
13741375 TIMER_START (hmx_core );
13751376 {
1376- const int n_row_tiles = hmx_ceil_div (n_rows , HMX_FP16_TILE_N_ROWS );
1377- const int n_col_tiles = hmx_ceil_div (n_cols , HMX_FP16_TILE_N_COLS );
13781377 core_dot_chunk_fp16 (vtcm_output , vtcm_activation , vtcm_weight , vtcm_scales , n_row_tiles , n_col_tiles , k / 32 );
13791378 }
13801379 TIMER_STOP (hmx_core );
@@ -1521,14 +1520,16 @@ void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __f
15211520
15221521 Q6_bias_mxmem2_A ((void * )col_scales );
15231522
1524- for (int i = 0 ; i < n_row_tiles ; ++ i ) {
1525- for (int j = 0 ; j < n_col_tiles ; ++ j ) {
1523+ const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS ;
1524+ for (size_t i = 0 ; i < n_row_tiles ; ++ i ) {
1525+ const __fp16 * row_base = a + i * dot_tile_stride ;
1526+ __fp16 * res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS ;
1527+ for (size_t j = 0 ; j < n_col_tiles ; ++ j ) {
15261528 Q6_mxclracc_hf ();
15271529
1528- const __fp16 * row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS ;
1529- const __fp16 * col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS ;
1530-
1531- __fp16 * accum_tile = c + (i * n_col_tiles + j ) * HMX_FP16_TILE_N_ELMS ;
1530+ const __fp16 * col_tiles = b + j * dot_tile_stride ;
1531+ const __fp16 * row_tiles = row_base ;
1532+ __fp16 * accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS ;
15321533 if (!zero_init ) {
15331534 Q6_activation_hf_mxmem_RR ((unsigned int )accum_tile , 2047 );
15341535 Q6_weight_hf_mxmem_RR ((unsigned int )eye_tile , 2047 );
@@ -1697,7 +1698,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
16971698 v = Q6_V_vror_VR (v , VLEN - 8 );
16981699 }
16991700 }
1700- hmx_init_column_scales (vtcm_scales , Q6_V_vsplat_R (0x3c00 )); // fp16 : 1.0
1701+ hmx_init_column_scales (vtcm_scales , Q6_V_vsplat_R (0x3c00 )); // scale : 1.0, bias: 0.0 in FP16
17011702
17021703 TIMER_DEFINE (fetch );
17031704 TIMER_DEFINE (act_load );
@@ -1715,7 +1716,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
17151716 const int n_col_tiles = hmx_ceil_div (n_blk_sz , HMX_FP16_TILE_N_COLS );
17161717
17171718 for (size_t kk = 0 ; kk < k ; kk += K_BLOCK_SIZE ) {
1718- size_t k_blk_sz = hex_smin (k - kk , K_BLOCK_SIZE );
1719+ const size_t k_blk_sz = hex_smin (k - kk , K_BLOCK_SIZE );
17191720
17201721 TIMER_START (fetch );
17211722 // fetch activation block into VTCM
@@ -1731,13 +1732,13 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
17311732 }
17321733
17331734 // fetch weight block into VTCM (x4x2 sub-block: quants + scales)
1735+ const size_t sub_row_stride = get_x4x2_row_stride (weight_type , k_blk_sz );
17341736 {
17351737 qweight_fetch_task_state_t s ;
17361738
17371739 const int blk_start = kk / QK_Q4_0x4x2 ;
17381740 const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1 ) / QK_Q4_0x4x2 ;
17391741 const int full_qrow = (weight_type == HTP_TYPE_Q8_0 ) ? k : (k / 2 );
1740- const size_t sub_row_stride = get_x4x2_row_stride (weight_type , k_blk_sz );
17411742 const int scale_blk_size =
17421743 (weight_type == HTP_TYPE_MXFP4 ) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE ;
17431744
@@ -1777,7 +1778,6 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
17771778 dma_queue_pop (ctx -> dma [0 ]);
17781779 // vtcm_scratch0 is used to store the qweight chunk
17791780 // worker_pool_run_func already returned, so fetch is done
1780- const size_t sub_row_stride = get_x4x2_row_stride (weight_type , k_blk_sz );
17811781 dequantize_x4x2_weight_chunk_to_fp16_tiles (ctx , vtcm_weight , vtcm_scratch0 ,
17821782 n_blk_sz , k_blk_sz , sub_row_stride , weight_type );
17831783 }
0 commit comments