Skip to content

Commit 9a532ae

Browse files
pdhinakatboinovski1max-krasnyansky
authored
hexagon: add support for TRI op (#22822)
* Hexagon: TRI HVX Kernel addition to ggml hexagon HTP ops and context * addressed PR review comments for TRI op * hexagon: clang format * hex-unary: remove merge conflict markers * hex-ggml: remove duplicate op cases (merge conflict) * hex-ggml: fix editor config errors --------- Co-authored-by: Todor Boinovski <todorb@qti.qualcomm.com> Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
1 parent b734044 commit 9a532ae

5 files changed

Lines changed: 137 additions & 1 deletion

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,6 +2828,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session *
28282828
return true;
28292829
}
28302830

2831+
static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2832+
2833+
const struct ggml_tensor * src0 = op->src[0];
2834+
const struct ggml_tensor * dst = op;
2835+
2836+
if (src0->type != GGML_TYPE_F32) { return false; }
2837+
if (dst->type != GGML_TYPE_F32) { return false; }
2838+
if (!ggml_are_same_shape(src0, dst)) { return false; }
2839+
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; }
2840+
2841+
return true;
2842+
2843+
GGML_UNUSED(sess);
2844+
}
2845+
28312846
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
28322847
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
28332848
return sess->c_name();
@@ -2869,6 +2884,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
28692884
case GGML_OP_FILL: return HTP_OP_FILL;
28702885
case GGML_OP_DIAG: return HTP_OP_DIAG;
28712886
case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
2887+
case GGML_OP_TRI: return HTP_OP_TRI;
28722888
case GGML_OP_PAD: return HTP_OP_PAD;
28732889

28742890
case GGML_OP_UNARY:
@@ -3430,6 +3446,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
34303446
supp = ggml_hexagon_supported_solve_tri(sess, op);
34313447
break;
34323448

3449+
case GGML_OP_TRI:
3450+
supp = ggml_hexagon_supported_tri(sess, op);
3451+
break;
3452+
34333453
case GGML_OP_PAD:
34343454
supp = ggml_hexagon_supported_pad(sess, op);
34353455
break;

ggml/src/ggml-hexagon/htp/htp-ctx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ int op_fill(struct htp_ops_context * octx);
107107
int op_diag(struct htp_ops_context * octx);
108108
int op_solve_tri(struct htp_ops_context * octx);
109109
int op_gated_delta_net(struct htp_ops_context * octx);
110+
int op_tri(struct htp_ops_context * octx);
110111
int op_pad(struct htp_ops_context * octx);
111112

112113
#endif /* HTP_CTX_H */

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ enum htp_op_code {
8686
HTP_OP_SOLVE_TRI,
8787
HTP_OP_L2_NORM,
8888
HTP_OP_GATED_DELTA_NET,
89+
HTP_OP_TRI,
8990
HTP_OP_PAD,
9091

9192
HTP_OP_INVALID

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,9 @@ static int execute_op(struct htp_ops_context * octx) {
601601
case HTP_OP_GATED_DELTA_NET:
602602
return op_gated_delta_net(octx);
603603

604+
case HTP_OP_TRI:
605+
return op_tri(octx);
606+
604607
case HTP_OP_INVALID:
605608
break;
606609

ggml/src/ggml-hexagon/htp/unary-ops.c

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "ggml-common.h"
1818
#include "htp-ctx.h"
1919
#include "htp-ops.h"
20-
#include "htp-ops.h"
2120

2221
struct htp_unary_context {
2322
struct htp_ops_context * octx;
@@ -277,6 +276,95 @@ static void sigmoid_f32(const float * restrict src,
277276
}
278277
}
279278

279+
static void tri_f32(const float * restrict src,
280+
float * restrict dst,
281+
uint8_t * restrict spad,
282+
const uint32_t num_rows,
283+
const uint32_t row_elems,
284+
const size_t row_size,
285+
int32_t * op_params,
286+
const uint32_t ir,
287+
const struct htp_unary_context * uctx) {
288+
289+
const int32_t ttype = op_params[0];
290+
const HVX_Vector zero = hvx_vec_splat_f32(0.0f);
291+
const uint32_t nvec = row_elems / VLEN_FP32;
292+
const uint32_t nloe = row_elems % VLEN_FP32;
293+
294+
const uint32_t ne01 = uctx->octx->src[0]->ne[1];
295+
296+
for (uint32_t b = 0; b < num_rows; b++) {
297+
const uint32_t abs_row = ir + b;
298+
const uint32_t i01 = abs_row % ne01;
299+
300+
const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size);
301+
HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size);
302+
303+
uint32_t boundary;
304+
int keep_left;
305+
switch (ttype) {
306+
case 0: boundary = i01; keep_left = 0; break; // keep col >= row
307+
case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row
308+
case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row
309+
case 3: boundary = i01; keep_left = 1; break; // keep col < row
310+
default: boundary = 0; keep_left = 0; break;
311+
}
312+
if (boundary > row_elems) boundary = row_elems;
313+
314+
// Full HVX vectors — each starts at a 128-byte aligned offset
315+
for (uint32_t i = 0; i < nvec; i++) {
316+
const uint32_t vec_start = i * VLEN_FP32;
317+
const uint32_t vec_end = vec_start + VLEN_FP32;
318+
if (keep_left) {
319+
if (vec_end <= boundary) {
320+
v_dst[i] = v_src[i];
321+
} else if (vec_start >= boundary) {
322+
v_dst[i] = zero;
323+
} else {
324+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
325+
v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero);
326+
}
327+
} else {
328+
if (vec_end <= boundary) {
329+
v_dst[i] = zero;
330+
} else if (vec_start >= boundary) {
331+
v_dst[i] = v_src[i];
332+
} else {
333+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
334+
v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]);
335+
}
336+
}
337+
}
338+
339+
// Tail elements (row_elems not a multiple of VLEN_FP32)
340+
if (nloe > 0) {
341+
const uint32_t vec_start = nvec * VLEN_FP32;
342+
const uint32_t vec_end = vec_start + nloe;
343+
HVX_Vector tail_val;
344+
if (keep_left) {
345+
if (vec_end <= boundary) {
346+
tail_val = v_src[nvec];
347+
} else if (vec_start >= boundary) {
348+
tail_val = zero;
349+
} else {
350+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
351+
tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero);
352+
}
353+
} else {
354+
if (vec_end <= boundary) {
355+
tail_val = zero;
356+
} else if (vec_start >= boundary) {
357+
tail_val = v_src[nvec];
358+
} else {
359+
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
360+
tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]);
361+
}
362+
}
363+
hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val);
364+
}
365+
}
366+
}
367+
280368
static void softplus_f32(const float * restrict src,
281369
float * restrict dst,
282370
uint8_t * restrict spad,
@@ -498,6 +586,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
498586
case HTP_OP_L2_NORM:
499587
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
500588
break;
589+
case HTP_OP_TRI:
590+
tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx);
591+
break;
501592
default:
502593
break;
503594
}
@@ -571,6 +662,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
571662
case HTP_OP_L2_NORM:
572663
op_type = "l2norm-f32";
573664
break;
665+
case HTP_OP_TRI:
666+
op_type = "tri-f32";
667+
break;
668+
574669
default:
575670
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
576671
return HTP_STATUS_NO_SUPPORT;
@@ -640,6 +735,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
640735
return err;
641736
}
642737

738+
int op_tri(struct htp_ops_context * octx) {
739+
int err = HTP_STATUS_OK;
740+
741+
switch (octx->src[0]->type) {
742+
case HTP_TYPE_F32:
743+
err = execute_op_unary_f32(octx);
744+
break;
745+
746+
default:
747+
err = HTP_STATUS_NO_SUPPORT;
748+
break;
749+
}
750+
751+
return err;
752+
}
753+
643754
int op_unary(struct htp_ops_context * octx) {
644755
int err = HTP_STATUS_OK;
645756

0 commit comments

Comments
 (0)