@@ -158,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
158158 }
159159}
160160
161+ static void hvx_fast_norm_f32 (const uint8_t * restrict src ,
162+ uint8_t * restrict dst ,
163+ uint8_t * restrict pad ,
164+ const int num_elems ,
165+ float epsilon ) {
166+ (void )pad ;
167+
168+ const HVX_Vector * restrict v_src = (HVX_Vector * ) src ;
169+ HVX_Vector * restrict v_dst = (HVX_Vector * ) dst ;
170+
171+ const int nvec = num_elems / VLEN_FP32 ; // number of full vectors
172+ const int nloe = num_elems % VLEN_FP32 ; // leftover elements
173+
174+ // Compute sum of squares and sum of values for full vectors
175+ HVX_Vector sum_sq_v = Q6_V_vsplat_R (0x00000000 );
176+ HVX_Vector sum_x_v = Q6_V_vsplat_R (0x00000000 );
177+ HVX_Vector epsilon_v = hvx_vec_splat_f32 (epsilon );
178+
179+ #pragma unroll(4)
180+ for (int i = 0 ; i < nvec ; i ++ ) {
181+ HVX_Vector v1 = v_src [i ];
182+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf (v1 , v1 );
183+ sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32 (sum_sq_v , v2 );
184+ sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32 (sum_x_v , Q6_Vqf32_vadd_VsfVsf (v1 , Q6_V_vzero ()));
185+ }
186+
187+ // Handle tail elements using vectorized ops with masking
188+ if (nloe > 0 ) {
189+ HVX_VectorPred bmask = Q6_Q_vsetq_R (nloe * 4 );
190+ HVX_Vector v1 = Q6_V_vand_QV (bmask , v_src [nvec ]);
191+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf (v1 , v1 );
192+ sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32 (sum_sq_v , v2 );
193+ sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32 (sum_x_v , Q6_Vqf32_vadd_VsfVsf (v1 , Q6_V_vzero ()));
194+ }
195+
196+ // Reduce HVX sums
197+ sum_sq_v = hvx_vec_reduce_sum_f32 (Q6_Vsf_equals_Vqf32 (sum_sq_v ));
198+ sum_x_v = hvx_vec_reduce_sum_f32 (Q6_Vsf_equals_Vqf32 (sum_x_v ));
199+
200+ HVX_Vector t_v = hvx_vec_splat_f32 ((float ) num_elems );
201+ HVX_Vector denom_v = hvx_vec_inverse_f32 (t_v );
202+ HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf (sum_sq_v , denom_v );
203+ HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf (sum_x_v , denom_v );
204+ HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf (Q6_Vsf_equals_Vqf32 (mean_x_v ), Q6_Vsf_equals_Vqf32 (mean_x_v ));
205+ HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32 (mean_sq_v , mean_x_sq_v );
206+ HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf (var_v , epsilon_v );
207+
208+ // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
209+ HVX_Vector scale_v = hvx_vec_rsqrt_f32 (Q6_Vsf_equals_Vqf32 (var_epsilon_v ));
210+ HVX_Vector mean_x_b = hvx_vec_splat_f32 (hvx_vec_get_f32 (Q6_Vsf_equals_Vqf32 (mean_x_v )));
211+
212+ #pragma unroll(4)
213+ for (int i = 0 ; i < nvec ; i ++ ) {
214+ HVX_Vector v1 = v_src [i ];
215+ HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf (v1 , mean_x_b );
216+ HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf (Q6_Vsf_equals_Vqf32 (v2 ), scale_v );
217+ v_dst [i ] = Q6_Vsf_equals_Vqf32 (v3 );
218+ }
219+
220+ // Handle tail elements using vectorized ops with masking
221+ if (nloe > 0 ) {
222+
223+ HVX_VectorPred bmask = Q6_Q_vsetq_R (nloe * 4 );
224+ HVX_Vector v1 = Q6_V_vand_QV (bmask , v_src [nvec ]);
225+ HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf (v1 , mean_x_b );
226+ HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf (Q6_Vsf_equals_Vqf32 (v2 ), scale_v );
227+ HVX_Vector result = Q6_Vsf_equals_Vqf32 (v3 );
228+
229+ // Store with masking to avoid overwriting memory beyond the tensor
230+ hvx_vec_store_a (& v_dst [nvec ], nloe * 4 , result );
231+ }
232+ }
233+
161234static void scale_f32 (const float * restrict src ,
162235 float * restrict dst ,
163236 uint8_t * restrict spad ,
@@ -196,6 +269,24 @@ static void rms_norm_f32(const float * restrict src,
196269 }
197270}
198271
272+ static void norm_f32 (const float * restrict src ,
273+ float * restrict dst ,
274+ uint8_t * restrict spad ,
275+ const uint32_t num_rows ,
276+ const uint32_t row_elems ,
277+ const size_t row_size ,
278+ int32_t * op_params ) {
279+ float epsilon = 0.f ;
280+ memcpy (& epsilon , op_params , sizeof (float ));
281+
282+ for (uint32_t ir = 0 ; ir < num_rows ; ir ++ ) {
283+ const uint8_t * restrict src_local = (const uint8_t * )src + (ir * row_size );
284+ uint8_t * restrict dst_local = (uint8_t * )dst + (ir * row_size );
285+
286+ hvx_fast_norm_f32 ((const uint8_t * ) src_local , (uint8_t * ) dst_local , spad , row_elems , epsilon );
287+ }
288+ }
289+
199290static void sqr_f32 (const float * restrict src ,
200291 float * restrict dst ,
201292 uint8_t * restrict spad ,
@@ -556,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
556647
557648 // Process block in VTCM
558649 switch (htp_op ) {
650+ case HTP_OP_NORM :
651+ norm_f32 (src0_spad , dst_spad , NULL , block_size , ne0 , src0_row_size_aligned , op_params );
652+ break ;
559653 case HTP_OP_RMS_NORM :
560654 rms_norm_f32 (src0_spad , dst_spad , NULL , block_size , ne0 , src0_row_size_aligned , op_params );
561655 break ;
@@ -632,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
632726 const char * op_type = NULL ;
633727
634728 switch (octx -> op ) {
729+ case HTP_OP_NORM :
730+ op_type = "norm-f32" ;
731+ break ;
635732 case HTP_OP_RMS_NORM :
636733 op_type = "rmsnorm-f32" ;
637734 break ;
0 commit comments