Skip to content

Commit 5a4cd67

Browse files
authored
Hexagon: DAIG op (ggml-org#22195)
* hexagon: Add DIAG op * hexagon: add HVX support and DMA double buffering * hexagon: fix fatal error * hexagon: remove as many pragma(s) as possible
1 parent 2248799 commit 5a4cd67

6 files changed

Lines changed: 250 additions & 0 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2596,6 +2596,29 @@ static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * se
25962596
return true;
25972597
}
25982598

2599+
static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2600+
const struct ggml_tensor * src0 = op->src[0];
2601+
const struct ggml_tensor * dst = op;
2602+
2603+
// diag only supports F32 currently
2604+
if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
2605+
return false;
2606+
}
2607+
2608+
// Input must have ne[1] == 1 (vector input)
2609+
if (src0->ne[1] != 1) {
2610+
return false;
2611+
}
2612+
2613+
// Output must be square in first two dimensions
2614+
if (dst->ne[0] != dst->ne[1] || dst->ne[0] != src0->ne[0]) {
2615+
return false;
2616+
}
2617+
2618+
GGML_UNUSED(sess);
2619+
return true;
2620+
}
2621+
25992622
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
26002623
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
26012624
return sess->c_name();
@@ -2632,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
26322655
case GGML_OP_ROPE: return HTP_OP_ROPE;
26332656
case GGML_OP_REPEAT: return HTP_OP_REPEAT;
26342657
case GGML_OP_CUMSUM: return HTP_OP_CUMSUM;
2658+
case GGML_OP_DIAG: return HTP_OP_DIAG;
26352659

26362660
case GGML_OP_UNARY:
26372661
switch (ggml_get_unary_op(t)) {
@@ -3159,6 +3183,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
31593183
supp = ggml_hexagon_supported_cumsum(sess, op);
31603184
break;
31613185

3186+
case GGML_OP_DIAG:
3187+
supp = ggml_hexagon_supported_diag(sess, op);
3188+
break;
3189+
31623190
default:
31633191
break;
31643192
}

ggml/src/ggml-hexagon/htp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED
3434
argsort-ops.c
3535
ssm-conv.c
3636
cumsum-ops.c
37+
diag-ops.c
3738
)
3839

3940
target_compile_definitions(${HTP_LIB} PRIVATE
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
2+
3+
#include <HAP_farf.h>
4+
#include <HAP_perf.h>
5+
6+
#define GGML_COMMON_DECL_C
7+
#include "ggml-common.h"
8+
#include "htp-ctx.h"
9+
#include "htp-ops.h"
10+
#include "hvx-types.h"
11+
#include "hex-utils.h"
12+
#include "hvx-copy.h"
13+
#include "hex-dma.h"
14+
15+
#define htp_diag_tensors_preamble \
16+
const struct htp_tensor * restrict src0 = octx->src[0]; \
17+
const struct htp_tensor * restrict dst = octx->dst; \
18+
\
19+
const uint32_t ne02 = src0->ne[2]; \
20+
\
21+
const uint32_t ne0 = dst->ne[0]; \
22+
const uint32_t ne1 = dst->ne[1]; \
23+
\
24+
const uint32_t nb02 = src0->nb[2]; \
25+
const uint32_t nb03 = src0->nb[3]; \
26+
\
27+
const uint32_t nb1 = dst->nb[1]; \
28+
const uint32_t nb2 = dst->nb[2]; \
29+
const uint32_t nb3 = dst->nb[3];
30+
31+
struct htp_diag_context {
32+
struct htp_ops_context * octx;
33+
size_t src_batch_size;
34+
size_t dst_row_size;
35+
size_t src_batch_size_aligned;
36+
size_t dst_row_size_aligned;
37+
uint32_t batches_per_thread;
38+
uint32_t total_batches;
39+
};
40+
41+
#define htp_diag_preamble \
42+
struct htp_diag_context * dctx = (struct htp_diag_context *) data; \
43+
struct htp_ops_context * octx = dctx->octx; \
44+
htp_diag_tensors_preamble;
45+
46+
static inline void hvx_diag_row_f32(const float * restrict src, float * restrict dst,
47+
uint32_t row_idx, uint32_t n) {
48+
hvx_splat_f32_a((uint8_t *) dst, 0.0f, n);
49+
dst[row_idx] = src[row_idx];
50+
}
51+
52+
// ---------------------------------------------------------------------------
53+
// Per thread worker: DMA src fetch, compute in VTCM, DMA dst writeback
54+
// ---------------------------------------------------------------------------
55+
56+
static void diag_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) {
57+
htp_diag_preamble;
58+
dma_queue * dma_queue = octx->ctx->dma[ith];
59+
60+
uint64_t t1, t2;
61+
t1 = HAP_perf_get_qtimer_count();
62+
63+
const uint32_t ib0 = dctx->batches_per_thread * ith;
64+
const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches);
65+
66+
if (ib0 >= ib1) {
67+
return;
68+
}
69+
70+
const size_t src_batch_size = dctx->src_batch_size;
71+
const size_t dst_row_size = dctx->dst_row_size;
72+
const size_t src_batch_size_aligned = dctx->src_batch_size_aligned;
73+
const size_t dst_row_size_aligned = dctx->dst_row_size_aligned;
74+
75+
const uint8_t * src_data = (const uint8_t *) src0->data;
76+
uint8_t * dst_data = (uint8_t *) dst->data;
77+
78+
// 1 src buffer + 1 dst row buffer per thread in VTCM
79+
uint8_t * src_spad = octx->src0_spad.data + (ith * src_batch_size_aligned);
80+
uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned);
81+
82+
for (uint32_t ib = ib0; ib < ib1; ib++) {
83+
const uint32_t i3 = ib / ne02;
84+
const uint32_t i2 = ib % ne02;
85+
86+
const uint8_t * src_batch = src_data + i3 * nb03 + i2 * nb02;
87+
88+
// Fetch source vector into VTCM
89+
dma_queue_push_ddr_to_vtcm(dma_queue,
90+
dma_make_ptr(src_spad, src_batch),
91+
src_batch_size_aligned, src_batch_size, 1);
92+
dma_queue_flush(dma_queue);
93+
94+
const float * src_spad_f32 = (const float *) src_spad;
95+
float * dst_spad_f32 = (float *) dst_spad;
96+
97+
for (uint32_t i1 = 0; i1 < ne1; i1++) {
98+
// Compute row in VTCM
99+
hvx_diag_row_f32(src_spad_f32, dst_spad_f32, i1, ne0);
100+
101+
// Write completed row back to DDR
102+
uint8_t * dst_row = dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1;
103+
dma_queue_push_vtcm_to_ddr(dma_queue,
104+
dma_make_ptr(dst_row, dst_spad),
105+
dst_row_size, dst_row_size_aligned, 1);
106+
dma_queue_flush(dma_queue);
107+
}
108+
}
109+
110+
t2 = HAP_perf_get_qtimer_count();
111+
112+
FARF(HIGH, "diag-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
113+
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1,
114+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
115+
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
116+
}
117+
118+
// ---------------------------------------------------------------------------
119+
// Per thread worker: Direct HVX (no DMA)
120+
// ---------------------------------------------------------------------------
121+
122+
static void diag_thread_f32(unsigned int nth, unsigned int ith, void * data) {
123+
htp_diag_preamble;
124+
125+
uint64_t t1, t2;
126+
t1 = HAP_perf_get_qtimer_count();
127+
128+
const uint8_t * src_data = (const uint8_t *) src0->data;
129+
uint8_t * dst_data = (uint8_t *) dst->data;
130+
131+
const uint32_t ib0 = dctx->batches_per_thread * ith;
132+
const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches);
133+
134+
for (uint32_t ib = ib0; ib < ib1; ib++) {
135+
const uint32_t i3 = ib / ne02;
136+
const uint32_t i2 = ib % ne02;
137+
138+
const float * restrict src_batch = (const float *)(src_data + i3 * nb03 + i2 * nb02);
139+
140+
for (uint32_t i1 = 0; i1 < ne1; i1++) {
141+
float * restrict dst_row = (float *)(dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1);
142+
hvx_diag_row_f32(src_batch, dst_row, i1, ne0);
143+
}
144+
}
145+
146+
t2 = HAP_perf_get_qtimer_count();
147+
148+
FARF(HIGH, "diag-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
149+
ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1,
150+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
151+
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
152+
}
153+
154+
int op_diag_f32(struct htp_ops_context * octx) {
155+
const struct htp_tensor * src0 = octx->src[0];
156+
const struct htp_tensor * dst = octx->dst;
157+
158+
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
159+
return HTP_STATUS_OK;
160+
}
161+
162+
const uint32_t total_batches = src0->ne[2] * src0->ne[3];
163+
const uint32_t n_threads = MIN(octx->n_threads, total_batches);
164+
165+
const size_t src_batch_size = src0->ne[0] * sizeof(float);
166+
const size_t dst_row_size = dst->ne[0] * sizeof(float);
167+
const size_t src_batch_size_aligned = hex_round_up(src_batch_size, VLEN);
168+
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
169+
170+
// 1 src buffer + 1 dst row buffer per thread
171+
const size_t spad_per_thread = src_batch_size_aligned + dst_row_size_aligned;
172+
173+
octx->src0_spad.size_per_thread = src_batch_size_aligned;
174+
octx->dst_spad.size_per_thread = dst_row_size_aligned;
175+
176+
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
177+
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
178+
179+
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
180+
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL;
181+
182+
struct htp_diag_context dctx = {
183+
.octx = octx,
184+
.src_batch_size = src_batch_size,
185+
.dst_row_size = dst_row_size,
186+
.src_batch_size_aligned = src_batch_size_aligned,
187+
.dst_row_size_aligned = dst_row_size_aligned,
188+
.batches_per_thread = (total_batches + n_threads - 1) / n_threads,
189+
.total_batches = total_batches,
190+
};
191+
192+
if (octx->ctx->vtcm_size < spad_per_thread * n_threads) {
193+
worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32, &dctx, n_threads);
194+
} else {
195+
worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32_dma, &dctx, n_threads);
196+
}
197+
198+
return HTP_STATUS_OK;
199+
}
200+
201+
int op_diag(struct htp_ops_context * octx) {
202+
const struct htp_tensor * dst = octx->dst;
203+
204+
int err = HTP_STATUS_OK;
205+
206+
switch (dst->type) {
207+
case HTP_TYPE_F32:
208+
err = op_diag_f32(octx);
209+
break;
210+
default:
211+
err = HTP_STATUS_NO_SUPPORT;
212+
break;
213+
}
214+
215+
return err;
216+
}

ggml/src/ggml-hexagon/htp/htp-ctx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,6 @@ int op_repeat(struct htp_ops_context * octx);
9898
int op_argsort(struct htp_ops_context * octx);
9999
int op_ssm_conv(struct htp_ops_context * octx);
100100
int op_cumsum(struct htp_ops_context * octx);
101+
int op_diag(struct htp_ops_context * octx);
101102

102103
#endif /* HTP_CTX_H */

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ enum htp_op_code {
8080
HTP_OP_SSM_CONV,
8181
HTP_OP_REPEAT,
8282
HTP_OP_CUMSUM,
83+
HTP_OP_DIAG,
8384

8485
HTP_OP_INVALID
8586
};

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) {
514514
case HTP_OP_CUMSUM:
515515
return op_cumsum(octx);
516516

517+
case HTP_OP_DIAG:
518+
return op_diag(octx);
519+
517520
case HTP_OP_INVALID:
518521
break;
519522

0 commit comments

Comments
 (0)