
vulkan: TQ4_1s support for model weights #69

Merged
TheTom merged 3 commits into TheTom:feature/turboquant-kv-cache from Titaniumtown:pr/vulkan-tq4-1s
Apr 20, 2026

Conversation


Titaniumtown commented Apr 11, 2026

Overview

Adds TQ4_1S model-weight quantization support to the Vulkan backend.

| Model         | Config  | Size (MiB) | Reduction | PPL Δ  | pp512 vs Q8_0 | tg128 vs Q8_0 |
|---------------|---------|-----------:|----------:|-------:|--------------:|--------------:|
| Qwen2.5-1.5B  | I       | 1570→1082 |   -31.1%  | +4.66% |    53.9% |   107.5% |
| Phi-3.5-mini  | I       | 3873→2839 |   -26.7%  | +5.36% |    57.6% |    52.8% |
| Llama-3.2-3B  | hybrid  | 3263→2147 |   -34.2%  | +2.03% |    82.4% |    84.2% |
| Llama-3.2-3B  | premium | 3263→2577 |   -21.0%  | +0.98% |    71.3% |    67.3% |

Additional information

Port of #45 and #57 to Vulkan.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Claude Opus 4.6 via oh-my-pi was used in a loop with access to real hardware to implement and verify the code, based on the Metal implementation.


twobombs commented Apr 11, 2026

On Qwen3.5-35B, turbo3 is functional, albeit slower in tokens/s; turbo2 and turbo4 fail to start for me.

Built with: `cmake -B build -DGGML_VULKAN=1 && cmake --build build --config Release -j $(grep -c ^processor /proc/cpuinfo)`

Ran with: `./build/bin/llama-server -m /media/aryan/nvme/models/llama.cpp/Qwen3.5-35B-A3B-Q4_K_S.gguf --jinja --device Vulkan2,Vulkan0 --host 0.0.0.0 -np 1 --port 8033 -c 128000 -ctk turbo3 -ctv turbo3 -fa 1 -ngl 99`

Turbo3 log

system info: n_threads = 24, n_threads_batch = 24, total_threads = 48

system_info: n_threads = 24 (n_threads_batch = 24) / 48 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |

Running without SSL
init: using 47 threads for HTTP server
start: binding port with default address family
main: loading model
srv load_model: loading model '/media/aryan/nvme/models/llama.cpp/Qwen3.5-35B-A3B-Q4_K_S.gguf'
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
llama_params_fit_impl: projected memory use with initial parameters [MiB]:
llama_params_fit_impl: - Vulkan2 (AMD Radeon (TM) Pro VII (RADV VEGA20)): 16384 total, 13773 used, 2502 free vs. target of 1024
llama_params_fit_impl: - Vulkan0 (NVIDIA CMP 50HX) : 10294 total, 8274 used, 1470 free vs. target of 1024
llama_params_fit_impl: projected to use 22047 MiB of device memory vs. 26020 MiB of free device memory
llama_params_fit_impl: targets for free memory can be met on all devices, no changes needed
llama_params_fit: successfully fit params to free device memory
llama_params_fit: fitting params to free memory took 1.42 seconds
llama_model_load_from_file_impl: using device Vulkan2 (AMD Radeon (TM) Pro VII (RADV VEGA20)) (0000:44:00.0) - 16359 MiB free
llama_model_load_from_file_impl: using device Vulkan0 (NVIDIA CMP 50HX) (0000:03:00.0) - 9812 MiB free
llama_model_loader: loaded meta data with 52 key-value pairs and 733 tensors from /media/aryan/nvme/models/llama.cpp/Qwen3.5-35B-A3B-Q4_K_S.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = qwen35moe
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.sampling.top_k i32 = 20
llama_model_loader: - kv 3: general.sampling.top_p f32 = 0.950000
llama_model_loader: - kv 4: general.sampling.temp f32 = 1.000000
llama_model_loader: - kv 5: general.name str = Qwen3.5-35B-A3B
llama_model_loader: - kv 6: general.basename str = Qwen3.5-35B-A3B
llama_model_loader: - kv 7: general.quantized_by str = Unsloth
llama_model_loader: - kv 8: general.size_label str = 35B-A3B
llama_model_loader: - kv 9: general.license str = apache-2.0
llama_model_loader: - kv 10: general.license.link str = https://huggingface.co/Qwen/Qwen3.5-3...
llama_model_loader: - kv 11: general.repo_url str = https://huggingface.co/unsloth
llama_model_loader: - kv 12: general.base_model.count u32 = 1
llama_model_loader: - kv 13: general.base_model.0.name str = Qwen3.5 35B A3B
llama_model_loader: - kv 14: general.base_model.0.organization str = Qwen
llama_model_loader: - kv 15: general.base_model.0.repo_url str = https://huggingface.co/Qwen/Qwen3.5-3...
llama_model_loader: - kv 16: general.tags arr[str,2] = ["unsloth", "image-text-to-text"]
llama_model_loader: - kv 17: qwen35moe.block_count u32 = 40
llama_model_loader: - kv 18: qwen35moe.context_length u32 = 262144
llama_model_loader: - kv 19: qwen35moe.embedding_length u32 = 2048
llama_model_loader: - kv 20: qwen35moe.attention.head_count u32 = 16
llama_model_loader: - kv 21: qwen35moe.attention.head_count_kv u32 = 2
llama_model_loader: - kv 22: qwen35moe.rope.dimension_sections arr[i32,4] = [11, 11, 10, 0]
llama_model_loader: - kv 23: qwen35moe.rope.freq_base f32 = 10000000.000000
llama_model_loader: - kv 24: qwen35moe.attention.layer_norm_rms_epsilon f32 = 0.000001
llama_model_loader: - kv 25: qwen35moe.expert_count u32 = 256
llama_model_loader: - kv 26: qwen35moe.expert_used_count u32 = 8
llama_model_loader: - kv 27: qwen35moe.attention.key_length u32 = 256
llama_model_loader: - kv 28: qwen35moe.attention.value_length u32 = 256
llama_model_loader: - kv 29: qwen35moe.expert_feed_forward_length u32 = 512
llama_model_loader: - kv 30: qwen35moe.expert_shared_feed_forward_length u32 = 512
llama_model_loader: - kv 31: qwen35moe.ssm.conv_kernel u32 = 4
llama_model_loader: - kv 32: qwen35moe.ssm.state_size u32 = 128
llama_model_loader: - kv 33: qwen35moe.ssm.group_count u32 = 16
llama_model_loader: - kv 34: qwen35moe.ssm.time_step_rank u32 = 32
llama_model_loader: - kv 35: qwen35moe.ssm.inner_size u32 = 4096
llama_model_loader: - kv 36: qwen35moe.full_attention_interval u32 = 4
llama_model_loader: - kv 37: qwen35moe.rope.dimension_count u32 = 64
llama_model_loader: - kv 38: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 39: tokenizer.ggml.pre str = qwen35
llama_model_loader: - kv 40: tokenizer.ggml.tokens arr[str,248320] = ["!", """, "#", "$", "%", "&", "'", ...
llama_model_loader: - kv 41: tokenizer.ggml.token_type arr[i32,248320] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 42: tokenizer.ggml.merges arr[str,247587] = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv 43: tokenizer.ggml.eos_token_id u32 = 248046
llama_model_loader: - kv 44: tokenizer.ggml.padding_token_id u32 = 248055
llama_model_loader: - kv 45: tokenizer.chat_template str = {%- set image_count = namespace(value...
llama_model_loader: - kv 46: general.quantization_version u32 = 2
llama_model_loader: - kv 47: general.file_type u32 = 14
llama_model_loader: - kv 48: quantize.imatrix.file str = Qwen3.5-35B-A3B-GGUF/imatrix_unsloth....
llama_model_loader: - kv 49: quantize.imatrix.dataset str = unsloth_calibration_Qwen3.5-35B-A3B.txt
llama_model_loader: - kv 50: quantize.imatrix.entries_count u32 = 510
llama_model_loader: - kv 51: quantize.imatrix.chunks_count u32 = 76
llama_model_loader: - type f32: 301 tensors
llama_model_loader: - type q8_0: 311 tensors
llama_model_loader: - type q4_K: 120 tensors
llama_model_loader: - type q6_K: 1 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type = Q4_K - Small
print_info: file size = 19.24 GiB (4.77 BPW)
load: 0 unused tokens
load: printing all EOG tokens:
load: - 248044 ('<|endoftext|>')
load: - 248046 ('<|im_end|>')
load: - 248063 ('<|fim_pad|>')
load: - 248064 ('<|repo_name|>')
load: - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch = qwen35moe
print_info: vocab_only = 0
print_info: no_alloc = 0
print_info: n_ctx_train = 262144
print_info: n_embd = 2048
print_info: n_embd_inp = 2048
print_info: n_layer = 40
print_info: n_head = 16
print_info: n_head_kv = 2
print_info: n_rot = 64
print_info: n_swa = 0
print_info: is_swa_any = 0
print_info: n_embd_head_k = 256
print_info: n_embd_head_v = 256
print_info: n_gqa = 8
print_info: n_embd_k_gqa = 512
print_info: n_embd_v_gqa = 512
print_info: f_norm_eps = 0.0e+00
print_info: f_norm_rms_eps = 1.0e-06
print_info: f_clamp_kqv = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale = 0.0e+00
print_info: f_attn_scale = 0.0e+00
print_info: n_ff = 0
print_info: n_expert = 256
print_info: n_expert_used = 8
print_info: n_expert_groups = 0
print_info: n_group_used = 0
print_info: causal attn = 1
print_info: pooling type = -1
print_info: rope type = 40
print_info: rope scaling = linear
print_info: freq_base_train = 10000000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn = 262144
print_info: rope_yarn_log_mul = 0.0000
print_info: rope_finetuned = unknown
print_info: mrope sections = [11, 11, 10, 0]
print_info: ssm_d_conv = 4
print_info: ssm_d_inner = 4096
print_info: ssm_d_state = 128
print_info: ssm_dt_rank = 32
print_info: ssm_n_group = 16
print_info: ssm_dt_b_c_rms = 0
print_info: model type = 35B.A3B
print_info: model params = 34.66 B
print_info: general.name = Qwen3.5-35B-A3B
print_info: vocab type = BPE
print_info: n_vocab = 248320
print_info: n_merges = 247587
print_info: BOS token = 11 ','
print_info: EOS token = 248046 '<|im_end|>'
print_info: EOT token = 248046 '<|im_end|>'
print_info: PAD token = 248055 '<|vision_pad|>'
print_info: LF token = 198 'Ċ'
print_info: FIM PRE token = 248060 '<|fim_prefix|>'
print_info: FIM SUF token = 248062 '<|fim_suffix|>'
print_info: FIM MID token = 248061 '<|fim_middle|>'
print_info: FIM PAD token = 248063 '<|fim_pad|>'
print_info: FIM REP token = 248064 '<|repo_name|>'
print_info: FIM SEP token = 248065 '<|file_sep|>'
print_info: EOG token = 248044 '<|endoftext|>'
print_info: EOG token = 248046 '<|im_end|>'
print_info: EOG token = 248063 '<|fim_pad|>'
print_info: EOG token = 248064 '<|repo_name|>'
print_info: EOG token = 248065 '<|file_sep|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 39 repeating layers to GPU
load_tensors: offloaded 41/41 layers to GPU
load_tensors: CPU_Mapped model buffer size = 515.31 MiB
load_tensors: Vulkan0 model buffer size = 6971.91 MiB
load_tensors: Vulkan2 model buffer size = 12218.42 MiB
..................................................................................................
common_init_result: added <|endoftext|> logit bias = -inf
common_init_result: added <|im_end|> logit bias = -inf
common_init_result: added <|fim_pad|> logit bias = -inf
common_init_result: added <|repo_name|> logit bias = -inf
common_init_result: added <|file_sep|> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max = 1
llama_context: n_ctx = 128000
llama_context: n_ctx_seq = 128000
llama_context: n_batch = 2048
llama_context: n_ubatch = 512
llama_context: causal_attn = 1
llama_context: flash_attn = enabled
llama_context: kv_unified = false
llama_context: freq_base = 10000000.0
llama_context: freq_scale = 1
llama_context: n_ctx_seq (128000) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context: Vulkan_Host output buffer size = 0.95 MiB
llama_kv_cache: Vulkan0 KV buffer size = 195.31 MiB
llama_kv_cache: Vulkan2 KV buffer size = 293.09 MiB
llama_kv_cache: TurboQuant rotation matrices initialized (128x128)
llama_kv_cache: size = 488.28 MiB (128000 cells, 10 layers, 1/1 seqs), K (turbo3): 244.14 MiB, V (turbo3): 244.14 MiB
llama_kv_cache: upstream attention rotation disabled (TurboQuant uses kernel-level WHT)
llama_kv_cache: attn_rot_k = 0
llama_kv_cache: attn_rot_v = 0
llama_memory_recurrent: Vulkan0 RS buffer size = 20.94 MiB
llama_memory_recurrent: Vulkan2 RS buffer size = 41.88 MiB
llama_memory_recurrent: size = 62.81 MiB ( 1 cells, 40 layers, 1 seqs), R (f32): 2.81 MiB, S (f32): 60.00 MiB
llama_context: pipeline parallelism enabled
sched_reserve: reserving ...
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve: Vulkan2 compute buffer size = 1220.16 MiB
sched_reserve: Vulkan0 compute buffer size = 1014.07 MiB
sched_reserve: Vulkan_Host compute buffer size = 1008.08 MiB
sched_reserve: graph nodes = 3749
sched_reserve: graph splits = 3
sched_reserve: reserve took 1736.76 ms, sched copies = 4
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
srv load_model: initializing slots, n_slots = 1
common_speculative_is_compat: the target context does not support partial sequence removal
srv load_model: speculative decoding not supported by this context
slot load_model: id 0 | task -1 | new slot, n_ctx = 128000
srv load_model: prompt cache is enabled, size limit: 8192 MiB
srv load_model: use --cache-ram 0 to disable the prompt cache
srv load_model: for more info see https://github.com/ggml-org/pull/16391
srv init: init: --clear-idle requires --kv-unified, disabling
init: chat template, example_format: '<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant

'
srv init: init: chat template, thinking = 1
main: model loaded
main: server is listening on http://0.0.0.0:8033
main: starting the main loop...
srv update_slots: all slots are idle


However, turbo2 and turbo4 result in a core dump on an Ubuntu 24.04 instance:

`./build/bin/llama-server -m /media/aryan/nvme/models/llama.cpp/Qwen3.5-35B-A3B-Q4_K_S.gguf --jinja --device Vulkan2,Vulkan0 --host 0.0.0.0 -np 1 --port 8033 -c 128000 -ctk turbo2 -ctv turbo2 -fa 1 -ngl 99`

Turbo2 log

system info: n_threads = 24, n_threads_batch = 24, total_threads = 48

system_info: n_threads = 24 (n_threads_batch = 24) / 48 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |

Running without SSL
init: using 47 threads for HTTP server
start: binding port with default address family
main: loading model
srv load_model: loading model '/media/aryan/nvme/models/llama.cpp/Qwen3.5-35B-A3B-Q4_K_S.gguf'
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
/media/aryan/nvme/llama-cpp-turboquant/ggml/src/ggml-backend.cpp:809: pre-allocated tensor (cache_k_l3 (view)) in a buffer (Vulkan2) that cannot run the operation (SET_ROWS)
[New LWP 81708]
…
[New LWP 81628]

This GDB supports auto-downloading debuginfo from the following URLs:
https://debuginfod.ubuntu.com
Enable debuginfod for this session? (y or [n]) [answered N; input not from terminal]
Debuginfod has been disabled.
To make this setting permanent, add 'set debuginfod enabled off' to .gdbinit.
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libvulkan_virtio.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libvulkan_intel.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libvulkan_asahi.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libvulkan_lvp.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libtinfo.so.6
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libvulkan_intel_hasvk.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libvulkan_gfxstream.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libvulkan_radeon.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libvulkan_nouveau.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libVkLayer_MESA_device_select.so
warning: could not find '.gnu_debugaltlink' file for /lib/x86_64-linux-gnu/libcap.so.2
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007d3025510813 in __GI___wait4 (pid=81709, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30 ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#0 0x00007d3025510813 in __GI___wait4 (pid=81709, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30 in ../sysdeps/unix/sysv/linux/wait4.c
#1 0x00007d3026050933 in ggml_print_backtrace () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libggml-base.so.0
#2 0x00007d3026050adb in ggml_abort () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libggml-base.so.0
#3 0x00007d3026068c38 in ggml_backend_sched_backend_id_from_cur(ggml_backend_sched*, ggml_tensor*) () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libggml-base.so.0
#4 0x00007d302606aca7 in ggml_backend_sched_split_graph () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libggml-base.so.0
#5 0x00007d3025ccccf7 in llama_context::graph_reserve(unsigned int, unsigned int, unsigned int, llama_memory_context_i const*, bool, unsigned long*) () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libllama.so.0
#6 0x00007d3025ccdd80 in llama_context::sched_reserve() () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libllama.so.0
#7 0x00007d3025cd18c9 in llama_context::llama_context(llama_model const&, llama_context_params) () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libllama.so.0
#8 0x00007d3025cd243b in llama_init_from_model () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libllama.so.0
#9 0x00007d3025ca30a0 in llama_get_device_memory_data(char const*, llama_model_params const*, llama_context_params const*, std::vector<ggml_backend_device*, std::allocator<ggml_backend_device*> >&, unsigned int&, unsigned int&, unsigned int&, ggml_log_level) () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libllama.so.0
#10 0x00007d3025ca3e0b in llama_params_fit_impl(char const*, llama_model_params*, llama_context_params*, float*, llama_model_tensor_buft_override*, unsigned long*, unsigned int, ggml_log_level) () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libllama.so.0
#11 0x00007d3025ca74b2 in llama_params_fit () from /media/aryan/nvme/llama-cpp-turboquant/build/bin/libllama.so.0
#12 0x00005e5ce19fe967 in common_init_result::common_init_result(common_params&) ()
#13 0x00005e5ce19ffbdc in common_init_from_params(common_params&) ()
#14 0x00005e5ce191724e in server_context_impl::load_model(common_params const&) ()
#15 0x00005e5ce1863239 in main ()
[Inferior 1 (process 81608) detached]
Aborted (core dumped)

Titaniumtown (Author) commented:

@twobombs you posted a long log; is there something you're trying to say? Please put it in a collapsible block.

winkay2000 pushed a commit to winkay2000/llama-cpp-turboquant that referenced this pull request Apr 14, 2026
Head-to-head benchmarks vs TheTom and Duster. Key finding: TBQ is
accidentally 1-bit quantization with temperature scaling. 16 new
experiment action items from analysis.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
winkay2000 pushed a commit to winkay2000/llama-cpp-turboquant that referenced this pull request Apr 14, 2026
5-14% PPL improvement at ALL context lengths for both 3-bit and 2-bit TCQ.
Multiplies stored norm by 1.2 to sharpen attention logits.
Beats every competitor at every context length at both bit rates.
Override via TURBO_TCQ_ALPHA env var.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
TheTom force-pushed the feature/turboquant-kv-cache branch from 45f8a06 to 1073622 on April 16, 2026 01:14
TheTom (Owner) commented Apr 20, 2026

Hey @Titaniumtown — we went through the open PRs today and want to merge this, but it needs a rebase onto the current tip of feature/turboquant-kv-cache. There are conflicts in ggml-vulkan.cpp (the SET_ROWS macro signature changed upstream, plus new types like Q1_0/NVFP4 on HEAD) and in vulkan-shaders-gen.cpp.

Could you rebase onto the latest feature/turboquant-kv-cache? Once it's clean we'll pull, build, and merge. Thanks!

vulkan: add TQ4_1S weight compression support

Adds Vulkan shader support for TQ4_1S (4-bit WHT-rotated weight
compression with 16 Lloyd-Max centroids, 32-element blocks).

Shaders:
- dequant_tq4_1s.comp: standalone dequant with WHT inverse via
  subgroupShuffleXor (32-thread workgroup, 5-stage butterfly)
- mul_mat_vec_tq4_1s.comp: specialized MUL_MAT_VEC with inline
  activation pre-rotation (forward RHT on activation, centroid*scale
  dequant without inverse RHT)
- copy_from_quant.comp: TQ4_1S dequant path with full WHT inverse
- copy_to_quant.comp: TQ4_1S SET_ROWS quantization path with forward
  RHT, dual half-block RMS scales, 16-centroid quantization
- types.glsl: block_tq4_1s struct (d0, d1, qs[16]) (sketched after this commit message)
- dequant_funcs.glsl: TQ4_1S centroid*scale dequant (no RHT)

Pipeline wiring (ggml-vulkan.cpp):
- MUL_MAT, SET_ROWS, CPY supports_op
- pipeline_dequant, pipeline_set_rows, pipeline_cpy_quant_f32
- Specialized MUL_MAT_VEC with forced subgroup workgroup size

Tests:
- test_set_rows_tq4_1s: SET_ROWS round-trip validation
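
To make the block layout above concrete: a minimal host-side C++ sketch of `block_tq4_1s` and the centroid*scale dequant, assuming ggml-style IEEE-f16 scale bits, low-nibble-first index packing, and a placeholder centroid table. None of these helper names, nor the centroid values, come from the PR itself; the real table is Lloyd-Max trained.

```cpp
#include <cstdint>

// Sketch of the TQ4_1S block: 32 weights, two half-block scales, 4-bit indices.
struct block_tq4_1s {
    uint16_t d0;      // f16 bits: scale for elements 0..15
    uint16_t d1;      // f16 bits: scale for elements 16..31
    uint8_t  qs[16];  // 32 packed 4-bit indices into the 16-entry centroid table
};
static_assert(sizeof(block_tq4_1s) == 20, "2 f16 scales + 16 index bytes");

// Placeholder values only; the real 16 centroids come from Lloyd-Max training.
static const float kCentroids[16] = {
    -1.8f, -1.3f, -0.9f, -0.6f, -0.4f, -0.25f, -0.12f, -0.03f,
     0.03f,  0.12f,  0.25f,  0.4f,   0.6f,   0.9f,   1.3f,   1.8f,
};

// Centroid*scale dequant with no RHT (what dequant_funcs.glsl is described as
// doing): returns the *stored* rotated weight; the inverse WHT is applied
// elsewhere or folded into the activation. d0/d1 are already converted to f32.
inline float dequant_stored(const block_tq4_1s &b, float d0, float d1, int k) {
    const uint8_t byte = b.qs[k >> 1];
    const uint8_t idx  = (k & 1) ? uint8_t(byte >> 4) : uint8_t(byte & 0x0F);
    return ((k < 16) ? d0 : d1) * kCentroids[idx];  // dual half-block RMS scales
}
```
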
vulkan: add fused mul_mat_vec kernel for TQ4_1S

Adds a specialised MUL_MAT_VEC shader for TQ4_1S weights so the
per-decode-step matrix-vector product no longer has to dequant the
full weight tensor to f16 and then go through the generic matmul
path.  The kernel pre-rotates the activation via a forward
Walsh-Hadamard Transform in shared memory and dot-products against
the raw centroid*scale stored weights, folding the inverse-WHT on
the weight side into the activation by the symmetry H = H^T.

Math:
  w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
  sum_k w[k] * a[k] = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]
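
A standalone numeric check of this identity (an illustrative sketch using a plain reference 32-point unnormalized WHT; nothing here is taken from the shaders): the left side reconstructs w explicitly, the right side folds the rotation into the sign-flipped activation, and the two dot products agree up to float rounding.

```cpp
#include <cstdio>
#include <cstdlib>
#include <cmath>

// Reference unnormalized 32-point Walsh-Hadamard transform: y = H @ x.
static void wht32(const float *x, float *y) {
    for (int k = 0; k < 32; ++k) y[k] = x[k];
    for (int stride = 1; stride < 32; stride <<= 1)        // 5 butterfly stages
        for (int i = 0; i < 32; i += 2 * stride)
            for (int j = i; j < i + stride; ++j) {
                const float a = y[j], b = y[j + stride];
                y[j]          = a + b;
                y[j + stride] = a - b;
            }
}

int main() {
    const float INV_SQRT32 = 1.0f / sqrtf(32.0f);
    float stored[32], act[32], sign[32];
    for (int k = 0; k < 32; ++k) {
        stored[k] = (float)rand() / RAND_MAX - 0.5f;
        act[k]    = (float)rand() / RAND_MAX - 0.5f;
        sign[k]   = (rand() & 1) ? 1.0f : -1.0f;
    }

    // Left side: reconstruct w[k] explicitly, then dot with the activation.
    float Hs[32];
    wht32(stored, Hs);
    float lhs = 0.0f;
    for (int k = 0; k < 32; ++k) lhs += sign[k] * INV_SQRT32 * Hs[k] * act[k];

    // Right side: fold the rotation into the activation (what the kernel does).
    float sa[32], Hsa[32];
    for (int k = 0; k < 32; ++k) sa[k] = sign[k] * act[k];
    wht32(sa, Hsa);
    float rhs = 0.0f;
    for (int j = 0; j < 32; ++j) rhs += stored[j] * Hsa[j];
    rhs *= INV_SQRT32;

    printf("lhs=%f rhs=%f diff=%g\n", lhs, rhs, fabs(lhs - rhs));
    return 0;
}
```
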

Portability choices:

- Workgroup size is pinned to 32 threads regardless of the
  DMMV_WG_SIZE bucket the rest of the mul_mat_vec family picks for
  the current architecture.  The butterfly operates on 32-element
  blocks with one element per thread; that contract is fixed by the
  quantization format, not by the GPU.  Earlier revisions used
  `gl_WorkGroupSize.x` as the stride unit, which silently skipped
  half the work on Intel drivers that force the subgroup to 16
  (tests passed via NMSE tolerance while real inference output was
  garbage).

- Butterfly implementation is shared memory only.  A subgroup-shuffle
  variant (`subgroupShuffleXor`) was prototyped and measured on Intel
  Arc A380 with Mesa Xe HPG: it ran ~60-85% slower than the
  explicit shared-memory butterfly, because Mesa emulates subgroup
  shuffles via LDS and ends up doing the same LDS traffic with extra
  driver overhead.  The shared-memory butterfly is correct on every
  device regardless of subgroup-op support, is the fastest path on
  every device we can actually measure, and leaves the
  `pipeline_dequant_mul_mat_vec_f32_f32[w][TQ4_1S]` slot uniform
  across all DMMV_WG_SIZE buckets.  (A CPU model of this butterfly
  follows after this list.)

- Reduction is the shared-memory tree reduction (no subgroupAdd), for
  the same reason: on Intel Arc the subgroupAdd is also LDS-backed
  and the hybrid reduction path was measurably slower.  Future
  vendor-specific heuristics can switch to the hybrid or pure-subgroup
  reduction variants on NVIDIA / AMD RDNA if hardware subgroup ops
  turn out to beat the LDS roundtrip there; the existing reduction
  modes in `mul_mat_vec_base.glsl` already provide the necessary
  variants.

- NUM_ROWS is 8 so the butterfly cost amortises across 8 output rows
  per workgroup.  Each thread holds one position of each of the 8
  weight blocks and pairs them with the shared rotated activation.

- `mul_mm` and `flash_attn_cm2` shader generation is skipped for
  TQ4_1S because it is a weight-only format that never reaches the
  coopmat2 matmul or the KV cache flash-attention paths.
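
A CPU model of the shared-memory butterfly described in the list above: each loop over `t` stands in for one of the 32 pinned threads, the XOR pairing gives the five butterfly stages, and the two per-stage loops mark where a GLSL `barrier()` would sit. This is an illustrative sketch, not the shader itself.

```cpp
// Models the 32-thread, 5-stage shared-memory WHT butterfly: one element
// per "thread", partner chosen by XOR-ing the thread id with the stride.
static void wht32_butterfly_model(float smem[32]) {
    float reg[32];                                  // per-"thread" registers
    for (int stage = 0; stage < 5; ++stage) {       // log2(32) = 5 stages
        const int stride = 1 << stage;
        for (int t = 0; t < 32; ++t) {              // barrier(): all reads first
            const float mine   = smem[t];
            const float theirs = smem[t ^ stride];  // butterfly partner
            reg[t] = (t & stride) ? (theirs - mine) : (mine + theirs);
        }
        for (int t = 0; t < 32; ++t) {              // barrier(): then all writes
            smem[t] = reg[t];
        }
    }
    // The INV_SQRT32 normalization from the Math section is assumed to be
    // applied by the caller (or folded into the scale), so it is omitted here.
}
```

Note how the fixed 32-element contract makes the stride unit independent of `gl_WorkGroupSize.x`, which is exactly the pitfall the first bullet describes on Intel drivers that force a 16-wide subgroup.
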

Tests:

- `test-backend-ops` MUL_MAT tolerance tightened from 2.0 to 0.01
  NMSE so real defects can't hide behind a loose check.
- Added Gemma-4 E2B, Qwen, Phi and Llama dimensional coverage
  (k in {1536, 2048, 2304, 3072, 4096}, m in {256, 1152, 1536,
  2048, 5120, 6144}, n in {1..8, 16, 64, 256}).  148 MUL_MAT test
  cases total.

Verification (Intel Arc A380, 6 GB VRAM, Vulkan ANV / Mesa Xe HPG,
`llama-bench -p 512 -n 128 -r 3` and `llama-perplexity -c 512
--chunks 20 wiki.test.raw`):

| Model         | Config  |     Size  | Reduction | PPL Δ  | pp512/Q8 | tg128/Q8 |
|---------------|---------|----------:|----------:|-------:|---------:|---------:|
| Qwen2.5-1.5B  | I       | 1570→1082 |   -31.1%  | +4.66% |    53.9% |   107.5% |
| Phi-3.5-mini  | I       | 3873→2839 |   -26.7%  | +5.36% |    57.6% |    52.8% |
| Llama-3.2-3B  | hybrid  | 3263→2147 |   -34.2%  | +2.03% |    82.4% |    84.2% |
| Llama-3.2-3B  | premium | 3263→2577 |   -21.0%  | +0.98% |    71.3% |    67.3% |

Qwen2.5-1.5B is faster than its own Q8_0 baseline with Config I:
the compressed model fits in less VRAM, and on a small model the
TQ4_1S compute cost is offset by the reduced memory traffic.

All four models produce coherent output end-to-end and the
reductions line up with the TurboQuant paper's validation matrix
(§5.8).  The remaining gap to Q8_0 on the bigger models is
compute-bound on the A380; it closes further on GPUs with more raw
throughput.
vulkan: restructure TQ4_1S inner loop for cross-row smem reuse

Splits the dequant+accumulate phase into two sub-loops (sketched
after this commit message):

  1. Pre-compute w_vals[n] for all NUM_ROWS rows (centroid lookup +
     scale multiply, reads from weight buffer only).
  2. Read the rotated activation from shared memory ONCE per column,
     then FMA across all rows in a tight register loop.

This is the Vulkan analogue of the 'hot loop load dedup' from the
CUDA kernel (PR #57 optimisation #2).  It makes the shared memory
read explicitly loop-invariant across rows, which helps compilers
that don't auto-hoist LDS loads out of unrolled loops.

Measured effect on Intel Arc A380 (Llama-3.2-3B premium,
llama-bench tg128, r=5): 15.50 -> 15.78 t/s (+1.8%, within noise
but not a regression).  The structure is cleaner regardless and
should benefit architectures with higher LDS latency.
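
A CPU-side sketch of that two-phase structure. `NUM_ROWS` is the kernel's own constant; everything else here (`w_stored` standing in for the centroid-lookup-plus-scale dequant, `act_rot` standing in for the shared-memory rotated activation) is an illustrative stand-in, not the kernel's code.

```cpp
constexpr int NUM_ROWS = 8;   // from the kernel: 8 output rows per workgroup

// Phase 1 touches only weight data; phase 2 reads the rotated activation
// once per column and FMAs it across all rows in registers, making the
// shared-memory load explicitly loop-invariant across rows.
static void tq4_1s_inner_loop(const float w_stored[NUM_ROWS][32],
                              const float act_rot[32],  // rotated activation
                              float acc[NUM_ROWS]) {
    for (int col = 0; col < 32; ++col) {
        float w_vals[NUM_ROWS];               // phase 1: per-row weight values
        for (int n = 0; n < NUM_ROWS; ++n) {
            w_vals[n] = w_stored[n][col];     // stands in for centroid * scale
        }
        const float a = act_rot[col];         // phase 2: ONE activation read
        for (int n = 0; n < NUM_ROWS; ++n) {
            acc[n] += w_vals[n] * a;          // tight register FMA loop
        }
    }
}
```
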
Titaniumtown (Author) commented:

@TheTom done!

TheTom (Owner) commented Apr 20, 2026

Will add it to my test queue, thank you!

TheTom (Owner) commented Apr 20, 2026

Tested — Metal builds clean, turbo3 KV passes sanity check, CUDA changes are additive only (f16+turbo mixed FA instances). No regressions. Merging, thank you @Titaniumtown!

@TheTom TheTom merged commit 8ba9f12 into TheTom:feature/turboquant-kv-cache Apr 20, 2026
22 of 50 checks passed
TheTom pushed a commit that referenced this pull request Apr 22, 2026
jimbothigpen pushed a commit to jimbothigpen/frankenturbo2 that referenced this pull request May 2, 2026
sbaier1 pushed a commit to sbaier1/llama-cpp-turboquant that referenced this pull request May 8, 2026
