Skip to content

Commit bf83733

Browse files
committed
ggml-ve: batched VEBP matmul for N>1 (one weight pass across n_tok columns)
The N>1 VEBP prompt matmul looped the matvec _inner per activation column, so each column re-traversed the whole ternary weight from HBM. ve_vebp_matmul_ptr_inner instead quantises all n_tok activation columns in parallel, then runs ONE rowblock loop in which each rowblock's interleaved weight is loaded once and reused (from cache) across all n_tok columns. The graph-compiler MUL_MAT_VEBP codegen calls it for the scales_n (N>1) case; decode / the n_out tail keep the per-column matvec. Standalone test (test_vebp_matmul.c, random planes vs the per-column matvec): bit-identical across clean / tail / large-N cases. Integrated output is token-for-token identical to the interpreter. Bonsai-VEBP prompt: 12.65 -> 13.2 tok/s (warm, run twice; ~5%). The gain is small because VEBP prompt is compute-bound: V.OP is already 92% and the ternary vpcnt is inherently one dot product per (row, column) — batching recovers the weight HBM traffic + per-column barriers + serial quant, not the vpcnt work. Net VEBP prompt this session: 9.8 (interpreter) -> 13.2 (1.35x). Decode and Llama paths unchanged. A weight-word-reuse intrinsics kernel (load each weight word once, tile N columns in registers) could shave a little more but the inner loop is vpcnt-bound, so the headroom is small — left as a possible follow-up.
1 parent 113ba08 commit bf83733

3 files changed

Lines changed: 174 additions & 7 deletions

File tree

ggml/src/ggml-ve/graph_compiler.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,13 +1002,26 @@ std::string GraphCompiler::gen_op_code(const TracedOp & op, int idx) const {
10021002
std::string ws_p = "p[" + std::to_string(op.vebp_ws_idx) + "]";
10031003
std::string wn_p = "p[" + std::to_string(op.vebp_wn_idx) + "]";
10041004
std::string wsc_p = "p[" + std::to_string(op.vebp_wsc_idx) + "]";
1005-
ss << " for (int64_t col = 0; col < " << col_n << "; col++)\n";
1006-
ss << " ve_vebp_matvec_ptr_inner((float*)" << dst << " + col*" << M
1007-
<< ", (const unsigned long*)" << ws_p
1008-
<< ", (const unsigned long*)" << wn_p
1009-
<< ", (const float*)" << wsc_p
1010-
<< ", (const float*)" << src1 << " + col*" << K
1011-
<< ", " << M << ", " << K << ");\n";
1005+
if (scales_n) {
1006+
// N>1: one batched call — the rowblock weight is read once and
1007+
// reused across all n_tok columns (vs the col-loop re-traversing
1008+
// the whole weight from HBM per column).
1009+
ss << " ve_vebp_matmul_ptr_inner((float*)" << dst
1010+
<< ", (const unsigned long*)" << ws_p
1011+
<< ", (const unsigned long*)" << wn_p
1012+
<< ", (const float*)" << wsc_p
1013+
<< ", (const float*)" << src1
1014+
<< ", " << M << ", " << K << ", (int)n_tok);\n";
1015+
} else {
1016+
// decode / n_out tail: per-column matvec (col_n is 1).
1017+
ss << " for (int64_t col = 0; col < " << col_n << "; col++)\n";
1018+
ss << " ve_vebp_matvec_ptr_inner((float*)" << dst << " + col*" << M
1019+
<< ", (const unsigned long*)" << ws_p
1020+
<< ", (const unsigned long*)" << wn_p
1021+
<< ", (const float*)" << wsc_p
1022+
<< ", (const float*)" << src1 << " + col*" << K
1023+
<< ", " << M << ", " << K << ");\n";
1024+
}
10121025
break;
10131026
}
10141027

@@ -1314,6 +1327,7 @@ std::string GraphCompiler::generate_source(const std::string & func_name) const
13141327
ss << "extern void ve_bf16_matvec_rowmajor_ptr_inner(float* y, const uint16_t* W, const float* x, int M, int K);\n";
13151328
ss << "extern void ve_q4k_matvec_rowmajor_ptr_inner(float* y, const unsigned char* qs, const unsigned char* hdr, const float* x, int M, int K);\n";
13161329
ss << "extern void ve_vebp_matvec_ptr_inner(float* y, const unsigned long* ws, const unsigned long* wn, const float* wsc, const float* x, int M, int K);\n";
1330+
ss << "extern void ve_vebp_matmul_ptr_inner(float* y, const unsigned long* ws, const unsigned long* wn, const float* wsc, const float* x, int M, int K, int N);\n";
13171331
ss << "extern void ve_f32_matvec_ptr(float* y, const float* W, const float* x, int M, int K);\n";
13181332
ss << "extern void attention_f32_raw_gqa_stride_omp(float* out, const float* q, const void* k, const void* v,"
13191333
<< " int head_dim, int n_q_heads, int n_kv_heads, int seq_len, float scale,"
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/* Standalone correctness test for ve_vebp_matmul_ptr_inner (batched VEBP matmul)
2+
* vs the per-column ve_vebp_matvec_ptr_inner loop the graph compiler used.
3+
* Random interleaved planes are fine — both kernels run the SAME vpcnt math, so
4+
* the batched output must match the per-column reference column-for-column.
5+
*
6+
* Build: ncc -O4 -fopenmp test_vebp_matmul.c libve_sgemv.so -o /tmp/test_vebp_matmul
7+
*/
8+
#include <stdint.h>
9+
#include <stdio.h>
10+
#include <stdlib.h>
11+
#include <math.h>
12+
#include <omp.h>
13+
14+
extern void ve_vebp_matvec_ptr_inner(float *y, const uint64_t *ws, const uint64_t *wn,
15+
const float *wsc, const float *x, int M, int K);
16+
extern void ve_vebp_matmul_ptr_inner(float *y, const uint64_t *ws, const uint64_t *wn,
17+
const float *wsc, const float *x, int M, int K, int N);
18+
19+
/* Link-only stubs for symbols that libve_sgemv.so imports (device runtime,
20+
* cblas, ftrace). Per the author, none are reached from the matvec/matmul _inner code paths this test exercises; the signatures below are placeholders, so these stubs must never actually be called. */
21+
int vedaMemPtr(void **p, uint64_t v) { (void) p; (void) v; return 1; }
22+
void cblas_sgemv(void) {}
23+
void cblas_sgemm(void) {}
24+
void __ftrace_func_enter(void) {}
25+
void __ftrace_func_exit(void) {}
26+
27+
#define RB 256
28+
29+
static int run_case(int M, int K, int N) {
30+
long wpr = K / 64, ng = K / 128;
31+
long Mblk = (M + RB - 1) / RB; /* padded full blocks */
32+
long ws_n = Mblk * wpr * RB, wsc_n = Mblk * ng * RB;
33+
uint64_t *ws = aligned_alloc(64, ws_n * 8), *wn = aligned_alloc(64, ws_n * 8);
34+
float *wsc = aligned_alloc(64, wsc_n * 4);
35+
float *x = aligned_alloc(64, (long) K * N * 4);
36+
float *yr = aligned_alloc(64, (long) M * N * 4);
37+
float *yb = aligned_alloc(64, (long) M * N * 4);
38+
39+
for (long i = 0; i < ws_n; i++) { ws[i] = ((uint64_t) rand() << 32) ^ rand();
40+
wn[i] = ((uint64_t) rand() << 32) ^ rand(); }
41+
for (long i = 0; i < wsc_n; i++) wsc[i] = (float) (rand() % 1000) / 1000.0f;
42+
for (long i = 0; i < (long) K * N; i++) x[i] = (float) (rand() % 2000 - 1000) / 100.0f;
43+
44+
for (int j = 0; j < N; j++) {
45+
#pragma omp parallel
46+
{ ve_vebp_matvec_ptr_inner(yr + (long) j * M, ws, wn, wsc, x + (long) j * K, M, K); }
47+
}
48+
#pragma omp parallel
49+
{ ve_vebp_matmul_ptr_inner(yb, ws, wn, wsc, x, M, K, N); }
50+
51+
double maxd = 0; long nbad = 0;
52+
for (long i = 0; i < (long) M * N; i++) {
53+
double d = fabs((double) yr[i] - (double) yb[i]);
54+
if (d > maxd) maxd = d;
55+
if (d > 1e-3) nbad++;
56+
}
57+
printf("M=%d K=%d N=%d max|ref-batch|=%.3e nbad=%ld/%ld %s\n",
58+
M, K, N, maxd, nbad, (long) M * N, nbad ? "FAIL" : "PASS");
59+
free(ws); free(wn); free(wsc); free(x); free(yr); free(yb);
60+
return nbad != 0;
61+
}
62+
63+
/* Seed the generator deterministically, run every shape in the table and
 * exit non-zero if any case failed. All cases run even after a failure. */
int main(void) {
    static const struct { int M, K, N; } cases[] = {
        {  768,  512, 5 },   /* clean: 3 full row blocks, no tail */
        {  800,  512, 5 },   /* tail: 3 full blocks + 32 rows */
        { 4096, 1024, 8 },   /* bigger, prompt-ish N */
    };
    int failed = 0;

    srand(12345);
    for (size_t i = 0; i < sizeof cases / sizeof cases[0]; i++)
        failed |= run_case(cases[i].M, cases[i].K, cases[i].N);
    return failed;
}

ggml/src/ggml-ve/kernels-veda/ve_vebp_dispatch.c

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ static uint64_t *g_asign = NULL, *g_amag = NULL;
3333
static size_t g_acap = 0; /* in uint64 units of wpr */
3434
static float g_ptr_ax = 0.0f; /* shares ax single->for in ptr_inner */
3535

36+
/* batched (N-column) activation scratch for ve_vebp_matmul_ptr_inner */
37+
static uint64_t *g_asignN = NULL, *g_amagN = NULL;
38+
static float *g_axN = NULL;
39+
static long g_Ncap = 0, g_wprcapN = 0;
40+
3641
/* Pointer-arg variant for the graph compiler's generated kernel. Called
3742
* from INSIDE the kernel's `#pragma omp parallel` region (all threads enter).
3843
* Builds the activation planes once via `#pragma omp single` (implicit
@@ -83,6 +88,84 @@ void ve_vebp_matvec_ptr_inner(float *y, const uint64_t *ws, const uint64_t *wn,
8388
}
8489
}
8590

91+
/* Batched matmul variant: y[M,N] = W[M,K] @ X[K,N], one fused call instead of
92+
* the graph compiler's per-column matvec loop. Called from INSIDE the kernel's
93+
* `#pragma omp parallel`. Quantises all N activation columns in parallel, then
94+
* shares ONE rowblock loop across the team — each rowblock's interleaved weight
95+
* is loaded from HBM once and reused (from cache) across all N columns, instead
96+
* of the col-loop re-traversing the whole weight from HBM N times. The vpcnt
97+
* work itself is inherently N x (one dot product per column), so this recovers
98+
* the weight-traffic + barrier + serial-quant overhead, not the compute.
99+
* y column j is at y + j*M (matches the col-major dst the codegen used). */
100+
void ve_vebp_matmul_ptr_inner(float *y, const uint64_t *ws, const uint64_t *wn,
101+
const float *wsc, const float *x,
102+
int M, int K, int N) {
103+
const long wpr = (long) K / 64;
104+
const long ng = (long) K / 128;
105+
const long Mfull = (long) M / VEBP_RB;
106+
const long Mtail = (long) M % VEBP_RB;
107+
108+
#pragma omp single
109+
{
110+
if ((long) N > g_Ncap || wpr > g_wprcapN) {
111+
if (g_asignN) free(g_asignN);
112+
if (g_amagN) free(g_amagN);
113+
if (g_axN) free(g_axN);
114+
g_asignN = (uint64_t *) aligned_alloc(64, (size_t) N * wpr * sizeof(uint64_t));
115+
g_amagN = (uint64_t *) aligned_alloc(64, (size_t) N * VEBP_NB * wpr * sizeof(uint64_t));
116+
g_axN = (float *) aligned_alloc(64, (size_t) N * sizeof(float));
117+
g_Ncap = N; g_wprcapN = wpr;
118+
}
119+
} /* implicit barrier: scratch published */
120+
121+
/* Quantise each of the N activation columns (independent -> omp for). */
122+
#pragma omp for
123+
for (long j = 0; j < (long) N; j++) {
124+
const float *xj = x + j * (long) K;
125+
float amax = 0.0f;
126+
for (long k = 0; k < (long) K; k++) { float a = fabsf(xj[k]); if (a > amax) amax = a; }
127+
float axj = amax / 127.0f + 1e-12f;
128+
g_axN[j] = axj;
129+
vebp_build_act_planes(xj, VEBP_NB, wpr,
130+
g_asignN + j * wpr,
131+
g_amagN + j * VEBP_NB * wpr,
132+
1.0f / axj);
133+
} /* implicit barrier: all planes ready */
134+
135+
/* One rowblock loop; inner N columns reuse the cached weight block. */
136+
#pragma omp for
137+
for (long rb = 0; rb < Mfull; rb++) {
138+
const uint64_t *wsb = ws + rb * wpr * VEBP_RB;
139+
const uint64_t *wnb = wn + rb * wpr * VEBP_RB;
140+
const float *wscb = wsc + rb * ng * VEBP_RB;
141+
for (long j = 0; j < (long) N; j++) {
142+
vebp_block_vpcnt_scaled(wsb, wnb, wscb,
143+
g_asignN + j * wpr,
144+
g_amagN + j * VEBP_NB * wpr,
145+
VEBP_NB, wpr, g_axN[j],
146+
y + j * (long) M + rb * VEBP_RB);
147+
}
148+
}
149+
if (Mtail) {
150+
#pragma omp single
151+
{
152+
float ytail[VEBP_RB];
153+
const long rb = Mfull;
154+
const uint64_t *wsb = ws + rb * wpr * VEBP_RB;
155+
const uint64_t *wnb = wn + rb * wpr * VEBP_RB;
156+
const float *wscb = wsc + rb * ng * VEBP_RB;
157+
for (long j = 0; j < (long) N; j++) {
158+
vebp_block_vpcnt_scaled(wsb, wnb, wscb,
159+
g_asignN + j * wpr,
160+
g_amagN + j * VEBP_NB * wpr,
161+
VEBP_NB, wpr, g_axN[j], ytail);
162+
for (long r = 0; r < Mtail; r++)
163+
y[j * (long) M + rb * VEBP_RB + r] = ytail[r];
164+
}
165+
}
166+
}
167+
}
168+
86169
uint64_t ve_vebp_matvec_hbm(uint64_t y_vptr, uint64_t Ws_vptr, uint64_t Wn_vptr,
87170
uint64_t wscale_vptr, uint64_t x_vptr,
88171
uint64_t M, uint64_t K) {

0 commit comments

Comments
 (0)