Commit f03d331

TitaniumtownTheTom
authored and committed
vulkan: TQ4_1S support for model weights (#69)
* vulkan: add TQ4_1S weight compression support

Adds Vulkan shader support for TQ4_1S (4-bit WHT-rotated weight compression with 16 Lloyd-Max centroids, 32-element blocks).

Shaders:
- dequant_tq4_1s.comp: standalone dequant with WHT inverse via subgroupShuffleXor (32-thread workgroup, 5-stage butterfly)
- mul_mat_vec_tq4_1s.comp: specialized MUL_MAT_VEC with inline activation pre-rotation (forward RHT on the activation, centroid*scale dequant without inverse RHT)
- copy_from_quant.comp: TQ4_1S dequant path with full WHT inverse
- copy_to_quant.comp: TQ4_1S SET_ROWS quantization path with forward RHT, dual half-block RMS scales, 16-centroid quantization
- types.glsl: block_tq4_1s struct (d0, d1, qs[16])
- dequant_funcs.glsl: TQ4_1S centroid*scale dequant (no RHT)

Pipeline wiring (ggml-vulkan.cpp):
- MUL_MAT, SET_ROWS, CPY supports_op
- pipeline_dequant, pipeline_set_rows, pipeline_cpy_quant_f32
- specialized MUL_MAT_VEC with forced subgroup workgroup size

Tests:
- test_set_rows_tq4_1s: SET_ROWS round-trip validation

* vulkan: add fused mul_mat_vec kernel for TQ4_1S

Adds a specialised MUL_MAT_VEC shader for TQ4_1S weights so the per-decode-step matrix-vector product no longer has to dequant the full weight tensor to f16 and then go through the generic matmul path. The kernel pre-rotates the activation via a forward Walsh-Hadamard transform in shared memory and dot-products against the raw centroid*scale stored weights, folding the inverse WHT on the weight side into the activation via the symmetry H = H^T.

Math:

    w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
    sum_k w[k] * a[k] = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]

Portability choices:
- Workgroup size is pinned to 32 threads regardless of the DMMV_WG_SIZE bucket the rest of the mul_mat_vec family picks for the current architecture. The butterfly operates on 32-element blocks with one element per thread; that contract is fixed by the quantization format, not by the GPU.
  Earlier revisions used `gl_WorkGroupSize.x` as the stride unit, which silently skipped half the work on Intel drivers that force the subgroup to 16 (tests passed via NMSE tolerance while real inference output was garbage).
- Butterfly implementation is shared memory only. A subgroup-shuffle variant (`subgroupShuffleXor`) was prototyped and measured on Intel Arc A380 with Mesa Xe HPG: it ran ~60-85% slower than the explicit shared-memory butterfly, because Mesa emulates subgroup shuffles via LDS and ends up doing the same LDS traffic with extra driver overhead. The shared-memory butterfly is correct on every device regardless of subgroup-op support, is the fastest path on every device we can actually measure, and leaves the `pipeline_dequant_mul_mat_vec_f32_f32[w][TQ4_1S]` slot uniform across all DMMV_WG_SIZE buckets.
- Reduction is the shared-memory tree reduction (no subgroupAdd), for the same reason: on Intel Arc the subgroupAdd is also LDS-backed and the hybrid reduction path was measurably slower. Future vendor-specific heuristics can switch to the hybrid or pure-subgroup reduction variants on NVIDIA / AMD RDNA if hardware subgroup ops turn out to beat the LDS roundtrip there; the existing reduction modes in `mul_mat_vec_base.glsl` already provide the necessary variants.
- NUM_ROWS is 8 so the butterfly cost amortises across 8 output rows per workgroup. Each thread holds one position of each of the 8 weight blocks and pairs them with the shared rotated activation.
- `mul_mm` and `flash_attn_cm2` shader generation is skipped for TQ4_1S because it is a weight-only format that never reaches the coopmat2 matmul or the KV cache flash-attention paths.

Tests:
- `test-backend-ops` MUL_MAT tolerance tightened from 2.0 to 0.01 NMSE so real defects can't hide behind a loose check.
- Added Gemma-4 E2B, Qwen, Phi and Llama dimensional coverage (k in {1536, 2048, 2304, 3072, 4096}, m in {256, 1152, 1536, 2048, 5120, 6144}, n in {1..8, 16, 64, 256}).
- 148 MUL_MAT test cases total.

Verification (Intel Arc A380, 6 GB VRAM, Vulkan ANV / Mesa Xe HPG, `llama-bench -p 512 -n 128 -r 3` and `llama-perplexity -c 512 --chunks 20 wiki.test.raw`):

| Model         | Config  |      Size | Reduction |  PPL Δ | pp512/Q8 | tg128/Q8 |
|---------------|---------|----------:|----------:|-------:|---------:|---------:|
| Qwen2.5-1.5B  | I       | 1570→1082 |    -31.1% | +4.66% |    53.9% |   107.5% |
| Phi-3.5-mini  | I       | 3873→2839 |    -26.7% | +5.36% |    57.6% |    52.8% |
| Llama-3.2-3B  | hybrid  | 3263→2147 |    -34.2% | +2.03% |    82.4% |    84.2% |
| Llama-3.2-3B  | premium | 3263→2577 |    -21.0% | +0.98% |    71.3% |    67.3% |

Qwen2.5-1.5B is faster than its own Q8_0 baseline with Config I: the compressed model fits in less VRAM, and on a small model the TQ4_1S compute cost is offset by the reduced memory traffic. All four models produce coherent output end-to-end and the reductions line up with the TurboQuant paper's validation matrix (§5.8). The remaining gap to Q8_0 on the bigger models is compute-bound on the A380; it closes further on GPUs with more raw throughput.

* vulkan: restructure TQ4_1S inner loop for cross-row smem reuse

Splits the dequant+accumulate phase into two sub-loops:
1. Pre-compute w_vals[n] for all NUM_ROWS rows (centroid lookup + scale multiply, reads from the weight buffer only).
2. Read the rotated activation from shared memory ONCE per column, then FMA across all rows in a tight register loop.

This is the Vulkan analogue of the 'hot loop load dedup' from the CUDA kernel (PR #57, optimisation #2). It makes the shared-memory read explicitly loop-invariant across rows, which helps compilers that don't auto-hoist LDS loads out of unrolled loops.

Measured effect on Intel Arc A380 (Llama-3.2-3B premium, llama-bench tg128, r=5): 15.50 -> 15.78 t/s (+1.8%, within noise but not a regression). The structure is cleaner regardless and should benefit architectures with higher LDS latency.
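The folding identity quoted in the commit message can be checked numerically. Below is a quick Python sketch (not part of the patch): `stored`, `sign`, and `act` are made-up values, and `hadamard_matvec` is the unnormalized Sylvester-ordered fast Walsh-Hadamard transform, whose matrix H satisfies H = H^T:

```python
import random

N = 32
INV_SQRT32 = 1.0 / N ** 0.5

def hadamard_matvec(x):
    # Unnormalized fast Walsh-Hadamard transform (Sylvester order),
    # equivalent to multiplying by the 32x32 Hadamard matrix H = H^T.
    y = list(x)
    step = 1
    while step < N:
        for i in range(0, N, step * 2):
            for j in range(i, i + step):
                a, b = y[j], y[j + step]
                y[j], y[j + step] = a + b, a - b
        step <<= 1
    return y

random.seed(0)
stored = [random.uniform(-1, 1) for _ in range(N)]    # hypothetical centroid*scale values
sign   = [random.choice((-1.0, 1.0)) for _ in range(N)]
act    = [random.uniform(-1, 1) for _ in range(N)]    # hypothetical activation block

# Weight-side reconstruction: w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
w = [s * INV_SQRT32 * h for s, h in zip(sign, hadamard_matvec(stored))]
lhs = sum(wk * ak for wk, ak in zip(w, act))

# Activation-side folding: forward-rotate sign*a instead of inverse-rotating weights
rot = hadamard_matvec([s * a for s, a in zip(sign, act)])
rhs = INV_SQRT32 * sum(sj * rj for sj, rj in zip(stored, rot))

assert abs(lhs - rhs) < 1e-9
```

Because H is symmetric, (H @ stored) · (sign ∘ a) = stored · (H @ (sign ∘ a)), which is exactly the rearrangement the fused kernel exploits.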
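The two-sub-loop restructuring from the last sub-commit can be sketched in scalar Python (illustrative only; `W` stands in for the dequantized centroid*scale weights and `rotated_act` for the shared-memory rotated activation):

```python
NUM_ROWS = 8
COLS = 32

# Hypothetical dequantized weights (NUM_ROWS x COLS) and rotated activation.
W = [[(r * COLS + c) % 7 - 3.0 for c in range(COLS)] for r in range(NUM_ROWS)]
rotated_act = [(c * 5) % 11 - 5.0 for c in range(COLS)]

# Naive form: one shared-memory read of rotated_act per (row, column) pair.
fused = [sum(W[r][c] * rotated_act[c] for c in range(COLS)) for r in range(NUM_ROWS)]

# Restructured form: per column, dequant all rows first, then read the
# activation ONCE and FMA across rows -- the cross-row smem reuse.
acc = [0.0] * NUM_ROWS
for c in range(COLS):
    w_vals = [W[r][c] for r in range(NUM_ROWS)]  # sub-loop 1: weight buffer only
    a = rotated_act[c]                           # single shared-memory read
    for r in range(NUM_ROWS):                    # sub-loop 2: tight register loop
        acc[r] += w_vals[r] * a

assert all(abs(x - y) < 1e-9 for x, y in zip(fused, acc))
```

Both forms compute the same dot products; the second makes the activation load explicitly loop-invariant across rows, which is the property the commit relies on GPU compilers failing to derive on their own.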
1 parent 037047e

9 files changed

Lines changed: 598 additions & 3 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 45 additions & 1 deletion
@@ -4155,6 +4155,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
     const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
+
+    // TQ4_1S uses a dedicated pipeline whose workgroup size is always 32 and
+    // whose reduction path is always the shared-memory variant.
+    //
+    // The Walsh-Hadamard butterfly inside the shader operates on 32-element
+    // blocks with one element per thread, so the workgroup contract is fixed
+    // regardless of what the rest of the mul_mat_vec family picks for the
+    // current DMMV_WG_SIZE bucket. We always use 32 threads per workgroup.
+    //
+    // Reduction choice: the shader uses the SHMEM tree reduction even when
+    // subgroup arithmetic is available. A subgroup-shuffle butterfly + pure
+    // subgroupAdd reduction variant was tried and measured ~70% slower on
+    // Intel Arc (Mesa Xe HPG), where subgroup shuffles and subgroup adds are
+    // emulated over LDS and end up doing the same amount of LDS traffic as
+    // the explicit shared-memory path but with extra driver overhead. Going
+    // through SHMEM directly is always correct and is fastest on the devices
+    // we can actually measure. Future vendor-specific heuristics can switch
+    // to the hybrid reduction variant on NVIDIA / AMD RDNA if hardware
+    // subgroup shuffles beat the LDS roundtrip there.
+    const uint32_t tq4_1s_wg_size = 32u;
+    const uint32_t tq4_1s_force_sg_size = 0u;
+    const bool tq4_1s_use_subgroups = false;
+    const shader_reduction_mode tq4_1s_reduc = SHADER_REDUCTION_MODE_SHMEM;
+
     static constexpr uint32_t mul_mat_vec_num_bindings = 5;
     static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;
 
@@ -4196,6 +4220,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+        // TQ4_1S: fixed 32-thread workgroup, shared-memory WHT butterfly,
+        // shared-memory reduction. NUM_ROWS=8 amortises the butterfly cost
+        // across 8 output rows per workgroup.
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ4_1S][i], "mul_mat_vec_tq4_1s_f32_f32", arr_dmmv_tq4_1s_f32_f32_len[tq4_1s_reduc], arr_dmmv_tq4_1s_f32_f32_data[tq4_1s_reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {8, 1, 1}, {tq4_1s_wg_size, 8, i+1}, 1, true, tq4_1s_use_subgroups, tq4_1s_force_sg_size);
 
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
@@ -4222,6 +4250,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ4_1S][i], "mul_mat_vec_tq4_1s_f16_f32", arr_dmmv_tq4_1s_f16_f32_len[tq4_1s_reduc], arr_dmmv_tq4_1s_f16_f32_data[tq4_1s_reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {8, 1, 1}, {tq4_1s_wg_size, 8, i+1}, 1, true, tq4_1s_use_subgroups, tq4_1s_force_sg_size);
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (device->integer_dot_product) {
@@ -4331,6 +4360,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {128, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ4_1S], "dequant_tq4_1s", dequant_tq4_1s_len, dequant_tq4_1s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 
     // TurboQuant WHT
     ggml_vk_create_pipeline(device, device->pipeline_turbo_wht, "turbo_wht", turbo_wht_len, turbo_wht_data, "main", 2, 3 * sizeof(uint32_t), {128, 1, 1}, {}, 1);
@@ -4471,7 +4501,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
     ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
     ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
-    ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## _len, set_rows_turbo3_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## _len, set_rows_turbo3_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+    ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TQ4_1S], "set_rows_tq4_1s" #itype, set_rows_tq4_1s ## itype ## _len, set_rows_tq4_1s ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
 
     SET_ROWS(_i32)
     SET_ROWS(_i64)
@@ -4486,6 +4517,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_TURBO3_0], "cpy_turbo3_0_f32", cpy_turbo3_0_f32_len, cpy_turbo3_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_TURBO3_0), 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_TQ4_1S], "cpy_tq4_1s_f32", cpy_tq4_1s_f32_len, cpy_tq4_1s_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_TQ4_1S), 1, 1}, {}, 1);
 
     auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
         std::string s;
@@ -6141,6 +6173,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
     case GGML_TYPE_IQ4_NL:
     case GGML_TYPE_MXFP4:
     case GGML_TYPE_NVFP4:
+    case GGML_TYPE_TQ4_1S:
         break;
     default:
         return nullptr;
@@ -6281,6 +6314,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
     case GGML_TYPE_IQ4_NL:
     case GGML_TYPE_MXFP4:
     case GGML_TYPE_NVFP4:
+    case GGML_TYPE_TQ4_1S:
         break;
     default:
         return nullptr;
@@ -6296,6 +6330,10 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
         if (m < 4096 && k >= 1024) {
             dmmv_wg = DMMV_WG_SIZE_LARGE;
         }
+    } else if (a_type == GGML_TYPE_TQ4_1S) {
+        // TQ4_1S needs exactly 32 threads (one subgroup) to cooperate on the
+        // 32-element WHT butterfly in shared memory. Force SUBGROUP-sized wg.
+        dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
     } else {
         if (m <= 8192 && k >= 1024) {
             dmmv_wg = DMMV_WG_SIZE_LARGE;
@@ -7393,6 +7431,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_TYPE_Q8_0:
     case GGML_TYPE_IQ4_NL:
     case GGML_TYPE_TURBO3_0:
+    case GGML_TYPE_TQ4_1S:
         return ctx->device->pipeline_cpy_quant_f32[src->type];
     default:
         break;
@@ -10216,6 +10255,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     uint32_t ne = ggml_nelements(src0);
     if (dst->type == GGML_TYPE_TURBO3_0) {
         ne = ne / 128;
+    } else if (dst->type == GGML_TYPE_TQ4_1S) {
+        ne = ne / 32;
    } else if (ggml_is_quantized(dst->type)) {
         // quants run 32 threads each doing QUANT_K elements
         ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
@@ -15467,6 +15508,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_TYPE_IQ4_NL:
     case GGML_TYPE_MXFP4:
     case GGML_TYPE_NVFP4:
+    case GGML_TYPE_TQ4_1S:
         break;
     default:
         return false;
@@ -15607,6 +15649,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_TYPE_Q8_0:
     case GGML_TYPE_IQ4_NL:
     case GGML_TYPE_TURBO3_0:
+    case GGML_TYPE_TQ4_1S:
         return true;
     default:
         return false;
@@ -15647,6 +15690,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_TYPE_Q8_0:
     case GGML_TYPE_IQ4_NL:
     case GGML_TYPE_TURBO3_0:
+    case GGML_TYPE_TQ4_1S:
         return true;
     default:
         break;
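For orientation, here is a hedged Python sketch of what dequantizing one block_tq4_1s block might look like, based only on the struct described in the commit message (d0, d1, qs[16], 32 elements, 16 centroids). The nibble packing order, the half-block scale assignment, and the centroid table are all assumptions for illustration, not the real values from types.glsl / dequant_funcs.glsl:

```python
# Assumed layout (illustrative only): element 2*i in the low nibble of qs[i],
# element 2*i+1 in the high nibble; d0 scales elements 0..15, d1 elements 16..31.
CENTROIDS = [(-7.5 + k) / 7.5 for k in range(16)]  # placeholder, NOT the real Lloyd-Max table

def dequant_tq4_1s(d0, d1, qs):
    # Produces the centroid*scale values; the inverse WHT is applied separately.
    out = []
    for i in range(16):
        lo = qs[i] & 0x0F
        hi = (qs[i] >> 4) & 0x0F
        for idx in (lo, hi):
            scale = d0 if len(out) < 16 else d1
            out.append(CENTROIDS[idx] * scale)
    return out

vals = dequant_tq4_1s(0.5, 0.25, list(range(16)))
assert len(vals) == 32
```

This matches the commit's description of dequant_funcs.glsl as a pure centroid*scale lookup ("no RHT"): the rotation back to weight space happens in the shaders that need it.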

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 36 additions & 0 deletions
@@ -30,6 +30,41 @@ void main() {
 
     const uint a_offset = 0;
     const uint ib = src_idx;
+
+#if defined(DATA_A_TQ4_1S)
+    // TQ4_1S requires full inverse WHT after centroid*scale dequant.
+    // Dequant all 32 elements into a buffer, apply butterfly, then write.
+    const float tq4_signs[32] = float[32](
+        +1.0, -1.0, +1.0, -1.0, +1.0, +1.0, -1.0, +1.0,
+        -1.0, -1.0, +1.0, -1.0, +1.0, +1.0, -1.0, +1.0,
+        -1.0, -1.0, +1.0, -1.0, +1.0, -1.0, -1.0, +1.0,
+        -1.0, +1.0, +1.0, -1.0, +1.0, -1.0, -1.0, +1.0
+    );
+    const float TQ4_INV_SQRT32 = 0.17677669529663688;
+
+    float buf[32];
+    for (int j = 0; j < 32; j += 2) {
+        vec2 v = dequantize(ib, j, a_offset);
+        buf[j] = v.x;
+        buf[j+1] = v.y;
+    }
+
+    // Inverse WHT butterfly (5 stages for 32 elements)
+    for (uint step = 1u; step < 32u; step <<= 1u) {
+        for (uint i = 0u; i < 32u; i += step * 2u) {
+            for (uint j2 = i; j2 < i + step; j2++) {
+                float a2 = buf[j2], b2 = buf[j2 + step];
+                buf[j2] = a2 + b2;
+                buf[j2 + step] = a2 - b2;
+            }
+        }
+    }
+
+    // Normalize and apply sign pattern
+    for (int j = 0; j < 32; j++) {
+        data_d[dst_idx + j] = buf[j] * TQ4_INV_SQRT32 * tq4_signs[j];
+    }
+#else
     const vec2 dm = get_dm(ib, a_offset);
 
     [[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
@@ -48,4 +83,5 @@ void main() {
         data_d[dst_idx + j + 3] = v[3];
 #endif
     }
+#endif
 }
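One property this shader relies on, shown here as a standalone Python check (not part of the patch): the unnormalized Walsh-Hadamard butterfly is an involution up to a factor of 32, so the identical 5-stage loop serves as both forward and inverse transform, with all normalization folded into the TQ4_INV_SQRT32 = 1/sqrt(32) factor:

```python
N = 32

def butterfly(buf):
    # Same 5-stage in-place butterfly structure as the shader's inverse-WHT loop.
    out = list(buf)
    step = 1
    while step < N:
        for i in range(0, N, step * 2):
            for j in range(i, i + step):
                a, b = out[j], out[j + step]
                out[j], out[j + step] = a + b, a - b
        step <<= 1
    return out

x = [float((3 * k) % 13 - 6) for k in range(N)]  # arbitrary test vector
twice = butterfly(butterfly(x))

# H @ H = 32 * I, so applying the butterfly twice returns 32x the input.
assert all(abs(t - N * v) < 1e-9 for t, v in zip(twice, x))

# The shader's normalization constant is exactly 1/sqrt(32).
assert abs(0.17677669529663688 - 1 / N ** 0.5) < 1e-15
```

Applying 1/sqrt(32) on the quantization side and again on the dequantization side therefore yields the overall 1/32 needed for a true round trip.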
