Commit f03d331
vulkan: TQ4_1S support for model weights (#69)
* vulkan: add TQ4_1S weight compression support
Adds Vulkan shader support for TQ4_1S (4-bit WHT-rotated weight
compression with 16 Lloyd-Max centroids, 32-element blocks).
Shaders:
- dequant_tq4_1s.comp: standalone dequant with WHT inverse via
subgroupShuffleXor (32-thread workgroup, 5-stage butterfly)
- mul_mat_vec_tq4_1s.comp: specialized MUL_MAT_VEC with inline
activation pre-rotation (forward RHT on activation, centroid*scale
dequant without inverse RHT)
- copy_from_quant.comp: TQ4_1S dequant path with full WHT inverse
- copy_to_quant.comp: TQ4_1S SET_ROWS quantization path with forward
RHT, dual half-block RMS scales, 16-centroid quantization
- types.glsl: block_tq4_1s struct (d0, d1, qs[16])
- dequant_funcs.glsl: TQ4_1S centroid*scale dequant (no RHT); a rough
  sketch of the block layout and dequant follows this list
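For orientation, here is a rough GLSL sketch of the block layout and the
no-RHT dequant described above. The field names, the centroid values and
the helper name are illustrative assumptions based on this description;
the real codebook comes from Lloyd-Max training and lives in the shaders
listed above.

```glsl
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8    : require

// Sketch of one 32-element block: two half-block scales plus 32 packed
// 4-bit centroid indices (field names assumed from the description above).
struct block_tq4_1s {
    float16_t d0;      // scale for elements  0..15
    float16_t d1;      // scale for elements 16..31
    uint8_t   qs[16];  // 32 x 4-bit Lloyd-Max centroid indices, two per byte
};

// Placeholder 16-entry codebook; the trained Lloyd-Max centroids replace this.
const float tq4_1s_centroids[16] = float[16](
    -2.6, -1.9, -1.4, -1.0, -0.7, -0.45, -0.25, -0.08,
     0.08, 0.25, 0.45, 0.7,  1.0,  1.4,   1.9,   2.6);

// centroid*scale dequant of element i (0..31) of one block, no inverse RHT.
float dequant_tq4_1s_elem(block_tq4_1s b, uint i) {
    const uint  q    = uint(b.qs[i >> 1]);
    const uint  idx4 = (q >> ((i & 1u) * 4u)) & 0xFu;
    const float d    = float(i < 16u ? b.d0 : b.d1);
    return d * tq4_1s_centroids[idx4];
}
```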
Pipeline wiring (ggml-vulkan.cpp):
- MUL_MAT, SET_ROWS, CPY supports_op
- pipeline_dequant, pipeline_set_rows, pipeline_cpy_quant_f32
- Specialized MUL_MAT_VEC with forced subgroup workgroup size
Tests:
- test_set_rows_tq4_1s: SET_ROWS round-trip validation
* vulkan: add fused mul_mat_vec kernel for TQ4_1S
Adds a specialised MUL_MAT_VEC shader for TQ4_1S weights so the
per-decode-step matrix-vector product no longer has to dequant the
full weight tensor to f16 and then go through the generic matmul
path. The kernel pre-rotates the activation via a forward
Walsh-Hadamard Transform in shared memory and dot-products against
the raw centroid*scale stored weights, folding the inverse-WHT on
the weight side into the activation by the symmetry H = H^T.
Math:
w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
sum_k w[k] * a[k] = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]
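As a rough illustration of the activation-side rotation (one element per
thread, 32-thread workgroup), a shared-memory forward-WHT butterfly might
look like the sketch below. `act_sh`, `sign_flip` and the exact
thread-to-element mapping are illustrative stand-ins, not the actual
shader symbols.

```glsl
shared float act_sh[32];

// Placeholder sign pattern; the real RHT sign sequence is fixed by the format.
bool sign_flip(uint i) { return (i & 1u) == 1u; }

// Forward Walsh-Hadamard transform of one 32-element activation block:
// 5 XOR-butterfly stages (log2(32)) through shared memory, one element per
// thread. tid is the invocation index 0..31, fixed by the quantization format.
void rotate_activation_block(uint tid, float a) {
    act_sh[tid] = sign_flip(tid) ? -a : a;   // fold sign[] into the activation
    barrier();

    for (uint stride = 1u; stride < 32u; stride <<= 1u) {
        const float mine  = act_sh[tid];
        const float other = act_sh[tid ^ stride];
        barrier();
        // lower lane of each pair takes the sum, upper lane the difference
        act_sh[tid] = ((tid & stride) == 0u) ? (mine + other) : (other - mine);
        barrier();
    }
    // act_sh[] now holds H @ (sign * a); the INV_SQRT32 factor can be folded
    // into the final accumulation instead of being applied per stage.
}
```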
Portability choices:
- Workgroup size is pinned to 32 threads regardless of the
DMMV_WG_SIZE bucket the rest of the mul_mat_vec family picks for
the current architecture. The butterfly operates on 32-element
blocks with one element per thread; that contract is fixed by the
quantization format, not by the GPU. Earlier revisions used
`gl_WorkGroupSize.x` as the stride unit, which silently skipped
half the work on Intel drivers that force the subgroup to 16
(tests passed via NMSE tolerance while real inference output was
garbage).
- Butterfly implementation is shared memory only. A subgroup-shuffle
variant (`subgroupShuffleXor`) was prototyped and measured on Intel
Arc A380 with Mesa Xe HPG: it ran ~60-85% slower than the
explicit shared-memory butterfly, because Mesa emulates subgroup
shuffles via LDS and ends up doing the same LDS traffic with extra
driver overhead. The shared-memory butterfly is correct on every
device regardless of subgroup-op support, is the fastest path on
every device we can actually measure, and leaves the
`pipeline_dequant_mul_mat_vec_f32_f32[w][TQ4_1S]` slot uniform
across all DMMV_WG_SIZE buckets.
- Reduction is the shared-memory tree reduction (no subgroupAdd; sketched
  after this list), for
the same reason: on Intel Arc the subgroupAdd is also LDS-backed
and the hybrid reduction path was measurably slower. Future
vendor-specific heuristics can switch to the hybrid or pure-subgroup
reduction variants on NVIDIA / AMD RDNA if hardware subgroup ops
turn out to beat the LDS roundtrip there; the existing reduction
modes in `mul_mat_vec_base.glsl` already provide the necessary
variants.
- NUM_ROWS is 8 so the butterfly cost amortises across 8 output rows
per workgroup. Each thread holds one position of each of the 8
weight blocks and pairs them with the shared rotated activation.
- `mul_mm` and `flash_attn_cm2` shader generation is skipped for
TQ4_1S because it is a weight-only format that never reaches the
coopmat2 matmul or the KV cache flash-attention paths.
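A rough sketch of the tree reduction referenced in the reduction bullet,
assuming one partial sum per thread per output row; `tmp_sh`, `dst`,
`reduce_and_store` and the binding layout are illustrative, not the
actual shader code.

```glsl
layout(binding = 0) writeonly buffer DstBuf { float dst[]; };

#define NUM_ROWS 8
shared float tmp_sh[NUM_ROWS][32];

// LDS tree reduction of 32 partial sums per output row, no subgroupAdd.
void reduce_and_store(uint tid, uint first_row, float partial[NUM_ROWS]) {
    for (uint n = 0u; n < NUM_ROWS; ++n) {
        tmp_sh[n][tid] = partial[n];
    }
    barrier();
    // halve the active stride each step: 16, 8, 4, 2, 1
    for (uint s = 16u; s > 0u; s >>= 1u) {
        if (tid < s) {
            for (uint n = 0u; n < NUM_ROWS; ++n) {
                tmp_sh[n][tid] += tmp_sh[n][tid + s];
            }
        }
        barrier();
    }
    if (tid < NUM_ROWS) {
        dst[first_row + tid] = tmp_sh[tid][0];
    }
}
```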
Tests:
- `test-backend-ops` MUL_MAT tolerance tightened from 2.0 to 0.01
NMSE so real defects can't hide behind a loose check.
- Added Gemma-4 E2B, Qwen, Phi and Llama dimensional coverage
(k in {1536, 2048, 2304, 3072, 4096}, m in {256, 1152, 1536,
2048, 5120, 6144}, n in {1..8, 16, 64, 256}). 148 MUL_MAT test
cases total.
Verification (Intel Arc A380, 6 GB VRAM, Vulkan ANV / Mesa Xe HPG,
`llama-bench -p 512 -n 128 -r 3` and `llama-perplexity -c 512
--chunks 20 wiki.test.raw`):
| Model         | Config  | Size (MiB) | Reduction | PPL Δ  | pp512 vs Q8_0 | tg128 vs Q8_0 |
|---------------|---------|-----------:|----------:|-------:|--------------:|--------------:|
| Qwen2.5-1.5B | I | 1570→1082 | -31.1% | +4.66% | 53.9% | 107.5% |
| Phi-3.5-mini | I | 3873→2839 | -26.7% | +5.36% | 57.6% | 52.8% |
| Llama-3.2-3B | hybrid | 3263→2147 | -34.2% | +2.03% | 82.4% | 84.2% |
| Llama-3.2-3B | premium | 3263→2577 | -21.0% | +0.98% | 71.3% | 67.3% |
Qwen2.5-1.5B token generation is faster than its own Q8_0 baseline with
Config I:
the compressed model fits in less VRAM, and on a small model the
TQ4_1S compute cost is offset by the reduced memory traffic.
All four models produce coherent output end-to-end and the
reductions line up with the TurboQuant paper's validation matrix
(§5.8). The remaining gap to Q8_0 on the bigger models is
compute-bound on the A380; it closes further on GPUs with more raw
throughput.
* vulkan: restructure TQ4_1S inner loop for cross-row smem reuse
Splits the dequant+accumulate phase into two sub-loops:
1. Pre-compute w_vals[n] for all NUM_ROWS rows (centroid lookup +
scale multiply, reads from weight buffer only).
2. Read the rotated activation from shared memory ONCE per column,
then FMA across all rows in a tight register loop.
This is the Vulkan analogue of the 'hot loop load dedup' from the
CUDA kernel (PR #57 optimisation #2). It makes the shared memory
read explicitly loop-invariant across rows, which helps compilers
that don't auto-hoist LDS loads out of unrolled loops.
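In sketch form, with `data_a`, `act_sh`, `temp`, `block_base`,
`blocks_per_row` and `tid` standing in for the real kernel's bindings and
loop state (and `dequant_tq4_1s_elem` being the per-element dequant
sketched earlier in this message), the restructured body of the column
loop looks roughly like this:

```glsl
// Phase 1: dequant the same position of all NUM_ROWS weight blocks
// (weight-buffer reads only: one centroid lookup + scale multiply each).
float w_vals[NUM_ROWS];
[[unroll]] for (uint n = 0u; n < NUM_ROWS; ++n) {
    w_vals[n] = dequant_tq4_1s_elem(data_a[block_base + n * blocks_per_row], tid);
}

// Phase 2: read the rotated activation from shared memory once, then FMA
// across all rows in registers, keeping the LDS load loop-invariant.
const float a_rot = act_sh[tid];
[[unroll]] for (uint n = 0u; n < NUM_ROWS; ++n) {
    temp[n] = fma(w_vals[n], a_rot, temp[n]);
}
```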
Measured effect on Intel Arc A380 (Llama-3.2-3B premium,
llama-bench tg128, r=5): 15.50 -> 15.78 t/s (+1.8%, within noise
but not a regression). The structure is cleaner regardless and
should benefit architectures with higher LDS latency.