Skip to content

Commit ac76808

Browse files
authored
hexagon: enable support for NORM op (ggml-org#23319)
1 parent baf3cc6 commit ac76808

4 files changed

Lines changed: 101 additions & 3 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,6 +2870,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
28702870
case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
28712871
case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
28722872
case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
2873+
case GGML_OP_NORM: return HTP_OP_NORM;
28732874
case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
28742875
case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
28752876
case GGML_OP_SCALE: return HTP_OP_SCALE;
@@ -3338,10 +3339,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
33383339
supp = ggml_hexagon_supported_add_id(sess, op);
33393340
break;
33403341

3342+
case GGML_OP_NORM:
33413343
case GGML_OP_L2_NORM:
3342-
supp = ggml_hexagon_supported_unary(sess, op);
3343-
break;
3344-
33453344
case GGML_OP_RMS_NORM:
33463345
case GGML_OP_SCALE:
33473346
supp = ggml_hexagon_supported_unary(sess, op);

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ enum htp_op_code {
8888
HTP_OP_GATED_DELTA_NET,
8989
HTP_OP_TRI,
9090
HTP_OP_PAD,
91+
HTP_OP_NORM,
9192

9293
HTP_OP_INVALID
9394
};

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ static int execute_op(struct htp_ops_context * octx) {
534534
case HTP_OP_ADD_ID:
535535
return op_binary(octx);
536536

537+
case HTP_OP_NORM:
537538
case HTP_OP_RMS_NORM:
538539
case HTP_OP_SCALE:
539540
case HTP_OP_SQR:

ggml/src/ggml-hexagon/htp/unary-ops.c

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
158158
}
159159
}
160160

161+
static void hvx_fast_norm_f32(const uint8_t * restrict src,
162+
uint8_t * restrict dst,
163+
uint8_t * restrict pad,
164+
const int num_elems,
165+
float epsilon) {
166+
(void)pad;
167+
168+
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
169+
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
170+
171+
const int nvec = num_elems / VLEN_FP32; // number of full vectors
172+
const int nloe = num_elems % VLEN_FP32; // leftover elements
173+
174+
// Compute sum of squares and sum of values for full vectors
175+
HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000);
176+
HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000);
177+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
178+
179+
#pragma unroll(4)
180+
for (int i = 0; i < nvec; i++) {
181+
HVX_Vector v1 = v_src[i];
182+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
183+
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
184+
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
185+
}
186+
187+
// Handle tail elements using vectorized ops with masking
188+
if (nloe > 0) {
189+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
190+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
191+
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
192+
sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2);
193+
sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero()));
194+
}
195+
196+
// Reduce HVX sums
197+
sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v));
198+
sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v));
199+
200+
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
201+
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
202+
HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v);
203+
HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v);
204+
HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v));
205+
HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v);
206+
HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v);
207+
208+
// scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction
209+
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v));
210+
HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v)));
211+
212+
#pragma unroll(4)
213+
for (int i = 0; i < nvec; i++) {
214+
HVX_Vector v1 = v_src[i];
215+
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
216+
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
217+
v_dst[i] = Q6_Vsf_equals_Vqf32(v3);
218+
}
219+
220+
// Handle tail elements using vectorized ops with masking
221+
if (nloe > 0) {
222+
223+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
224+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
225+
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b);
226+
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v);
227+
HVX_Vector result = Q6_Vsf_equals_Vqf32(v3);
228+
229+
// Store with masking to avoid overwriting memory beyond the tensor
230+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
231+
}
232+
}
233+
161234
static void scale_f32(const float * restrict src,
162235
float * restrict dst,
163236
uint8_t * restrict spad,
@@ -196,6 +269,24 @@ static void rms_norm_f32(const float * restrict src,
196269
}
197270
}
198271

272+
static void norm_f32(const float * restrict src,
273+
float * restrict dst,
274+
uint8_t * restrict spad,
275+
const uint32_t num_rows,
276+
const uint32_t row_elems,
277+
const size_t row_size,
278+
int32_t * op_params) {
279+
float epsilon = 0.f;
280+
memcpy(&epsilon, op_params, sizeof(float));
281+
282+
for (uint32_t ir = 0; ir < num_rows; ir++) {
283+
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
284+
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
285+
286+
hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
287+
}
288+
}
289+
199290
static void sqr_f32(const float * restrict src,
200291
float * restrict dst,
201292
uint8_t * restrict spad,
@@ -556,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
556647

557648
// Process block in VTCM
558649
switch (htp_op) {
650+
case HTP_OP_NORM:
651+
norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
652+
break;
559653
case HTP_OP_RMS_NORM:
560654
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
561655
break;
@@ -632,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
632726
const char * op_type = NULL;
633727

634728
switch (octx->op) {
729+
case HTP_OP_NORM:
730+
op_type = "norm-f32";
731+
break;
635732
case HTP_OP_RMS_NORM:
636733
op_type = "rmsnorm-f32";
637734
break;

0 commit comments

Comments
 (0)