46 changes: 45 additions & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -4155,6 +4155,30 @@ static void ggml_vk_load_shaders(vk_device& device) {

const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;

// TQ4_1S uses a dedicated pipeline whose workgroup size is always 32 and
// whose reduction path is always the shared-memory variant.
//
// The Walsh-Hadamard butterfly inside the shader operates on 32-element
// blocks with one element per thread, so the workgroup contract is fixed
// at 32 threads regardless of what the rest of the mul_mat_vec family
// picks for the current DMMV_WG_SIZE bucket.
//
// Reduction choice: the shader uses the SHMEM tree reduction even when
// subgroup arithmetic is available. A subgroup-shuffle butterfly + pure
// subgroupAdd reduction variant was tried and measured ~70% slower on
// Intel Arc (Mesa Xe HPG), where subgroup shuffles and subgroup adds are
// emulated over LDS and end up doing the same amount of LDS traffic as
// the explicit shared-memory path but with extra driver overhead. Going
// through SHMEM directly is always correct and is fastest on the devices
// we can actually measure. Future vendor-specific heuristics can switch
// to the hybrid reduction variant on NVIDIA / AMD RDNA if hardware
// subgroup shuffles beat the LDS roundtrip there.
const uint32_t tq4_1s_wg_size = 32u;
const uint32_t tq4_1s_force_sg_size = 0u;
const bool tq4_1s_use_subgroups = false;
const shader_reduction_mode tq4_1s_reduc = SHADER_REDUCTION_MODE_SHMEM;

static constexpr uint32_t mul_mat_vec_num_bindings = 5;
static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;
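For reference, the "shared-memory variant" mentioned above is the classic log2(N) tree fold. Below is a minimal CPU-side sketch that emulates one 32-lane workgroup with a plain array; all names are illustrative and not part of this patch. In GLSL the array would be declared `shared` and each fold stage would be followed by `barrier()`.

#include <cstdio>

// Sketch only: emulates SHADER_REDUCTION_MODE_SHMEM for one 32-thread
// workgroup. Each lane starts with a partial dot product; five fold
// stages in shared memory leave the workgroup sum in lane 0.
static float shmem_tree_reduce(float partial[32]) {
    for (unsigned offset = 16; offset > 0; offset >>= 1) {
        // On the GPU all lanes below `offset` run this in parallel,
        // with a barrier() between stages; here it is sequential.
        for (unsigned lane = 0; lane < offset; ++lane) {
            partial[lane] += partial[lane + offset];
        }
    }
    return partial[0];
}

int main() {
    float partial[32];
    for (int i = 0; i < 32; ++i) partial[i] = 1.0f; // expected sum: 32
    std::printf("sum = %f\n", shmem_tree_reduce(partial));
    return 0;
}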

@@ -4196,6 +4220,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
// TQ4_1S: fixed 32-thread workgroup, shared-memory WHT butterfly,
// shared-memory reduction. NUM_ROWS=8 amortises the butterfly cost
// across 8 output rows per workgroup.
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ4_1S][i], "mul_mat_vec_tq4_1s_f32_f32", arr_dmmv_tq4_1s_f32_f32_len[tq4_1s_reduc], arr_dmmv_tq4_1s_f32_f32_data[tq4_1s_reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {8, 1, 1}, {tq4_1s_wg_size, 8, i+1}, 1, true, tq4_1s_use_subgroups, tq4_1s_force_sg_size);

ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
@@ -4222,6 +4250,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ4_1S][i], "mul_mat_vec_tq4_1s_f16_f32", arr_dmmv_tq4_1s_f16_f32_len[tq4_1s_reduc], arr_dmmv_tq4_1s_f16_f32_data[tq4_1s_reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {8, 1, 1}, {tq4_1s_wg_size, 8, i+1}, 1, true, tq4_1s_use_subgroups, tq4_1s_force_sg_size);

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -4331,6 +4360,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {128, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ4_1S], "dequant_tq4_1s", dequant_tq4_1s_len, dequant_tq4_1s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);

// TurboQuant WHT
ggml_vk_create_pipeline(device, device->pipeline_turbo_wht, "turbo_wht", turbo_wht_len, turbo_wht_data, "main", 2, 3 * sizeof(uint32_t), {128, 1, 1}, {}, 1);
@@ -4471,7 +4501,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## _len, set_rows_turbo3_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## _len, set_rows_turbo3_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TQ4_1S], "set_rows_tq4_1s" #itype, set_rows_tq4_1s ## itype ## _len, set_rows_tq4_1s ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);

SET_ROWS(_i32)
SET_ROWS(_i64)
@@ -4486,6 +4517,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_TURBO3_0], "cpy_turbo3_0_f32", cpy_turbo3_0_f32_len, cpy_turbo3_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_TURBO3_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_TQ4_1S], "cpy_tq4_1s_f32", cpy_tq4_1s_f32_len, cpy_tq4_1s_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_TQ4_1S), 1, 1}, {}, 1);

auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
std::string s;
@@ -6141,6 +6173,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_TQ4_1S:
break;
default:
return nullptr;
@@ -6281,6 +6314,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_TQ4_1S:
break;
default:
return nullptr;
@@ -6296,6 +6330,10 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
if (m < 4096 && k >= 1024) {
dmmv_wg = DMMV_WG_SIZE_LARGE;
}
} else if (a_type == GGML_TYPE_TQ4_1S) {
// TQ4_1S needs exactly 32 threads (one subgroup) to cooperate on the
// 32-element WHT butterfly in shared memory. Force SUBGROUP-sized wg.
dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
} else {
if (m <= 8192 && k >= 1024) {
dmmv_wg = DMMV_WG_SIZE_LARGE;
@@ -7393,6 +7431,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TURBO3_0:
case GGML_TYPE_TQ4_1S:
return ctx->device->pipeline_cpy_quant_f32[src->type];
default:
break;
@@ -10216,6 +10255,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
uint32_t ne = ggml_nelements(src0);
if (dst->type == GGML_TYPE_TURBO3_0) {
ne = ne / 128;
} else if (dst->type == GGML_TYPE_TQ4_1S) {
ne = ne / 32;
} else if (ggml_is_quantized(dst->type)) {
// quants run 32 threads each doing QUANT_K elements
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
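As a concrete check of the TQ4_1S branch above (assuming QUANT_K of 32 for this type, matching the shader's 32-element blocks, and presuming each block gets one dispatch unit so that 32 threads can cooperate on its WHT): a 4096 x 4096 tensor yields 16777216 / 32 = 524288 units, 32 times more than the generic quant bucketing would produce. A sketch with illustrative names:

#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    const uint32_t nelements = 4096u * 4096u;                   // 16777216
    // TQ4_1S path: one unit per 32-element block.
    const uint32_t ne_tq4_1s  = nelements / 32u;                // 524288
    // Generic quant path: 32 threads each covering QUANT_K elements.
    const uint32_t ne_generic = ceil_div(nelements, 32u * 32u); // 16384
    std::printf("tq4_1s = %u, generic = %u\n", ne_tq4_1s, ne_generic);
    return 0;
}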
@@ -15467,6 +15508,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_TQ4_1S:
break;
default:
return false;
@@ -15607,6 +15649,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TURBO3_0:
case GGML_TYPE_TQ4_1S:
return true;
default:
return false;
@@ -15647,6 +15690,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TURBO3_0:
case GGML_TYPE_TQ4_1S:
return true;
default:
break;
36 changes: 36 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp
@@ -30,6 +30,41 @@ void main() {

const uint a_offset = 0;
const uint ib = src_idx;

#if defined(DATA_A_TQ4_1S)
// TQ4_1S requires full inverse WHT after centroid*scale dequant.
// Dequant all 32 elements into a buffer, apply butterfly, then write.
const float tq4_signs[32] = float[32](
+1.0, -1.0, +1.0, -1.0, +1.0, +1.0, -1.0, +1.0,
-1.0, -1.0, +1.0, -1.0, +1.0, +1.0, -1.0, +1.0,
-1.0, -1.0, +1.0, -1.0, +1.0, -1.0, -1.0, +1.0,
-1.0, +1.0, +1.0, -1.0, +1.0, -1.0, -1.0, +1.0
);
const float TQ4_INV_SQRT32 = 0.17677669529663688; // 1/sqrt(32)

float buf[32];
for (int j = 0; j < 32; j += 2) {
vec2 v = dequantize(ib, j, a_offset);
buf[j] = v.x;
buf[j+1] = v.y;
}

// Inverse WHT butterfly (5 stages for 32 elements)
for (uint step = 1u; step < 32u; step <<= 1u) {
for (uint i = 0u; i < 32u; i += step * 2u) {
for (uint j2 = i; j2 < i + step; j2++) {
float a2 = buf[j2], b2 = buf[j2 + step];
buf[j2] = a2 + b2;
buf[j2 + step] = a2 - b2;
}
}
}

// Normalize and apply sign pattern
for (int j = 0; j < 32; j++) {
data_d[dst_idx + j] = buf[j] * TQ4_INV_SQRT32 * tq4_signs[j];
}
#else
const vec2 dm = get_dm(ib, a_offset);

[[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
@@ -48,4 +83,5 @@ void main() {
data_d[dst_idx + j + 3] = v[3];
#endif
}
#endif
}
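A useful sanity check for the shader above: the unnormalized Walsh-Hadamard butterfly is self-inverse up to a factor of N, so scaling by 1/sqrt(32) on both the quantizer side and the dequantizer side makes the round trip exact, and a per-element +/-1 sign pattern applied before the forward pass and after the inverse pass cancels because the signs square to one. A standalone C++ reference follows; the sign pattern is an arbitrary stand-in for the shader's tq4_signs table, and the forward-pass convention is an assumption about the quantizer, not taken from this patch.

#include <cmath>
#include <cstdio>

// Same 5-stage, 32-point butterfly structure as the shader loop above.
static void wht32(float buf[32]) {
    for (unsigned step = 1; step < 32; step <<= 1) {
        for (unsigned i = 0; i < 32; i += step * 2) {
            for (unsigned j = i; j < i + step; ++j) {
                const float a = buf[j];
                const float b = buf[j + step];
                buf[j]        = a + b;
                buf[j + step] = a - b;
            }
        }
    }
}

int main() {
    const float inv_sqrt32 = 1.0f / std::sqrt(32.0f); // == TQ4_INV_SQRT32
    float sign[32], x[32], y[32];
    for (int i = 0; i < 32; ++i) {
        sign[i] = (i % 3 == 0) ? -1.0f : 1.0f;  // stand-in sign pattern
        x[i] = float(i % 7) - 3.0f;
        y[i] = sign[i] * x[i];                  // quantizer side: signs first
    }
    wht32(y);                                   // forward WHT
    for (float & v : y) v *= inv_sqrt32;
    wht32(y);                                   // dequantizer side: inverse WHT
    for (int i = 0; i < 32; ++i) y[i] = y[i] * inv_sqrt32 * sign[i];

    float max_err = 0.0f;
    for (int i = 0; i < 32; ++i) {
        max_err = std::fmax(max_err, std::fabs(y[i] - x[i]));
    }
    // Expect ~0: these small integer inputs round-trip exactly in float.
    std::printf("max round-trip error = %g\n", max_err);
    return 0;
}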