Skip to content

Commit ad96bb8

Browse files
hexagon: add unary tanh op (#22999)
1 parent e75cd5e commit ad96bb8

4 files changed

Lines changed: 25 additions & 1 deletion

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2865,6 +2865,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
28652865
case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG;
28662866
case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP;
28672867
case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS;
2868+
case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH;
28682869
default:
28692870
break;
28702871
}
@@ -3335,6 +3336,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
33353336
case GGML_UNARY_OP_EXP:
33363337
case GGML_UNARY_OP_SIGMOID:
33373338
case GGML_UNARY_OP_SOFTPLUS:
3339+
case GGML_UNARY_OP_TANH:
33383340
supp = ggml_hexagon_supported_unary(sess, op);
33393341
break;
33403342
case GGML_UNARY_OP_SILU:

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ enum htp_op_code {
6262
HTP_OP_UNARY_EXP,
6363
HTP_OP_UNARY_NEG,
6464
HTP_OP_UNARY_SOFTPLUS,
65+
HTP_OP_UNARY_TANH,
6566
HTP_OP_GLU_SWIGLU,
6667
HTP_OP_GLU_SWIGLU_OAI,
6768
HTP_OP_GLU_GEGLU,

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) {
542542
case HTP_OP_UNARY_SIGMOID:
543543
case HTP_OP_UNARY_NEG:
544544
case HTP_OP_UNARY_EXP:
545+
case HTP_OP_UNARY_TANH:
545546
case HTP_OP_L2_NORM:
546547
return op_unary(octx);
547548

ggml/src/ggml-hexagon/htp/unary-ops.c

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,21 @@ static void l2_norm_f32(const float * restrict src,
373373
}
374374
}
375375

376+
static void tanh_f32(const float * restrict src,
377+
float * restrict dst,
378+
uint8_t * restrict spad,
379+
const uint32_t num_rows,
380+
const uint32_t row_elems,
381+
const size_t row_size,
382+
int32_t * op_params) {
383+
for (uint32_t ir = 0; ir < num_rows; ir++) {
384+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
385+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
386+
387+
hvx_tanh_f32_aa(dst_local, src_local, row_elems);
388+
}
389+
}
390+
376391
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
377392
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
378393
struct htp_ops_context * octx = uctx->octx;
@@ -477,6 +492,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
477492
case HTP_OP_UNARY_SOFTPLUS:
478493
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
479494
break;
495+
case HTP_OP_UNARY_TANH:
496+
tanh_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
497+
break;
480498
case HTP_OP_L2_NORM:
481499
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
482500
break;
@@ -547,10 +565,12 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
547565
case HTP_OP_UNARY_SOFTPLUS:
548566
op_type = "softplus-f32";
549567
break;
568+
case HTP_OP_UNARY_TANH:
569+
op_type = "tanh-f32";
570+
break;
550571
case HTP_OP_L2_NORM:
551572
op_type = "l2norm-f32";
552573
break;
553-
554574
default:
555575
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
556576
return HTP_STATUS_NO_SUPPORT;

0 commit comments

Comments
 (0)