Skip to content

Commit 43fe657

Browse files
committed
vulkan: add Q1_0 (1-bit ternary) shader support
Adds Vulkan shader support for the GGML_TYPE_Q1_0 (1-bit sign quantization, group size 128) format. Without these shaders, Q1_0 models fall back to CPU dequantization on Vulkan devices, resulting in ~291 graph splits and extremely poor performance.

New files:
- dequant_q1_0.comp: standalone dequantization shader
- mul_mat_vec_q1_0.comp: fused matrix-vector multiply (4 threads/block)

Changes:
- types.glsl: block_q1_0 struct + GL_EXT_shader_8bit_storage
- dequant_funcs.glsl: dequantize/dequantize4/get_dm for Q1_0
- mul_mm_funcs.glsl: load_a_to_shmem with branchless FMA decoding
- vulkan-shaders-gen.cpp: Q1_0 type registration; exclusions for coopmat2/flash_attn/integer_dot (no dequantFunc/Q8_1 mapping)
- ggml-vulkan.cpp: pipeline registration, supports_op, f32acc forced
- test-backend-ops.cpp: get_rows/mul_mat/mul_mat_id tests

Does NOT force Q1_0 through the dequant fallback in mul_mat_q_f16 — the fused matmul path is used when available, giving a 2x+ prefill speedup on AMD RDNA2 Vulkan over the dequant path.

Tested on: AMD Radeon 680M (RDNA2), AMD RX 470, AMD Vega 56.

Rebased onto PrismML master (Q1_0_g128 renamed to Q1_0).
1 parent d0a6dfe commit 43fe657

File tree

8 files changed

+234
-10
lines changed

8 files changed

+234
-10
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 48 additions & 1 deletion
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
2222
}
2323
#endif
2424

25+
#if defined(DATA_A_Q1_0)
// Q1_0: 128 sign bits per block, bit == 1 -> +1, bit == 0 -> -1.
// The per-block scale d is applied separately via get_dm(), matching the
// other "legacy" quant formats in this file.

// Decode the sign for element `iqs` of block `ib`.
float q1_0_sign(uint ib, uint iqs, uint a_offset) {
    const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8]);
    return ((bits >> (iqs % 8)) & 1u) != 0 ? 1.0f : -1.0f;
}

vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    return vec2(q1_0_sign(ib, iqs,     a_offset),
                q1_0_sign(ib, iqs + 1, a_offset));
}

vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    return vec4(q1_0_sign(ib, iqs,     a_offset),
                q1_0_sign(ib, iqs + 1, a_offset),
                q1_0_sign(ib, iqs + 2, a_offset),
                q1_0_sign(ib, iqs + 3, a_offset));
}
#endif
2526
#if defined(DATA_A_Q4_0)
2627
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
2728
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -448,7 +449,8 @@ vec2 get_dm(uint ib, uint a_offset) {
448449
}
449450
#endif
450451

451-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
452+
// NOTE(review): a second, identical copy of the Q1_0 dequantize()/dequantize4()
// definitions appeared here, next to get_dm(). With DATA_A_Q1_0 defined both
// copies would compile and GLSL rejects function redefinition, so the duplicate
// is removed; the canonical definitions live in the DATA_A_Q1_0 section near
// the top of this file. (If the duplication was a diff-rendering artifact
// rather than a real commit bug, this removal is a no-op — verify against the
// actual file.)
453+
// Shared scale accessor for the "legacy" single-scale quant layouts, now
// including Q1_0: .x is the per-block fp16 scale d, .y (the min/offset used by
// Q4_1-style formats) is always 0 here.
#if defined(DATA_A_Q1_0) || defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
452454
vec2 get_dm(uint ib, uint a_offset) {
453455
return vec2(float(data_a[a_offset + ib].d), 0);
454456
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#version 450
2+
3+
#include "dequant_head.glsl"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_q1_0 data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
void main() {
    // One thread per 128-element Q1_0 block.
    // gl_GlobalInvocationID.x == gl_WorkGroupID.x * local_size_x + local id,
    // so this stays correct if the local_size_x above is ever changed
    // (the original hardcoded the 256 workgroup width here).
    const uint ib = gl_GlobalInvocationID.x;

    // p.nel is the total element count; guard against the partial last group.
    if (ib >= p.nel / 128) {
        return;
    }

    const uint b_idx = ib * 128;          // first output element of this block
    const float d = float(data_a[ib].d);  // per-block fp16 scale

    // 16 bytes * 8 bits = 128 sign bits: bit == 1 -> +d, bit == 0 -> -d.
    [[unroll]] for (uint byte_idx = 0; byte_idx < 16; ++byte_idx) {
        const uint bits = uint(data_a[ib].qs[byte_idx]);
        [[unroll]] for (uint bit_idx = 0; bit_idx < 8; ++bit_idx) {
            const float sign = ((bits >> bit_idx) & 1u) != 0 ? 1.0f : -1.0f;
            data_b[b_idx + byte_idx * 8 + bit_idx] = D_TYPE(d * sign);
        }
    }
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#version 450
2+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
3+
#extension GL_EXT_shader_8bit_storage : require
4+
5+
#include "mul_mat_vec_base.glsl"
6+
7+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
8+
9+
// Fused 1-bit matrix-vector multiply for Q1_0.
10+
// 4 threads per block, each handles 32 elements (one uint32 of packed bits).
11+
// Uses simple ternary sign selection which compiles to v_cndmask on RDNA.
12+
13+
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
14+
15+
// Accumulate one 128-element Q1_0 block into temp[][] for all NUM_COLS columns
// and up to num_rows rows.
//   itid — lane (0..3) within the 4-thread group that shares a block; each
//          lane handles 32 consecutive elements (one uint32 of packed bits)
//   i    — block index within the row
void calc_block(const uint a_offset, const uint b_offset, const uint itid, const uint i,
                const uint num_blocks_per_row, const uint first_row, const uint num_rows) {

    const uint y_idx_base = i * 128 + itid * 32;

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        // Load this lane's 32 B-vector elements once (8 vec4s); they are
        // reused for every row processed below.
        const uint base_b = (j * p.batch_stride_b + b_offset + y_idx_base) / 4;
        vec4 bv[8];
        [[unroll]] for (uint g = 0; g < 8; ++g) {
            bv[g] = vec4(data_b_v4[base_b + g]);
        }

        uint ibi = a_offset + first_row * num_blocks_per_row + i;

        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const float d = float(data_a[ibi].d);

            // Assemble this lane's 32 packed sign bits into one uint32.
            const uint byte_base = itid * 4;
            const uint bits = uint(data_a[ibi].qs[byte_base])
                | (uint(data_a[ibi].qs[byte_base + 1]) << 8)
                | (uint(data_a[ibi].qs[byte_base + 2]) << 16)
                | (uint(data_a[ibi].qs[byte_base + 3]) << 24);

            FLOAT_TYPE partial = FLOAT_TYPE(0);

            // 8 groups of 4 bits each, accumulated in the same (low-to-high)
            // order as the original unrolled code so the float summation order
            // is unchanged. The simple ternary sign selects compile to
            // v_cndmask on RDNA.
            [[unroll]] for (uint g = 0; g < 8; ++g) {
                const uint nib = bits >> (g * 4);
                partial += FLOAT_TYPE(dot(vec4(
                    (nib & 0x1u) != 0 ? 1.0 : -1.0, (nib & 0x2u) != 0 ? 1.0 : -1.0,
                    (nib & 0x4u) != 0 ? 1.0 : -1.0, (nib & 0x8u) != 0 ? 1.0 : -1.0), bv[g]));
            }

            temp[j][n] = fma(FLOAT_TYPE(d), partial, temp[j][n]);
            ibi += num_blocks_per_row;
        }
    }
}
74+
75+
// Compute num_rows output rows starting at first_row for this workgroup.
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    uint a_offset, b_offset, d_offset;
    get_offsets(a_offset, b_offset, d_offset);

    // Four threads cooperate on each 128-element block; the workgroup walks
    // the row's blocks in strides of (workgroup size / 4).
    const uint num_blocks_per_row = p.ncols / 128;
    const uint blocks_per_wg      = gl_WorkGroupSize.x / 4;
    const uint tid  = gl_LocalInvocationID.x;
    const uint itid = tid % 4;  // lane within a 4-thread block group
    const uint ix   = tid / 4;  // first block index this thread handles

    // Zero the per-thread accumulators.
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
            temp[j][i] = FLOAT_TYPE(0);
        }
    }

    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) {
        calc_block(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}
96+
97+
void main() {
    // Each workgroup produces up to NUM_ROWS consecutive output rows.
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // Entirely past the end of the output: nothing to do.
    if (first_row >= p.stride_d) {
        return;
    }

    // Clamp the row count for the final, possibly partial, group of rows.
    const uint rows = min(NUM_ROWS, p.stride_d - first_row);
    compute_outputs(first_row, rows);
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,36 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
128128
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
129129
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
130130

131+
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
132+
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
133+
#elif defined(DATA_A_Q1_0)
134+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
135+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
136+
137+
// LOAD_VEC_A = 4, so each load processes 4 elements.
138+
// 128 elements per block / 4 = 32 loads per block.
139+
const uint ib = idx / 32; // block index
140+
const uint iel = (idx % 32) * 4; // element offset within block (0,4,8,...124)
141+
142+
const float d = float(data_a[ib].d);
143+
const float d2 = d + d;
144+
const float neg_d = -d;
145+
146+
// Decode the containing 16-bit chunk, then select the 4-bit sub-group for this load.
147+
const uint chunk16 = iel / 16;
148+
const uint chunk_bit = iel % 16;
149+
const uint byte_offset = chunk16 * 2;
150+
const uint bits16 = uint(data_a[ib].qs[byte_offset])
151+
| (uint(data_a[ib].qs[byte_offset + 1]) << 8);
152+
const uint bits = (bits16 >> chunk_bit) & 0xFu;
153+
154+
// Branchless FMA: d*(2*bit-1) = fma(2d, bit_float, -d)
155+
const vec4 bit_floats = vec4(
156+
float(bits & 1u), float((bits >> 1) & 1u),
157+
float((bits >> 2) & 1u), float((bits >> 3) & 1u)
158+
);
159+
const vec4 v = fma(vec4(d2), bit_floats, vec4(neg_d));
160+
131161
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
132162
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
133163
#elif defined(DATA_A_Q2_K)

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
66
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
77
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
8+
#extension GL_EXT_shader_8bit_storage : require
89
#extension GL_EXT_shader_16bit_storage : require
910

1011
#if defined(DATA_A_F32)
@@ -46,6 +47,7 @@
4647
#endif
4748
#endif
4849

50+
#define QUANT_K_Q1_0 128
#define QUANT_R_Q1_0 1

// Q1_0: 1-bit sign quantization, 128 elements per block.
//   d  — fp16 scale shared by the whole block
//   qs — 16 bytes = 128 sign bits (bit == 1 -> +d, bit == 0 -> -d)
struct block_q1_0
{
    float16_t d;
    uint8_t qs[16];
};

#if defined(DATA_A_Q1_0)
#define QUANT_K QUANT_K_Q1_0
#define QUANT_R QUANT_R_Q1_0
#define QUANT_AUXF 1
#define A_TYPE block_q1_0
#define DATA_A_QUANT_LEGACY
#endif
4951
#define QUANT_K_Q4_0 32
5052
#define QUANT_R_Q4_0 2
5153

0 commit comments

Comments
 (0)