@@ -681,6 +681,15 @@ struct vk_device_struct {
681681 bool mul_mat_id_m[GGML_TYPE_COUNT];
682682 bool mul_mat_id_s[GGML_TYPE_COUNT];
683683
684+ // Separate flags for the q8_1 (integer dot) mmq path, whose shader uses
685+ // a different shared-memory layout than the float matmul shaders.
686+ bool mul_mat_l_int[GGML_TYPE_COUNT];
687+ bool mul_mat_m_int[GGML_TYPE_COUNT];
688+ bool mul_mat_s_int[GGML_TYPE_COUNT];
689+ bool mul_mat_id_l_int[GGML_TYPE_COUNT];
690+ bool mul_mat_id_m_int[GGML_TYPE_COUNT];
691+ bool mul_mat_id_s_int[GGML_TYPE_COUNT];
692+
684693 vk::DescriptorSetLayout dsl;
685694
686695 vk_matmul_pipeline pipeline_matmul_f32 {};
@@ -3207,6 +3216,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
32073216 return supported;
32083217}
32093218
3219+ // Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses
3220+ // block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather
3221+ // than the float load buffers checked by ggml_vk_matmul_shmem_support.
3222+ // Sizes follow std430 rules. Returns false for types without a q8_1 pipeline.
3223+ static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
3224+
3225+ // FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float.
3226+ const uint32_t fp_size = device->fp16 ? 2u : 4u;
3227+ const uint32_t fp_align = fp_size;
3228+ const uint32_t fp2_size = 2u * fp_size;
3229+ const uint32_t fp2_align = device->fp16 ? 4u : 8u;
3230+
3231+ struct member { uint32_t size, align; };
3232+ auto std430_size = [](std::initializer_list<member> members) {
3233+ uint32_t off = 0, struct_align = 1;
3234+ for (const auto &m : members) {
3235+ off = (off + m.align - 1) & ~(m.align - 1);
3236+ off += m.size;
3237+ struct_align = std::max(struct_align, m.align);
3238+ }
3239+ return (off + struct_align - 1) & ~(struct_align - 1);
3240+ };
3241+
3242+ uint32_t block_a_size = 0;
3243+ switch (src0_type) {
3244+ case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm
3245+ case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2)
3246+ case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm
3247+ case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2)
3248+ case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm
3249+ case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d
3250+ case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2)
3251+ case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2)
3252+ case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2)
3253+ case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2)
3254+ case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2)
3255+ default:
3256+ return false;
3257+ }
3258+
3259+ // block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; }
3260+ const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}});
3261+
3262+ const uint32_t BM = warptile[1];
3263+ const uint32_t BN = warptile[2];
3264+ // mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise.
3265+ const uint32_t BK_STEP = mul_mat_id ? 1u : 4u;
3266+
3267+ const uint32_t buf_a_size = BM * BK_STEP * block_a_size;
3268+ const uint32_t buf_b_size = BN * BK_STEP * block_b_size;
3269+ const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u;
3270+
3271+ const uint32_t warps = warptile[0] / warptile[10];
3272+ const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u;
3273+
3274+ const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh;
3275+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
3276+
3277+ VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
3278+ "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported);
3279+
3280+ return supported;
3281+ }
3282+
32103283struct GpuPipelineConfig {
32113284 // GPU architecture identifier.
32123285 // Example: vk_device_architecture::AMD_GCN
@@ -3453,6 +3526,40 @@ static void ggml_vk_load_shaders(vk_device& device) {
34533526 } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
34543527 device->mul_mat_id_l[i] = false;
34553528 }
3529+
3530+ // The q8_1 mmq path has its own (larger) shmem layout, check it separately.
3531+ // K-quants use the _int_k warptiles, others use _int.
3532+ const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K ||
3533+ t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K ||
3534+ t == GGML_TYPE_Q6_K);
3535+ const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int;
3536+ const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int;
3537+ const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int;
3538+ const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int;
3539+ const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int;
3540+ const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int;
3541+
3542+ if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) {
3543+ device->mul_mat_s_int[i] = false;
3544+ device->mul_mat_m_int[i] = false;
3545+ device->mul_mat_l_int[i] = false;
3546+ } else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) {
3547+ device->mul_mat_m_int[i] = false;
3548+ device->mul_mat_l_int[i] = false;
3549+ } else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) {
3550+ device->mul_mat_l_int[i] = false;
3551+ }
3552+
3553+ if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) {
3554+ device->mul_mat_id_s_int[i] = false;
3555+ device->mul_mat_id_m_int[i] = false;
3556+ device->mul_mat_id_l_int[i] = false;
3557+ } else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) {
3558+ device->mul_mat_id_m_int[i] = false;
3559+ device->mul_mat_id_l_int[i] = false;
3560+ } else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) {
3561+ device->mul_mat_id_l_int[i] = false;
3562+ }
34563563 }
34573564 }
34583565
@@ -5613,6 +5720,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
56135720 device->mul_mat_id_s[i] = true;
56145721 break;
56155722 }
5723+
5724+ device->mul_mat_l_int[i] = true;
5725+ device->mul_mat_m_int[i] = true;
5726+ device->mul_mat_s_int[i] = true;
5727+ device->mul_mat_id_l_int[i] = true;
5728+ device->mul_mat_id_m_int[i] = true;
5729+ device->mul_mat_id_s_int[i] = true;
56165730 }
56175731
56185732
@@ -7220,6 +7334,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m,
72207334static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
72217335 VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
72227336
7337+ // The q8_1 (integer dot) mmq path uses a different shader with its own
7338+ // shared-memory layout, so use the int-specific availability flags.
7339+ const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
7340+ const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type];
7341+ const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type];
7342+ const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type];
7343+
72237344 if (ctx->device->coopmat2) {
72247345 const uint32_t shader_core_count = ctx->device->shader_core_count;
72257346 const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
@@ -7236,26 +7357,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
72367357 // split_k==3 with large tiles likely better than medium tiles with no split_k.
72377358 (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
72387359
7239- if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type] )) {
7360+ if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s )) {
72407361 return aligned ? mmp->a_l : mmp->l;
72417362 }
72427363 // Use medium shader when the N dimension is greater than the small shader's tile size
72437364 uint32_t crossover_medium = mmp->s->wg_denoms[1];
7244- if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type] ) {
7365+ if ((mm_m && (n > crossover_medium)) || !mm_s ) {
72457366 return aligned ? mmp->a_m : mmp->m;
72467367 }
72477368 return aligned ? mmp->a_s : mmp->s;
72487369 }
72497370
7250- if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type] )) {
7371+ if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l )) {
72517372 return aligned ? mmp->a_s : mmp->s;
72527373 }
7253- if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type] ) {
7374+ if ((mm_m && (m <= 64 || n <= 64)) || !mm_l ) {
72547375 return aligned ? mmp->a_m : mmp->m;
72557376 }
72567377 return aligned ? mmp->a_l : mmp->l;
7257-
7258- GGML_UNUSED(src1_type);
72597378}
72607379
72617380static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
@@ -7312,35 +7431,42 @@ static void ggml_vk_matmul(
73127431 ctx->prealloc_split_k_need_sync = true;
73137432}
73147433
7315- static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
7316- VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
7434+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
7435+ VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
7436+
7437+ // The q8_1 (integer dot) mmq path uses a different shader with its own
7438+ // shared-memory layout, so use the int-specific availability flags.
7439+ const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
7440+ const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type];
7441+ const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type];
7442+ const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type];
73177443
73187444 if (ctx->device->coopmat2) {
73197445 // Use large shader when the N dimension is greater than the medium shader's tile size
73207446 uint32_t crossover_large = mmp->m->wg_denoms[1];
7321- if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type] )) {
7447+ if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s )) {
73227448 return aligned ? mmp->a_l : mmp->l;
73237449 }
73247450 // Use medium shader when the N dimension is greater than the small shader's tile size
73257451 uint32_t crossover_medium = mmp->s->wg_denoms[1];
7326- if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type] ) {
7452+ if ((mm_m && (n > crossover_medium)) || !mm_s ) {
73277453 return aligned ? mmp->a_m : mmp->m;
73287454 }
73297455 return aligned ? mmp->a_s : mmp->s;
73307456 }
73317457
7332- if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type] )) {
7458+ if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l )) {
73337459 return aligned ? mmp->a_s : mmp->s;
73347460 }
7335- if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type] ) {
7461+ if ((mm_m && (m <= 64 || n <= 64)) || !mm_l ) {
73367462 return aligned ? mmp->a_m : mmp->m;
73377463 }
73387464 return aligned ? mmp->a_l : mmp->l;
73397465}
73407466
7341- static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
7342- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
7343- return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
7467+ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type ) {
7468+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << " )");
7469+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type )->align;
73447470}
73457471
73467472static void ggml_vk_matmul_id(
@@ -7636,10 +7762,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
76367762 // Not implemented
76377763 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
76387764
7639- const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
7765+ const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
7766+
7767+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
76407768 const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
76417769
7642- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type) );
7770+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type );
76437771
76447772 if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
76457773 pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
@@ -8471,10 +8599,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
84718599 // Not implemented
84728600 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
84738601
8474- const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
8602+ const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
8603+
8604+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
84758605 const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
84768606
8477- vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
8607+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type );
84788608
84798609 if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
84808610 pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
0 commit comments