Skip to content

Commit 67855db

Browse files
committed
vulkan: add Q1_0_g128 (1-bit binary/sign) shader support
Add Vulkan compute shader support for the GGML_TYPE_Q1_0_g128 quantization format (1-bit sign / binary quantization, group size 128). New files: - dequant_q1_0_g128.comp: standalone dequantization shader - mul_mat_vec_q1_0_g128.comp: fused matrix-vector multiply shader (4 threads/block, 32 elements/thread, 8x dot(vec4,vec4)) Modified files: - types.glsl: block_q1_0_g128 struct, QUANT_K=128, QUANT_R=1 - dequant_funcs.glsl: dequantize/dequantize4 + single-scale get_dm - mul_mm_funcs.glsl: branchless FMA load path for batched matmul - vulkan-shaders-gen.cpp: type registration, LOAD_VEC_A=4, excluded from coopmat2 flash attention and integer dot product Q8_1 paths - ggml-vulkan.cpp: pipeline registration for dequant, get_rows, mul_mat_vec (f32/f16/id), mul_mat_mat, mul_mat_mat_id, supports_op - test-backend-ops.cpp: Q1_0_g128 test cases for get_rows, mul_mat, mul_mat_id Performance on AMD Radeon 680M (RDNA2 iGPU): eval: 0.28 -> 23.4 t/s (84x), prompt: 0.31 -> 38.3 t/s (124x) graph splits: 291 -> 2
1 parent f5dda72 commit 67855db

File tree

8 files changed

+272
-10
lines changed

8 files changed

+272
-10
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 38 additions & 1 deletion
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,41 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
2222
}
2323
#endif
2424

25+
#if defined(DATA_A_Q1_0_G128)
26+
// Q1_0_G128: decode two consecutive 1-bit weights of block `ib` into signs.
//   ib       - block index, added to a_offset to index data_a
//   iqs      - element index of the first value inside the block
//   a_offset - base offset into data_a
// Returns vec2(sign(iqs), sign(iqs + 1)), where a set bit yields +1.0 and a
// cleared bit yields -1.0. The per-block scale `d` is NOT applied here.
// Element e is stored in byte e / 8 at bit position e % 8 (bit 0 = LSB).
// NOTE(review): iqs + 1 must stay inside the block; presumably callers only
// pass even iqs values - confirm against the call sites.
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
27+
// iqs is the element index within the block (0..127)
28+
const uint byte_idx = iqs / 8;
29+
const uint bit_idx = iqs % 8;
30+
const uint bits = uint(data_a[a_offset + ib].qs[byte_idx]);
31+
// Bit set -> +1, bit clear -> -1.
const float sign0 = ((bits >> bit_idx) & 1) == 1 ? 1.0f : -1.0f;
32+
// Second element: the byte/bit position is recomputed because iqs + 1 may
// fall into the next byte.
33+
const uint byte_idx2 = (iqs + 1) / 8;
34+
const uint bit_idx2 = (iqs + 1) % 8;
35+
const uint bits2 = uint(data_a[a_offset + ib].qs[byte_idx2]);
36+
const float sign1 = ((bits2 >> bit_idx2) & 1) == 1 ? 1.0f : -1.0f;
37+
return vec2(sign0, sign1);
38+
}
39+
// Q1_0_G128: decode four consecutive 1-bit weights of block `ib` into signs.
//   ib       - block index, added to a_offset to index data_a
//   iqs      - element index of the first of the four values inside the block
//   a_offset - base offset into data_a
// Returns vec4 of +1.0 / -1.0 for elements iqs .. iqs + 3 (set bit -> +1.0,
// cleared bit -> -1.0). The per-block scale `d` is NOT applied here.
// Each element's byte (e / 8) and bit (e % 8) are computed independently, so
// the four elements may straddle a byte boundary.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
40+
// Element iqs + 0
const uint byte_idx0 = iqs / 8;
41+
const uint bit_idx0 = iqs % 8;
42+
const uint bits0 = uint(data_a[a_offset + ib].qs[byte_idx0]);
43+
const float s0 = ((bits0 >> bit_idx0) & 1) == 1 ? 1.0f : -1.0f;
44+
// Element iqs + 1
const uint byte_idx1 = (iqs + 1) / 8;
45+
const uint bit_idx1 = (iqs + 1) % 8;
46+
const uint bits1 = uint(data_a[a_offset + ib].qs[byte_idx1]);
47+
const float s1 = ((bits1 >> bit_idx1) & 1) == 1 ? 1.0f : -1.0f;
48+
// Element iqs + 2
const uint byte_idx2 = (iqs + 2) / 8;
49+
const uint bit_idx2 = (iqs + 2) % 8;
50+
const uint bits2 = uint(data_a[a_offset + ib].qs[byte_idx2]);
51+
const float s2 = ((bits2 >> bit_idx2) & 1) == 1 ? 1.0f : -1.0f;
52+
// Element iqs + 3
const uint byte_idx3 = (iqs + 3) / 8;
53+
const uint bit_idx3 = (iqs + 3) % 8;
54+
const uint bits3 = uint(data_a[a_offset + ib].qs[byte_idx3]);
55+
const float s3 = ((bits3 >> bit_idx3) & 1) == 1 ? 1.0f : -1.0f;
56+
return vec4(s0, s1, s2, s3);
57+
}
58+
#endif
59+
2560
#if defined(DATA_A_Q4_0)
2661
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
2762
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -448,7 +483,7 @@ vec2 get_dm(uint ib, uint a_offset) {
448483
}
449484
#endif
450485

451-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
486+
#if defined(DATA_A_Q1_0_G128) || defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
452487
// Returns the block's `d` field converted to float in .x and a constant 0 in
// .y, for the quant types selected by the enclosing #if.
//   ib       - block index, added to a_offset to index data_a
//   a_offset - base offset into data_a
vec2 get_dm(uint ib, uint a_offset) {
453488
return vec2(float(data_a[a_offset + ib].d), 0);
454489
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#version 450
2+
3+
#include "dequant_head.glsl"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_q1_0_g128 data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
// Standalone Q1_0_g128 dequantization entry point.
// Each invocation expands one 128-element block of data_a into 128 D_TYPE
// values in data_b, writing d * (+1 or -1) per element.
void main() {
11+
// Each thread handles one 128-element block
12+
// 256 matches local_size_x declared for this shader.
const uint ib = gl_WorkGroupID.x * 256 + gl_LocalInvocationID.x;
13+
14+
// Skip invocations past the last block (p.nel / 128 blocks in total).
// NOTE(review): integer division - a trailing partial block would be
// skipped; presumably p.nel is always a multiple of 128, confirm on the host side.
if (ib >= p.nel / 128) {
15+
return;
16+
}
17+
18+
// First output element written by this block.
const uint b_idx = ib * 128;
19+
const float d = float(data_a[ib].d);
20+
21+
// Each block has 16 bytes = 128 bits = 128 elements
22+
[[unroll]] for (uint byte_idx = 0; byte_idx < 16; ++byte_idx) {
23+
const uint bits = uint(data_a[ib].qs[byte_idx]);
24+
// Bits are consumed LSB-first: bit k of byte b is element b * 8 + k.
[[unroll]] for (uint bit_idx = 0; bit_idx < 8; ++bit_idx) {
25+
// Set bit -> +d, cleared bit -> -d.
const float sign = ((bits >> bit_idx) & 1) == 1 ? 1.0f : -1.0f;
26+
data_b[b_idx + byte_idx * 8 + bit_idx] = D_TYPE(d * sign);
27+
}
28+
}
29+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#version 450
2+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
3+
#extension GL_EXT_shader_8bit_storage : require
4+
5+
#include "mul_mat_vec_base.glsl"
6+
7+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8+
9+
// Fused 1-bit matrix-vector multiply for Q1_0_g128.
10+
// 4 threads per block, each handles 32 elements (one uint32 of packed bits).
11+
// Uses a simple conditional-operator (?:) sign selection, which compiles to v_cndmask on RDNA.
12+
13+
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
14+
15+
// Accumulate into temp[j][n] the contribution of one 32-element slice of
// block `i` for every B column j and every A row n handled by this call.
//   a_offset, b_offset - base offsets into data_a / the B buffer
//   itid               - which 32-element quarter (0..3) of the 128-element
//                        block this invocation processes
//   i                  - block index within a row
//   num_blocks_per_row - number of blocks per row of A; also the stride used
//                        to step `ibi` from one row to the next
//   first_row          - first A row handled by this workgroup
//   num_rows           - number of rows to process starting at first_row
// For each row: temp[j][n] += d * sum_k(sign_k * b_k), where sign_k is +1 when
// bit k of the 32-bit word is set and -1 otherwise.
void calc_block(const uint a_offset, const uint b_offset, const uint itid, const uint i,
16+
const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
17+
18+
// Element offset of this invocation's 32-element slice within the row.
const uint y_idx_base = i * 128 + itid * 32;
19+
20+
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
21+
// Index into the vec4 view of B (hence the division by 4).
// NOTE(review): assumes the element offset is a multiple of 4 - confirm.
const uint base_b = (j * p.batch_stride_b + b_offset + y_idx_base) / 4;
22+
// Eight vec4 loads = the 32 B values matching this slice.
const vec4 bv0 = vec4(data_b_v4[base_b]);
23+
const vec4 bv1 = vec4(data_b_v4[base_b + 1]);
24+
const vec4 bv2 = vec4(data_b_v4[base_b + 2]);
25+
const vec4 bv3 = vec4(data_b_v4[base_b + 3]);
26+
const vec4 bv4 = vec4(data_b_v4[base_b + 4]);
27+
const vec4 bv5 = vec4(data_b_v4[base_b + 5]);
28+
const vec4 bv6 = vec4(data_b_v4[base_b + 6]);
29+
const vec4 bv7 = vec4(data_b_v4[base_b + 7]);
30+
31+
// Block index in data_a for the first row; advanced by one row per iteration.
uint ibi = a_offset + first_row * num_blocks_per_row + i;
32+
33+
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
34+
const float d = float(data_a[ibi].d);
35+
36+
// Assemble this slice's four bytes into one 32-bit word, lowest byte first,
// so bit k of `bits` corresponds to element k of the slice.
const uint byte_base = itid * 4;
37+
const uint bits = uint(data_a[ibi].qs[byte_base])
38+
| (uint(data_a[ibi].qs[byte_base + 1]) << 8)
39+
| (uint(data_a[ibi].qs[byte_base + 2]) << 16)
40+
| (uint(data_a[ibi].qs[byte_base + 3]) << 24);
41+
42+
// Unscaled signed sum over the 32 elements; scaled by d once at the end.
FLOAT_TYPE partial = FLOAT_TYPE(0);
43+
44+
// Bits 0..3 against bv0
partial += FLOAT_TYPE(dot(vec4(
45+
(bits & 0x1u) != 0 ? 1.0 : -1.0, (bits & 0x2u) != 0 ? 1.0 : -1.0,
46+
(bits & 0x4u) != 0 ? 1.0 : -1.0, (bits & 0x8u) != 0 ? 1.0 : -1.0), bv0));
47+
// Bits 4..7 against bv1
partial += FLOAT_TYPE(dot(vec4(
48+
(bits & 0x10u) != 0 ? 1.0 : -1.0, (bits & 0x20u) != 0 ? 1.0 : -1.0,
49+
(bits & 0x40u) != 0 ? 1.0 : -1.0, (bits & 0x80u) != 0 ? 1.0 : -1.0), bv1));
50+
// Bits 8..11 against bv2
partial += FLOAT_TYPE(dot(vec4(
51+
(bits & 0x100u) != 0 ? 1.0 : -1.0, (bits & 0x200u) != 0 ? 1.0 : -1.0,
52+
(bits & 0x400u) != 0 ? 1.0 : -1.0, (bits & 0x800u) != 0 ? 1.0 : -1.0), bv2));
53+
// Bits 12..15 against bv3
partial += FLOAT_TYPE(dot(vec4(
54+
(bits & 0x1000u) != 0 ? 1.0 : -1.0, (bits & 0x2000u) != 0 ? 1.0 : -1.0,
55+
(bits & 0x4000u) != 0 ? 1.0 : -1.0, (bits & 0x8000u) != 0 ? 1.0 : -1.0), bv3));
56+
// Bits 16..19 against bv4
partial += FLOAT_TYPE(dot(vec4(
57+
(bits & 0x10000u) != 0 ? 1.0 : -1.0, (bits & 0x20000u) != 0 ? 1.0 : -1.0,
58+
(bits & 0x40000u) != 0 ? 1.0 : -1.0, (bits & 0x80000u) != 0 ? 1.0 : -1.0), bv4));
59+
// Bits 20..23 against bv5
partial += FLOAT_TYPE(dot(vec4(
60+
(bits & 0x100000u) != 0 ? 1.0 : -1.0, (bits & 0x200000u) != 0 ? 1.0 : -1.0,
61+
(bits & 0x400000u) != 0 ? 1.0 : -1.0, (bits & 0x800000u) != 0 ? 1.0 : -1.0), bv5));
62+
// Bits 24..27 against bv6
partial += FLOAT_TYPE(dot(vec4(
63+
(bits & 0x1000000u) != 0 ? 1.0 : -1.0, (bits & 0x2000000u) != 0 ? 1.0 : -1.0,
64+
(bits & 0x4000000u) != 0 ? 1.0 : -1.0, (bits & 0x8000000u) != 0 ? 1.0 : -1.0), bv6));
65+
// Bits 28..31 against bv7
partial += FLOAT_TYPE(dot(vec4(
66+
(bits & 0x10000000u) != 0 ? 1.0 : -1.0, (bits & 0x20000000u) != 0 ? 1.0 : -1.0,
67+
(bits & 0x40000000u) != 0 ? 1.0 : -1.0, (bits & 0x80000000u) != 0 ? 1.0 : -1.0), bv7));
68+
69+
// Apply the block scale once: temp += d * partial.
temp[j][n] = fma(FLOAT_TYPE(d), partial, temp[j][n]);
70+
// Same block column, next row of A.
ibi += num_blocks_per_row;
71+
}
72+
}
73+
}
74+
75+
// Compute the outputs for rows [first_row, first_row + num_rows) in this
// workgroup: zero the accumulators, let every invocation accumulate its share
// of blocks via calc_block, then hand the partial sums to reduce_result.
//   first_row - first A row handled by this workgroup
//   num_rows  - number of rows to process (at most NUM_ROWS, per the caller)
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
76+
uint a_offset, b_offset, d_offset;
77+
get_offsets(a_offset, b_offset, d_offset);
78+
79+
// 128 elements per block; four invocations cooperate on each block.
const uint num_blocks_per_row = p.ncols / 128;
80+
const uint blocks_per_wg = gl_WorkGroupSize.x / 4;
81+
const uint tid = gl_LocalInvocationID.x;
82+
// itid: which 32-element quarter of a block this invocation handles.
const uint itid = tid % 4;
83+
// ix: first block index this invocation handles.
const uint ix = tid / 4;
84+
85+
// Clear all accumulators, including rows beyond num_rows.
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
86+
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
87+
temp[j][i] = FLOAT_TYPE(0);
88+
}
89+
}
90+
91+
// Walk the row's blocks with a stride of blocks_per_wg.
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
92+
calc_block(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
93+
94+
// reduce_result is defined outside this file; presumably it combines the
// per-invocation partial sums and writes the outputs - confirm there.
reduce_result(temp, d_offset, first_row, num_rows, tid);
95+
}
96+
97+
// Entry point of the fused Q1_0_g128 matrix-vector multiply.
// Derives this workgroup's first row and calls compute_outputs with either a
// full tile of NUM_ROWS rows or the remaining rows; p.stride_d is the row
// bound in both checks.
void main() {
98+
// Each workgroup owns NUM_ROWS consecutive rows; the z workgroup index extends
// the range by gl_NumWorkGroups.x workgroups per z step.
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
99+
100+
// Full tile: all NUM_ROWS rows are below the bound.
if (first_row + NUM_ROWS <= p.stride_d) {
101+
compute_outputs(first_row, NUM_ROWS);
102+
} else {
103+
// Nothing to do when the tile starts at or past the bound.
if (first_row >= p.stride_d) {
104+
return;
105+
}
106+
// Partial tile: only the remaining rows below the bound.
compute_outputs(first_row, p.stride_d - first_row);
107+
}
108+
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,37 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
128128
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
129129
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
130130

131+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
132+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
133+
#elif defined(DATA_A_Q1_0_G128)
134+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
135+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
136+
137+
// LOAD_VEC_A = 4, so each load processes 4 elements.
138+
// 128 elements per block / 4 = 32 loads per block.
139+
const uint ib = idx / 32; // block index
140+
const uint iel = (idx % 32) * 4; // element offset within block (0,4,8,...124)
141+
142+
const float d = float(data_a[ib].d);
143+
const float d2 = d + d;
144+
const float neg_d = -d;
145+
146+
// Mirror Metal's chunking more directly: q1_0_g128 is 8 chunks of 16 sign bits.
147+
// Decode the containing 16-bit chunk, then select the 4-bit sub-group for this load.
148+
const uint chunk16 = iel / 16;
149+
const uint chunk_bit = iel % 16;
150+
const uint byte_offset = chunk16 * 2;
151+
const uint bits16 = uint(data_a[ib].qs[byte_offset])
152+
| (uint(data_a[ib].qs[byte_offset + 1]) << 8);
153+
const uint bits = (bits16 >> chunk_bit) & 0xFu;
154+
155+
// Branchless FMA: d*(2*bit-1) = fma(2d, bit_float, -d)
156+
const vec4 bit_floats = vec4(
157+
float(bits & 1u), float((bits >> 1) & 1u),
158+
float((bits >> 2) & 1u), float((bits >> 3) & 1u)
159+
);
160+
const vec4 v = fma(vec4(d2), bit_floats, vec4(neg_d));
161+
131162
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
132163
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
133164
#elif defined(DATA_A_Q2_K)

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
66
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
77
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
8+
#extension GL_EXT_shader_8bit_storage : require
89
#extension GL_EXT_shader_16bit_storage : require
910

1011
#if defined(DATA_A_F32)
@@ -46,6 +47,23 @@
4647
#endif
4748
#endif
4849

50+
#define QUANT_K_Q1_0_G128 128
51+
#define QUANT_R_Q1_0_G128 1
52+
53+
// One Q1_0_g128 quantization block: 128 one-bit weights sharing a single
// half-precision value `d`.
struct block_q1_0_g128
54+
{
55+
// Per-block fp16 value (QUANT_K_Q1_0_G128 = 128 elements share it).
float16_t d;
56+
// 16 bytes = 128 bits, one bit per element.
uint8_t qs[16];
57+
};
58+
59+
#if defined(DATA_A_Q1_0_G128)
60+
#define QUANT_K QUANT_K_Q1_0_G128
61+
#define QUANT_R QUANT_R_Q1_0_G128
62+
#define QUANT_AUXF 1
63+
#define A_TYPE block_q1_0_g128
64+
#define DATA_A_QUANT_LEGACY
65+
#endif
66+
4967
#define QUANT_K_Q4_0 32
5068
#define QUANT_R_Q4_0 2
5169

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ const std::vector<std::string> type_names = {
5050
"q5_0",
5151
"q5_1",
5252
"q8_0",
53+
"q1_0_g128",
5354
"q2_k",
5455
"q3_k",
5556
"q4_k",
@@ -220,7 +221,7 @@ bool is_quantized_type(const std::string& type_name) {
220221
}
221222

222223
bool is_legacy_quant(const std::string& type_name) {
223-
return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0";
224+
return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0" || type_name == "q1_0_g128";
224225
}
225226

226227
bool is_k_quant(const std::string& type_name) {
@@ -554,7 +555,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
554555
std::string load_vec_quant = "2";
555556
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
556557
load_vec_quant = "8";
557-
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
558+
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q1_0_g128") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
558559
load_vec_quant = "4";
559560

560561
if (tname == "bf16") {
@@ -580,14 +581,14 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
580581
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
581582
}
582583

583-
if (tname != "f16" && tname != "f32") {
584+
if (tname != "f16" && tname != "f32" && !(coopmat2 && tname == "q1_0_g128")) {
584585
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
585586
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
586587
}
587588

588589
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
589590
// Integer dot mmq performs better with f32 accumulators
590-
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
591+
if (!f16acc && !coopmat && !coopmat2 && tname != "q1_0_g128" && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
591592
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
592593
}
593594
#endif
@@ -645,7 +646,7 @@ void process_shaders() {
645646
if (tname == "f16") {
646647
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
647648
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
648-
} else {
649+
} else if (tname != "q1_0_g128") {
649650
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
650651
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
651652
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
@@ -680,7 +681,7 @@ void process_shaders() {
680681
for (const auto& tname : type_names) {
681682
// mul mat vec
682683
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
683-
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
684+
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_") || tname == "q1_0_g128") ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
684685

685686
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
686687
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
@@ -697,7 +698,7 @@ void process_shaders() {
697698

698699
// mul mat vec with integer dot product
699700
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
700-
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
701+
if (tname != "q1_0_g128" && (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m")) {
701702
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
702703
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
703704
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@@ -1139,7 +1140,7 @@ void write_output_files() {
11391140

11401141
for (const std::string& btype : btypes) {
11411142
for (const auto& tname : type_names) {
1142-
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
1143+
if (btype == "q8_1" && (!is_legacy_quant(tname) || tname == "q1_0_g128") && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
11431144
continue;
11441145
}
11451146
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7094,6 +7094,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
70947094
}
70957095

70967096
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
7097+
test_cases.emplace_back(new test_get_rows(GGML_TYPE_Q1_0_g128, 256, 5, 4, 1, 1, false));
70977098
for (ggml_type type : all_types) {
70987099
for (int b : {1, 7}) {
70997100
for (bool v : {false, true}) {
@@ -7796,6 +7797,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
77967797
}
77977798
#endif
77987799

7800+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q1_0_g128, GGML_TYPE_F32, 16, 16, 256, {1, 1}, {1, 1}));
7801+
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q1_0_g128, GGML_TYPE_F32, 8, 2, false, 16, 16, 256));
77997802
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1}));
78007803
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1}));
78017804
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1}));

0 commit comments

Comments
 (0)