|
17 | 17 | #include "ggml-common.h" |
18 | 18 | #include "htp-ctx.h" |
19 | 19 | #include "htp-ops.h" |
20 | | -#include "htp-ops.h" |
21 | 20 |
|
22 | 21 | struct htp_unary_context { |
23 | 22 | struct htp_ops_context * octx; |
@@ -277,6 +276,95 @@ static void sigmoid_f32(const float * restrict src, |
277 | 276 | } |
278 | 277 | } |
279 | 278 |
|
| 279 | +static void tri_f32(const float * restrict src, |
| 280 | + float * restrict dst, |
| 281 | + uint8_t * restrict spad, |
| 282 | + const uint32_t num_rows, |
| 283 | + const uint32_t row_elems, |
| 284 | + const size_t row_size, |
| 285 | + int32_t * op_params, |
| 286 | + const uint32_t ir, |
| 287 | + const struct htp_unary_context * uctx) { |
| 288 | + |
| 289 | + const int32_t ttype = op_params[0]; |
| 290 | + const HVX_Vector zero = hvx_vec_splat_f32(0.0f); |
| 291 | + const uint32_t nvec = row_elems / VLEN_FP32; |
| 292 | + const uint32_t nloe = row_elems % VLEN_FP32; |
| 293 | + |
| 294 | + const uint32_t ne01 = uctx->octx->src[0]->ne[1]; |
| 295 | + |
| 296 | + for (uint32_t b = 0; b < num_rows; b++) { |
| 297 | + const uint32_t abs_row = ir + b; |
| 298 | + const uint32_t i01 = abs_row % ne01; |
| 299 | + |
| 300 | + const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size); |
| 301 | + HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size); |
| 302 | + |
| 303 | + uint32_t boundary; |
| 304 | + int keep_left; |
| 305 | + switch (ttype) { |
| 306 | + case 0: boundary = i01; keep_left = 0; break; // keep col >= row |
| 307 | + case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row |
| 308 | + case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row |
| 309 | + case 3: boundary = i01; keep_left = 1; break; // keep col < row |
| 310 | + default: boundary = 0; keep_left = 0; break; |
| 311 | + } |
| 312 | + if (boundary > row_elems) boundary = row_elems; |
| 313 | + |
| 314 | + // Full HVX vectors — each starts at a 128-byte aligned offset |
| 315 | + for (uint32_t i = 0; i < nvec; i++) { |
| 316 | + const uint32_t vec_start = i * VLEN_FP32; |
| 317 | + const uint32_t vec_end = vec_start + VLEN_FP32; |
| 318 | + if (keep_left) { |
| 319 | + if (vec_end <= boundary) { |
| 320 | + v_dst[i] = v_src[i]; |
| 321 | + } else if (vec_start >= boundary) { |
| 322 | + v_dst[i] = zero; |
| 323 | + } else { |
| 324 | + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); |
| 325 | + v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero); |
| 326 | + } |
| 327 | + } else { |
| 328 | + if (vec_end <= boundary) { |
| 329 | + v_dst[i] = zero; |
| 330 | + } else if (vec_start >= boundary) { |
| 331 | + v_dst[i] = v_src[i]; |
| 332 | + } else { |
| 333 | + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); |
| 334 | + v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]); |
| 335 | + } |
| 336 | + } |
| 337 | + } |
| 338 | + |
| 339 | + // Tail elements (row_elems not a multiple of VLEN_FP32) |
| 340 | + if (nloe > 0) { |
| 341 | + const uint32_t vec_start = nvec * VLEN_FP32; |
| 342 | + const uint32_t vec_end = vec_start + nloe; |
| 343 | + HVX_Vector tail_val; |
| 344 | + if (keep_left) { |
| 345 | + if (vec_end <= boundary) { |
| 346 | + tail_val = v_src[nvec]; |
| 347 | + } else if (vec_start >= boundary) { |
| 348 | + tail_val = zero; |
| 349 | + } else { |
| 350 | + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); |
| 351 | + tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero); |
| 352 | + } |
| 353 | + } else { |
| 354 | + if (vec_end <= boundary) { |
| 355 | + tail_val = zero; |
| 356 | + } else if (vec_start >= boundary) { |
| 357 | + tail_val = v_src[nvec]; |
| 358 | + } else { |
| 359 | + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); |
| 360 | + tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]); |
| 361 | + } |
| 362 | + } |
| 363 | + hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val); |
| 364 | + } |
| 365 | + } |
| 366 | +} |
| 367 | + |
280 | 368 | static void softplus_f32(const float * restrict src, |
281 | 369 | float * restrict dst, |
282 | 370 | uint8_t * restrict spad, |
@@ -498,6 +586,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * |
498 | 586 | case HTP_OP_L2_NORM: |
499 | 587 | l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); |
500 | 588 | break; |
| 589 | + case HTP_OP_TRI: |
| 590 | + tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx); |
| 591 | + break; |
501 | 592 | default: |
502 | 593 | break; |
503 | 594 | } |
@@ -571,6 +662,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { |
571 | 662 | case HTP_OP_L2_NORM: |
572 | 663 | op_type = "l2norm-f32"; |
573 | 664 | break; |
| 665 | + case HTP_OP_TRI: |
| 666 | + op_type = "tri-f32"; |
| 667 | + break; |
| 668 | + |
574 | 669 | default: |
575 | 670 | FARF(ERROR, "Unsupported unary Op %u\n", octx->op); |
576 | 671 | return HTP_STATUS_NO_SUPPORT; |
@@ -640,6 +735,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { |
640 | 735 | return err; |
641 | 736 | } |
642 | 737 |
|
| 738 | +int op_tri(struct htp_ops_context * octx) { |
| 739 | + int err = HTP_STATUS_OK; |
| 740 | + |
| 741 | + switch (octx->src[0]->type) { |
| 742 | + case HTP_TYPE_F32: |
| 743 | + err = execute_op_unary_f32(octx); |
| 744 | + break; |
| 745 | + |
| 746 | + default: |
| 747 | + err = HTP_STATUS_NO_SUPPORT; |
| 748 | + break; |
| 749 | + } |
| 750 | + |
| 751 | + return err; |
| 752 | +} |
| 753 | + |
643 | 754 | int op_unary(struct htp_ops_context * octx) { |
644 | 755 | int err = HTP_STATUS_OK; |
645 | 756 |
|
|
0 commit comments