Skip to content

Commit 0dedb9e

Browse files
hexagon: add support for FILL op (ggml-org#22198)
Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
1 parent 2799d93 commit 0dedb9e

6 files changed

Lines changed: 145 additions & 0 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2655,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
26552655
case GGML_OP_ROPE: return HTP_OP_ROPE;
26562656
case GGML_OP_REPEAT: return HTP_OP_REPEAT;
26572657
case GGML_OP_CUMSUM: return HTP_OP_CUMSUM;
2658+
case GGML_OP_FILL: return HTP_OP_FILL;
26582659
case GGML_OP_DIAG: return HTP_OP_DIAG;
26592660

26602661
case GGML_OP_UNARY:
@@ -3053,6 +3054,17 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se
30533054
return true;
30543055
}
30553056

3057+
static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
3058+
const struct ggml_tensor * dst = op;
3059+
3060+
if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
3061+
return false;
3062+
}
3063+
3064+
GGML_UNUSED(sess);
3065+
return true;
3066+
}
3067+
30563068
static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
30573069
auto sess = static_cast<ggml_hexagon_session *>(dev->context);
30583070

@@ -3183,6 +3195,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
31833195
supp = ggml_hexagon_supported_cumsum(sess, op);
31843196
break;
31853197

3198+
case GGML_OP_FILL:
3199+
supp = ggml_hexagon_supported_fill(sess, op);
3200+
break;
3201+
31863202
case GGML_OP_DIAG:
31873203
supp = ggml_hexagon_supported_diag(sess, op);
31883204
break;

ggml/src/ggml-hexagon/htp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED
3434
argsort-ops.c
3535
ssm-conv.c
3636
cumsum-ops.c
37+
fill-ops.c
3738
diag-ops.c
3839
)
3940

ggml/src/ggml-hexagon/htp/fill-ops.c

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#pragma clang diagnostic ignored "-Wunused-variable"
2+
#pragma clang diagnostic ignored "-Wunused-function"
3+
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
4+
5+
#include <HAP_farf.h>
6+
#include <HAP_perf.h>
7+
8+
#include <string.h>
9+
10+
#include "hvx-copy.h"
11+
#include "hvx-utils.h"
12+
13+
#define GGML_COMMON_DECL_C
14+
#include "ggml-common.h"
15+
#include "htp-ctx.h"
16+
#include "htp-ops.h"
17+
18+
// ggml op_params layout for FILL:
19+
// op_params[0] (as float) - the scalar fill value
20+
21+
// Declares the locals shared by the FILL entry points. Expects a variable named
// `octx` (struct htp_ops_context *) to be in scope at the expansion site.
//   dst       - destination tensor, read from octx->dst
//   ne0..ne3  - destination extents, read from dst->ne[0..3]
//   nb1..nb3  - destination strides for dims 1..3, read from dst->nb[1..3]
//   nr        - ne1 * ne2 * ne3, the flat row count
#define fill_preamble                               \
    const struct htp_tensor * dst = octx->dst;      \
                                                    \
    const uint32_t ne0 = dst->ne[0];                \
    const uint32_t ne1 = dst->ne[1];                \
    const uint32_t ne2 = dst->ne[2];                \
    const uint32_t ne3 = dst->ne[3];                \
                                                    \
    const uint32_t nb1 = dst->nb[1];                \
    const uint32_t nb2 = dst->nb[2];                \
    const uint32_t nb3 = dst->nb[3];                \
                                                    \
    const uint32_t nr = ne1 * ne2 * ne3;
34+
35+
struct htp_fill_context {
36+
struct htp_ops_context * octx;
37+
uint32_t nrows_per_thread;
38+
uint32_t total_rows; // ne1 * ne2 * ne3
39+
bool opt_path;
40+
HVX_Vector splat_vec;
41+
uint32_t elem_size;
42+
};
43+
44+
static void fill_thread(unsigned int nth, unsigned int ith, void * data) {
45+
const struct htp_fill_context * fctx = (const struct htp_fill_context *) data;
46+
struct htp_ops_context * octx = fctx->octx;
47+
fill_preamble;
48+
49+
// Parallelise over the flat row index spanning ne1*ne2*ne3
50+
const uint32_t ir0 = fctx->nrows_per_thread * ith;
51+
const uint32_t ir1 = MIN(ir0 + fctx->nrows_per_thread, fctx->total_rows);
52+
53+
uint64_t t1 = HAP_perf_get_qtimer_count();
54+
55+
if (fctx->opt_path) {
56+
// Opt path: tensor is fully contiguous, treat as flat array
57+
const uint32_t elem_start = ir0 * ne0;
58+
const uint32_t elem_end = ir1 * ne0;
59+
uint8_t * dst_ptr = (uint8_t *) dst->data + elem_start * fctx->elem_size;
60+
hvx_splat_u(dst_ptr, fctx->splat_vec, elem_end - elem_start, fctx->elem_size);
61+
} else {
62+
// Non-contiguous path: must respect strides
63+
for (uint32_t ir = ir0; ir < ir1; ++ir) {
64+
const uint32_t i1 = ir % ne1;
65+
const uint32_t i2 = (ir / ne1) % ne2;
66+
const uint32_t i3 = ir / (ne1 * ne2);
67+
uint8_t * dst_ptr = (uint8_t *) dst->data + i1*nb1 + i2*nb2 + i3*nb3;
68+
hvx_splat_u(dst_ptr, fctx->splat_vec, ne0, fctx->elem_size);
69+
}
70+
}
71+
72+
uint64_t t2 = HAP_perf_get_qtimer_count();
73+
FARF(HIGH, "fill %u/%u: rows %u:%u usec %u\n",
74+
ith, nth, ir0, ir1, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
75+
}
76+
77+
int op_fill(struct htp_ops_context * octx) {
78+
fill_preamble;
79+
80+
if (dst->type != HTP_TYPE_F32 && dst->type != HTP_TYPE_F16) {
81+
return HTP_STATUS_NO_SUPPORT;
82+
}
83+
84+
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
85+
return HTP_STATUS_OK;
86+
}
87+
88+
// nr = ne1*ne2*ne3 (flat row count across all outer dims); parallelise over it.
89+
const uint32_t n_threads = MIN(nr, octx->n_threads);
90+
91+
// Optimize if fully contiguous: skip stride arithmetic, treat as flat array
92+
const bool opt_path = (nb2 == nb1 * ne1) && (nb3 == nb2 * ne2);
93+
94+
FARF(HIGH, "fill: (%ux%ux%ux%u) type=%u opt=%d\n",
95+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->type, (int) opt_path);
96+
97+
float val_f32 = 0.f;
98+
memcpy(&val_f32, &octx->op_params[0], sizeof(float));
99+
100+
struct htp_fill_context fctx = {
101+
.octx = octx,
102+
.nrows_per_thread = (nr + n_threads - 1) / n_threads,
103+
.total_rows = nr,
104+
.opt_path = opt_path,
105+
};
106+
107+
switch (dst->type) {
108+
case HTP_TYPE_F32:
109+
fctx.splat_vec = hvx_vec_splat_f32(val_f32);
110+
fctx.elem_size = sizeof(float);
111+
break;
112+
case HTP_TYPE_F16:
113+
fctx.splat_vec = hvx_vec_splat_f16((_Float16) val_f32);
114+
fctx.elem_size = sizeof(_Float16);
115+
break;
116+
default:
117+
return HTP_STATUS_NO_SUPPORT;
118+
}
119+
120+
worker_pool_run_func(octx->ctx->worker_pool, fill_thread, &fctx, n_threads);
121+
122+
return HTP_STATUS_OK;
123+
}

ggml/src/ggml-hexagon/htp/htp-ctx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ int op_repeat(struct htp_ops_context * octx);
9898
int op_argsort(struct htp_ops_context * octx);
9999
int op_ssm_conv(struct htp_ops_context * octx);
100100
int op_cumsum(struct htp_ops_context * octx);
101+
int op_fill(struct htp_ops_context * octx);
101102
int op_diag(struct htp_ops_context * octx);
102103

103104
#endif /* HTP_CTX_H */

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ enum htp_op_code {
8080
HTP_OP_SSM_CONV,
8181
HTP_OP_REPEAT,
8282
HTP_OP_CUMSUM,
83+
HTP_OP_FILL,
8384
HTP_OP_DIAG,
8485

8586
HTP_OP_INVALID

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) {
514514
case HTP_OP_CUMSUM:
515515
return op_cumsum(octx);
516516

517+
case HTP_OP_FILL:
518+
return op_fill(octx);
519+
517520
case HTP_OP_DIAG:
518521
return op_diag(octx);
519522

0 commit comments

Comments
 (0)