Commit 885d5fe
committed
vulkan: add fused mul_mat_vec kernel for TQ4_1S
Adds a specialised MUL_MAT_VEC shader for TQ4_1S weights so the
per-decode-step matrix-vector product no longer has to dequant the
full weight tensor to f16 and then go through the generic matmul
path. The kernel pre-rotates the activation via a forward
Walsh-Hadamard Transform in shared memory and dot-products against
the raw centroid*scale stored weights, folding the inverse-WHT on
the weight side into the activation by the symmetry H = H^T.
Math:
w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
sum_k w[k] * a[k] = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]
Portability choices:
- Workgroup size is pinned to 32 threads regardless of the
DMMV_WG_SIZE bucket the rest of the mul_mat_vec family picks for
the current architecture. The butterfly operates on 32-element
blocks with one element per thread; that contract is fixed by the
quantization format, not by the GPU. Earlier revisions used
`gl_WorkGroupSize.x` as the stride unit, which silently skipped
half the work on Intel drivers that force the subgroup to 16
(tests passed via NMSE tolerance while real inference output was
garbage).
- Butterfly implementation is shared memory only. A subgroup-shuffle
variant (`subgroupShuffleXor`) was prototyped and measured on Intel
Arc A380 with Mesa Xe HPG: it ran ~60-85 %% slower than the
explicit shared-memory butterfly, because Mesa emulates subgroup
shuffles via LDS and ends up doing the same LDS traffic with extra
driver overhead. The shared-memory butterfly is correct on every
device regardless of subgroup-op support, is the fastest path on
every device we can actually measure, and leaves the
`pipeline_dequant_mul_mat_vec_f32_f32[w][TQ4_1S]` slot uniform
across all DMMV_WG_SIZE buckets.
- Reduction is the shared-memory tree reduction (no subgroupAdd), for
the same reason: on Intel Arc the subgroupAdd is also LDS-backed
and the hybrid reduction path was measurably slower. Future
vendor-specific heuristics can switch to the hybrid or pure-subgroup
reduction variants on NVIDIA / AMD RDNA if hardware subgroup ops
turn out to beat the LDS roundtrip there; the existing reduction
modes in `mul_mat_vec_base.glsl` already provide the necessary
variants.
- NUM_ROWS is 8 so the butterfly cost amortises across 8 output rows
per workgroup. Each thread holds one position of each of the 8
weight blocks and pairs them with the shared rotated activation.
- `mul_mm` and `flash_attn_cm2` shader generation is skipped for
TQ4_1S because it is a weight-only format that never reaches the
coopmat2 matmul or the KV cache flash-attention paths.
Tests:
- `test-backend-ops` MUL_MAT tolerance tightened from 2.0 to 0.01
NMSE so real defects can't hide behind a loose check.
- Added Gemma-4 E2B, Qwen, Phi and Llama dimensional coverage
(k in {1536, 2048, 2304, 3072, 4096}, m in {256, 1152, 1536,
2048, 5120, 6144}, n in {1..8, 16, 64, 256}). 148 MUL_MAT test
cases total.
Verification (Intel Arc A380, 6 GB VRAM, Vulkan ANV / Mesa Xe HPG,
`llama-bench -p 512 -n 128 -r 3` and `llama-perplexity -c 512
--chunks 20 wiki.test.raw`):
| Model | Config | Size | Reduction | PPL Δ | pp512/Q8 | tg128/Q8 |
|---------------|---------|----------:|----------:|-------:|---------:|---------:|
| Qwen2.5-1.5B | I | 1570→1082 | -31.1% | +4.66% | 53.9% | 107.5% |
| Phi-3.5-mini | I | 3873→2839 | -26.7% | +5.36% | 57.6% | 52.8% |
| Llama-3.2-3B | hybrid | 3263→2147 | -34.2% | +2.03% | 82.4% | 84.2% |
| Llama-3.2-3B | premium | 3263→2577 | -21.0% | +0.98% | 71.3% | 67.3% |
Qwen2.5-1.5B is faster than its own Q8_0 baseline with Config I:
the compressed model fits in less VRAM, and on a small model the
TQ4_1S compute cost is offset by the reduced memory traffic.
All four models produce coherent output end-to-end and the
reductions line up with the TurboQuant paper's validation matrix
(§5.8). The remaining gap to Q8_0 on the bigger models is
compute-bound on the A380; it closes further on GPUs with more raw
throughput.1 parent 9424395 commit 885d5fe
4 files changed
Lines changed: 200 additions & 7 deletions
File tree
- ggml/src/ggml-vulkan
- vulkan-shaders
- tests
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4023 | 4023 | | |
4024 | 4024 | | |
4025 | 4025 | | |
| 4026 | + | |
| 4027 | + | |
| 4028 | + | |
| 4029 | + | |
| 4030 | + | |
| 4031 | + | |
| 4032 | + | |
| 4033 | + | |
| 4034 | + | |
| 4035 | + | |
| 4036 | + | |
| 4037 | + | |
| 4038 | + | |
| 4039 | + | |
| 4040 | + | |
| 4041 | + | |
| 4042 | + | |
| 4043 | + | |
| 4044 | + | |
| 4045 | + | |
| 4046 | + | |
| 4047 | + | |
| 4048 | + | |
| 4049 | + | |
4026 | 4050 | | |
4027 | 4051 | | |
4028 | 4052 | | |
| |||
4062 | 4086 | | |
4063 | 4087 | | |
4064 | 4088 | | |
| 4089 | + | |
| 4090 | + | |
| 4091 | + | |
| 4092 | + | |
4065 | 4093 | | |
4066 | 4094 | | |
4067 | 4095 | | |
| |||
4086 | 4114 | | |
4087 | 4115 | | |
4088 | 4116 | | |
| 4117 | + | |
4089 | 4118 | | |
4090 | 4119 | | |
4091 | 4120 | | |
| |||
6181 | 6210 | | |
6182 | 6211 | | |
6183 | 6212 | | |
| 6213 | + | |
6184 | 6214 | | |
6185 | 6215 | | |
6186 | 6216 | | |
| |||
6196 | 6226 | | |
6197 | 6227 | | |
6198 | 6228 | | |
| 6229 | + | |
| 6230 | + | |
| 6231 | + | |
| 6232 | + | |
6199 | 6233 | | |
6200 | 6234 | | |
6201 | 6235 | | |
| |||
8206 | 8240 | | |
8207 | 8241 | | |
8208 | 8242 | | |
8209 | | - | |
8210 | | - | |
| 8243 | + | |
8211 | 8244 | | |
8212 | 8245 | | |
8213 | 8246 | | |
| |||
Lines changed: 129 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
562 | 562 | | |
563 | 563 | | |
564 | 564 | | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
565 | 570 | | |
566 | 571 | | |
567 | 572 | | |
| |||
641 | 646 | | |
642 | 647 | | |
643 | 648 | | |
| 649 | + | |
| 650 | + | |
644 | 651 | | |
645 | 652 | | |
646 | 653 | | |
| |||
682 | 689 | | |
683 | 690 | | |
684 | 691 | | |
685 | | - | |
| 692 | + | |
686 | 693 | | |
687 | 694 | | |
688 | 695 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2376 | 2376 | | |
2377 | 2377 | | |
2378 | 2378 | | |
2379 | | - | |
2380 | | - | |
2381 | | - | |
2382 | | - | |
| 2379 | + | |
| 2380 | + | |
| 2381 | + | |
2383 | 2382 | | |
2384 | 2383 | | |
2385 | 2384 | | |
| |||
8187 | 8186 | | |
8188 | 8187 | | |
8189 | 8188 | | |
| 8189 | + | |
| 8190 | + | |
| 8191 | + | |
| 8192 | + | |
| 8193 | + | |
| 8194 | + | |
| 8195 | + | |
| 8196 | + | |
| 8197 | + | |
| 8198 | + | |
| 8199 | + | |
| 8200 | + | |
| 8201 | + | |
| 8202 | + | |
| 8203 | + | |
| 8204 | + | |
| 8205 | + | |
| 8206 | + | |
| 8207 | + | |
| 8208 | + | |
| 8209 | + | |
| 8210 | + | |
| 8211 | + | |
| 8212 | + | |
| 8213 | + | |
8190 | 8214 | | |
8191 | 8215 | | |
8192 | 8216 | | |
| |||
0 commit comments