From da44953329daec5f16b7b19d7ebfdc6552415362 Mon Sep 17 00:00:00 2001 From: guyfischman <138163913+guyfischman@users.noreply.github.com> Date: Tue, 12 May 2026 07:15:02 +0200 Subject: [PATCH 1/5] metal : promote mul_mv/mul_mm batch divisors to function constants (#22711) * metal : promote mul_mv/mul_mm batch divisors to function constants * metal : take op directly in get_pipeline_mul_mv_ext --- ggml/src/ggml-metal/ggml-metal-device.cpp | 46 +++++- ggml/src/ggml-metal/ggml-metal-device.h | 2 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 165 +++++++++++----------- 4 files changed, 127 insertions(+), 88 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d211bf79f14..f0147af84c1 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m return res; } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + const int ne12 = op->src[1]->ne[2]; + const int r2 = ne12 / op->src[0]->ne[2]; + const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3]; + + GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX); + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); - snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -687,8 +698,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); + GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX); + const int16_t ne12 = (int16_t) op->src[1]->ne[2]; + const int16_t ne13 = (int16_t) op->src[1]->ne[3]; + const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]); + const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]); + snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); - snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); + snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d", + base, bc_inp, bc_out, ne12, ne13, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -696,6 +714,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2); + ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta } }; + GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX); + const int16_t r2 = (int16_t) (ne12 / ne02); + const int16_t r3 = (int16_t) (ne13 / ne03); + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); - snprintf(name, 256, "%s_nsg=%d", base, nsg); + snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m ggml_metal_cv_t cv = ggml_metal_cv_init(); ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4718ca083b0..1f212a92f98 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5fa162c875c..a114391c2e8 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c372eaedeae..3882b955847 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3353,6 +3353,9 @@ static inline void helper_mv_reduce_and_write( constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; +constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]]; +constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]]; +constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]]; template void mul_vec_q_n_f32_impl( @@ -3376,10 +3379,10 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); @@ -3388,7 +3391,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[NR0]; FOR_UNROLL (int row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -3462,8 +3465,8 @@ void kernel_mul_mv_q1_0_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; @@ -3471,7 +3474,7 @@ void kernel_mul_mv_q1_0_f32_impl( device const block_q1_0 * ax[nr0]; for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); } @@ -3590,10 +3593,10 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); @@ -3602,7 +3605,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -3682,10 +3685,10 @@ void kernel_mul_mv_ext_q4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -3785,10 +3788,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -4000,10 +4003,10 @@ void kernel_mul_mv_t_t_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const T0 * x = (device const T0 *) (src0 + offset0); @@ -4012,7 +4015,7 @@ void kernel_mul_mv_t_t_impl( // pointers to src0 rows device const T0 * ax [NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const T0 *) ((device char *) src0 + offset0); } @@ -4122,10 +4125,10 @@ void kernel_mul_mv_t_t_4_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -4135,7 +4138,7 @@ void kernel_mul_mv_t_t_4_impl( device const T0 * ax [NR0]; device const T04 * ax4[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax [row] = (device const T0 *) ((device char *) src0 + offset0); ax4[row] = (device const T04 *) ((device char *) src0 + offset0); @@ -4239,10 +4242,10 @@ void kernel_mul_mv_t_t_short_impl( return; } - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -7462,10 +7465,10 @@ void kernel_mul_mv_q2_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); @@ -7567,10 +7570,10 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); @@ -7741,10 +7744,10 @@ void kernel_mul_mv_q4_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); @@ -7853,10 +7856,10 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); @@ -7989,10 +7992,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); @@ -8094,10 +8097,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); @@ -8202,10 +8205,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); @@ -8321,10 +8324,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); @@ -8433,10 +8436,10 @@ void kernel_mul_mv_iq3_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); @@ -8545,10 +8548,10 @@ void kernel_mul_mv_iq2_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); @@ -8658,10 +8661,10 @@ void kernel_mul_mv_iq1_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); @@ -8757,10 +8760,10 @@ void kernel_mul_mv_iq1_m_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); @@ -8866,10 +8869,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); @@ -8975,10 +8978,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); @@ -9086,10 +9089,10 @@ void kernel_mul_mv_mxfp4_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); @@ -9304,6 +9307,10 @@ kernel void kernel_diag_f32( constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]]; +constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]]; +constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]]; +constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]]; // each block_q contains 16*nl weights #ifdef GGML_METAL_HAS_TENSOR @@ -9330,11 +9337,11 @@ kernel void kernel_mul_mm( // Batch dimension handling const int im = tgpig.z; - const int i12 = im % args.ne12; - const int i13 = im / args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; // Batch offsets for srcA and srcB - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; // Tile dimensions constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; @@ -9473,10 +9480,10 @@ kernel void kernel_mul_mm( short il = il0; - const int i12 = im%args.ne12; - const int i13 = im/args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; const short offset1 = il0/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; From 78fbbc2c0788efc8857a2c0dc9802ec689fa12c1 Mon Sep 17 00:00:00 2001 From: Jesus Talavera <145992175+jesus-talavera-ibm@users.noreply.github.com> Date: Tue, 12 May 2026 07:17:04 +0200 Subject: [PATCH 2/5] convert : add split() to LoraTorchTensor in LoRA converter (#22832) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * convert : add split() method to LoraTorchTensor * Fix python type-check * Fix flake8 Lint * fix: handle positional dim arg in torch.split dispatch * Fix type-check again * Fix type-checks * Remove unit test per reviewers feedback * work around ty deficiency --------- Co-authored-by: Sigbjørn Skjæret --- convert_lora_to_gguf.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index d5833420560..ad4751bb96e 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -188,6 +188,24 @@ def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor: def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: return self.transpose(axis0, axis1) + def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]: + shape = self.shape + ndim = len(shape) + if dim < 0: + dim += ndim + if dim == ndim - 1: + A_chunks = self._lora_A.split(split_size, dim=-1) + return tuple(LoraTorchTensor(a, self._lora_B) for a in A_chunks) + elif dim == ndim - 2: + B_chunks = self._lora_B.split(split_size, dim=-2) + return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks) + else: + B_chunks = self._lora_B.split(split_size, dim=dim) + if self._lora_A.shape[dim] == 1: + return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks) + A_chunks = self._lora_A.split(split_size, dim=dim) + return tuple(LoraTorchTensor(a, b) for a, b in zip(A_chunks, B_chunks)) + def to(self, *args, **kwargs): return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) @@ -230,6 +248,11 @@ def __torch_function__(cls, func: Callable, types, args=(), kwargs=None): ) else: raise NotImplementedError + elif func is torch.split: + assert len(args) and len(args) >= 2 + tensor, split_size = args[0], args[1] + dim = args[2] if len(args) > 2 else kwargs.get("dim", 0) + return tensor.split(split_size, dim=dim) else: raise NotImplementedError From 41782591303ada159467fc75b232b2c9110f45aa Mon Sep 17 00:00:00 2001 From: AesSedai <7980540+AesSedai@users.noreply.github.com> Date: Tue, 12 May 2026 02:11:14 -0700 Subject: [PATCH 3/5] mtmd: add MiMo v2.5 vision (#22883) * mimo-v2.5: vision support * mimo-v2.5: use fused qkv for vision * mimi-v2.5: fix f16 vision overflow * mimo-v2.5: comment cleanups * mimo-v2.5: Flash doesn't have mmproj more cleanup remember to use filter_tensors * mimo-v2.5: fix trailing whitespace --- convert_hf_to_gguf.py | 67 +++++++++++ gguf-py/gguf/constants.py | 48 ++++---- gguf-py/gguf/gguf_writer.py | 6 + gguf-py/gguf/tensor_mapping.py | 4 + tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/clip-graph.h | 3 +- tools/mtmd/clip-impl.h | 5 + tools/mtmd/clip-model.h | 4 + tools/mtmd/clip.cpp | 126 +++++++++++++++++++- tools/mtmd/models/mimovl.cpp | 209 +++++++++++++++++++++++++++++++++ tools/mtmd/models/models.h | 9 ++ tools/mtmd/mtmd.cpp | 1 + 12 files changed, 460 insertions(+), 23 deletions(-) create mode 100644 tools/mtmd/models/mimovl.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bf76fa40614..d79372ceac7 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9760,6 +9760,73 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("MiMoV2ForCausalLM") +class MiMoV2VisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + hp = self.hparams_vision + + hp["image_size"] = hp.get("image_size", 560) + hp["num_attention_heads"] = hp.get("num_heads", 32) + hp["num_hidden_layers"] = hp.get("depth", 28) + + self.n_q_heads = int(hp["num_heads"]) + self.num_kv_heads = int(hp.get("num_key_value_heads", 8)) + self.head_dim = int(hp.get("qk_channels", 64)) + self.spatial_merge_size = int(hp["spatial_merge_size"]) + # MiMoV2 vision RMSNorm: HF uses getattr(config, "rms_norm_eps", 1e-6) and the + # field is absent from MiMo-V2.5's vision_config + self.rms_norm_eps = float(hp.get("rms_norm_eps", 1e-6)) + + # fullatt_block_indexes are also reflected in vit_window_attn_types as -1 + self.fullatt_block_indexes = list(hp.get("fullatt_block_indexes") or []) + self.vit_window_attn_types = list(hp.get("vit_window_attn_types") or []) + self.visual_token_window_size = int(hp.get("visual_token_window_size", -1)) + self.use_sink = bool(hp.get("use_sink", False)) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MIMOVL) + self.gguf_writer.add_vision_use_silu(True) + self.gguf_writer.add_vision_head_count_kv(self.num_kv_heads) + self.gguf_writer.add_vision_spatial_merge_size(self.spatial_merge_size) + self.gguf_writer.add_uint32(gguf.Keys.ClipVision.WINDOW_SIZE, self.visual_token_window_size) + self.gguf_writer.add_vision_wa_pattern_mode(self.vit_window_attn_types) + self.gguf_writer.add_vision_attention_layernorm_eps(self.rms_norm_eps) + self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"])) + self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"])) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + # Sinks must be F32: any sink-style softmax/mask add in ggml requires + # F32, and we fold sinks into a host-built F32 mask at encode time. + if new_name.endswith(".attn_sinks"): + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + @classmethod + def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: + name, _ = item + if not name.startswith("visual."): + return None + return super().filter_tensors(item) + + def modify_tensors(self, data_torch, name, bid): + # Conv3D patch embed: split along the temporal axis (kt=2) into two Conv2D + # weights that the existing qwen2vl-style two-Conv2D path consumes. + if name == "visual.patch_embed.proj.weight": + _, _, kt, _, _ = data_torch.shape + if kt != 2: + raise ValueError(f"unexpected temporal_patch_size: {kt}") + embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + yield (embd_name + ".weight", data_torch[:, :, 0, ...]) + yield (embd_name + ".weight.1", data_torch[:, :, 1, ...]) + return + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Step3p5ForCausalLM") class Step35Model(TextModel): model_arch = gguf.MODEL_ARCH.STEP35 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 617cbc49d19..4055ec2873a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -299,30 +299,32 @@ class Clip: HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" class ClipVision: - PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models - IMAGE_SIZE = "clip.vision.image_size" - IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" - IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" - PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles" - PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles" - PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" - PATCH_SIZE = "clip.vision.patch_size" - EMBEDDING_LENGTH = "clip.vision.embedding_length" - FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" - PROJECTION_DIM = "clip.vision.projection_dim" - BLOCK_COUNT = "clip.vision.block_count" - IMAGE_MEAN = "clip.vision.image_mean" - IMAGE_STD = "clip.vision.image_std" - SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" - USE_GELU = "clip.use_gelu" - USE_SILU = "clip.use_silu" - N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl - WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl - IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" - WINDOW_SIZE = "clip.vision.window_size" + PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models + IMAGE_SIZE = "clip.vision.image_size" + IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" + IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" + PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles" + PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles" + PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" + PATCH_SIZE = "clip.vision.patch_size" + EMBEDDING_LENGTH = "clip.vision.embedding_length" + FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" + PROJECTION_DIM = "clip.vision.projection_dim" + BLOCK_COUNT = "clip.vision.block_count" + IMAGE_MEAN = "clip.vision.image_mean" + IMAGE_STD = "clip.vision.image_std" + SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" + USE_GELU = "clip.use_gelu" + USE_SILU = "clip.use_silu" + N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl + WA_PATTERN_MODE = "clip.vision.wa_pattern_mode" # used by mimovl, per-layer -1/0/1 + IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" + WINDOW_SIZE = "clip.vision.window_size" class Attention: HEAD_COUNT = "clip.vision.attention.head_count" + HEAD_COUNT_KV = "clip.vision.attention.head_count_kv" # used by mimovl (GQA) LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon" class Projector: @@ -733,6 +735,7 @@ class MODEL_TENSOR(IntEnum): V_ENC_ATTN_V = auto() V_ENC_ATTN_O = auto() V_ENC_ATTN_O_NORM = auto() + V_ENC_ATTN_SINKS = auto() # mimovl V_ENC_POST_ATTN_NORM = auto() V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() @@ -1246,6 +1249,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", + MODEL_TENSOR.V_ENC_ATTN_SINKS: "v.blk.{bid}.attn_sinks", MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", @@ -1426,6 +1430,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_ENC_ATTN_V, MODEL_TENSOR.V_ENC_ATTN_O, MODEL_TENSOR.V_ENC_ATTN_O_NORM, + MODEL_TENSOR.V_ENC_ATTN_SINKS, MODEL_TENSOR.V_ENC_POST_ATTN_NORM, MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, @@ -4258,6 +4263,7 @@ class VisionProjectorType: HUNYUANVL = "hunyuanvl" MINICPMV4_6 = "minicpmv4_6" GRANITE_SPEECH = "granite_speech" # audio + MIMOVL = "mimovl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 35fb01470c4..a101382719d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1151,6 +1151,9 @@ def add_vision_block_count(self, value: int) -> None: def add_vision_head_count(self, value: int) -> None: self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value) + def add_vision_head_count_kv(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT_KV, value) + def add_vision_attention_layernorm_eps(self, value: float) -> None: self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value) @@ -1222,6 +1225,9 @@ def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + def add_vision_wa_pattern_mode(self, modes: Sequence[int]) -> None: + self.add_array(Keys.ClipVision.WA_PATTERN_MODE, modes) + def add_vision_window_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index f27f0e4c997..f40cb828201 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1569,6 +1569,10 @@ class TensorNameMap: "vision_model.transformer.resblocks.{bid}.attn.out_proj", # Step3-VL ), + MODEL_TENSOR.V_ENC_ATTN_SINKS: ( + "visual.blocks.{bid}.attn.sinks", # mimovl + ), + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", "model.vision_tower.encoder.layers.{bid}.layer_norm2", # minicpmv4_6 diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 21d17dbaa41..a76adc9b80b 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(mtmd models/pixtral.cpp models/qwen2vl.cpp models/qwen3vl.cpp + models/mimovl.cpp models/qwen3a.cpp models/step3vl.cpp models/siglip.cpp diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index d3e7b1ed044..39f069501bc 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -98,7 +98,8 @@ struct clip_graph { ggml_tensor * v_cur, ggml_tensor * kq_mask, float kq_scale, - int il) const; + int il, + ggml_tensor * sinks = nullptr) const; // implementation of the 2D RoPE without adding a new op in ggml // this is not efficient (use double the memory), but works on all backends diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 817bf26b217..8e09f26e988 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -31,6 +31,7 @@ #define KEY_N_BLOCK "clip.%s.block_count" #define KEY_PROJ_DIM "clip.%s.projection_dim" #define KEY_N_HEAD "clip.%s.attention.head_count" +#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv" #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" // vision-specific @@ -53,6 +54,7 @@ #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" #define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes" +#define KEY_WA_PATTERN_MODE "clip.vision.wa_pattern_mode" #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" @@ -86,6 +88,7 @@ #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s" #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" +#define TN_ATTN_SINKS "%s.blk.%d.attn_sinks" #define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s" #define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" @@ -344,6 +347,7 @@ enum projector_type { PROJECTOR_TYPE_HUNYUANVL, PROJECTOR_TYPE_MINICPMV4_6, PROJECTOR_TYPE_GRANITE_SPEECH, + PROJECTOR_TYPE_MIMOVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -393,6 +397,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"}, { PROJECTOR_TYPE_MINICPMV4_6, "minicpmv4_6"}, { PROJECTOR_TYPE_GRANITE_SPEECH, "granite_speech"}, + { PROJECTOR_TYPE_MIMOVL, "mimovl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 48f8b1a193f..ce15dbcd11e 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -42,6 +42,7 @@ struct clip_hparams { int32_t n_ff = 0; int32_t projection_dim = 0; int32_t n_head = 0; + int32_t n_head_kv = 0; int32_t n_layer = 0; // idefics3 int32_t n_merge = 0; // number of patch merges **per-side** @@ -83,6 +84,7 @@ struct clip_hparams { int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) + std::vector wa_pattern_mode; // mimovl: per-layer window-attention mode // deepseek-ocr (sam) int32_t sam_n_layer = 0; @@ -166,6 +168,8 @@ struct clip_layer { ggml_tensor * o_w = nullptr; ggml_tensor * o_b = nullptr; + ggml_tensor * attn_sinks = nullptr; + ggml_tensor * k_norm = nullptr; ggml_tensor * q_norm = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 513b94f2a65..f0c63d375a3 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -642,7 +642,8 @@ ggml_tensor * clip_graph::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_mask, float kq_scale, - int il) const { + int il, + ggml_tensor * sinks) const { // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -665,6 +666,9 @@ ggml_tensor * clip_graph::build_attn( cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f); ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + if (sinks != nullptr) { + ggml_flash_attn_ext_add_sinks(cur, sinks); + } cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); @@ -677,6 +681,9 @@ ggml_tensor * clip_graph::build_attn( // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + if (sinks != nullptr) { + ggml_soft_max_add_sinks(kq, sinks); + } ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); @@ -866,6 +873,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_MIMOVL: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_STEP3VL: { builder = std::make_unique(ctx, img); @@ -1389,6 +1400,22 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_MIMOVL: + { + hparams.n_merge = 2; // spatial_merge_size + hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(string_format(KEY_N_HEAD_KV, "vision"), hparams.n_head_kv); + // 1D banded sliding-window radius (visual_token_window_size); required + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size); + std::vector pat; + get_arr_int(KEY_WA_PATTERN_MODE, pat, true); + GGML_ASSERT((int) pat.size() == hparams.n_layer && "mimovl wa_pattern_mode length must equal n_layer"); + hparams.wa_pattern_mode.assign(pat.begin(), pat.end()); + get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels); + get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels); + hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_STEP3VL: { hparams.n_merge = 4; // two stride-2 downsamplers after patching @@ -1729,6 +1756,8 @@ struct clip_model_loader { layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); + // mimovl per-head attention sink bias + layer.attn_sinks = get_tensor(string_format(TN_ATTN_SINKS, prefix, il), false); // qwen3vl deepstack layer layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false); @@ -1913,6 +1942,13 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_MIMOVL: + { + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + } break; case PROJECTOR_TYPE_STEP3VL: { model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); @@ -3011,6 +3047,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_HUNYUANOCR: @@ -3032,6 +3069,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_HUNYUANVL: @@ -3110,6 +3148,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_YOUTUVL: { @@ -3681,6 +3720,89 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; + case PROJECTOR_TYPE_MIMOVL: + { + const int merge = hparams.n_merge; // 2 + const int merge_unit = merge * merge; // 4 + const int patch = hparams.patch_size; // 16 + const int H = image_size_height / patch; + const int W = image_size_width / patch; + const int n_pos_full = H * W; + const int llm_h = H / merge; + const int llm_w = W / merge; + const int n_units = llm_h * llm_w; // n_pos / merge_unit + + // Row-major merge-tile-ordered (h, w) positions + std::vector pos_h_row(n_pos_full); + std::vector pos_w_row(n_pos_full); + { + int idx = 0; + for (int ty = 0; ty < llm_h; ty++) { + for (int tx = 0; tx < llm_w; tx++) { + for (int dy = 0; dy < merge; dy++) { + for (int dx = 0; dx < merge; dx++) { + pos_h_row[idx] = ty * merge + dy; + pos_w_row[idx] = tx * merge + dx; + idx++; + } + } + } + } + } + + // Col-major merge-unit permutation + std::vector idx_col(n_units); + for (int r = 0; r < llm_h; r++) { + for (int c = 0; c < llm_w; c++) { + int u_row = r * llm_w + c; + int u_col = c * llm_h + r; + idx_col[u_col] = (float) u_row; + } + } + + // Col-mode positions: permute pos_*_row by idx_col + std::vector pos_h_col(n_pos_full); + std::vector pos_w_col(n_pos_full); + for (int u = 0; u < n_units; u++) { + int src = (int) idx_col[u]; + for (int k = 0; k < merge_unit; k++) { + pos_h_col[u * merge_unit + k] = pos_h_row[src * merge_unit + k]; + pos_w_col[u * merge_unit + k] = pos_w_row[src * merge_unit + k]; + } + } + + // Pack into ggml_rope_multi VISION-mode layout. The non-CPU kernels + // only read slots 0 and 1, so pack h in slot 0, w in slot 1: + // positions[0..n_pos) = h + // positions[n_pos..2*n_pos) = w + // positions[2*n_pos..3*n_pos) = 0 + // positions[3*n_pos..4*n_pos) = 0 + std::vector positions_row(static_cast(n_pos_full) * 4, 0); + std::vector positions_col(static_cast(n_pos_full) * 4, 0); + for (int i = 0; i < n_pos_full; i++) { + positions_row[0 * n_pos_full + i] = pos_h_row[i]; + positions_row[1 * n_pos_full + i] = pos_w_row[i]; + positions_col[0 * n_pos_full + i] = pos_h_col[i]; + positions_col[1 * n_pos_full + i] = pos_w_col[i]; + } + + // Banded 1D sliding-window mask + const int window = hparams.attn_window_size; + GGML_ASSERT(window > 0); + std::vector mask(static_cast(n_pos_full) * n_pos_full, std::numeric_limits::lowest()); + for (int q = 0; q < n_pos_full; q++) { + int lo = std::max(0, q - window); + int hi = std::min(n_pos_full - 1, q + window); + for (int k = lo; k <= hi; k++) { + mask[static_cast(q) * n_pos_full + k] = 0.0f; + } + } + + set_input_i32("mimovl_positions_row", positions_row); + set_input_i32("mimovl_positions_col", positions_col); + set_input_f32("mimovl_idx_col", idx_col); + set_input_f32("mimovl_window_mask", mask); + } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_KIMIK25: @@ -4081,6 +4203,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers); + case PROJECTOR_TYPE_MIMOVL: + return ctx->model.mm_1_w->ne[1]; case PROJECTOR_TYPE_STEP3VL: return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_GEMMA3: diff --git a/tools/mtmd/models/mimovl.cpp b/tools/mtmd/models/mimovl.cpp new file mode 100644 index 00000000000..19db88f132a --- /dev/null +++ b/tools/mtmd/models/mimovl.cpp @@ -0,0 +1,209 @@ +#include "models.h" + +ggml_tensor * clip_graph_mimovl::build_mm(ggml_tensor * w, ggml_tensor * x) const { + ggml_tensor * cur = ggml_mul_mat(ctx0, w, x); + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + return cur; +} + +// MiMoVL vision tower for MiMo-V2.5 (non-Pro). Qwen2.5-VL-shaped ViT, except: +// 1. GQA in attention (32 Q / 8 KV heads, head_dim 64). +// 2. Per-head attention sinks on every windowed layer. The sinks adjust +// the softmax denominator (equivalently, a virtual extra K column with V=0), +// so they decay attention weight without contributing to the output. +// 3. Per-layer window-attention mode in hparams.wa_pattern_mode: +// -1 -> full, 0 -> row-window+sinks, 1 -> col-window+sinks. +// Col mode transposes the merge-unit grid on entry and restores +// it on exit. Both patch and rotary orderings are pre-computed +// host-side. +// 4. 1D banded sliding window (|q-k| > window_size -> -inf) as a +// single 2D mask broadcast across heads. +// 5. Per-block MLP biases. +ggml_cgraph * clip_graph_mimovl::build() { + GGML_ASSERT(model.patch_embeddings_0 != nullptr); + GGML_ASSERT(model.patch_embeddings_1 != nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + GGML_ASSERT(hparams.n_head_kv > 0); + GGML_ASSERT(n_head % hparams.n_head_kv == 0); + GGML_ASSERT((int) hparams.wa_pattern_mode.size() == n_layer); + + const int batch_size = 1; + const int n_pos = n_patches; + const int n_head_kv = hparams.n_head_kv; + const int merge = hparams.n_merge > 0 ? hparams.n_merge : 2; + const int merge_unit = merge * merge; + const int n_units = n_pos / merge_unit; + GGML_ASSERT(n_units * merge_unit == n_pos); + + // MiMoVL has head_dim=64 with n_embd=1280, so n_embd is NOT n_head*head_dim + // (the base class's d_head = n_embd/n_head = 40 is wrong here). Derive + // head_dim from the fused QKV projection: rows = (n_head + 2*n_head_kv)*head_dim. + GGML_ASSERT(model.layers[0].qkv_w != nullptr); + const int qkv_rows = model.layers[0].qkv_w->ne[1]; + const int head_dim = qkv_rows / (n_head + 2 * n_head_kv); + GGML_ASSERT(head_dim * (n_head + 2 * n_head_kv) == qkv_rows); + const float attn_scale = 1.0f / std::sqrt((float) head_dim); + const int rope_n_dims = head_dim / 2; + int mrope_sections[4] = {rope_n_dims/2, rope_n_dims/2, 0, 0}; + + // Patch embed: Conv3D(kt=2) split into two Conv2D, then interleave-merge + // along the height axis to match the merge-tile token order. + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, + patch_size, patch_size, 0, 0, 1, 1); + { + ggml_tensor * inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, + patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + GGML_ASSERT(img.nx % (patch_size * 2) == 0); + GGML_ASSERT(img.ny % (patch_size * 2) == 0); + + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w,h,c,b] -> [c,w,h,b] + inp = ggml_cont_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + inp = ggml_reshape_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d(ctx0, inp, n_embd, n_patches_x * n_patches_y, batch_size); + } + cb(inp, "patch_embed", -1); + + ggml_tensor * positions_row = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos * 4); + ggml_set_name(positions_row, "mimovl_positions_row"); + ggml_set_input(positions_row); + + ggml_tensor * positions_col = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos * 4); + ggml_set_name(positions_col, "mimovl_positions_col"); + ggml_set_input(positions_col); + + // idx_col is the col-major merge-unit permutation. Take it as F32 so we can + // derive the inverse permutation in-graph via ggml_argsort; + // ggml_get_rows requires its index tensor to be I32, so cast back as well. + ggml_tensor * idx_col_f = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_units); + ggml_set_name(idx_col_f, "mimovl_idx_col"); + ggml_set_input(idx_col_f); + ggml_tensor * idx_col = ggml_cast(ctx0, idx_col_f, GGML_TYPE_I32); + ggml_tensor * idx_col_inv = ggml_argsort(ctx0, idx_col_f, GGML_SORT_ORDER_ASC); + + ggml_tensor * window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "mimovl_window_mask"); + ggml_set_input(window_mask); + + ggml_tensor * window_mask_attn = (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) + ? ggml_cast(ctx0, window_mask, GGML_TYPE_F16) + : window_mask; + + // Reorder helper: permute patches at merge-unit granularity. The patch + // sequence is laid out as n_units groups of merge_unit (=4) consecutive + // patches; the row<->col transpose only permutes whole groups. We keep + // the per-group (h,w) ordering intact by reshaping to + // [n_embd*merge_unit, n_units] before ggml_get_rows. + auto reorder = [&](ggml_tensor * x, ggml_tensor * idx) { + ggml_tensor * y = ggml_reshape_2d(ctx0, x, n_embd * merge_unit, n_units); + y = ggml_get_rows(ctx0, y, idx); + return ggml_reshape_3d(ctx0, y, n_embd, n_pos, batch_size); + }; + + ggml_tensor * inpL = inp; + int prev_mode = -1; + + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const int mode = hparams.wa_pattern_mode[il]; + const bool is_full = (mode == -1); + const bool is_col = (mode == 1); + + // Reorder transitions on entry/exit of a col-mode run. + if (is_col && prev_mode != 1) { + inpL = reorder(inpL, idx_col); + cb(inpL, "reorder_to_col", il); + } else if (!is_col && prev_mode == 1) { + inpL = reorder(inpL, idx_col_inv); + cb(inpL, "reorder_to_row", il); + } + + ggml_tensor * cur = inpL; + + // Pre-attention RMSNorm. + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_RMS, eps, il); + cb(cur, "ln1", il); + + // Fused QKV with GQA. + ggml_tensor * qkv = build_mm(layer.qkv_w, cur); + qkv = ggml_add(ctx0, qkv, layer.qkv_b); + + const size_t row = ggml_row_size(qkv->type, head_dim); + const size_t off_k = ggml_row_size(qkv->type, n_head * head_dim); + const size_t off_v = ggml_row_size(qkv->type, (n_head + n_head_kv) * head_dim); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, head_dim, n_head, n_pos, row, qkv->nb[1], 0); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, head_dim, n_head_kv, n_pos, row, qkv->nb[1], off_k); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, head_dim, n_head_kv, n_pos, row, qkv->nb[1], off_v); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // 2D RoPE + ggml_tensor * pos = is_col ? positions_col : positions_row; + Qcur = ggml_rope_multi(ctx0, Qcur, pos, nullptr, rope_n_dims, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); + Kcur = ggml_rope_multi(ctx0, Kcur, pos, nullptr, rope_n_dims, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + // Full layers: plain attention. Windowed layers: banded mask and per-head sinks. + ggml_tensor * mask = is_full ? nullptr : window_mask_attn; + ggml_tensor * sinks = is_full ? nullptr : layer.attn_sinks; + if (!is_full) { + GGML_ASSERT(layer.attn_sinks != nullptr); + } + ggml_tensor * attn_out = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, mask, attn_scale, il, sinks); + cb(attn_out, "attn_out", il); + + // Residual 1. + cur = ggml_add(ctx0, attn_out, inpL); + inpL = cur; + cb(cur, "ffn_inp", il); + + // Pre-FFN RMSNorm. + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_RMS, eps, il); + cb(cur, "ffn_inp_normed", il); + + // SwiGLU MLP with biases + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + cb(cur, "ffn_out", il); + + // Residual 2. + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + inpL = cur; + prev_mode = mode; + } + + // If the last block was col-mode, undo the transpose so the merger sees patches in row order. + if (prev_mode == 1) { + inpL = reorder(inpL, idx_col_inv); + cb(inpL, "reorder_to_row_final", -1); + } + + // Merger: post-LayerNorm + inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, 1e-6f, n_layer); + cb(inpL, "post_ln", -1); + + // Spatial merge: pack each merge_unit (=4) of patches into a single + // (n_embd*merge_unit)-wide row, then run the 2-layer MLP. + ggml_tensor * embeddings = ggml_reshape_3d(ctx0, inpL, n_embd * merge_unit, n_units, batch_size); + embeddings = build_ffn(embeddings, + model.mm_0_w, nullptr, + nullptr, nullptr, + model.mm_1_w, nullptr, + FFN_GELU, -1); + cb(embeddings, "vit_out", -1); + + ggml_build_forward_expand(gf, embeddings); + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index dbba233b16f..955daa6d6d3 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -33,6 +33,15 @@ struct clip_graph_qwen3vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_mimovl : clip_graph { + clip_graph_mimovl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + // Force F32 mat-mul accumulation to avoid F16 overflow in the FFN down-proj + // when the mmproj is stored in F16 (the source weights are BF16; downcasting + // to F16 reduces dynamic range below the SwiGLU output magnitude on the last few layers). + ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const override; +}; + struct clip_graph_step3vl : clip_graph { clip_graph_step3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 87da6876f7c..22092f6a606 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -325,6 +325,7 @@ struct mtmd_context { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; From fa62042af9cc858832b613ec51cce98f9884ce01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Tue, 12 May 2026 11:34:10 +0200 Subject: [PATCH 4/5] ci : bump ty to 0.0.35 (#22961) --- .github/workflows/python-type-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml index 2d3fa163d23..cbeeb39d05b 100644 --- a/.github/workflows/python-type-check.yml +++ b/.github/workflows/python-type-check.yml @@ -31,7 +31,7 @@ jobs: uses: actions/setup-python@v6 with: python-version: "3.11" - pip-install: -r requirements/requirements-all.txt ty==0.0.33 + pip-install: -r requirements/requirements-all.txt ty==0.0.35 # - name: Type-check with Pyright # uses: jakebailey/pyright-action@v2 # with: From 706fbd8ab62797bd3702a3b20312efb801db49a6 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 12 May 2026 04:41:58 -0500 Subject: [PATCH 5/5] vulkan: Check shared memory size for mmq shaders (#22693) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 168 ++++++++++++++++++++++++--- 1 file changed, 149 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7e450a559dd..90ea7cc1a9b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -681,6 +681,15 @@ struct vk_device_struct { bool mul_mat_id_m[GGML_TYPE_COUNT]; bool mul_mat_id_s[GGML_TYPE_COUNT]; + // Separate flags for the q8_1 (integer dot) mmq path, whose shader uses + // a different shared-memory layout than the float matmul shaders. + bool mul_mat_l_int[GGML_TYPE_COUNT]; + bool mul_mat_m_int[GGML_TYPE_COUNT]; + bool mul_mat_s_int[GGML_TYPE_COUNT]; + bool mul_mat_id_l_int[GGML_TYPE_COUNT]; + bool mul_mat_id_m_int[GGML_TYPE_COUNT]; + bool mul_mat_id_s_int[GGML_TYPE_COUNT]; + vk::DescriptorSetLayout dsl; vk_matmul_pipeline pipeline_matmul_f32 {}; @@ -3207,6 +3216,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec return supported; } +// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses +// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather +// than the float load buffers checked by ggml_vk_matmul_shmem_support. +// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline. +static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + // FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float. + const uint32_t fp_size = device->fp16 ? 2u : 4u; + const uint32_t fp_align = fp_size; + const uint32_t fp2_size = 2u * fp_size; + const uint32_t fp2_align = device->fp16 ? 4u : 8u; + + struct member { uint32_t size, align; }; + auto std430_size = [](std::initializer_list members) { + uint32_t off = 0, struct_align = 1; + for (const auto &m : members) { + off = (off + m.align - 1) & ~(m.align - 1); + off += m.size; + struct_align = std::max(struct_align, m.align); + } + return (off + struct_align - 1) & ~(struct_align - 1); + }; + + uint32_t block_a_size = 0; + switch (src0_type) { + case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm + case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2) + case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm + case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2) + case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm + case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d + case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2) + case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2) + case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2) + case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2) + case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2) + default: + return false; + } + + // block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; } + const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); + + const uint32_t BM = warptile[1]; + const uint32_t BN = warptile[2]; + // mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise. + const uint32_t BK_STEP = mul_mat_id ? 1u : 4u; + + const uint32_t buf_a_size = BM * BK_STEP * block_a_size; + const uint32_t buf_b_size = BN * BK_STEP * block_b_size; + const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u; + + const uint32_t warps = warptile[0] / warptile[10]; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u; + + const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported); + + return supported; +} + struct GpuPipelineConfig { // GPU architecture identifier. // Example: vk_device_architecture::AMD_GCN @@ -3453,6 +3526,40 @@ static void ggml_vk_load_shaders(vk_device& device) { } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { device->mul_mat_id_l[i] = false; } + + // The q8_1 mmq path has its own (larger) shmem layout, check it separately. + // K-quants use the _int_k warptiles, others use _int. + const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K || + t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K || + t == GGML_TYPE_Q6_K); + const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int; + const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int; + const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int; + const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int; + const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int; + const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int; + + if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) { + device->mul_mat_s_int[i] = false; + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) { + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) { + device->mul_mat_l_int[i] = false; + } + + if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) { + device->mul_mat_id_s_int[i] = false; + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) { + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) { + device->mul_mat_id_l_int[i] = false; + } } } @@ -5613,6 +5720,13 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_s[i] = true; break; } + + device->mul_mat_l_int[i] = true; + device->mul_mat_m_int[i] = true; + device->mul_mat_s_int[i] = true; + device->mul_mat_id_l_int[i] = true; + device->mul_mat_id_m_int[i] = true; + device->mul_mat_id_s_int[i] = true; } @@ -7220,6 +7334,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type]; + if (ctx->device->coopmat2) { const uint32_t shader_core_count = ctx->device->shader_core_count; const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); @@ -7236,26 +7357,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, // split_k==3 with large tiles likely better than medium tiles with no split_k. (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); - if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; - - GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { @@ -7312,35 +7431,42 @@ static void ggml_vk_matmul( ctx->prealloc_split_k_need_sync = true; } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type]; if (ctx->device->coopmat2) { // Use large shader when the N dimension is greater than the medium shader's tile size uint32_t crossover_large = mmp->m->wg_denoms[1]; - if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul_id( @@ -7636,10 +7762,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); @@ -8471,10 +8599,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);