Skip to content

Commit 8c77a04

Browse files
authored
vulkan: more mul mat optimizations (#18533)
* q4_k
* q5_k
* q2_k
* q4_1
* q5_1
* better buf index
1 parent ffba4f2 commit 8c77a04

3 files changed

Lines changed: 49 additions & 44 deletions

File tree

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -462,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) {
462462

463463
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
464464
vec2 get_dm(uint ib, uint a_offset) {
465-
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
465+
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
466+
return dm;
466467
}
467468
#endif
468469

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 45 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
4747
#endif
4848
#elif defined(DATA_A_Q4_0)
4949
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
50-
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
50+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
5151

5252
const uint ib = idx / 4;
5353
const uint iqs = idx & 0x03;
@@ -63,24 +63,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
6363
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
6464
#elif defined(DATA_A_Q4_1)
6565
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
66-
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
66+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
6767

6868
const uint ib = idx / 4;
6969
const uint iqs = idx & 0x03;
7070

71-
const float d = float(data_a_packed16[ib].d);
72-
const float m = float(data_a_packed16[ib].m);
73-
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
74-
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
75-
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
71+
const vec2 dm = vec2(data_a_packed32[ib].dm);
72+
const uint vui = data_a_packed32[ib].qs[iqs];
73+
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
74+
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
7675

7776
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
7877
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
7978
buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
8079
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
8180
#elif defined(DATA_A_Q5_0)
8281
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
83-
const uint buf_idx = col * SHMEM_STRIDE + row;
82+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
8483

8584
const uint ib = idx / 8;
8685
const uint iqs = idx & 0x07;
@@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
9796
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
9897
#elif defined(DATA_A_Q5_1)
9998
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
100-
const uint buf_idx = col * SHMEM_STRIDE + row;
101-
102-
const uint ib = idx / 8;
103-
const uint iqs = idx & 0x07;
104-
105-
const float d = float(data_a_packed16[ib].d);
106-
const float m = float(data_a_packed16[ib].m);
107-
const uint uint_qh = data_a_packed16[ib].qh;
108-
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
109-
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
99+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
110100

111-
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
112-
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
101+
const uint ib = idx / 4;
102+
const uint iqs = idx & 0x03;
113103

114-
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
115-
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
104+
const vec2 dm = vec2(data_a_packed32[ib].dm);
105+
const uint uint_qh = data_a_packed32[ib].qh;
106+
const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);
107+
const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);
108+
const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);
109+
const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);
110+
111+
const uint vui = data_a_packed32[ib].qs[iqs];
112+
const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
113+
const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
114+
115+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
116+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
117+
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
118+
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
116119
#elif defined(DATA_A_Q8_0)
117120
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
118121
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
131134
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
132135
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
133136

134-
const uint ib = idx / 128; // 2 values per idx
135-
const uint iqs = idx % 128; // 0..127
137+
const uint ib = idx / 64; // 4 values per idx
138+
const uint iqs = (idx % 64) * 2; // 0,2,4..126
136139

137140
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
138141
const uint scalesi = iqs / 8; // 0..15
139142
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
140143

141-
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
144+
const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));
142145
const uint scales = data_a[ib].scales[scalesi];
143146
const vec2 dm = vec2(data_a[ib].dm);
144147

145-
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
148+
const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
146149

147-
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
150+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
151+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
148152
#elif defined(DATA_A_Q3_K)
149153
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
150154
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
173177
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
174178
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
175179

176-
const uint ib = idx / 128; // 2 values per idx
177-
const uint iqs = idx % 128; // 0..127
180+
const uint ib = idx / 64; // 4 values per idx
181+
const uint iqs = (idx % 64) * 2; // 0,2,4..126
178182

179183
const uint n = iqs / 32; // 0,1,2,3
180184
const uint b = (iqs % 32) / 16; // 0,1
@@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
200204
const float d = loadd.x * sc;
201205
const float m = -loadd.y * mbyte;
202206

203-
const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
207+
const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
204208

205-
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
206-
fma(d, q.y, m));
209+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
210+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
207211
#elif defined(DATA_A_Q5_K)
208212
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
209213
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
210214

211-
const uint ib = idx / 128; // 2 values per idx
212-
const uint iqs = idx % 128; // 0..127
215+
const uint ib = idx / 64; // 4 values per idx
216+
const uint iqs = (idx % 64) * 2; // 0,2,4..126
213217

214218
const uint n = iqs / 32; // 0,1,2,3
215219
const uint b = (iqs % 32) / 16; // 0,1
@@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
236240
const float d = loadd.x * sc;
237241
const float m = -loadd.y * mbyte;
238242

239-
const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
240-
const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
241-
const vec2 q = vec2(unpack8(qs | qh).xy);
243+
const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;
244+
const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
245+
const vec4 q = vec4(unpack8(qs | qh));
242246

243-
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
244-
fma(d, q.y, m));
247+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
248+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
245249
#elif defined(DATA_A_Q6_K)
246250
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
247251
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
455459
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
456460
#elif defined(DATA_A_IQ4_NL)
457461
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
458-
const uint buf_idx = col * SHMEM_STRIDE + row;
462+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
459463

460464
const uint ib = idx / 8;
461465
const uint iqs = idx & 0x07;
@@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
469473
kvalues_iq4nl[vui >> 12]);
470474
#elif defined(DATA_A_MXFP4)
471475
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
472-
const uint buf_idx = col * SHMEM_STRIDE + row;
476+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
473477

474478
const uint ib = idx / 8;
475479
const uint iqs = (idx & 0x07) * 2;

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
552552

553553
for (const auto& tname : type_names) {
554554
std::string load_vec_quant = "2";
555-
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
555+
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
556556
load_vec_quant = "8";
557-
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
557+
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
558558
load_vec_quant = "4";
559559

560560
if (tname == "bf16") {

0 commit comments

Comments
 (0)