Commit 81edfc7

Titaniumtown authored and LogicDaemon committed
vulkan: TQ4_1s support for model weights (TheTom#69)
* vulkan: add TQ4_1S weight compression support

Adds Vulkan shader support for TQ4_1S (4-bit WHT-rotated weight compression with 16 Lloyd-Max centroids, 32-element blocks).

Shaders:
- dequant_tq4_1s.comp: standalone dequant with WHT inverse via subgroupShuffleXor (32-thread workgroup, 5-stage butterfly)
- mul_mat_vec_tq4_1s.comp: specialized MUL_MAT_VEC with inline activation pre-rotation (forward RHT on activation, centroid*scale dequant without inverse RHT)
- copy_from_quant.comp: TQ4_1S dequant path with full WHT inverse
- copy_to_quant.comp: TQ4_1S SET_ROWS quantization path with forward RHT, dual half-block RMS scales, 16-centroid quantization
- types.glsl: block_tq4_1s struct (d0, d1, qs[16])
- dequant_funcs.glsl: TQ4_1S centroid*scale dequant (no RHT)

Pipeline wiring (ggml-vulkan.cpp):
- MUL_MAT, SET_ROWS, CPY supports_op
- pipeline_dequant, pipeline_set_rows, pipeline_cpy_quant_f32
- Specialized MUL_MAT_VEC with forced subgroup workgroup size

Tests:
- test_set_rows_tq4_1s: SET_ROWS round-trip validation

* vulkan: add fused mul_mat_vec kernel for TQ4_1S

Adds a specialised MUL_MAT_VEC shader for TQ4_1S weights so the per-decode-step matrix-vector product no longer has to dequant the full weight tensor to f16 and then go through the generic matmul path. The kernel pre-rotates the activation via a forward Walsh-Hadamard Transform in shared memory and dot-products against the raw centroid*scale stored weights, folding the inverse-WHT on the weight side into the activation by the symmetry H = H^T.

Math:

    w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
    sum_k w[k] * a[k] = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]

Portability choices:
- Workgroup size is pinned to 32 threads regardless of the DMMV_WG_SIZE bucket the rest of the mul_mat_vec family picks for the current architecture. The butterfly operates on 32-element blocks with one element per thread; that contract is fixed by the quantization format, not by the GPU. Earlier revisions used `gl_WorkGroupSize.x` as the stride unit, which silently skipped half the work on Intel drivers that force the subgroup to 16 (tests passed via NMSE tolerance while real inference output was garbage).
- Butterfly implementation is shared memory only. A subgroup-shuffle variant (`subgroupShuffleXor`) was prototyped and measured on Intel Arc A380 with Mesa Xe HPG: it ran ~60-85% slower than the explicit shared-memory butterfly, because Mesa emulates subgroup shuffles via LDS and ends up doing the same LDS traffic with extra driver overhead. The shared-memory butterfly is correct on every device regardless of subgroup-op support, is the fastest path on every device we can actually measure, and leaves the `pipeline_dequant_mul_mat_vec_f32_f32[w][TQ4_1S]` slot uniform across all DMMV_WG_SIZE buckets.
- Reduction is the shared-memory tree reduction (no subgroupAdd), for the same reason: on Intel Arc the subgroupAdd is also LDS-backed and the hybrid reduction path was measurably slower. Future vendor-specific heuristics can switch to the hybrid or pure-subgroup reduction variants on NVIDIA / AMD RDNA if hardware subgroup ops turn out to beat the LDS roundtrip there; the existing reduction modes in `mul_mat_vec_base.glsl` already provide the necessary variants.
- NUM_ROWS is 8 so the butterfly cost amortises across 8 output rows per workgroup. Each thread holds one position of each of the 8 weight blocks and pairs them with the shared rotated activation.
- `mul_mm` and `flash_attn_cm2` shader generation is skipped for TQ4_1S because it is a weight-only format that never reaches the coopmat2 matmul or the KV cache flash-attention paths.

Tests:
- `test-backend-ops` MUL_MAT tolerance tightened from 2.0 to 0.01 NMSE so real defects can't hide behind a loose check.
- Added Gemma-4 E2B, Qwen, Phi and Llama dimensional coverage (k in {1536, 2048, 2304, 3072, 4096}, m in {256, 1152, 1536, 2048, 5120, 6144}, n in {1..8, 16, 64, 256}). 148 MUL_MAT test cases total.

Verification (Intel Arc A380, 6 GB VRAM, Vulkan ANV / Mesa Xe HPG, `llama-bench -p 512 -n 128 -r 3` and `llama-perplexity -c 512 --chunks 20 wiki.test.raw`):

| Model         | Config  |      Size | Reduction |  PPL Δ | pp512/Q8 | tg128/Q8 |
|---------------|---------|----------:|----------:|-------:|---------:|---------:|
| Qwen2.5-1.5B  | I       | 1570→1082 |    -31.1% | +4.66% |    53.9% |   107.5% |
| Phi-3.5-mini  | I       | 3873→2839 |    -26.7% | +5.36% |    57.6% |    52.8% |
| Llama-3.2-3B  | hybrid  | 3263→2147 |    -34.2% | +2.03% |    82.4% |    84.2% |
| Llama-3.2-3B  | premium | 3263→2577 |    -21.0% | +0.98% |    71.3% |    67.3% |

Qwen2.5-1.5B is faster than its own Q8_0 baseline with Config I: the compressed model fits in less VRAM, and on a small model the TQ4_1S compute cost is offset by the reduced memory traffic. All four models produce coherent output end-to-end and the reductions line up with the TurboQuant paper's validation matrix (§5.8). The remaining gap to Q8_0 on the bigger models is compute-bound on the A380; it closes further on GPUs with more raw throughput.

* vulkan: restructure TQ4_1S inner loop for cross-row smem reuse

Splits the dequant+accumulate phase into two sub-loops:
1. Pre-compute w_vals[n] for all NUM_ROWS rows (centroid lookup + scale multiply, reads from weight buffer only).
2. Read the rotated activation from shared memory ONCE per column, then FMA across all rows in a tight register loop.

This is the Vulkan analogue of the 'hot loop load dedup' from the CUDA kernel (PR TheTom#57 optimisation #2). It makes the shared memory read explicitly loop-invariant across rows, which helps compilers that don't auto-hoist LDS loads out of unrolled loops.

Measured effect on Intel Arc A380 (Llama-3.2-3B premium, llama-bench tg128, r=5): 15.50 -> 15.78 t/s (+1.8%, within noise but not a regression). The structure is cleaner regardless and should benefit architectures with higher LDS latency.
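The "Math" identity in the commit message can be checked on the CPU. The following is a minimal sketch, not the shader itself: `wht32`, `dot_dequant_then_multiply` and `dot_rotated_activation` are hypothetical names, and the transform is assumed to be the unnormalised Sylvester-ordered Walsh-Hadamard matrix (which is symmetric, so H = H^T holds). It compares the dequantise-then-multiply path against the fused path that rotates the activation once.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

constexpr std::size_t kBlock = 32;  // block size stated in the commit message
using Block = std::array<float, kBlock>;

// In-place, unnormalised Walsh-Hadamard transform (Sylvester ordering).
// 32 elements need log2(32) = 5 butterfly stages.
inline void wht32(Block& v) {
    for (std::size_t h = 1; h < kBlock; h <<= 1) {
        for (std::size_t i = 0; i < kBlock; i += 2 * h) {
            for (std::size_t j = i; j < i + h; ++j) {
                const float x = v[j];
                const float y = v[j + h];
                v[j]     = x + y;
                v[j + h] = x - y;
            }
        }
    }
}

inline float inv_sqrt32() { return 1.0f / std::sqrt(32.0f); }

// Reference path: rebuild w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k],
// then take the dot product with the activation a.
inline float dot_dequant_then_multiply(const Block& stored, const Block& sign, const Block& a) {
    Block w = stored;
    wht32(w);
    float acc = 0.0f;
    for (std::size_t k = 0; k < kBlock; ++k) {
        acc += sign[k] * inv_sqrt32() * w[k] * a[k];
    }
    return acc;
}

// Fused path: rotate (sign * a) once, then dot against the raw stored values.
inline float dot_rotated_activation(const Block& stored, const Block& sign, const Block& a) {
    Block r{};
    for (std::size_t k = 0; k < kBlock; ++k) {
        r[k] = sign[k] * a[k];
    }
    wht32(r);
    float acc = 0.0f;
    for (std::size_t j = 0; j < kBlock; ++j) {
        acc += stored[j] * r[j];
    }
    return inv_sqrt32() * acc;
}
```

Both functions should agree up to floating-point rounding for any `stored`, `sign` and `a`. The fused form matters because the rotated activation can be shared by every weight row that multiplies the same 32 activation values, while the reference form needs one transform per weight block.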
1 parent fbd0a4e commit 81edfc7

2 files changed

Lines changed: 57 additions & 3 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 45 additions & 1 deletion
@@ -4155,6 +4155,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
  const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
  const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
+
+ // TQ4_1S uses a dedicated pipeline whose workgroup size is always 32 and
+ // whose reduction path is always the shared-memory variant.
+ //
+ // The Walsh-Hadamard butterfly inside the shader operates on 32-element
+ // blocks with one element per thread, so the workgroup contract is fixed
+ // regardless of what the rest of the mul_mat_vec family picks for the
+ // current DMMV_WG_SIZE bucket. We always use 32 threads per workgroup.
+ //
+ // Reduction choice: the shader uses the SHMEM tree reduction even when
+ // subgroup arithmetic is available. A subgroup-shuffle butterfly + pure
+ // subgroupAdd reduction variant was tried and measured ~70% slower on
+ // Intel Arc (Mesa Xe HPG), where subgroup shuffles and subgroup adds are
+ // emulated over LDS and end up doing the same amount of LDS traffic as
+ // the explicit shared-memory path but with extra driver overhead. Going
+ // through SHMEM directly is always correct and is fastest on the devices
+ // we can actually measure. Future vendor-specific heuristics can switch
+ // to the hybrid reduction variant on NVIDIA / AMD RDNA if hardware
+ // subgroup shuffles beat the LDS roundtrip there.
+ const uint32_t tq4_1s_wg_size = 32u;
+ const uint32_t tq4_1s_force_sg_size = 0u;
+ const bool tq4_1s_use_subgroups = false;
+ const shader_reduction_mode tq4_1s_reduc = SHADER_REDUCTION_MODE_SHMEM;
+
  static constexpr uint32_t mul_mat_vec_num_bindings = 5;
  static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;
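The comment in this hunk pins the reduction to the shared-memory tree variant. As a rough illustration of what that means, here is a CPU simulation of a 32-lane tree reduction. It is a sketch under assumptions: `tree_reduce32` and `kLanes` are hypothetical names, the array stands in for the shader's shared-memory scratch buffer, and each pass of the outer loop corresponds to one barrier-separated step on the GPU.

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t kLanes = 32;  // matches tq4_1s_wg_size in the hunk above

// Sums 32 per-lane partial results in log2(32) = 5 halving steps.
// In the shader each step would be followed by a workgroup barrier;
// here the sequential loop makes the ordering explicit instead.
inline float tree_reduce32(std::array<float, kLanes> scratch) {
    for (std::size_t active = kLanes / 2; active > 0; active >>= 1) {
        for (std::size_t tid = 0; tid < active; ++tid) {
            scratch[tid] += scratch[tid + active];
        }
    }
    return scratch[0];
}
```

Nothing in this pattern depends on the hardware subgroup width, which is consistent with the comment's claim that the shared-memory path is correct on every device.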
@@ -4196,6 +4220,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ // TQ4_1S: fixed 32-thread workgroup, shared-memory WHT butterfly,
+ // shared-memory reduction. NUM_ROWS=8 amortises the butterfly cost
+ // across 8 output rows per workgroup.
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ4_1S][i], "mul_mat_vec_tq4_1s_f32_f32", arr_dmmv_tq4_1s_f32_f32_len[tq4_1s_reduc], arr_dmmv_tq4_1s_f32_f32_data[tq4_1s_reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {8, 1, 1}, {tq4_1s_wg_size, 8, i+1}, 1, true, tq4_1s_use_subgroups, tq4_1s_force_sg_size);

  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
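The NUM_ROWS=8 comment and the commit message's two sub-loops (pre-compute `w_vals[n]`, then read the rotated activation once and FMA across rows) can be sketched on the CPU as follows. This is an illustration under assumptions, not the shader: `dot_rows`, `kNumRows` and `kBlock` are hypothetical names, the weights are taken as already multiplied out to centroid*scale values, and the outer loop over `lane` stands in for the 32 threads of one workgroup.

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t kNumRows = 8;   // rows per workgroup, per the hunk comment
constexpr std::size_t kBlock   = 32;  // elements per TQ4_1S block

using Block = std::array<float, kBlock>;

// Accumulates one 32-element block of 8 weight rows against a single
// rotated activation block. Each `lane` iteration models one thread.
inline std::array<float, kNumRows> dot_rows(const std::array<Block, kNumRows>& weights,
                                            const Block& rotated_act) {
    std::array<float, kNumRows> acc{};
    for (std::size_t lane = 0; lane < kBlock; ++lane) {
        // Sub-loop 1: gather this lane's weight value for every row.
        std::array<float, kNumRows> w_vals{};
        for (std::size_t n = 0; n < kNumRows; ++n) {
            w_vals[n] = weights[n][lane];
        }
        // Sub-loop 2: read the shared activation once, reuse it for all rows.
        const float a = rotated_act[lane];
        for (std::size_t n = 0; n < kNumRows; ++n) {
            acc[n] += w_vals[n] * a;
        }
    }
    return acc;
}
```

The point of the split is that `a` is visibly loop-invariant in the second inner loop, so a compiler does not have to prove that a shared-memory load can be hoisted.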
@@ -4222,6 +4250,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ4_1S][i], "mul_mat_vec_tq4_1s_f16_f32", arr_dmmv_tq4_1s_f16_f32_len[tq4_1s_reduc], arr_dmmv_tq4_1s_f16_f32_data[tq4_1s_reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {8, 1, 1}, {tq4_1s_wg_size, 8, i+1}, 1, true, tq4_1s_use_subgroups, tq4_1s_force_sg_size);

  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  if (device->integer_dot_product) {
@@ -4331,6 +4360,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {128, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ4_1S], "dequant_tq4_1s", dequant_tq4_1s_len, dequant_tq4_1s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);

  // TurboQuant WHT
  ggml_vk_create_pipeline(device, device->pipeline_turbo_wht, "turbo_wht", turbo_wht_len, turbo_wht_data, "main", 2, 3 * sizeof(uint32_t), {128, 1, 1}, {}, 1);
@@ -4471,7 +4501,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
  ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
  ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## _len, set_rows_turbo3_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## _len, set_rows_turbo3_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TQ4_1S], "set_rows_tq4_1s" #itype, set_rows_tq4_1s ## itype ## _len, set_rows_tq4_1s ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);

  SET_ROWS(_i32)
  SET_ROWS(_i64)
@@ -4486,6 +4517,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_TURBO3_0], "cpy_turbo3_0_f32", cpy_turbo3_0_f32_len, cpy_turbo3_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_TURBO3_0), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_TQ4_1S], "cpy_tq4_1s_f32", cpy_tq4_1s_f32_len, cpy_tq4_1s_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_TQ4_1S), 1, 1}, {}, 1);

  auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
  std::string s;
@@ -6141,6 +6173,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
  case GGML_TYPE_IQ4_NL:
  case GGML_TYPE_MXFP4:
  case GGML_TYPE_NVFP4:
+ case GGML_TYPE_TQ4_1S:
  break;
  default:
  return nullptr;
@@ -6281,6 +6314,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
  case GGML_TYPE_IQ4_NL:
  case GGML_TYPE_MXFP4:
  case GGML_TYPE_NVFP4:
+ case GGML_TYPE_TQ4_1S:
  break;
  default:
  return nullptr;
@@ -6296,6 +6330,10 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
  if (m < 4096 && k >= 1024) {
  dmmv_wg = DMMV_WG_SIZE_LARGE;
  }
+ } else if (a_type == GGML_TYPE_TQ4_1S) {
+ // TQ4_1S needs exactly 32 threads (one subgroup) to cooperate on the
+ // 32-element WHT butterfly in shared memory. Force SUBGROUP-sized wg.
+ dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
  } else {
  if (m <= 8192 && k >= 1024) {
  dmmv_wg = DMMV_WG_SIZE_LARGE;
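To see why the butterfly needs exactly 32 cooperating threads, it helps to write it the way each thread would execute it. The sketch below is a CPU simulation with hypothetical names (`wht32_lanes`, `kLanes`), not the shader source: every lane `tid` reads its own slot and its partner `tid ^ stride`, and the copy into `next` plays the role of the barrier between stages.

```cpp
#include <array>

constexpr unsigned kLanes = 32;  // one element per thread, per the comment above

// Unnormalised 32-point Walsh-Hadamard transform, one lane at a time.
// Five stages with strides 1, 2, 4, 8, 16; lane tid pairs with tid ^ stride.
inline void wht32_lanes(std::array<float, kLanes>& smem) {
    for (unsigned stride = 1; stride < kLanes; stride <<= 1) {
        std::array<float, kLanes> next{};
        for (unsigned tid = 0; tid < kLanes; ++tid) {
            const float mine  = smem[tid];
            const float other = smem[tid ^ stride];
            // The lower lane of each pair keeps the sum, the upper lane the difference.
            next[tid] = (tid & stride) ? (other - mine) : (mine + other);
        }
        smem = next;  // stands in for a workgroup barrier between stages
    }
}
```

Applying the transform twice returns the input scaled by 32, which gives a cheap self-check. If only 16 lanes took part, the stride-16 stage would have no lane to fill slots 16..31, so the partner indices would fall outside the active range; that would be consistent with the half-skipped work described in the commit message, although the exact failure mode there is not spelled out.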
@@ -7393,6 +7431,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_IQ4_NL:
  case GGML_TYPE_TURBO3_0:
+ case GGML_TYPE_TQ4_1S:
  return ctx->device->pipeline_cpy_quant_f32[src->type];
  default:
  break;
@@ -10216,6 +10255,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  uint32_t ne = ggml_nelements(src0);
  if (dst->type == GGML_TYPE_TURBO3_0) {
  ne = ne / 128;
+ } else if (dst->type == GGML_TYPE_TQ4_1S) {
+ ne = ne / 32;
  } else if (ggml_is_quantized(dst->type)) {
  // quants run 32 threads each doing QUANT_K elements
  ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
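The element-count adjustment in this hunk can be restated as a small pure function. This is a sketch with hypothetical names (`DstKind`, `dispatch_element_count`, `ceil_div`); the constants 128 and 32 and the `CEIL_DIV(ne, 32 * blck_size)` fallback come from the diff, while reading the TQ4_1S division as one unit of work per 32-element block is an inference from the block size, not something the diff states.

```cpp
#include <cstdint>

// Hypothetical stand-in for the dst->type checks in the hunk above.
enum class DstKind { Turbo3_0, Tq4_1s, OtherQuant, NotQuantized };

constexpr std::uint32_t ceil_div(std::uint32_t a, std::uint32_t b) {
    return (a + b - 1) / b;
}

// Mirrors the branch order of the diff: the two WHT-based formats are
// handled before the generic quantized path.
constexpr std::uint32_t dispatch_element_count(std::uint32_t ne, DstKind kind,
                                               std::uint32_t blck_size) {
    switch (kind) {
        case DstKind::Turbo3_0:     return ne / 128;
        case DstKind::Tq4_1s:       return ne / 32;
        case DstKind::OtherQuant:   return ceil_div(ne, 32 * blck_size);
        case DstKind::NotQuantized: return ne;
    }
    return ne;
}
```

For a 4096-element tensor this yields 128 for TQ4_1S but only 4 for a generic quantized type with 32-element blocks. The ordering matters if `ggml_is_quantized` also returns true for TQ4_1S, because the dedicated branch then has to come first.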
@@ -15467,6 +15508,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_TYPE_IQ4_NL:
  case GGML_TYPE_MXFP4:
  case GGML_TYPE_NVFP4:
+ case GGML_TYPE_TQ4_1S:
  break;
  default:
  return false;
@@ -15607,6 +15649,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_IQ4_NL:
  case GGML_TYPE_TURBO3_0:
+ case GGML_TYPE_TQ4_1S:
  return true;
  default:
  return false;
@@ -15647,6 +15690,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_IQ4_NL:
  case GGML_TYPE_TURBO3_0:
+ case GGML_TYPE_TQ4_1S:
  return true;
  default:
  break;

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 12 additions & 2 deletions
@@ -69,6 +69,7 @@ const std::vector<std::string> type_names = {
  "nvfp4",
  "bf16",
  "turbo3_0",
+ "tq4_1s",
  };

  enum MatMulIdType {
@@ -564,6 +565,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
  if (tname == "bf16") {
  continue;
  }
+ // TQ4_1S uses a specialized mul_mat_vec shader for small N and
+ // the dequant+f16 matmul fallback for large N. No dedicated mul_mm needed.
+ if (tname == "tq4_1s") {
+ continue;
+ }

  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  // For unaligned, load one at a time for f32/f16, or two at a time for quants
@@ -644,6 +650,8 @@ void process_shaders() {

  for (const auto& tname : type_names) {
  if (tname == "bf16") continue;
+ // TQ4_1S is a weight-only format; flash attention isn't defined for it.
+ if (tname == "tq4_1s") continue;

  if (fp16) {
  #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -692,7 +700,7 @@ void process_shaders() {
  for (const auto& tname : type_names) {
  // mul mat vec
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
- std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
+ std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_") || tname == "tq4_1s") ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";

  string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
  string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}));
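The one-line change to the `shader` expression decides which GLSL source each type's mul_mat_vec variant is compiled from. Pulled out into a standalone function it reads as below. This is a sketch: `mul_mat_vec_shader_for`, `starts_with` and `ends_with` are hypothetical re-implementations written for this example, not the generator's own `string_starts_with` / `string_ends_with` helpers.

```cpp
#include <string>

inline bool starts_with(const std::string& s, const std::string& prefix) {
    return s.compare(0, prefix.size(), prefix) == 0;
}

inline bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Types with a dedicated kernel get "mul_mat_vec_<type>.comp";
// everything else shares the generic "mul_mat_vec.comp".
inline std::string mul_mat_vec_shader_for(const std::string& tname) {
    const bool dedicated = ends_with(tname, "_k") ||
                           starts_with(tname, "iq1_") ||
                           starts_with(tname, "iq2_") ||
                           starts_with(tname, "iq3_") ||
                           tname == "tq4_1s";  // added by this commit
    return dedicated ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
}
```

With this predicate `"tq4_1s"` maps to `mul_mat_vec_tq4_1s.comp`, while a type such as `"q4_0"` still maps to the shared `mul_mat_vec.comp`.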
@@ -775,8 +783,10 @@ void process_shaders() {
  }
  // turbo3_0 copy-from-quant only; copy-to-quant (cpy_f32_turbo3_0) omitted because the non-SET_ROWS quantize() path lacks the WHT transform
  string_to_spv("cpy_turbo3_0_f32", "copy_from_quant.comp", {{"DATA_A_TURBO3_0", "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ // tq4_1s copy-from-quant only; copy-to-quant requires WHT forward (handled in SET_ROWS path)
+ string_to_spv("cpy_tq4_1s_f32", "copy_from_quant.comp", {{"DATA_A_TQ4_1S", "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

- for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0"}) {
+ for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0", "tq4_1s"}) {
  string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
  }
