Commit b5be42e

vulkan: restructure TQ4_1S inner loop for cross-row smem reuse
Splits the dequant+accumulate phase into two sub-loops:

1. Pre-compute w_vals[n] for all NUM_ROWS rows (centroid lookup + scale multiply, reads from the weight buffer only).
2. Read the rotated activation from shared memory ONCE per column, then FMA across all rows in a tight register loop.

This is the Vulkan analogue of the 'hot loop load dedup' from the CUDA kernel (PR #57, optimisation #2). It makes the shared memory read explicitly loop-invariant across rows, which helps compilers that don't auto-hoist LDS loads out of unrolled loops.

Measured effect on Intel Arc A380 (Llama-3.2-3B premium, llama-bench tg128, r=5): 15.50 -> 15.78 t/s (+1.8%, within noise but not a regression). The structure is cleaner regardless and should benefit architectures with higher LDS latency.
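Schematically, the change swaps a weight-major inner loop for the two-stage form described above. The sketch below is a simplified illustration using identifiers from the diff (tq4_smem, temp, NUM_COLS, NUM_ROWS, tid); dequant_weight() is a hypothetical stand-in for the centroid-lookup-and-scale shown in stage 3 of the diff, and the surrounding block loop and buffer offsets are elided.

// Before: every row's FMA re-reads the rotated activation from shared memory.
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
    const float w = dequant_weight(n);              // placeholder for centroid * scale * INV_SQRT32
    [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
        temp[c][n] += FLOAT_TYPE(w * tq4_smem[c * 32u + tid]);   // num_rows smem reads per column
    }
}

// After: stage 3 fills a per-row register array, stage 4 reads smem once per column.
float w_vals[NUM_ROWS];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
    w_vals[n] = dequant_weight(n);                  // weight-buffer reads only
}
[[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
    const float b_rotated = tq4_smem[c * 32u + tid];             // one smem read per column
    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        temp[c][n] += FLOAT_TYPE(w_vals[n] * b_rotated);         // tight register FMA loop
    }
}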
1 parent ffc7128 commit b5be42e

1 file changed: 23 additions & 33 deletions

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq4_1s.comp
@@ -24,25 +24,12 @@ const float TQ4_SIGNS[32] = float[32](
 
 const float TQ4_INV_SQRT32 = 0.17677669529663688;
 
-// Math: the stored weights satisfy w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
-// where H is the 32x32 symmetric Hadamard matrix and stored[j] = centroid[qs[j]] * d[j].
+// See the commit message on a850ccc for the full derivation and portability
+// rationale. Short version: pre-rotate the activation block via forward WHT
+// in shared memory, then dot-product against the raw centroid*scale weights.
 //
-//   sum_k w[k] * a[k]
-//     = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]
-//
-// So we pre-rotate the activation once per block via forward RHT, then each
-// thread dot-products against the raw centroid*scale weights at its own
-// position of the block.
-//
-// Workgroup contract: local_size_x (spec constant 0) is always 32, and every
-// thread owns exactly one element of the 32-element block. The butterfly is
-// performed in shared memory. A subgroup-shuffle variant was tried but it
-// was measurably slower on Intel Arc / Mesa (where shuffles are emulated over
-// shared memory anyway) and the shared-memory path is correct on every
-// device regardless of whether subgroup shuffles are supported.
-//
-// Shared memory budget: NUM_COLS * 32 floats (128 bytes per column, max 1 KiB
-// at NUM_COLS=8), plus whatever tmpsh the reduction helper allocates.
+// Shared memory budget: NUM_COLS * 32 floats (max 1 KiB at NUM_COLS=8)
+// plus whatever tmpsh the reduction helper allocates.
 
 shared float tq4_smem[8 * 32];
 
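For reference, the identity the removed comment carried (and which the new comment defers to a850ccc for) needs only the symmetry of H: substituting the stored-weight relation w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k] into the dot product and moving H onto the activation gives

\[
\sum_k w[k]\,a[k]
  \;=\; \mathrm{INV\_SQRT32}\sum_k (H\,\mathrm{stored})[k]\;\mathrm{sign}[k]\,a[k]
  \;=\; \mathrm{INV\_SQRT32}\sum_j \mathrm{stored}[j]\,\bigl(H\,(\mathrm{sign}\odot a)\bigr)[j],
\]

which is why stages 1 and 2 below can pre-rotate the sign-flipped activation once per block, after which stages 3 and 4 take a plain dot product against the centroid*scale weights.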
@@ -65,18 +52,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const float sign_tid = TQ4_SIGNS[tid];
 
     for (uint blk = 0; blk < num_blocks_per_row; blk++) {
-        // Load the activation slice for each column, sign-flipped, into shared
-        // memory. Each of the 32 threads handles one element position.
+        // --- Stage 1: load activation, sign-flip, write to shared memory ---
         [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
             const uint b_base = c * p.batch_stride_b + b_offset + blk * 32u;
             tq4_smem[c * 32u + tid] = float(data_b[b_base + tid]) * sign_tid;
         }
         barrier();
 
-        // Forward WHT butterfly in shared memory (5 stages, log2(32)). At
-        // each stage the threads with the low bit of `step` clear take both
-        // slots of the pair and write back (sum, diff) so that only 16 threads
-        // are active per stage and no two threads touch the same slot.
+        // --- Stage 2: forward WHT butterfly in shared memory (5 stages) ---
         [[unroll]] for (uint step = 1u; step < 32u; step <<= 1u) {
             if ((tid & step) == 0u) {
                 const uint partner = tid + step;
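The hunk ends before the butterfly body (old lines 83-90 are not shown). Purely as a hedged sketch of the pair-wise update the removed comment describes, not a reconstruction of the shader's actual elided lines, one stage would look roughly like this (the per-column loop is an assumption):

// Assumed sketch: the thread with the low bit of `step` clear owns both slots
// of its pair and writes back (sum, diff), so only 16 threads are active per
// stage and no slot is written by two threads.
if ((tid & step) == 0u) {
    const uint partner = tid + step;
    [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
        const float lo = tq4_smem[c * 32u + tid];
        const float hi = tq4_smem[c * 32u + partner];
        tq4_smem[c * 32u + tid]     = lo + hi;
        tq4_smem[c * 32u + partner] = lo - hi;
    }
}
barrier();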
@@ -91,24 +74,31 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             barrier();
         }
 
-        // Dequant weight(s) for the current block and accumulate. The
-        // INV_SQRT32 normalisation of the inverse WHT is folded into w so
-        // the inner accumulate is just one multiply-add per (col, row).
+        // --- Stage 3: dequant all rows' weights for this block position ---
+        // Pre-computing the weight for every row before touching the column
+        // accumulator lets the compiler treat the smem read in stage 4 as
+        // loop-invariant across rows, which is the Vulkan analogue of the
+        // "hot loop load dedup" optimisation in the CUDA kernel (PR #57).
+        float w_vals[NUM_ROWS];
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            const uint ib = (first_row + n) * num_blocks_per_row + blk;
+            const uint ib = (first_row + n) * num_blocks_per_row + blk;
             const uint idx = (uint(data_a[a_offset + ib].qs[byte_idx]) >> nibble_shift) & 0xFu;
             const float d = (tid < 16u)
                 ? float(data_a[a_offset + ib].d0)
                 : float(data_a[a_offset + ib].d1);
-            const float w = TQ4_CENTROIDS[idx] * d * TQ4_INV_SQRT32;
+            w_vals[n] = TQ4_CENTROIDS[idx] * d * TQ4_INV_SQRT32;
+        }
 
-            [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
-                temp[c][n] += FLOAT_TYPE(w * tq4_smem[c * 32u + tid]);
+        // --- Stage 4: accumulate dot products ---
+        // Read the rotated activation once per column; reuse across all rows.
+        [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
+            const float b_rotated = tq4_smem[c * 32u + tid];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                temp[c][n] += FLOAT_TYPE(w_vals[n] * b_rotated);
             }
         }
 
-        // Ensure every thread is done reading the current block's rotated
-        // activation before the next iteration overwrites it.
+        // Ensure every thread is done reading before the next block's store.
         barrier();
     }
 