@@ -298,6 +298,81 @@ static void softplus_f32(const float * restrict src,
298298 }
299299}
300300
301+ // --- L2_NORM HVX kernel ---
302+ // Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row.
303+ // scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers
304+ // using rsqrt + inverse to avoid scalar extraction.
305+ static void hvx_fast_l2_norm_f32 (const uint8_t * restrict src ,
306+ uint8_t * restrict dst ,
307+ uint8_t * restrict pad ,
308+ const int num_elems ,
309+ float epsilon ) {
310+ (void )pad ;
311+
312+ const HVX_Vector * restrict v_src = (HVX_Vector * ) src ;
313+ HVX_Vector * restrict v_dst = (HVX_Vector * ) dst ;
314+
315+ HVX_Vector sum_v = hvx_vec_splat_f32 (0.0f );
316+
317+ const int nvec = num_elems / VLEN_FP32 ;
318+ const int nloe = num_elems % VLEN_FP32 ;
319+
320+ #pragma unroll(4)
321+ for (int i = 0 ; i < nvec ; i ++ ) {
322+ HVX_Vector v1 = v_src [i ];
323+ HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf (v1 , v1 );
324+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32 (sum_v , sq );
325+ }
326+
327+ // Include tail elements in the sum-of-squares using a predicate mask
328+ if (nloe > 0 ) {
329+ HVX_VectorPred bmask = Q6_Q_vsetq_R (nloe * 4 );
330+ HVX_Vector v1 = Q6_V_vand_QV (bmask , v_src [nvec ]);
331+ HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf (v1 , v1 );
332+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32 (sum_v , sq );
333+ }
334+
335+ // Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers.
336+ // hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction.
337+ HVX_Vector sum_sf = hvx_vec_reduce_sum_f32 (Q6_Vsf_equals_Vqf32 (sum_v ));
338+ HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32 (sum_sf ); // 1/sqrt(sum)
339+ HVX_Vector sqrt_v = hvx_vec_inverse_f32 (rsqrt_v ); // sqrt(sum)
340+ HVX_Vector epsilon_v = hvx_vec_splat_f32 (epsilon );
341+ HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf (sqrt_v , epsilon_v ); // fmax(sqrt(sum), epsilon)
342+ HVX_Vector scale_v = hvx_vec_inverse_f32 (denom_v ); // 1/fmax(sqrt(sum), epsilon)
343+
344+ #pragma unroll(4)
345+ for (int i = 0 ; i < nvec ; i ++ ) {
346+ HVX_Vector v1 = v_src [i ];
347+ v_dst [i ] = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vmpy_VsfVsf (v1 , scale_v ));
348+ }
349+
350+ if (nloe > 0 ) {
351+ HVX_VectorPred bmask = Q6_Q_vsetq_R (nloe * 4 );
352+ HVX_Vector v1 = Q6_V_vand_QV (bmask , v_src [nvec ]);
353+ HVX_Vector result = Q6_Vsf_equals_Vqf32 (Q6_Vqf32_vmpy_VsfVsf (v1 , scale_v ));
354+ hvx_vec_store_a (& v_dst [nvec ], nloe * 4 , result );
355+ }
356+ }
357+
// Applies L2 normalization to num_rows rows, one row at a time.
//
//   src       - first input row
//   dst       - first output row
//   spad      - forwarded unchanged to the row kernel (which ignores it)
//   num_rows  - number of rows to process
//   row_elems - fp32 elements per row
//   row_size  - byte stride between consecutive rows (used for both src and dst)
//   op_params - the first sizeof(float) bytes are read as the fp32 epsilon
static void l2_norm_f32(const float * restrict src,
                        float * restrict dst,
                        uint8_t * restrict spad,
                        const uint32_t num_rows,
                        const uint32_t row_elems,
                        const size_t   row_size,
                        int32_t * op_params) {
    // Copy the raw bytes instead of casting the pointer, so the int32_t storage
    // is never read through a float lvalue.
    float epsilon = 0.f;
    memcpy(&epsilon, op_params, sizeof(float));

    const uint8_t * src_bytes = (const uint8_t *) src;
    uint8_t *       dst_bytes = (uint8_t *) dst;

    // Advance a running byte offset by one row stride per iteration.
    size_t row_off = 0;
    for (uint32_t ir = 0; ir < num_rows; ir++) {
        hvx_fast_l2_norm_f32(src_bytes + row_off, dst_bytes + row_off, spad, row_elems, epsilon);
        row_off += row_size;
    }
}
375+
301376static void unary_job_f32_per_thread (unsigned int nth , unsigned int ith , void * data ) {
302377 const struct htp_unary_context * uctx = (const struct htp_unary_context * ) data ;
303378 struct htp_ops_context * octx = uctx -> octx ;
@@ -402,6 +477,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
402477 case HTP_OP_UNARY_SOFTPLUS :
403478 softplus_f32 (src0_spad , dst_spad , NULL , block_size , ne0 , src0_row_size_aligned , op_params );
404479 break ;
480+ case HTP_OP_L2_NORM :
481+ l2_norm_f32 (src0_spad , dst_spad , NULL , block_size , ne0 , src0_row_size_aligned , op_params );
482+ break ;
405483 default :
406484 break ;
407485 }
@@ -469,6 +547,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
469547 case HTP_OP_UNARY_SOFTPLUS :
470548 op_type = "softplus-f32" ;
471549 break ;
550+ case HTP_OP_L2_NORM :
551+ op_type = "l2norm-f32" ;
552+ break ;
472553
473554 default :
474555 FARF (ERROR , "Unsupported unary Op %u\n" , octx -> op );
0 commit comments