@@ -373,6 +373,21 @@ static void l2_norm_f32(const float * restrict src,
373373 }
374374}
375375
// Row-wise tanh kernel: runs hvx_tanh_f32_aa over each of `num_rows` rows.
//
//   src       - start of the input rows
//   dst       - start of the output rows
//   spad      - not used by this kernel; part of the signature it is called with
//   num_rows  - how many rows to process
//   row_elems - element count handed to hvx_tanh_f32_aa for every row
//   row_size  - distance in bytes between the starts of consecutive rows,
//               applied identically to src and dst
//   op_params - not used by this kernel; part of the signature it is called with
//
// NOTE(review): the "_aa" suffix presumably means both pointers must satisfy
// the HVX alignment requirement - confirm against the helper's definition.
static void tanh_f32(const float * restrict src,
                     float * restrict       dst,
                     uint8_t * restrict     spad,
                     const uint32_t         num_rows,
                     const uint32_t         row_elems,
                     const size_t           row_size,
                     int32_t *              op_params) {
    // Rows are addressed by byte offset, so step through both buffers as bytes.
    const uint8_t * const in_bytes  = (const uint8_t *) src;
    uint8_t * const       out_bytes = (uint8_t *) dst;

    for (uint32_t row = 0; row < num_rows; ++row) {
        // Same byte offset on both sides: input and output share one row stride.
        const size_t row_offset = row * row_size;

        hvx_tanh_f32_aa(out_bytes + row_offset, in_bytes + row_offset, row_elems);
    }
}
390+
376391static void unary_job_f32_per_thread (unsigned int nth , unsigned int ith , void * data ) {
377392 const struct htp_unary_context * uctx = (const struct htp_unary_context * ) data ;
378393 struct htp_ops_context * octx = uctx -> octx ;
@@ -477,6 +492,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
477492 case HTP_OP_UNARY_SOFTPLUS :
478493 softplus_f32 (src0_spad , dst_spad , NULL , block_size , ne0 , src0_row_size_aligned , op_params );
479494 break ;
495+ case HTP_OP_UNARY_TANH :
496+ tanh_f32 (src0_spad , dst_spad , NULL , block_size , ne0 , src0_row_size_aligned , op_params );
497+ break ;
480498 case HTP_OP_L2_NORM :
481499 l2_norm_f32 (src0_spad , dst_spad , NULL , block_size , ne0 , src0_row_size_aligned , op_params );
482500 break ;
@@ -547,10 +565,12 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
547565 case HTP_OP_UNARY_SOFTPLUS :
548566 op_type = "softplus-f32" ;
549567 break ;
568+ case HTP_OP_UNARY_TANH :
569+ op_type = "tanh-f32" ;
570+ break ;
550571 case HTP_OP_L2_NORM :
551572 op_type = "l2norm-f32" ;
552573 break ;
553-
554574 default :
555575 FARF (ERROR , "Unsupported unary Op %u\n" , octx -> op );
556576 return HTP_STATUS_NO_SUPPORT ;
0 commit comments