@@ -24,25 +24,12 @@ const float TQ4_SIGNS[32] = float[32](
 
 const float TQ4_INV_SQRT32 = 0.17677669529663688;
 
-// Math: the stored weights satisfy w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
-// where H is the 32x32 symmetric Hadamard matrix and stored[j] = centroid[qs[j]] * d[j].
+// See the commit message on a850ccc for the full derivation and portability
+// rationale. Short version: pre-rotate the activation block via forward WHT
+// in shared memory, then dot-product against the raw centroid*scale weights.
 //
-//   sum_k w[k] * a[k]
-//     = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]
-//
-// So we pre-rotate the activation once per block via forward RHT, then each
-// thread dot-products against the raw centroid*scale weights at its own
-// position of the block.
-//
-// Workgroup contract: local_size_x (spec constant 0) is always 32, and every
-// thread owns exactly one element of the 32-element block. The butterfly is
-// performed in shared memory. A subgroup-shuffle variant was tried but it
-// was measurably slower on Intel Arc / Mesa (where shuffles are emulated over
-// shared memory anyway) and the shared-memory path is correct on every
-// device regardless of whether subgroup shuffles are supported.
-//
-// Shared memory budget: NUM_COLS * 32 floats (128 bytes per column, max 1 KiB
-// at NUM_COLS=8), plus whatever tmpsh the reduction helper allocates.
+// Shared memory budget: NUM_COLS * 32 floats (max 1 KiB at NUM_COLS=8)
+// plus whatever tmpsh the reduction helper allocates.
 
 shared float tq4_smem[8 * 32];
 
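For reference, the identity this hunk now delegates to the commit message can be spot-checked off-line. Below is a minimal host-side C sketch (not shader code) of sum_k w[k] * a[k] = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]; the `wht` helper, the random test vectors, and the Sylvester ordering of H are assumptions of the sketch, with only the INV_SQRT32 constant taken from the shader:

```c
#include <stdio.h>
#include <stdlib.h>

#define N 32
static const double INV_SQRT32 = 0.17677669529663688;

/* In-place unnormalised Walsh-Hadamard transform, Sylvester ordering:
 * the same (sum, diff) butterfly the shader runs in shared memory.
 * The resulting matrix H is symmetric. */
static void wht(double v[N]) {
    for (int step = 1; step < N; step <<= 1)
        for (int i = 0; i < N; i++)
            if ((i & step) == 0) {
                double lo = v[i], hi = v[i + step];
                v[i]        = lo + hi;
                v[i + step] = lo - hi;
            }
}

int main(void) {
    double stored[N], sign[N], a[N];
    for (int k = 0; k < N; k++) {
        stored[k] = (double)rand() / RAND_MAX - 0.5;
        a[k]      = (double)rand() / RAND_MAX - 0.5;
        sign[k]   = (rand() & 1) ? 1.0 : -1.0;
    }

    /* Left side: materialise w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
     * and take the plain dot product with the activation. */
    double hs[N];
    for (int k = 0; k < N; k++) hs[k] = stored[k];
    wht(hs);
    double lhs = 0.0;
    for (int k = 0; k < N; k++) lhs += sign[k] * INV_SQRT32 * hs[k] * a[k];

    /* Right side: rotate the sign-flipped activation instead (what the
     * kernel does) and dot against the raw stored weights. Because H is
     * symmetric, the transform can hop across the dot product. */
    double sa[N];
    for (int k = 0; k < N; k++) sa[k] = sign[k] * a[k];
    wht(sa);
    double rhs = 0.0;
    for (int k = 0; k < N; k++) rhs += INV_SQRT32 * stored[k] * sa[k];

    printf("lhs = %.12f, rhs = %.12f, diff = %g\n", lhs, rhs, lhs - rhs);
    return 0;
}
```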
@@ -65,18 +52,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const float sign_tid = TQ4_SIGNS[tid];
 
     for (uint blk = 0; blk < num_blocks_per_row; blk++) {
-        // Load the activation slice for each column, sign-flipped, into shared
-        // memory. Each of the 32 threads handles one element position.
+        // --- Stage 1: load activation, sign-flip, write to shared memory ---
         [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
             const uint b_base = c * p.batch_stride_b + b_offset + blk * 32u;
             tq4_smem[c * 32u + tid] = float(data_b[b_base + tid]) * sign_tid;
         }
         barrier();
 
-        // Forward WHT butterfly in shared memory (5 stages, log2(32)). At
-        // each stage the threads with the low bit of `step` clear take both
-        // slots of the pair and write back (sum, diff) so that only 16 threads
-        // are active per stage and no two threads touch the same slot.
+        // --- Stage 2: forward WHT butterfly in shared memory (5 stages) ---
         [[unroll]] for (uint step = 1u; step < 32u; step <<= 1u) {
             if ((tid & step) == 0u) {
                 const uint partner = tid + step;
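The terse stage-2 marker drops the old explanation of why the butterfly is race-free, so a quick illustration may help: at each of the five stages only the 16 lanes whose tid has the `step` bit clear are active, and each active lane owns the disjoint slot pair (tid, tid + step). A throwaway host-side C trace of that schedule (illustrative only, not shader code):

```c
#include <stdio.h>

/* Print which lane handles which (slot, partner) pair at every WHT stage.
 * Each stage lists 16 disjoint pairs covering all 32 slots, so no two
 * lanes ever read or write the same slot within a stage. */
int main(void) {
    for (int step = 1; step < 32; step <<= 1) {
        printf("step %2d:", step);
        for (int tid = 0; tid < 32; tid++)
            if ((tid & step) == 0)
                printf(" (%2d,%2d)", tid, tid + step); /* writes sum, diff */
        printf("\n");
    }
    return 0;
}
```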
@@ -91,24 +74,31 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             barrier();
         }
 
-        // Dequant weight(s) for the current block and accumulate. The
-        // INV_SQRT32 normalisation of the inverse WHT is folded into w so
-        // the inner accumulate is just one multiply-add per (col, row).
+        // --- Stage 3: dequant all rows' weights for this block position ---
+        // Pre-computing the weight for every row before touching the column
+        // accumulator lets the compiler treat the smem read in stage 4 as
+        // loop-invariant across rows, which is the Vulkan analogue of the
+        // "hot loop load dedup" optimisation in the CUDA kernel (PR #57).
+        float w_vals[NUM_ROWS];
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib = (first_row + n) * num_blocks_per_row + blk;
             const uint idx = (uint(data_a[a_offset + ib].qs[byte_idx]) >> nibble_shift) & 0xFu;
             const float d = (tid < 16u)
                 ? float(data_a[a_offset + ib].d0)
                 : float(data_a[a_offset + ib].d1);
-            const float w = TQ4_CENTROIDS[idx] * d * TQ4_INV_SQRT32;
+            w_vals[n] = TQ4_CENTROIDS[idx] * d * TQ4_INV_SQRT32;
+        }
 
-            [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
-                temp[c][n] += FLOAT_TYPE(w * tq4_smem[c * 32u + tid]);
+        // --- Stage 4: accumulate dot products ---
+        // Read the rotated activation once per column; reuse across all rows.
+        [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
+            const float b_rotated = tq4_smem[c * 32u + tid];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                temp[c][n] += FLOAT_TYPE(w_vals[n] * b_rotated);
             }
         }
 
-        // Ensure every thread is done reading the current block's rotated
-        // activation before the next iteration overwrites it.
+        // Ensure every thread is done reading before the next block's store.
         barrier();
     }
 
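The stage-3/4 restructuring above is behaviour-preserving; only the shape of the loops changes so the rotated-activation load can be hoisted out of the row loop. A scalar C model of that claim, with stand-in NUM_ROWS/NUM_COLS values and synthetic data (nothing here mirrors the real block layout):

```c
#include <assert.h>
#include <stdio.h>

#define NUM_ROWS 4  /* stand-in values, not the shader's */
#define NUM_COLS 8

int main(void) {
    float smem[NUM_COLS], w_vals[NUM_ROWS];
    float old_t[NUM_COLS][NUM_ROWS] = {{0}}, new_t[NUM_COLS][NUM_ROWS] = {{0}};
    for (int c = 0; c < NUM_COLS; c++) smem[c]   = 0.25f * (float)(c + 1);
    for (int n = 0; n < NUM_ROWS; n++) w_vals[n] = 0.5f  * (float)(n + 1);

    /* Old shape: the shared-memory read sits inside the row loop, so it
     * is issued once per (row, column) pair. */
    for (int n = 0; n < NUM_ROWS; n++)
        for (int c = 0; c < NUM_COLS; c++)
            old_t[c][n] += w_vals[n] * smem[c];

    /* New shape: one read per column, reused across all rows. */
    for (int c = 0; c < NUM_COLS; c++) {
        const float b_rotated = smem[c];
        for (int n = 0; n < NUM_ROWS; n++)
            new_t[c][n] += w_vals[n] * b_rotated;
    }

    for (int c = 0; c < NUM_COLS; c++)
        for (int n = 0; n < NUM_ROWS; n++)
            assert(old_t[c][n] == new_t[c][n]);
    printf("both loop shapes produce identical sums\n");
    return 0;
}
```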