Skip to content

Commit 706fbd8

Browse files
authored
vulkan: Check shared memory size for mmq shaders (ggml-org#22693)
1 parent fa62042 commit 706fbd8

1 file changed

Lines changed: 149 additions & 19 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 149 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,15 @@ struct vk_device_struct {
681681
bool mul_mat_id_m[GGML_TYPE_COUNT];
682682
bool mul_mat_id_s[GGML_TYPE_COUNT];
683683

684+
// Separate flags for the q8_1 (integer dot) mmq path, whose shader uses
685+
// a different shared-memory layout than the float matmul shaders.
686+
bool mul_mat_l_int[GGML_TYPE_COUNT];
687+
bool mul_mat_m_int[GGML_TYPE_COUNT];
688+
bool mul_mat_s_int[GGML_TYPE_COUNT];
689+
bool mul_mat_id_l_int[GGML_TYPE_COUNT];
690+
bool mul_mat_id_m_int[GGML_TYPE_COUNT];
691+
bool mul_mat_id_s_int[GGML_TYPE_COUNT];
692+
684693
vk::DescriptorSetLayout dsl;
685694

686695
vk_matmul_pipeline pipeline_matmul_f32 {};
@@ -3207,6 +3216,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
32073216
return supported;
32083217
}
32093218

3219+
// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses
3220+
// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather
3221+
// than the float load buffers checked by ggml_vk_matmul_shmem_support.
3222+
// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline.
3223+
static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
3224+
3225+
// FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float.
3226+
const uint32_t fp_size = device->fp16 ? 2u : 4u;
3227+
const uint32_t fp_align = fp_size;
3228+
const uint32_t fp2_size = 2u * fp_size;
3229+
const uint32_t fp2_align = device->fp16 ? 4u : 8u;
3230+
3231+
struct member { uint32_t size, align; };
3232+
auto std430_size = [](std::initializer_list<member> members) {
3233+
uint32_t off = 0, struct_align = 1;
3234+
for (const auto &m : members) {
3235+
off = (off + m.align - 1) & ~(m.align - 1);
3236+
off += m.size;
3237+
struct_align = std::max(struct_align, m.align);
3238+
}
3239+
return (off + struct_align - 1) & ~(struct_align - 1);
3240+
};
3241+
3242+
uint32_t block_a_size = 0;
3243+
switch (src0_type) {
3244+
case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm
3245+
case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2)
3246+
case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm
3247+
case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2)
3248+
case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm
3249+
case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d
3250+
case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2)
3251+
case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2)
3252+
case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2)
3253+
case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2)
3254+
case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2)
3255+
default:
3256+
return false;
3257+
}
3258+
3259+
// block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; }
3260+
const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}});
3261+
3262+
const uint32_t BM = warptile[1];
3263+
const uint32_t BN = warptile[2];
3264+
// mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise.
3265+
const uint32_t BK_STEP = mul_mat_id ? 1u : 4u;
3266+
3267+
const uint32_t buf_a_size = BM * BK_STEP * block_a_size;
3268+
const uint32_t buf_b_size = BN * BK_STEP * block_b_size;
3269+
const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u;
3270+
3271+
const uint32_t warps = warptile[0] / warptile[10];
3272+
const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u;
3273+
3274+
const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh;
3275+
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
3276+
3277+
VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
3278+
"mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported);
3279+
3280+
return supported;
3281+
}
3282+
32103283
struct GpuPipelineConfig {
32113284
// GPU architecture identifier.
32123285
// Example: vk_device_architecture::AMD_GCN
@@ -3453,6 +3526,40 @@ static void ggml_vk_load_shaders(vk_device& device) {
34533526
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
34543527
device->mul_mat_id_l[i] = false;
34553528
}
3529+
3530+
// The q8_1 mmq path has its own (larger) shmem layout, check it separately.
3531+
// K-quants use the _int_k warptiles, others use _int.
3532+
const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K ||
3533+
t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K ||
3534+
t == GGML_TYPE_Q6_K);
3535+
const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int;
3536+
const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int;
3537+
const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int;
3538+
const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int;
3539+
const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int;
3540+
const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int;
3541+
3542+
if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) {
3543+
device->mul_mat_s_int[i] = false;
3544+
device->mul_mat_m_int[i] = false;
3545+
device->mul_mat_l_int[i] = false;
3546+
} else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) {
3547+
device->mul_mat_m_int[i] = false;
3548+
device->mul_mat_l_int[i] = false;
3549+
} else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) {
3550+
device->mul_mat_l_int[i] = false;
3551+
}
3552+
3553+
if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) {
3554+
device->mul_mat_id_s_int[i] = false;
3555+
device->mul_mat_id_m_int[i] = false;
3556+
device->mul_mat_id_l_int[i] = false;
3557+
} else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) {
3558+
device->mul_mat_id_m_int[i] = false;
3559+
device->mul_mat_id_l_int[i] = false;
3560+
} else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) {
3561+
device->mul_mat_id_l_int[i] = false;
3562+
}
34563563
}
34573564
}
34583565

@@ -5613,6 +5720,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
56135720
device->mul_mat_id_s[i] = true;
56145721
break;
56155722
}
5723+
5724+
device->mul_mat_l_int[i] = true;
5725+
device->mul_mat_m_int[i] = true;
5726+
device->mul_mat_s_int[i] = true;
5727+
device->mul_mat_id_l_int[i] = true;
5728+
device->mul_mat_id_m_int[i] = true;
5729+
device->mul_mat_id_s_int[i] = true;
56165730
}
56175731

56185732

@@ -7220,6 +7334,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m,
72207334
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
72217335
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
72227336

7337+
// The q8_1 (integer dot) mmq path uses a different shader with its own
7338+
// shared-memory layout, so use the int-specific availability flags.
7339+
const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
7340+
const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type];
7341+
const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type];
7342+
const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type];
7343+
72237344
if (ctx->device->coopmat2) {
72247345
const uint32_t shader_core_count = ctx->device->shader_core_count;
72257346
const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
@@ -7236,26 +7357,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
72367357
// split_k==3 with large tiles likely better than medium tiles with no split_k.
72377358
(tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
72387359

7239-
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
7360+
if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) {
72407361
return aligned ? mmp->a_l : mmp->l;
72417362
}
72427363
// Use medium shader when the N dimension is greater than the small shader's tile size
72437364
uint32_t crossover_medium = mmp->s->wg_denoms[1];
7244-
if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
7365+
if ((mm_m && (n > crossover_medium)) || !mm_s) {
72457366
return aligned ? mmp->a_m : mmp->m;
72467367
}
72477368
return aligned ? mmp->a_s : mmp->s;
72487369
}
72497370

7250-
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
7371+
if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) {
72517372
return aligned ? mmp->a_s : mmp->s;
72527373
}
7253-
if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
7374+
if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) {
72547375
return aligned ? mmp->a_m : mmp->m;
72557376
}
72567377
return aligned ? mmp->a_l : mmp->l;
7257-
7258-
GGML_UNUSED(src1_type);
72597378
}
72607379

72617380
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
@@ -7312,35 +7431,42 @@ static void ggml_vk_matmul(
73127431
ctx->prealloc_split_k_need_sync = true;
73137432
}
73147433

7315-
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
7316-
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
7434+
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
7435+
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
7436+
7437+
// The q8_1 (integer dot) mmq path uses a different shader with its own
7438+
// shared-memory layout, so use the int-specific availability flags.
7439+
const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
7440+
const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type];
7441+
const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type];
7442+
const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type];
73177443

73187444
if (ctx->device->coopmat2) {
73197445
// Use large shader when the N dimension is greater than the medium shader's tile size
73207446
uint32_t crossover_large = mmp->m->wg_denoms[1];
7321-
if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
7447+
if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) {
73227448
return aligned ? mmp->a_l : mmp->l;
73237449
}
73247450
// Use medium shader when the N dimension is greater than the small shader's tile size
73257451
uint32_t crossover_medium = mmp->s->wg_denoms[1];
7326-
if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
7452+
if ((mm_m && (n > crossover_medium)) || !mm_s) {
73277453
return aligned ? mmp->a_m : mmp->m;
73287454
}
73297455
return aligned ? mmp->a_s : mmp->s;
73307456
}
73317457

7332-
if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
7458+
if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) {
73337459
return aligned ? mmp->a_s : mmp->s;
73347460
}
7335-
if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
7461+
if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) {
73367462
return aligned ? mmp->a_m : mmp->m;
73377463
}
73387464
return aligned ? mmp->a_l : mmp->l;
73397465
}
73407466

7341-
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
7342-
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
7343-
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
7467+
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
7468+
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
7469+
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
73447470
}
73457471

73467472
static void ggml_vk_matmul_id(
@@ -7636,10 +7762,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
76367762
// Not implemented
76377763
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
76387764

7639-
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
7765+
const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
7766+
7767+
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
76407768
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
76417769

7642-
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
7770+
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type);
76437771

76447772
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
76457773
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
@@ -8471,10 +8599,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
84718599
// Not implemented
84728600
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
84738601

8474-
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
8602+
const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
8603+
8604+
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
84758605
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
84768606

8477-
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
8607+
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type);
84788608

84798609
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
84808610
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);

0 commit comments

Comments
 (0)