
Commit 799ebb9

chraac authored and arthw committed
hexagon: optimize HMX matmul operations (ggml-org#21071)
* optimize hmx_mat_mul functions by calculating row and column tiles upfront
* refactor core_dot_chunk_fp16 to use size_t for tile counts and improve readability
* wip
* set scale outside of loop
* wip
* refactor core_mma_chunk_fp16 and mat_mul_qk_0_d16a32 to use size_t for tile counts
* wip
* wip
* refactor transfer_output_chunk_fp16_to_fp32 to use size_t for dimensions
* refactor core_dot_chunk_fp16 to use size_t for tile row stride calculation
* wip
* refactor hmx_mat_mul functions to use hvx_vec_splat_f16 for column scales initialization
* refactor hmx_mat_mul_permuted_w16a32_batched to streamline scale setting and locking
* refactor core_dot_chunk_fp16 to improve tile stride calculations for output
* refactor hmx_mat_mul functions to use Q6_V_vsplat_R for column scales initialization
* fix compiling error
* wip
* optimize row and column tile indexing in core_mma_chunk_fp16 function
* wip
* Revert "wip"

  This reverts commit cde679e.

* Add size limit check for HAP_mmap in htp_iface_mmap and drop_mmap functions
* wip
1 parent 5d34078 commit 799ebb9

3 files changed

Lines changed: 80 additions & 49 deletions
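The recurring change in the commit message above is hoisting the tile-count math out of inner loops. A minimal sketch of that pattern, using the helper names from the diff below but with a simplified, hypothetical loop structure (out/act/w/scales are placeholder names, not the actual kernel variables):

// Before: tile counts recomputed inside the innermost (per-head) loop.
for (int g = 0; g < group_size; ++g) {
    const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
    const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
    core_dot_chunk_fp16(out, act, w, scales, n_row_tiles, n_col_tiles, k / 32);
}

// After: counts computed once where n_rows / n_cols are fixed, as size_t.
const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); // once per m-chunk
const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); // once per n-chunk
for (int g = 0; g < group_size; ++g) {
    core_dot_chunk_fp16(out, act, w, scales, n_row_tiles, n_col_tiles, k / 32);
}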


ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c

Lines changed: 48 additions & 48 deletions
@@ -648,9 +648,9 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
     assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
     assert(k_block % HMX_FP16_TILE_N_COLS == 0);
 
-    int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
-    int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
-    int n_tot_tiles = n_col_tiles * n_k_tiles;
+    size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
+    size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
+    size_t n_tot_tiles = n_col_tiles * n_k_tiles;
 
     size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads);
 

@@ -678,9 +678,8 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict
     __builtin_assume(n_dot_tiles > 0);
 
     Q6_bias_mxmem2_A((void *)scales);
-
     for (int r = 0; r < n_row_tiles; ++r) {
-        for (int c = 0; c < n_col_tiles; ++c) {
+        for (size_t c = 0; c < n_col_tiles; ++c) {
             Q6_mxclracc_hf();
 
             const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS;

@@ -738,25 +737,25 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
 
 static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) {
     assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
-    const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
+    const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS;
 
     const HVX_Vector one = hvx_vec_splat_f16(1.0);
 
-    for (int r = 0; r < n_rows; r += 2) {
-        int r0 = r / HMX_FP16_TILE_N_ROWS;
-        int r1 = r % HMX_FP16_TILE_N_ROWS;
+    for (size_t r = 0; r < n_rows; r += 2) {
+        const size_t r0 = r / HMX_FP16_TILE_N_ROWS;
+        const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile
+        const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
+        float *output_row_base = dst + r * n; // global memory row base for row r (and r+1)
 
 #pragma unroll(4)
-        for (int c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) {
-            int c0 = c / HMX_FP16_TILE_N_COLS;
-
-            const __fp16 *tile = vtcm_src + (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS;
-
-            HVX_Vector v = ((const HVX_Vector *) tile)[r1 / 2];
+        for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) {
+            const size_t c0 = c / HMX_FP16_TILE_N_COLS;
+            const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS;
+            HVX_Vector v = ((const HVX_Vector *) tile)[r1];
             HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
 
-            volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (dst + (r * n + c + 0));
-            volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (dst + (r * n + c + n)); // next row in global memory
+            volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0);
+            volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory
 
             *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
             if (r + 1 < n_rows) {
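For orientation, the rewritten indexing is equivalent to the old one. Assuming HMX_FP16_TILE_N_ROWS = 32 (the value cited in the comment in transfer_output_chunk_threaded below; the other HMX_FP16_TILE_* constants are not shown in this diff), a quick worked example:

// r = 70 (r is always even; rows are converted two at a time)
// old: r0 = 70 / 32 = 2,  r1 = 70 % 32 = 6,        vector index = r1 / 2 = 3
// new: r0 = 70 / 32 = 2,  r1 = (70 % 32) / 2 = 3,  vector index = r1     = 3
// row_base = vtcm_src + r0 * tile_row_stride folds the old per-column offset
// (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS into a per-row base plus a
// c0 * HMX_FP16_TILE_N_ELMS step inside the column loop.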
@@ -794,7 +793,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
     assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
 
     size_t n_tot_chunks = n_rows;
-    size_t n_chunks_per_task = 32; // must be multiple of HMX_FP16_TILE_N_ROWS (32)
+    size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32)
 
     output_transfer_task_state_t state;
     state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task;
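The task count is a plain integer ceiling division; for example:

// n_rows = 100, n_chunks_per_task = HMX_FP16_TILE_N_ROWS = 32
// n_tasks = (100 + 32 - 1) / 32 = 131 / 32 = 4   (three 32-row tasks plus one 4-row task)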
@@ -926,7 +925,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
         return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
     }
 
-    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0
+    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
 
     FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu",
          __func__, params->m, params->k, params->n, group_size, params->ne13,
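The splat value 0x3c00 is the IEEE 754 binary16 encoding of 1.0, which is what the updated comment spells out:

// 0x3c00 = 0 01111 0000000000   (sign = 0, biased exponent = 15, mantissa = 0)
// value  = (-1)^0 * 2^(15 - 15) * 1.0 = 1.0
// i.e. every column scale is initialized to 1.0 with a bias of 0.0, a no-op scaling.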
@@ -944,12 +943,15 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
     const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
     const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);
 
+    HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
+
     for (int b3 = 0; b3 < params->ne13; ++b3) {
         for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) {
             const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3);
 
             for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) {
                 const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows);
+                const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS);
 
                 // Pre-load activations for all heads in the group (once per m_chunk).
                 // When the source is strided (permuted Q), use 2D DMA to gather

@@ -987,10 +989,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
                                          fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
                 }
 
-                HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
-
                 for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) {
                     const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);
+                    const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
 
                     TIMER_START(weight_load);
                     {

@@ -1014,11 +1015,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
                     for (int g = 0; g < group_size; ++g) {
                         TIMER_START(hmx_core);
                         {
-                            const __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
-                            const int n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS);
-                            const int n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
-                            core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales,
-                                                n_row_tiles, n_col_tiles, params->k / 32);
+                            const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
+                            core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles,
+                                                params->k / 32);
                         }
                         TIMER_STOP(hmx_core);
 
@@ -1030,12 +1029,12 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
                         TIMER_STOP(output_store);
                     }
                 }
-
-                HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
             }
         }
     }
 
+    HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
+
     TIMER_STOP(total);
 
 #if defined(ENABLE_PROFILE_TIMERS)
@@ -1103,7 +1102,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
         return -1;
     }
 
-    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0
+    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
 
     FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu",
          __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols,

@@ -1121,7 +1120,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
 
     for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
         // transfer activation matrix chunk into VTCM
-        size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
+        const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
+        const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
 
         TIMER_START(activation_load);
         {

@@ -1159,7 +1159,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
         }
 
         for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
-            size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
+            const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
+            const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
 
             TIMER_START(weight_load);
             {

@@ -1184,8 +1185,6 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
 
             TIMER_START(hmx_core);
             {
-                const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
-                const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
                 core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
             }
             TIMER_STOP(hmx_core);
@@ -1307,7 +1306,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
         return -1;
     }
 
-    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0
+    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
 
     FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu",
          __func__, m, k, n, weight_type, use_pipeline,

@@ -1330,7 +1329,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
     HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
     for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
         // transfer activation matrix chunk into VTCM
-        size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
+        const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
+        const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
 
         TIMER_START(activation_load);
         {

@@ -1348,7 +1348,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
         }
 
         for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
-            size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
+            const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
+            const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
 
             TIMER_START(weight_load);
             {

@@ -1373,8 +1374,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
 
             TIMER_START(hmx_core);
             {
-                const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
-                const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
                 core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
             }
             TIMER_STOP(hmx_core);
@@ -1521,14 +1520,16 @@ void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __f
 
     Q6_bias_mxmem2_A((void *)col_scales);
 
-    for (int i = 0; i < n_row_tiles; ++i) {
-        for (int j = 0; j < n_col_tiles; ++j) {
+    const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS;
+    for (size_t i = 0; i < n_row_tiles; ++i) {
+        const __fp16 *row_base = a + i * dot_tile_stride;
+        __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS;
+        for (size_t j = 0; j < n_col_tiles; ++j) {
             Q6_mxclracc_hf();
 
-            const __fp16 *row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
-            const __fp16 *col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
-
-            __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS;
+            const __fp16 *col_tiles = b + j * dot_tile_stride;
+            const __fp16 *row_tiles = row_base;
+            __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS;
             if (!zero_init) {
                 Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047);
                 Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047);
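Hoisting row_base and res_base out of the j loop leaves the addressed tiles unchanged; the rewrite only factors the common term out of the index arithmetic:

// old: accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS
// new: res_base   = c +  i * n_col_tiles      * HMX_FP16_TILE_N_ELMS   // once per i
//      accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS
// and row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS = a + i * dot_tile_stride.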
@@ -1697,7 +1698,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
             v = Q6_V_vror_VR(v, VLEN - 8);
         }
     }
-    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0
+    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16
 
     TIMER_DEFINE(fetch);
     TIMER_DEFINE(act_load);

@@ -1715,7 +1716,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
     const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS);
 
     for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) {
-        size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE);
+        const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE);
 
         TIMER_START(fetch);
         // fetch activation block into VTCM

@@ -1731,13 +1732,13 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
         }
 
         // fetch weight block into VTCM (x4x2 sub-block: quants + scales)
+        const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
         {
             qweight_fetch_task_state_t s;
 
            const int blk_start = kk / QK_Q4_0x4x2;
            const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
            const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
-           const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
            const int scale_blk_size =
                (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
 

@@ -1777,7 +1778,6 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
            dma_queue_pop(ctx->dma[0]);
            // vtcm_scratch0 is used to store the qweight chunk
            // worker_pool_run_func already returned, so fetch is done
-           const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
            dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0,
                                                       n_blk_sz, k_blk_sz, sub_row_stride, weight_type);
         }

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 2 additions & 0 deletions
@@ -98,6 +98,8 @@ enum htp_op_code {
 #define HTP_OP_MAX_VMEM (3221225472u)
 #endif
 
+#define HTP_MMAP_MAX_VMEM (2147483648u)
+
 enum htp_tensor_flags {
     HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights)
     HTP_TENSOR_FLUSHED = (1U << 1)  // Tensor buffer has been flushed (set by the NPU)
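The new constant is the HAP_mmap size limit referenced in main.c: 2147483648 bytes is exactly 2 GiB (2^31), one GiB below HTP_OP_MAX_VMEM's 3221225472 = 3 GiB. A compile-time sanity check (a sketch only, not part of this patch) could look like:

_Static_assert(HTP_MMAP_MAX_VMEM == (1u << 31), "HAP_mmap can only map up to 2 GiB");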

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 30 additions & 1 deletion
@@ -118,7 +118,11 @@ AEEResult htp_iface_close(remote_handle64 handle) {
     // release the mmaps (if any)
     for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) {
         if (ctx->mmap[i].size) {
+#if __HVX_ARCH__ > 73
             HAP_munmap2((void *) ctx->mmap[i].base, ctx->mmap[i].size);
+#else
+            HAP_munmap((void *) ctx->mmap[i].base, ctx->mmap[i].size);
+#endif
             ctx->mmap[i].size = 0;
             ctx->mmap[i].base = NULL;
             ctx->mmap[i].fd = -1;
@@ -173,8 +177,16 @@ AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t
         struct htp_mmap *m = &ctx->mmap[i];
         if (!m->size) {
             FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned);
-
+#if __HVX_ARCH__ > 73
             void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
+#else
+            if (size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB
+                FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) size);
+                abort(); // can't do much else at this point
+            }
+
+            void *va = HAP_mmap(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
+#endif
             if (va == (void*)-1) {
                 FARF(ERROR, "mmap failed : va %p fd %u size %u", va, fd, (uint32_t) size);
                 return AEE_EFAILED;
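The same #if __HVX_ARCH__ > 73 split (HAP_mmap2 on newer targets, HAP_mmap plus the 2 GB guard otherwise) is repeated below in htp_iface_munmap, drop_mmap, and mmap_buf. A hypothetical helper, not part of this patch, could centralize the mapping side of it:

// Hypothetical wrapper (assumption, not in the diff): one place for the arch-gated mmap path.
static inline void * htp_do_mmap(uint32_t size, int fd) {
#if __HVX_ARCH__ > 73
    return HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
#else
    if (size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB
        FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) size);
        abort(); // can't do much else at this point
    }
    return HAP_mmap(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
#endif
}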
@@ -202,7 +214,11 @@ AEEResult htp_iface_munmap(remote_handle64 handle, int fd) {
         struct htp_mmap *m = &ctx->mmap[i];
         if (fd < 0 || m->fd == fd) {
             FARF(HIGH, "unmmap : base %p fd %u size %u", (void*) m->base, m->fd, (uint32_t) m->size);
+#if __HVX_ARCH__ > 73
             HAP_munmap2((void *) m->base, m->size);
+#else
+            HAP_munmap((void *) m->base, m->size);
+#endif
             m->size = 0;
             m->base = NULL;
             m->fd = -1;

@@ -526,7 +542,11 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct
 static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) {
     if (m->size && !m->pinned) {
         FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned);
+#if __HVX_ARCH__ > 73
         HAP_munmap2((void *) m->base, m->size);
+#else
+        HAP_munmap((void *) m->base, m->size);
+#endif
         m->size = 0;
         m->base = 0;
         m->fd = -1;

@@ -540,7 +560,16 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) {
     for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) {
         struct htp_mmap *m = &ctx->mmap[i];
         if (!m->size) {
+#if __HVX_ARCH__ > 73
             void *va = HAP_mmap2(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0);
+#else
+            if (b->size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB
+                FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) b->size);
+                abort(); // can't do much else at this point
+            }
+
+            void *va = HAP_mmap(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0);
+#endif
             if (va == (void*)-1) {
                 FARF(ERROR, "mmap failed : va %p fd %u size %u", va, b->fd, (uint32_t) b->size);
                 abort(); // can't do much else at this point
