Skip to content

Commit 706e3f9

Browse files
authored
vulkan: Implement mmvq for iq1_s/iq1_m (ggml-org#18450)
1 parent 5755e52 commit 706e3f9

5 files changed

Lines changed: 433 additions & 4 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2702,7 +2702,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
27022702
switch (src0_type) {
27032703
case GGML_TYPE_IQ1_S:
27042704
case GGML_TYPE_IQ1_M:
2705-
lut_size = 2*2048;
2705+
lut_size = 2*2048 + 4*2048;
27062706
break;
27072707
case GGML_TYPE_IQ2_XXS:
27082708
lut_size = 8*256;
@@ -3627,6 +3627,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
36273627
uint32_t rm_kq = 2;
36283628
uint32_t rm_stdq_int = 1;
36293629
uint32_t rm_kq_int = 1;
3630+
auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
36303631
if (device->vendor_id == VK_VENDOR_ID_AMD) {
36313632
if (device->architecture == AMD_GCN) {
36323633
rm_stdq = 2;
@@ -3730,6 +3731,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
37303731
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
37313732
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
37323733
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3734+
3735+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
3736+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
3737+
37333738
}
37343739
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
37353740
}
@@ -3776,13 +3781,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
37763781
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
37773782
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
37783783
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3784+
3785+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
3786+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
37793787
}
37803788
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
37813789
}
37823790

37833791
#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
37843792
GGML_UNUSED(rm_stdq_int);
37853793
GGML_UNUSED(rm_kq_int);
3794+
GGML_UNUSED(rm_iq_int);
37863795
#endif
37873796

37883797
// dequant shaders
@@ -5616,6 +5625,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
56165625
case GGML_TYPE_Q4_K:
56175626
case GGML_TYPE_Q5_K:
56185627
case GGML_TYPE_Q6_K:
5628+
case GGML_TYPE_IQ1_S:
5629+
case GGML_TYPE_IQ1_M:
56195630
break;
56205631
default:
56215632
return nullptr;
@@ -5772,6 +5783,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
57725783
case GGML_TYPE_Q4_K:
57735784
case GGML_TYPE_Q5_K:
57745785
case GGML_TYPE_Q6_K:
5786+
case GGML_TYPE_IQ1_S:
5787+
case GGML_TYPE_IQ1_M:
57755788
break;
57765789
default:
57775790
return nullptr;
@@ -7037,7 +7050,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
70377050
// Quantization overhead is not worth it for small k
70387051
switch (device->vendor_id) {
70397052
case VK_VENDOR_ID_NVIDIA:
7040-
if (src0_type == GGML_TYPE_Q2_K) {
7053+
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
70417054
return true;
70427055
}
70437056

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1414
#define K_PER_ITER 8
1515
#elif defined(DATA_A_QUANT_K)
1616
#define K_PER_ITER 16
17+
#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
18+
#define K_PER_ITER 32
1719
#else
1820
#error unimplemented
1921
#endif
@@ -49,6 +51,15 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
4951
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
5052
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
5153
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
54+
#elif K_PER_ITER == 32
55+
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 ];
56+
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1];
57+
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2];
58+
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3];
59+
cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4];
60+
cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5];
61+
cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6];
62+
cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7];
5263
#else
5364
#error unimplemented
5465
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,118 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
377377
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
378378
}
379379
#endif
380+
381+
#if defined(DATA_A_IQ1_S)
382+
// Expands one 32-element group of IQ1_S block `ib` into eight packed 32-bit words.
// `iqs` is an element offset inside the block; iqs / 32 selects the group.
// `out0` receives the words built from grid entries 0 and 1, `out1` those from
// grid entries 2 and 3. Each grid entry yields two words: its low nibbles and
// its high nibbles, one nibble per byte.
void repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) {
383+
const uint ib32 = iqs / 32; // index of the 32-element group within the block
384+
385+
const uint qh = data_a[ib].qh[ib32]; // per-group word; bits 0..11 are consumed below as four 3-bit fields
386+
387+
const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2]; // 16-bit word at index 2 * ib32
388+
const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2]; // 16-bit word at index 2 * ib32 + 1
389+
390+
const uint qs0 = qs16_0 & 0xFF; // low byte of the first word
391+
const uint qs1 = qs16_0 >> 8; // remaining upper bits of the first word
392+
const uint qs2 = qs16_1 & 0xFF; // low byte of the second word
393+
const uint qs3 = qs16_1 >> 8; // remaining upper bits of the second word
394+
395+
const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3); // qh bits 0..2
396+
const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3); // qh bits 3..5
397+
const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3); // qh bits 6..8
398+
const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3); // qh bits 9..11
399+
400+
const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]); // table index = qs byte | 3 high bits << 8
401+
const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]);
402+
const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]);
403+
const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]);
404+
405+
out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F, // low nibble of every byte of grid0
406+
(grid0 >> 4) & 0x0F0F0F0F, // high nibble of every byte of grid0
407+
(grid1 >> 0) & 0x0F0F0F0F,
408+
(grid1 >> 4) & 0x0F0F0F0F);
409+
out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F, // same low/high nibble split for grid2 and grid3
410+
(grid2 >> 4) & 0x0F0F0F0F,
411+
(grid3 >> 0) & 0x0F0F0F0F,
412+
(grid3 >> 4) & 0x0F0F0F0F);
413+
}
414+
415+
// Returns the two scale factors for one 32-element group of IQ1_S block `ib`.
// `iqs` is an element offset inside the block; iqs / 32 selects the group.
//   .x = dl                 (block scale times the odd group multiplier)
//   .y = dl * (delta - 1)   (offset term, to be multiplied by the B-side sum)
// The result is built as a 32-bit vec2 so it matches the declared return type
// and is not rounded through FLOAT_TYPE when that type is narrower than float.
vec2 get_dm(uint ib, uint iqs) {
416+
const uint ib32 = iqs / 32; // index of the 32-element group within the block
417+
418+
const uint qh = data_a[ib].qh[ib32];
419+
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; // bit 15 of qh selects the sign of delta
420+
421+
const float d = float(data_a[ib].d);
422+
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); // 3-bit field at qh bits 12..14 -> odd multiplier 1..15
423+
424+
// the -1 cancels out the bias in iq1s_grid_gpu
425+
return vec2(dl, dl * (delta - 1));
426+
}
427+
428+
// IQ1_S variant: integer dot product of one 32-element group of A with the
// eight cached B words in cache_b_qs[0..7], scaled to FLOAT_TYPE.
// `ib_a` is split into a block index (ib_a / 8) and a group index (ib_a % 8);
// `iqs` contributes a further offset in units of 32 elements.
// NOTE(review): cache_b_ds is presumably the q8_1 (scale, scaled-sum) pair for
// the cached B block - confirm against the shader that fills it.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
429+
int32_t q_sum = 0;
430+
431+
const uint ib_k = ib_a / 8; // block index into data_a
432+
const uint iqs_k = (ib_a % 8) * 32 + iqs * 32; // element offset inside that block
433+
434+
i32vec4 qs_a0;
435+
i32vec4 qs_a1;
436+
repack8(ib_k, iqs_k, qs_a0, qs_a1); // eight packed words of A values for this group
437+
438+
const vec2 dm = get_dm(ib_k, iqs_k); // (scale, offset factor) for this group
439+
440+
q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]); // each call accumulates four 8-bit products
441+
q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]);
442+
q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]);
443+
q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]);
444+
q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]);
445+
q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]);
446+
q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]);
447+
q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]);
448+
449+
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y)); // scaled integer sum plus the offset term
450+
}
451+
#endif
452+
453+
#if defined(DATA_A_IQ1_M)
454+
// IQ1_M variant: dot product of one 32-element group of A with the eight
// cached B words in cache_b_qs[0..7]. The group is processed as four 8-element
// pieces (l = 0..3), each with its own scale `dl` and delta sign.
// `ib_a` is split into a block index (ib_a / 8) and a group index (ib_a % 8);
// `iqs` contributes a further offset in units of 32 elements.
// NOTE(review): cache_b_ds.x is presumably the q8_1 scale of the cached B
// block - confirm against the shader that fills it.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
455+
const uint ib_k = ib_a / 8; // block index into data_a
456+
const uint iqs_k = (ib_a % 8) * 32 + iqs * 32; // element offset inside that block
457+
458+
const uint ib32 = iqs_k / 32; // 32-element group index
459+
const uint ib64 = ib32 / 2; // index of the scales word covering this group
460+
461+
const uint16_t[4] scales = data_a[ib_k].scales;
462+
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; // top 4 bits of each scales word
463+
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); // the four nibbles form a 16-bit half-float block scale
464+
465+
const uint qs32 = data_a_packed32[ib_k].qs[ib32]; // four qs bytes, one per l
466+
const uint qh16 = data_a_packed16[ib_k].qh[ib32]; // four 4-bit qh fields, one per l
467+
468+
float sum = 0;
469+
const uint sc = data_a[ib_k].scales[ib64];
470+
[[unroll]] for (int l = 0; l < 4; ++l) {
471+
const uint ib16 = 2 * ib32 + l / 2; // 16-element sub-group index; changes every two l
472+
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); // 3-bit scale field -> odd multiplier
473+
const uint qh = qh16 >> (4 * l); // nibble l in the low bits (upper bits are masked where used)
474+
const uint qs = (qs32 >> (8 * l)) & 0xFF; // byte l of qs32
475+
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; // bit 3 of the nibble selects the sign of delta
476+
477+
const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]); // table index = qs byte | low 3 nibble bits << 8
478+
479+
int32_t q_sum = 0;
480+
q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]); // low nibbles of grid against the first B word
481+
q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]); // high nibbles of grid against the second B word
482+
483+
int32_t y_sum = 0;
484+
y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]); // dot with all-ones bytes = sum of the four packed B components
485+
y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]);
486+
487+
// the -1 cancels out the bias in iq1s_grid_gpu
488+
sum += dl * (q_sum + y_sum * (delta - 1));
489+
}
490+
sum *= float(cache_b_ds.x); // apply the B-side scale once, after the loop
491+
492+
return sum;
493+
}
494+
#endif

0 commit comments

Comments
 (0)