Skip to content

Commit b46812d

Browse files
Feature hexagon l2 norm (ggml-org#22816)
* L2_NORM Updates * Addressed PR Comments * ggml-hexagon: add L2_NORM HVX kernel for Hexagon backend * hex-unary: remove supported_unary_nc since the outer loop is the same for all unary ops --------- Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
1 parent 4995604 commit b46812d

4 files changed

Lines changed: 91 additions & 2 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2420,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
24202420
return false;
24212421
}
24222422

2423-
// TODO: add support for non-contiguous elements within a row
2424-
if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) {
2423+
// dst must be contiguous; src0 may be non-contiguous
2424+
if (!ggml_is_contiguous(dst)) {
24252425
return false;
24262426
}
24272427

@@ -2791,6 +2791,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
27912791
case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
27922792
case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
27932793
case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
2794+
case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
27942795
case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
27952796
case GGML_OP_SCALE: return HTP_OP_SCALE;
27962797
case GGML_OP_SQR: return HTP_OP_SQR;
@@ -3253,6 +3254,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
32533254
supp = ggml_hexagon_supported_add_id(sess, op);
32543255
break;
32553256

3257+
case GGML_OP_L2_NORM:
3258+
supp = ggml_hexagon_supported_unary(sess, op);
3259+
break;
3260+
32563261
case GGML_OP_RMS_NORM:
32573262
case GGML_OP_SCALE:
32583263
supp = ggml_hexagon_supported_unary(sess, op);

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ enum htp_op_code {
8383
HTP_OP_FILL,
8484
HTP_OP_DIAG,
8585
HTP_OP_SOLVE_TRI,
86+
HTP_OP_L2_NORM,
87+
8688
HTP_OP_INVALID
8789
};
8890

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) {
542542
case HTP_OP_UNARY_SIGMOID:
543543
case HTP_OP_UNARY_NEG:
544544
case HTP_OP_UNARY_EXP:
545+
case HTP_OP_L2_NORM:
545546
return op_unary(octx);
546547

547548
case HTP_OP_UNARY_SILU:

ggml/src/ggml-hexagon/htp/unary-ops.c

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,81 @@ static void softplus_f32(const float * restrict src,
298298
}
299299
}
300300

301+
// --- L2_NORM HVX kernel ---
302+
// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row.
303+
// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers
304+
// using rsqrt + inverse to avoid scalar extraction.
305+
static void hvx_fast_l2_norm_f32(const uint8_t * restrict src,
306+
uint8_t * restrict dst,
307+
uint8_t * restrict pad,
308+
const int num_elems,
309+
float epsilon) {
310+
(void)pad;
311+
312+
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
313+
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
314+
315+
HVX_Vector sum_v = hvx_vec_splat_f32(0.0f);
316+
317+
const int nvec = num_elems / VLEN_FP32;
318+
const int nloe = num_elems % VLEN_FP32;
319+
320+
#pragma unroll(4)
321+
for (int i = 0; i < nvec; i++) {
322+
HVX_Vector v1 = v_src[i];
323+
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
324+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
325+
}
326+
327+
// Include tail elements in the sum-of-squares using a predicate mask
328+
if (nloe > 0) {
329+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
330+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
331+
HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
332+
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq);
333+
}
334+
335+
// Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers.
336+
// hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction.
337+
HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
338+
HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum)
339+
HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum)
340+
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
341+
HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon)
342+
HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon)
343+
344+
#pragma unroll(4)
345+
for (int i = 0; i < nvec; i++) {
346+
HVX_Vector v1 = v_src[i];
347+
v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
348+
}
349+
350+
if (nloe > 0) {
351+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
352+
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
353+
HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v));
354+
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
355+
}
356+
}
357+
358+
static void l2_norm_f32(const float * restrict src,
359+
float * restrict dst,
360+
uint8_t * restrict spad,
361+
const uint32_t num_rows,
362+
const uint32_t row_elems,
363+
const size_t row_size,
364+
int32_t * op_params) {
365+
float epsilon = 0.f;
366+
memcpy(&epsilon, op_params, sizeof(float));
367+
368+
for (uint32_t ir = 0; ir < num_rows; ir++) {
369+
const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size));
370+
float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size));
371+
372+
hvx_fast_l2_norm_f32((const uint8_t *)src_f, (uint8_t *)dst_f, spad, row_elems, epsilon);
373+
}
374+
}
375+
301376
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
302377
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
303378
struct htp_ops_context * octx = uctx->octx;
@@ -402,6 +477,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
402477
case HTP_OP_UNARY_SOFTPLUS:
403478
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
404479
break;
480+
case HTP_OP_L2_NORM:
481+
l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
482+
break;
405483
default:
406484
break;
407485
}
@@ -469,6 +547,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
469547
case HTP_OP_UNARY_SOFTPLUS:
470548
op_type = "softplus-f32";
471549
break;
550+
case HTP_OP_L2_NORM:
551+
op_type = "l2norm-f32";
552+
break;
472553

473554
default:
474555
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);

0 commit comments

Comments
 (0)