Skip to content

Commit 8bc492e

Browse files
hexagon: add SOLVE_TRI op (#21974)
* hexagon: add SOLVE_TRI op * ggml: fix TODO description for solve_tri * hexagon: rm unused variable/function warnings * hexagon: chunk vs batch processing for better thread utilization * hexagon: vectorize partial f32 loads * hexagon: move HVX f32 add/sub/mul wrappers to hvx-base.h --------- Co-authored-by: Todor Boinovski <todorb@qti.qualcomm.com>
1 parent e5f070a commit 8bc492e

7 files changed

Lines changed: 335 additions & 2 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2693,6 +2693,39 @@ static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess
26932693
return true;
26942694
}
26952695

2696+
static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2697+
const struct ggml_tensor * src0 = op->src[0]; // A
2698+
const struct ggml_tensor * src1 = op->src[1]; // B
2699+
const struct ggml_tensor * dst = op; // X
2700+
2701+
if (!src0 || !src1) {
2702+
return false;
2703+
}
2704+
2705+
if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
2706+
return false;
2707+
}
2708+
2709+
if (src0->ne[0] != src0->ne[1]) {
2710+
return false;
2711+
}
2712+
2713+
if (src0->ne[1] != src1->ne[1]) {
2714+
return false;
2715+
}
2716+
2717+
if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) {
2718+
return false;
2719+
}
2720+
2721+
if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || dst->ne[3] != src1->ne[3]) {
2722+
return false;
2723+
}
2724+
2725+
GGML_UNUSED(sess);
2726+
return true;
2727+
}
2728+
26962729
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
26972730
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
26982731
return sess->c_name();
@@ -2731,7 +2764,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
27312764
case GGML_OP_CUMSUM: return HTP_OP_CUMSUM;
27322765
case GGML_OP_FILL: return HTP_OP_FILL;
27332766
case GGML_OP_DIAG: return HTP_OP_DIAG;
2734-
2767+
case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
27352768
case GGML_OP_UNARY:
27362769
switch (ggml_get_unary_op(t)) {
27372770
case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU;
@@ -3277,6 +3310,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
32773310
supp = ggml_hexagon_supported_diag(sess, op);
32783311
break;
32793312

3313+
case GGML_OP_SOLVE_TRI:
3314+
supp = ggml_hexagon_supported_solve_tri(sess, op);
3315+
break;
3316+
32803317
default:
32813318
break;
32823319
}

ggml/src/ggml-hexagon/htp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ add_library(${HTP_LIB} SHARED
3636
cumsum-ops.c
3737
fill-ops.c
3838
diag-ops.c
39+
solve-tri-ops.c
3940
)
4041

4142
target_compile_definitions(${HTP_LIB} PRIVATE

ggml/src/ggml-hexagon/htp/htp-ctx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,5 +103,6 @@ int op_ssm_conv(struct htp_ops_context * octx);
103103
int op_cumsum(struct htp_ops_context * octx);
104104
int op_fill(struct htp_ops_context * octx);
105105
int op_diag(struct htp_ops_context * octx);
106+
int op_solve_tri(struct htp_ops_context * octx);
106107

107108
#endif /* HTP_CTX_H */

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ enum htp_op_code {
8282
HTP_OP_CUMSUM,
8383
HTP_OP_FILL,
8484
HTP_OP_DIAG,
85-
85+
HTP_OP_SOLVE_TRI,
8686
HTP_OP_INVALID
8787
};
8888

ggml/src/ggml-hexagon/htp/hvx-base.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
256256
return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
257257
}
258258

259+
// Elementwise f32 add: returns a + b per lane. Inputs/outputs are vectors of
// IEEE sf (float32) lanes; the arithmetic goes through the qf32 intermediate
// format and is converted back to sf.
static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) {
    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b));
}

// Elementwise f32 subtract: returns a - b per lane, via qf32 intermediate.
static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) {
    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b));
}

// Elementwise f32 multiply: returns a * b per lane, via qf32 intermediate.
static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b));
}
270+
259271
#else
260272

261273
static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
@@ -273,6 +285,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
273285
return Q6_Vhf_vmpy_VhfVhf(a, b);
274286
}
275287

288+
// Elementwise f32 add: a + b per lane using the direct sf instruction
// (this branch is selected by the surrounding #if on the HVX architecture).
static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) {
    return Q6_Vsf_vadd_VsfVsf(a, b);
}

// Elementwise f32 subtract: a - b per lane, direct sf instruction.
static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) {
    return Q6_Vsf_vsub_VsfVsf(a, b);
}

// Elementwise f32 multiply: a * b per lane, direct sf instruction.
static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
    return Q6_Vsf_vmpy_VsfVsf(a, b);
}
299+
276300
#endif // __HVX_ARCH__ < 79
277301

278302
#endif /* HVX_BASE_H */

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,9 @@ static int execute_op(struct htp_ops_context * octx) {
573573
case HTP_OP_DIAG:
574574
return op_diag(octx);
575575

576+
case HTP_OP_SOLVE_TRI:
577+
return op_solve_tri(octx);
578+
576579
case HTP_OP_INVALID:
577580
break;
578581

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
2+
3+
#include <HAP_farf.h>
4+
#include <HAP_perf.h>
5+
#include <string.h>
6+
7+
#define GGML_COMMON_DECL_C
8+
#include "ggml-common.h"
9+
#include "htp-ctx.h"
10+
#include "htp-ops.h"
11+
#include "hvx-types.h"
12+
#include "hvx-utils.h"
13+
14+
// Shared, read-only state handed to the SOLVE_TRI worker threads.
struct htp_solve_tri_context {
    struct htp_ops_context * octx;  // op context: src/dst tensors, flags, thread count
    uint32_t jobs_per_thread;       // ceil(total_jobs / n_threads); each thread takes one contiguous slice
    uint32_t total_jobs;            // batch mode: #batches; chunk mode: #batches * k_chunks
    uint32_t k_chunks;              // number of col_block-wide column chunks covering k columns
    uint32_t col_block;             // columns per chunk (VLEN_FP32)
};
21+
22+
static inline void solve_tri_row_scalar(const float * A_row,
23+
const float * B_row,
24+
float * X,
25+
uint32_t row,
26+
uint32_t k,
27+
uint32_t col0,
28+
uint32_t coln,
29+
float inv_diag) {
30+
for (uint32_t col = col0; col < col0 + coln; ++col) {
31+
float sum = 0.0f;
32+
for (uint32_t t = 0; t < row; ++t) {
33+
sum += A_row[t] * X[t * k + col];
34+
}
35+
X[row * k + col] = (B_row[col] - sum) * inv_diag;
36+
}
37+
}
38+
39+
// Load the first n floats from src into an HVX vector, zeroing all remaining
// lanes. NOTE(review): this performs a full unaligned vector load and then
// masks, so it reads a whole vector's worth of bytes from src — callers must
// guarantee that many bytes are readable past src (here the loads land inside
// larger tensor rows; confirm for edge-of-buffer rows).
static inline HVX_Vector hvx_load_partial_f32(const float * src, uint32_t n) {
    HVX_Vector v = *((const HVX_UVector *) src);
    HVX_VectorPred mask = Q6_Q_vsetq2_R(n * sizeof(float)); // true for the first n*4 bytes
    return Q6_V_vmux_QVV(mask, v, Q6_V_vzero());            // keep masked lanes, zero the rest
}
44+
45+
static inline void solve_tri_row_hvx(const float * A_row,
46+
const float * B_row,
47+
float * X,
48+
uint32_t row,
49+
uint32_t k,
50+
uint32_t col0,
51+
uint32_t coln,
52+
float inv_diag) {
53+
const bool full = (coln == VLEN_FP32);
54+
55+
HVX_Vector sum_v = Q6_V_vzero();
56+
for (uint32_t t = 0; t < row; ++t) {
57+
const float a = A_row[t];
58+
const float * x_row_col = X + t * k + col0;
59+
60+
HVX_Vector x_v = full ? *((const HVX_UVector *) x_row_col) : hvx_load_partial_f32(x_row_col, coln);
61+
HVX_Vector a_v = hvx_vec_splat_f32(a);
62+
sum_v = hvx_vec_add_f32_f32(sum_v, hvx_vec_mul_f32_f32(x_v, a_v));
63+
}
64+
65+
const float * b_row_col = B_row + col0;
66+
float * x_out_col = X + row * k + col0;
67+
68+
HVX_Vector b_v = full ? *((const HVX_UVector *) b_row_col) : hvx_load_partial_f32(b_row_col, coln);
69+
HVX_Vector inv_diag_v = hvx_vec_splat_f32(inv_diag);
70+
71+
HVX_Vector out_v = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(b_v, sum_v), inv_diag_v);
72+
hvx_vec_store_u((void *) x_out_col, coln * sizeof(float), out_v);
73+
}
74+
75+
// Batch-level worker: each job is one full (i02, i03) batch. Rows within a
// batch must be solved in order (row r reads X rows < r), so parallelism is
// only across batches in this mode.
static void solve_tri_batch_thread_f32(unsigned int nth, unsigned int ith, void * data) {
    struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data;
    struct htp_ops_context * octx = sctx->octx;

    const struct htp_tensor * src0 = octx->src[0]; // A
    const struct htp_tensor * src1 = octx->src[1]; // B
    const struct htp_tensor * dst = octx->dst; // X

    const uint32_t n = src0->ne[0]; // A is n x n
    const uint32_t k = src1->ne[0]; // B and X have k columns per row

    const uint32_t ne02 = src0->ne[2];

    const uint32_t col_block = VLEN_FP32;
    const uint32_t k_full = (k / col_block) * col_block; // columns covered by full vectors

    // Contiguous slice of batches for this thread.
    const uint32_t start_batch = sctx->jobs_per_thread * ith;
    const uint32_t end_batch = MIN(start_batch + sctx->jobs_per_thread, sctx->total_jobs);

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    for (uint32_t batch = start_batch; batch < end_batch; ++batch) {
        // Decompose flat batch index into (i02, i03).
        const uint32_t i03 = batch / ne02;
        const uint32_t i02 = batch - i03 * ne02;

        const float * A_batch =
            (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]);
        const float * B_batch =
            (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]);
        float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]);

        // Forward substitution, row by row.
        for (uint32_t row = 0; row < n; ++row) {
            const float diag = A_batch[row * n + row];
            const float inv_diag = 1.0f / diag; // assumes non-singular A (nonzero diagonal)
            const float * A_row = A_batch + row * n;
            const float * B_row = B_batch + row * k;

            // Full-vector column blocks first...
            uint32_t col0 = 0;
            for (; col0 < k_full; col0 += col_block) {
                solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, col_block, inv_diag);
            }

            // ...then the tail: vectorize when >= 8 columns remain, else scalar.
            if (col0 < k) {
                const uint32_t coln = k - col0;
                if (coln >= 8) {
                    solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag);
                } else {
                    solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag);
                }
            }
        }
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "solve-tri-batch %d/%d: A=(%ux%u) B=(%ux%u) batch %u:%u usec %u\n",
         ith, nth, n, n, k, n, start_batch, end_batch,
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
136+
137+
// Chunk-level worker: each job is one (batch, column-chunk) pair. This is
// safe because column c of X depends only on column c of earlier rows, so
// distinct column chunks of the same batch can be solved independently.
// Used when there are fewer batches than threads.
static void solve_tri_chunk_thread_f32(unsigned int nth, unsigned int ith, void * data) {
    struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data;
    struct htp_ops_context * octx = sctx->octx;

    const struct htp_tensor * src0 = octx->src[0]; // A
    const struct htp_tensor * src1 = octx->src[1]; // B
    const struct htp_tensor * dst = octx->dst; // X

    const uint32_t n = src0->ne[0]; // A is n x n
    const uint32_t k = src1->ne[0]; // B and X have k columns per row

    const uint32_t ne02 = src0->ne[2];

    // Contiguous slice of (batch, chunk) jobs for this thread.
    const uint32_t start_job = sctx->jobs_per_thread * ith;
    const uint32_t end_job = MIN(start_job + sctx->jobs_per_thread, sctx->total_jobs);

    uint64_t t1, t2;
    t1 = HAP_perf_get_qtimer_count();

    for (uint32_t job = start_job; job < end_job; ++job) {
        // Decompose flat job index into (batch, chunk), then batch into (i02, i03).
        const uint32_t batch = job / sctx->k_chunks;
        const uint32_t chunk = job - batch * sctx->k_chunks;

        const uint32_t i03 = batch / ne02;
        const uint32_t i02 = batch - i03 * ne02;

        const uint32_t col0 = chunk * sctx->col_block;
        const uint32_t coln = MIN(sctx->col_block, k - col0); // last chunk may be short

        const float * A_batch =
            (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]);
        const float * B_batch =
            (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]);
        float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]);

        // Chunk width is fixed for all rows: pick the kernel once.
        const bool use_hvx = (coln >= 8);

        // Forward substitution over this chunk's columns, row by row.
        for (uint32_t row = 0; row < n; ++row) {
            const float diag = A_batch[row * n + row];
            const float inv_diag = 1.0f / diag; // assumes non-singular A (nonzero diagonal)

            const float * A_row = A_batch + row * n;
            const float * B_row = B_batch + row * k;

            if (use_hvx) {
                solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag);
            } else {
                solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag);
            }
        }
    }

    t2 = HAP_perf_get_qtimer_count();

    FARF(HIGH, "solve-tri-chunk %d/%d: A=(%ux%u) B=(%ux%u) job %u:%u usec %u\n",
         ith, nth, n, n, k, n, start_job, end_job,
         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
196+
197+
// SOLVE_TRI entry point: solves A * X = B per batch, where A is lower
// triangular (left=true, lower=true, uni=false only — see checks below).
// Validates types/shapes, then dispatches to batch-level or chunk-level
// worker threads depending on how many batches there are.
// Returns HTP_STATUS_OK, HTP_STATUS_NO_SUPPORT, or HTP_STATUS_INVAL_PARAMS.
int op_solve_tri(struct htp_ops_context * octx) {
    const struct htp_tensor * src0 = octx->src[0]; // A
    const struct htp_tensor * src1 = octx->src[1]; // B
    const struct htp_tensor * dst = octx->dst; // X

    // F32-only implementation.
    if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    // left=true, lower=true, uni=false only
    // A must be square ...
    if (src0->ne[0] != src0->ne[1]) {
        return HTP_STATUS_INVAL_PARAMS;
    }
    // ... with the same row count as B ...
    if (src0->ne[1] != src1->ne[1]) {
        return HTP_STATUS_INVAL_PARAMS;
    }
    // ... matching batch dims ...
    if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) {
        return HTP_STATUS_INVAL_PARAMS;
    }
    // ... and X must have exactly B's shape.
    if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] ||
        dst->ne[3] != src1->ne[3]) {
        return HTP_STATUS_INVAL_PARAMS;
    }

    // Graph-capture / dry-run mode: validation only, no compute.
    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    const uint32_t k = src1->ne[0];

    const uint32_t col_block = VLEN_FP32;
    const uint32_t k_chunks = (k + col_block - 1) / col_block; // ceil(k / col_block)
    const uint32_t total_batches = src0->ne[2] * src0->ne[3];
    // Enough batches to keep every thread busy -> parallelize over batches;
    // otherwise split each batch's columns into chunks for finer-grained jobs.
    const bool batched = total_batches >= (uint32_t) octx->n_threads;

    FARF(HIGH, "solve-tri: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : batched %d\n",
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], batched);

    if (batched) {
        // Batch-level parallelism
        const uint32_t n_threads = MIN((uint32_t) octx->n_threads, total_batches);

        struct htp_solve_tri_context sctx = {
            .octx = octx,
            .jobs_per_thread = (total_batches + n_threads - 1) / n_threads, // ceil division
            .total_jobs = total_batches,
            .k_chunks = k_chunks,
            .col_block = col_block,
        };

        worker_pool_run_func(octx->ctx->worker_pool, solve_tri_batch_thread_f32, &sctx, n_threads);
    } else {
        // Chunk-level parallelism: one job per (batch, column-chunk) pair.
        const uint32_t total_jobs = total_batches * k_chunks;
        const uint32_t n_threads = MIN((uint32_t) octx->n_threads, MAX(total_jobs, 1));

        struct htp_solve_tri_context sctx = {
            .octx = octx,
            .jobs_per_thread = (total_jobs + n_threads - 1) / n_threads, // ceil division
            .total_jobs = total_jobs,
            .k_chunks = k_chunks,
            .col_block = col_block,
        };

        worker_pool_run_func(octx->ctx->worker_pool, solve_tri_chunk_thread_f32, &sctx, n_threads);
    }

    return HTP_STATUS_OK;
}

0 commit comments

Comments
 (0)