Commit bbeb89d

Hexagon: Process M-tail rows on HMX instead of HVX (ggml-org#22724)

* hex-mm: process m-tail rows on HMX instead of HVX
* hmx-mm: unroll and optimize padded activation loop

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>

1 parent ff806a1 commit bbeb89d

2 files changed: 48 additions & 39 deletions

ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c (42 additions & 9 deletions)
@@ -742,17 +742,45 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
 // activations : fp32 -> fp16
 
 static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) {
-    for (int r = 0; r < n_rows; r += 2) {
+    const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS);
+    const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS;
+
+    int r = 0;
+
+    #pragma unroll(2)
+    for (r = 0; r < n_rows_tiled; r += 2) {
         int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
         int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx
 
-        const bool next_row_valid = (r + 1) < n_rows;
-
         const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride);
         const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride);
         for (int c = 0; c < k_block; c += 32) {
             HVX_Vector v0 = *pv_in0++;
-            HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero();
+            HVX_Vector v1 = *pv_in1++;
+
+            HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
+
+            // compute output position
+            int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index
+            int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0;
+
+            HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
+            tile[r1 / 2] = v_out;
+        }
+    }
+
+    for (; r < n_rows_padded; r += 2) {
+        int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
+        int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx
+
+        const bool row0_valid = r < n_rows;
+        const bool row1_valid = (r + 1) < n_rows;
+
+        const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL;
+        const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL;
+        for (int c = 0; c < k_block; c += 32) {
+            HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero();
+            HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero();
 
             HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);
 
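The second loop zero-fills the padded tail so HMX always consumes whole HMX_FP16_TILE_N_ROWS-row tiles instead of handing the remainder to HVX. A scalar reference model of the layout both loops produce, useful for unit-testing the HVX path (a sketch only: the 32x32 fp16 tile geometry and the pair-interleaved element order of hvx_vec_f32_to_f16_shuff are assumptions, not spelled out in this diff):

// Hypothetical scalar model of the padded, tiled fp32 -> fp16 transfer.
// Assumes 32x32 fp16 tiles and element-wise interleaving of row pairs.
enum { TILE_ROWS = 32, TILE_COLS = 32, TILE_ELMS = TILE_ROWS * TILE_COLS };

static void ref_f32_to_f16_tiles(__fp16 *dst, const float *src,
                                 int n_rows, int k_block, int k_stride) {
    const int n_rows_padded = (n_rows + TILE_ROWS - 1) & ~(TILE_ROWS - 1);
    for (int r = 0; r < n_rows_padded; r++) {
        for (int c = 0; c < k_block; c++) {
            int tile = (r / TILE_ROWS) * (k_block / TILE_COLS) + c / TILE_COLS;
            int rr = r % TILE_ROWS, cc = c % TILE_COLS;
            // two consecutive rows share one 128-byte vector slot,
            // with their elements interleaved by the f32 -> f16 shuffle
            int idx = tile * TILE_ELMS + (rr / 2) * (2 * TILE_COLS) + 2 * cc + (rr & 1);
            dst[idx] = (r < n_rows) ? (__fp16) src[r * k_stride + c]
                                    : (__fp16) 0.0f; // zero padding for the M tail
        }
    }
}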
@@ -889,7 +917,9 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h
     // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
     const size_t m_block_cost = (size_t) n * 3;
     const size_t n_block_cost = (size_t) m * 2;
-    if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE,
+    if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn,
+                           hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
+                           m_block_cost, n_block_cost, &M_BLOCK_SIZE,
                            &N_BLOCK_SIZE, &vtcm_used) != 0) {
         FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
         return -1;
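The comment above is the whole heuristic: an extra M-block re-streams roughly 3*N of weight-side data, while an extra N-block re-streams roughly 2*M of activations, so for typical tall-weight shapes the chunker should shrink N first. A toy computation with made-up shapes, treating hmx_compute_chunks as a black box (note that m is now fed in pre-aligned, so the VTCM budget also covers the zero-padded tile rows):

#include <stdio.h>

// Illustrative only: compares the relative cost of splitting along M vs N
// using the block costs from the hunk above.
int main(void) {
    size_t m = 256, n = 4096;    // made-up activation rows / weight columns
    size_t m_block_cost = n * 3; // each extra M-block re-loads ~3*N weight-side data
    size_t n_block_cost = m * 2; // each extra N-block re-loads the M×K activations
    printf("m_block_cost=%zu n_block_cost=%zu -> split along %c first\n",
           m_block_cost, n_block_cost, n_block_cost < m_block_cost ? 'N' : 'M');
    return 0;                    // prints 12288 vs 512 -> split along N
}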
@@ -1084,7 +1114,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
 
     if (m >= 128) {
         size_t mc = 0, nc = 0, used = 0;
-        if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n,
+        if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn,
+                               hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
                                /*m_block_cost=*/(size_t) n * 3,
                                /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 &&
             hmx_ceil_div((size_t) n, nc) >= 2) {
@@ -1096,7 +1127,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
     }
 
     if (!use_pipeline) {
-        if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n,
+        if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn,
+                               hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
                                /*m_block_cost=*/(size_t) n * 3,
                                /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
             FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
@@ -1432,7 +1464,8 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
     if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
                            /*per_n=*/3 * vec_dot_size,
                            /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
-                           /*per_mn=*/sizeof(__fp16), params->m, params->n,
+                           /*per_mn=*/sizeof(__fp16),
+                           hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n,
                            /*m_block_cost=*/(size_t) params->n,
                            /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
         FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
@@ -1612,7 +1645,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
                            /*per_n=*/3 * vec_dot_size,                 // W + S0 + S1
                            /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
                            /*per_mn=*/sizeof(__fp16),                  // O
-                           m, n,
+                           hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
                            /*m_block_cost=*/(size_t) n,
                            /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
         FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
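Every hmx_compute_chunks call site above now passes m rounded up to a whole number of HMX tile rows, so the VTCM budget accounts for the zero-padded tail that transfer_activation_chunk_fp32_to_fp16 writes. hex_align_up itself does not appear in this diff; a conventional power-of-two align-up would look like this (a sketch, not the project's definition):

#include <stddef.h>

// Hypothetical hex_align_up: round x up to a multiple of a.
// Assumes a is a power of two (HMX_FP16_TILE_N_ROWS is 32 in the
// tail handling below), so the mask trick is valid.
static inline size_t hex_align_up(size_t x, size_t a) {
    return (x + a - 1) & ~(a - 1);
}
// e.g. hex_align_up(70, 32) == 96, hex_align_up(64, 32) == 64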

ggml/src/ggml-hexagon/htp/matmul-ops.c (6 additions & 30 deletions)
@@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) {
         return op_matmul_hvx(octx);
     }
 
-    // M alignment: when M > 32 but not 32-aligned, we split into
-    // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
-    // When M <= 32 and not 32-aligned, fall back entirely to HVX.
+    // M alignment: use HMX when M >= 32; the last partial tile (m_total % 32 rows)
+    // is handled by HMX itself. When M < 32, fall back to HVX.
     const int m_total = (int) src1->ne[1];
-    const int m_tail = m_total % 32;
-    const int m_hmx = m_total - m_tail;
+    const int m_hmx = m_total & ~31; // 0 when M < 32
 
     if (m_hmx == 0) {
         return op_matmul_hvx(octx);
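m_total & ~31 clears the low five bits, i.e. rounds down to a multiple of 32 (valid only because 32 is a power of two), so m_hmx doubles as the HMX gate: it is zero exactly when M < 32, while the HMX kernels now receive the full m_total and pad the partial tile themselves. A quick standalone check:

#include <stdio.h>

// Demonstrates the round-down gate from the hunk above.
int main(void) {
    for (int m_total = 0; m_total <= 100; m_total += 25) {
        int m_hmx = m_total & ~31; // 0, 0, 32, 64, 96
        printf("m_total=%3d -> m_hmx=%2d (%s)\n", m_total, m_hmx,
               m_hmx ? "HMX, tail padded on-engine" : "HVX fallback");
    }
    return 0;
}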
@@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) {
     int k = (int) src0->ne[0]; // inner dimension
     int n = (int) src0->ne[1]; // weight columns
 
-    // --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
     int ret = -1;
 
     // Row strides in elements. For compact tensors these equal k; for
@@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) {
         .dst = (float *) dst->data,
         .activation = (float *) src1->data,
         .permuted_weight = (const __fp16 *) src0->data,
-        .m = m_hmx,
+        .m = m_total,
         .k = k,
         .n = n,
         .act_stride = act_stride,
@@ -3048,40 +3045,19 @@ int op_matmul(struct htp_ops_context * octx) {
         } else {
             ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
                 (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
-                m_hmx, k, n, act_stride, wgt_stride);
+                m_total, k, n, act_stride, wgt_stride);
         }
     } else {
         ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
             (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
-            m_hmx, k, n, (int) src0->type);
+            m_total, k, n, (int) src0->type);
     }
 
     if (ret != 0) {
         FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
         return op_matmul(octx);
     }
 
-    // --- Phase 2: HVX on the remaining m_tail rows ---
-    if (m_tail > 0) {
-        // copy of src1 and dst
-        struct htp_tensor src1_tail = *src1;
-        struct htp_tensor dst_tail = *dst;
-
-        src1_tail.ne[1] = m_tail; // only tail rows
-        dst_tail.ne[1] = m_tail; // only tail rows
-
-        // Offset activation and dst pointers past the HMX-processed rows.
-        // Use nb[1] (row stride in bytes) to compute the byte offset.
-        src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
-        dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
-
-        octx->src[1] = &src1_tail;
-        octx->dst = &dst_tail;
-
-        FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
-        return op_matmul_hvx(octx);
-    }
-
     return 0;
 #endif // HTP_HAS_HMX
 }
