@@ -742,17 +742,45 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
742742// activations : fp32 -> fp16
743743
744744static void transfer_activation_chunk_fp32_to_fp16 (__fp16 * restrict vtcm_dst , const float * restrict src , int n_rows , int k_block , int k_stride ) {
745- for (int r = 0 ; r < n_rows ; r += 2 ) {
745+ const int n_rows_padded = hex_align_up (n_rows , HMX_FP16_TILE_N_ROWS );
746+ const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS ) * HMX_FP16_TILE_N_ROWS ;
747+
748+ int r = 0 ;
749+
750+ #pragma unroll(2)
751+ for (r = 0 ; r < n_rows_tiled ; r += 2 ) {
746752 int r0 = r / HMX_FP16_TILE_N_ROWS ; // tile row index
747753 int r1 = r % HMX_FP16_TILE_N_ROWS ; // intra-tile row idx
748754
749- const bool next_row_valid = (r + 1 ) < n_rows ;
750-
751755 const HVX_Vector * pv_in0 = (const HVX_Vector * ) (src + (r + 0 ) * k_stride );
752756 const HVX_Vector * pv_in1 = (const HVX_Vector * ) (src + (r + 1 ) * k_stride );
753757 for (int c = 0 ; c < k_block ; c += 32 ) {
754758 HVX_Vector v0 = * pv_in0 ++ ;
755- HVX_Vector v1 = next_row_valid ? * pv_in1 ++ : Q6_V_vzero ();
759+ HVX_Vector v1 = * pv_in1 ++ ;
760+
761+ HVX_Vector v_out = hvx_vec_f32_to_f16_shuff (v0 , v1 );
762+
763+ // compute output position
764+ int c0 = c / HMX_FP16_TILE_N_COLS ; // tile column index
765+ int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS ) + c0 ;
766+
767+ HVX_Vector * tile = (HVX_Vector * ) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS );
768+ tile [r1 / 2 ] = v_out ;
769+ }
770+ }
771+
772+ for (; r < n_rows_padded ; r += 2 ) {
773+ int r0 = r / HMX_FP16_TILE_N_ROWS ; // tile row index
774+ int r1 = r % HMX_FP16_TILE_N_ROWS ; // intra-tile row idx
775+
776+ const bool row0_valid = r < n_rows ;
777+ const bool row1_valid = (r + 1 ) < n_rows ;
778+
779+ const HVX_Vector * pv_in0 = row0_valid ? (const HVX_Vector * ) (src + (r + 0 ) * k_stride ) : NULL ;
780+ const HVX_Vector * pv_in1 = row1_valid ? (const HVX_Vector * ) (src + (r + 1 ) * k_stride ) : NULL ;
781+ for (int c = 0 ; c < k_block ; c += 32 ) {
782+ HVX_Vector v0 = row0_valid ? * pv_in0 ++ : Q6_V_vzero ();
783+ HVX_Vector v1 = row1_valid ? * pv_in1 ++ : Q6_V_vzero ();
756784
757785 HVX_Vector v_out = hvx_vec_f32_to_f16_shuff (v0 , v1 );
758786
@@ -889,7 +917,9 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h
889917 // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
890918 const size_t m_block_cost = (size_t ) n * 3 ;
891919 const size_t n_block_cost = (size_t ) m * 2 ;
892- if (hmx_compute_chunks (vtcm_budget , overhead , per_n , per_m , per_mn , m , n , m_block_cost , n_block_cost , & M_BLOCK_SIZE ,
920+ if (hmx_compute_chunks (vtcm_budget , overhead , per_n , per_m , per_mn ,
921+ hex_align_up (m , HMX_FP16_TILE_N_ROWS ), n ,
922+ m_block_cost , n_block_cost , & M_BLOCK_SIZE ,
893923 & N_BLOCK_SIZE , & vtcm_used ) != 0 ) {
894924 FARF (HIGH , "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)" , __func__ , m , k , n , vtcm_budget );
895925 return -1 ;
@@ -1084,7 +1114,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
10841114
10851115 if (m >= 128 ) {
10861116 size_t mc = 0 , nc = 0 , used = 0 ;
1087- if (hmx_compute_chunks (vtcm_budget , /*overhead=*/ 256 , pipe_per_n , /*per_m=*/ vec_dot_size , pipe_per_mn , m , n ,
1117+ if (hmx_compute_chunks (vtcm_budget , /*overhead=*/ 256 , pipe_per_n , /*per_m=*/ vec_dot_size , pipe_per_mn ,
1118+ hex_align_up (m , HMX_FP16_TILE_N_ROWS ), n ,
10881119 /*m_block_cost=*/ (size_t ) n * 3 ,
10891120 /*n_block_cost=*/ (size_t ) m * 2 , & mc , & nc , & used ) == 0 &&
10901121 hmx_ceil_div ((size_t ) n , nc ) >= 2 ) {
@@ -1096,7 +1127,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
10961127 }
10971128
10981129 if (!use_pipeline ) {
1099- if (hmx_compute_chunks (vtcm_budget , /*overhead=*/ 256 , seq_per_n , /*per_m=*/ vec_dot_size , seq_per_mn , m , n ,
1130+ if (hmx_compute_chunks (vtcm_budget , /*overhead=*/ 256 , seq_per_n , /*per_m=*/ vec_dot_size , seq_per_mn ,
1131+ hex_align_up (m , HMX_FP16_TILE_N_ROWS ), n ,
11001132 /*m_block_cost=*/ (size_t ) n * 3 ,
11011133 /*n_block_cost=*/ (size_t ) m * 2 , & m_chunk_n_rows , & n_chunk_n_cols , & vtcm_used ) != 0 ) {
11021134 FARF (HIGH , "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)" , __func__ , m , k , n , vtcm_budget );
@@ -1432,7 +1464,8 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
14321464 if (hmx_compute_chunks (vtcm_budget , /*overhead=*/ 256 ,
14331465 /*per_n=*/ 3 * vec_dot_size ,
14341466 /*per_m=*/ group_size * vec_dot_size + f32_scratch_per_m ,
1435- /*per_mn=*/ sizeof (__fp16 ), params -> m , params -> n ,
1467+ /*per_mn=*/ sizeof (__fp16 ),
1468+ hex_align_up (params -> m , HMX_FP16_TILE_N_ROWS ), params -> n ,
14361469 /*m_block_cost=*/ (size_t ) params -> n ,
14371470 /*n_block_cost=*/ (size_t ) params -> m , & m_chunk_n_rows , & n_chunk_n_cols , & vtcm_used ) != 0 ) {
14381471 FARF (HIGH , "%s: grouped path does not fit VTCM, falling back to legacy batched loop" , __func__ );
@@ -1612,7 +1645,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
16121645 /*per_n=*/ 3 * vec_dot_size , // W + S0 + S1
16131646 /*per_m=*/ vec_dot_size + f32_scratch_per_m , // A + optional F32 scratch
16141647 /*per_mn=*/ sizeof (__fp16 ), // O
1615- m , n ,
1648+ hex_align_up ( m , HMX_FP16_TILE_N_ROWS ) , n ,
16161649 /*m_block_cost=*/ (size_t ) n ,
16171650 /*n_block_cost=*/ (size_t ) m , & m_chunk_n_rows , & n_chunk_n_cols , & vtcm_used ) != 0 ) {
16181651 FARF (HIGH , "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)" , __func__ , m , k , n , vtcm_budget );
0 commit comments