Skip to content

Commit 2098fd6

Browse files
authored
hexagon: enable non-contiguous row tensor support for unary ops (ggml-org#22574)
1 parent ab6120c commit 2098fd6

3 files changed

Lines changed: 85 additions & 33 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,8 +2421,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
24212421
return false;
24222422
}
24232423

2424-
// TODO: add support for non-contigiuos tensors
2425-
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2424+
// TODO: add support for non-contiguous elements within a row
2425+
if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) {
24262426
return false;
24272427
}
24282428

ggml/src/ggml-hexagon/htp/hvx-exp.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
1818
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
1919
#define EXP_ONE (0x3f800000) // 1.0
20-
#define EXP_RANGE_R (0x42B16666) // 88.7
20+
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
2121
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
2222

2323
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
@@ -163,7 +163,7 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
163163
HVX_Vector vec_out = Q6_V_vzero();
164164

165165
static const float kInf = INFINITY;
166-
static const float kMaxExp = 88.7f;
166+
static const float kMaxExp = 88.7228f;
167167

168168
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
169169
const HVX_Vector inf = hvx_vec_splat_f32(kInf);

ggml/src/ggml-hexagon/htp/unary-ops.c

Lines changed: 81 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ struct htp_unary_context {
2626
const uint8_t * data_src0;
2727
uint8_t * data_dst;
2828

29-
size_t src0_row_size;
30-
size_t dst_row_size;
29+
size_t src0_data_row_size; // actual data bytes per row
30+
size_t dst_data_row_size; // actual data bytes per row
3131

3232
size_t src0_row_size_aligned;
3333
size_t dst_row_size_aligned;
@@ -41,6 +41,40 @@ struct htp_unary_context {
4141
uint32_t nc;
4242
};
4343

44+
// Map a flat row index to its byte offset in DDR using the tensor's
// actual per-dimension strides.
// The flat index decomposes as ir = i1 + ne1*(i2 + ne2*i3), so the
// offset is i1*nb1 + i2*nb2 + i3*nb3.
static inline size_t unary_row_offset(uint32_t ir,
                                      uint32_t ne1, uint32_t ne2,
                                      size_t nb1, size_t nb2, size_t nb3) {
    const uint32_t rows_per_plane = ne1 * ne2;      // rows in one dim-3 slice
    const uint32_t i3  = ir / rows_per_plane;
    const uint32_t rem = ir % rows_per_plane;       // row index within the plane
    const uint32_t i2  = rem / ne1;
    const uint32_t i1  = rem % ne1;
    return (size_t) i1 * nb1 + (size_t) i2 * nb2 + (size_t) i3 * nb3;
}
// Largest DMA-safe run of rows starting at row `ir`: at most `block`
// rows, never past `end_row`, and for a non-contiguous src/dst never
// across a dim-1 slice boundary (beyond which the single nb1 stride of
// a 2D DMA descriptor would no longer be valid).
static inline uint32_t unary_block_size(uint32_t ir,
                                        uint32_t end_row,
                                        uint32_t block,
                                        bool src_contig,
                                        bool dst_contig,
                                        uint32_t src_ne1,
                                        uint32_t dst_ne1) {
    uint32_t n = end_row - ir;
    if (n > block) {
        n = block;
    }

    if (!src_contig) {
        // rows remaining before the next src dim-1 slice boundary
        const uint32_t src_left = src_ne1 - (ir % src_ne1);
        if (n > src_left) {
            n = src_left;
        }
    }

    if (!dst_contig) {
        // rows remaining before the next dst dim-1 slice boundary
        const uint32_t dst_left = dst_ne1 - (ir % dst_ne1);
        if (n > dst_left) {
            n = dst_left;
        }
    }

    return n;
}
77+
4478
#define htp_unary_preamble \
4579
const uint32_t ne00 = src->ne[0]; \
4680
const uint32_t ne01 = src->ne[1]; \
@@ -276,8 +310,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
276310
int32_t * op_params = octx->op_params;
277311
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
278312

279-
const size_t src0_row_size = uctx->src0_row_size;
280-
const size_t dst_row_size = uctx->dst_row_size;
313+
const size_t src0_data_row_size = uctx->src0_data_row_size;
314+
const size_t dst_data_row_size = uctx->dst_data_row_size;
281315

282316
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
283317
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
@@ -303,7 +337,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
303337
size_t src0_spad_half_size = uctx->src0_spad_half_size;
304338
size_t dst_spad_half_size = uctx->dst_spad_half_size;
305339

306-
const int BLOCK = uctx->block;
340+
// Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
341+
// 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
342+
// transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
343+
const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
344+
(nb03 == (size_t)ne02 * nb02);
345+
const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
346+
(nb3 == (size_t)ne2 * nb2);
347+
const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
348+
const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
349+
const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
307350
if (BLOCK == 0) {
308351
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
309352
octx->src0_spad.size_per_thread, src0_row_size_aligned);
@@ -312,21 +355,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
312355

313356
dma_queue * dma_queue = octx->ctx->dma[ith];
314357

315-
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
316-
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
358+
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
359+
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
317360

318361
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
319-
dma_queue_push_vtcm_to_ddr(dma_queue,
362+
dma_queue_push(dma_queue,
320363
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
321-
dst_row_size, dst_row_size_aligned, 0);
364+
nb1, dst_row_size_aligned, dst_data_row_size, 0);
322365

323-
dma_queue_push_ddr_to_vtcm(dma_queue,
324-
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
325-
src0_row_size_aligned, src0_row_size, block_size);
366+
const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
367+
dma_queue_push(dma_queue,
368+
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
369+
src0_row_size_aligned, nb01, src0_data_row_size, block_size);
370+
ir += block_size;
326371
}
327372

328-
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
329-
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
373+
for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
374+
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
330375

331376
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
332377
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
@@ -361,18 +406,25 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
361406
break;
362407
}
363408

364-
dma_queue_push_vtcm_to_ddr(dma_queue,
365-
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
366-
dst_row_size, dst_row_size_aligned, block_size);
409+
const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
410+
dma_queue_push(dma_queue,
411+
dma_make_ptr(data_dst + dst_off, dst_spad),
412+
nb1, dst_row_size_aligned, dst_data_row_size, block_size);
367413

368414
// prefetch N+2 loop iteration if any
369-
const uint32_t pref_block = (ir + BLOCK * 2);
370-
if (pref_block < src0_end_row) {
371-
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
372-
dma_queue_push_ddr_to_vtcm(dma_queue,
373-
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
374-
src0_row_size_aligned, src0_row_size, pref_block_size);
415+
const uint32_t next_ir = ir + block_size;
416+
if (next_ir < src0_end_row) {
417+
const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
418+
const uint32_t pref_ir = next_ir + next_block_size;
419+
if (pref_ir < src0_end_row) {
420+
const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
421+
const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
422+
dma_queue_push(dma_queue,
423+
dma_make_ptr(src0_spad, data_src + src0_pref_off),
424+
src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
425+
}
375426
}
427+
ir += block_size;
376428
}
377429

378430
dma_queue_flush(dma_queue);
@@ -426,11 +478,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
426478
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
427479
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
428480

429-
const size_t src0_row_size = src0->nb[1];
430-
const size_t dst_row_size = dst->nb[1];
481+
const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
482+
const size_t dst_data_row_size = dst->ne[0] * sizeof(float);
431483

432-
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
433-
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
484+
const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
485+
const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);
434486

435487
// VTCM scratchpads for all tensors
436488
// N rows per thread, padded to HVX vector size
@@ -468,8 +520,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
468520
.data_src0 = (const uint8_t *)src0->data,
469521
.data_dst = (uint8_t *)dst->data,
470522

471-
.src0_row_size = src0_row_size,
472-
.dst_row_size = dst_row_size,
523+
.src0_data_row_size = src0_data_row_size,
524+
.dst_data_row_size = dst_data_row_size,
473525

474526
.src0_row_size_aligned = src0_row_size_aligned,
475527
.dst_row_size_aligned = dst_row_size_aligned,

0 commit comments

Comments (0)