|
| 1 | +#pragma clang diagnostic ignored "-Wunused-variable" |
| 2 | +#pragma clang diagnostic ignored "-Wunused-function" |
| 3 | +#pragma clang diagnostic ignored "-Wunused-but-set-variable" |
| 4 | + |
| 5 | +#include <HAP_farf.h> |
| 6 | +#include <HAP_perf.h> |
| 7 | + |
| 8 | +#include <string.h> |
| 9 | + |
| 10 | +#include "hvx-copy.h" |
| 11 | +#include "hvx-utils.h" |
| 12 | + |
| 13 | +#define GGML_COMMON_DECL_C |
| 14 | +#include "ggml-common.h" |
| 15 | +#include "htp-ctx.h" |
| 16 | +#include "htp-ops.h" |
| 17 | + |
| 18 | +// ggml op_params layout for FILL: |
| 19 | +// op_params[0] (as float) - the scalar fill value |
| 20 | + |
// Shared prologue for the FILL kernels.
//
// Expands to local declarations for the destination tensor (`dst`), its four
// extents (`ne0`..`ne3`), the byte strides of dims 1..3 (`nb1`..`nb3`), and
// `nr`, the flat row count ne1 * ne2 * ne3.
//
// A variable named `octx` must be in scope where the macro is expanded; the
// destination tensor is read from `octx->dst`. Not every user needs every
// local, so some of them end up unused at a given expansion site.
#define fill_preamble                                     \
    const struct htp_tensor * dst = octx->dst;            \
                                                          \
    const uint32_t ne0 = dst->ne[0], ne1 = dst->ne[1],    \
                   ne2 = dst->ne[2], ne3 = dst->ne[3];    \
                                                          \
    const uint32_t nb1 = dst->nb[1], nb2 = dst->nb[2],    \
                   nb3 = dst->nb[3];                      \
                                                          \
    const uint32_t nr = ne1 * ne2 * ne3;
| 34 | + |
| 35 | +struct htp_fill_context { |
| 36 | + struct htp_ops_context * octx; |
| 37 | + uint32_t nrows_per_thread; |
| 38 | + uint32_t total_rows; // ne1 * ne2 * ne3 |
| 39 | + bool opt_path; |
| 40 | + HVX_Vector splat_vec; |
| 41 | + uint32_t elem_size; |
| 42 | +}; |
| 43 | + |
| 44 | +static void fill_thread(unsigned int nth, unsigned int ith, void * data) { |
| 45 | + const struct htp_fill_context * fctx = (const struct htp_fill_context *) data; |
| 46 | + struct htp_ops_context * octx = fctx->octx; |
| 47 | + fill_preamble; |
| 48 | + |
| 49 | + // Parallelise over the flat row index spanning ne1*ne2*ne3 |
| 50 | + const uint32_t ir0 = fctx->nrows_per_thread * ith; |
| 51 | + const uint32_t ir1 = MIN(ir0 + fctx->nrows_per_thread, fctx->total_rows); |
| 52 | + |
| 53 | + uint64_t t1 = HAP_perf_get_qtimer_count(); |
| 54 | + |
| 55 | + if (fctx->opt_path) { |
| 56 | + // Opt path: tensor is fully contiguous, treat as flat array |
| 57 | + const uint32_t elem_start = ir0 * ne0; |
| 58 | + const uint32_t elem_end = ir1 * ne0; |
| 59 | + uint8_t * dst_ptr = (uint8_t *) dst->data + elem_start * fctx->elem_size; |
| 60 | + hvx_splat_u(dst_ptr, fctx->splat_vec, elem_end - elem_start, fctx->elem_size); |
| 61 | + } else { |
| 62 | + // Non-contiguous path: must respect strides |
| 63 | + for (uint32_t ir = ir0; ir < ir1; ++ir) { |
| 64 | + const uint32_t i1 = ir % ne1; |
| 65 | + const uint32_t i2 = (ir / ne1) % ne2; |
| 66 | + const uint32_t i3 = ir / (ne1 * ne2); |
| 67 | + uint8_t * dst_ptr = (uint8_t *) dst->data + i1*nb1 + i2*nb2 + i3*nb3; |
| 68 | + hvx_splat_u(dst_ptr, fctx->splat_vec, ne0, fctx->elem_size); |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + uint64_t t2 = HAP_perf_get_qtimer_count(); |
| 73 | + FARF(HIGH, "fill %u/%u: rows %u:%u usec %u\n", |
| 74 | + ith, nth, ir0, ir1, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); |
| 75 | +} |
| 76 | + |
| 77 | +int op_fill(struct htp_ops_context * octx) { |
| 78 | + fill_preamble; |
| 79 | + |
| 80 | + if (dst->type != HTP_TYPE_F32 && dst->type != HTP_TYPE_F16) { |
| 81 | + return HTP_STATUS_NO_SUPPORT; |
| 82 | + } |
| 83 | + |
| 84 | + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { |
| 85 | + return HTP_STATUS_OK; |
| 86 | + } |
| 87 | + |
| 88 | + // nr = ne1*ne2*ne3 (flat row count across all outer dims); parallelise over it. |
| 89 | + const uint32_t n_threads = MIN(nr, octx->n_threads); |
| 90 | + |
| 91 | + // Optimize if fully contiguous: skip stride arithmetic, treat as flat array |
| 92 | + const bool opt_path = (nb2 == nb1 * ne1) && (nb3 == nb2 * ne2); |
| 93 | + |
| 94 | + FARF(HIGH, "fill: (%ux%ux%ux%u) type=%u opt=%d\n", |
| 95 | + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->type, (int) opt_path); |
| 96 | + |
| 97 | + float val_f32 = 0.f; |
| 98 | + memcpy(&val_f32, &octx->op_params[0], sizeof(float)); |
| 99 | + |
| 100 | + struct htp_fill_context fctx = { |
| 101 | + .octx = octx, |
| 102 | + .nrows_per_thread = (nr + n_threads - 1) / n_threads, |
| 103 | + .total_rows = nr, |
| 104 | + .opt_path = opt_path, |
| 105 | + }; |
| 106 | + |
| 107 | + switch (dst->type) { |
| 108 | + case HTP_TYPE_F32: |
| 109 | + fctx.splat_vec = hvx_vec_splat_f32(val_f32); |
| 110 | + fctx.elem_size = sizeof(float); |
| 111 | + break; |
| 112 | + case HTP_TYPE_F16: |
| 113 | + fctx.splat_vec = hvx_vec_splat_f16((_Float16) val_f32); |
| 114 | + fctx.elem_size = sizeof(_Float16); |
| 115 | + break; |
| 116 | + default: |
| 117 | + return HTP_STATUS_NO_SUPPORT; |
| 118 | + } |
| 119 | + |
| 120 | + worker_pool_run_func(octx->ctx->worker_pool, fill_thread, &fctx, n_threads); |
| 121 | + |
| 122 | + return HTP_STATUS_OK; |
| 123 | +} |
0 commit comments