Skip to content

Commit dbe7901

Browse files
authored
vulkan: fix matmul integer pipeline selection (#23005)
* vulkan: fix matmul integer pipeline selection
* gate pipeline creation with the right bools
1 parent 320a6a4 commit dbe7901

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3954,13 +3954,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
39543954
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
39553955

39563956
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3957-
if (device->mul_mat ## ID ## _l[TYPE]) { \
3957+
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
39583958
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
39593959
} \
3960-
if (device->mul_mat ## ID ## _m[TYPE]) { \
3960+
if (device->mul_mat ## ID ## _m_int[TYPE]) { \
39613961
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
39623962
} \
3963-
if (device->mul_mat ## ID ## _s[TYPE]) { \
3963+
if (device->mul_mat ## ID ## _s_int[TYPE]) { \
39643964
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
39653965
} \
39663966

@@ -4131,11 +4131,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
41314131
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
41324132

41334133
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
4134-
if (device->mul_mat ## ID ## _l[TYPE]) \
4134+
if (device->mul_mat ## ID ## _l_int[TYPE]) \
41354135
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
4136-
if (device->mul_mat ## ID ## _m[TYPE]) \
4136+
if (device->mul_mat ## ID ## _m_int[TYPE]) \
41374137
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
4138-
if (device->mul_mat ## ID ## _s[TYPE]) \
4138+
if (device->mul_mat ## ID ## _s_int[TYPE]) \
41394139
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
41404140

41414141
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
@@ -5716,12 +5716,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
57165716
break;
57175717
}
57185718

5719-
device->mul_mat_l_int[i] = true;
5720-
device->mul_mat_m_int[i] = true;
5721-
device->mul_mat_s_int[i] = true;
5722-
device->mul_mat_id_l_int[i] = true;
5723-
device->mul_mat_id_m_int[i] = true;
5724-
device->mul_mat_id_s_int[i] = true;
5719+
device->mul_mat_l_int[i] = device->mul_mat_l[i];
5720+
device->mul_mat_m_int[i] = device->mul_mat_m[i];
5721+
device->mul_mat_s_int[i] = device->mul_mat_s[i];
5722+
device->mul_mat_id_l_int[i] = device->mul_mat_id_l[i];
5723+
device->mul_mat_id_m_int[i] = device->mul_mat_id_m[i];
5724+
device->mul_mat_id_s_int[i] = device->mul_mat_id_s[i];
57255725
}
57265726

57275727

0 commit comments

Comments (0)