Skip to content

Commit 068eed2

Browse files
jeffbolznvSascha
authored andcommitted
vulkan: Support Q1_0 (ggml-org#21539)
* vulkan: Support Q1_0 * use get_dm
1 parent ec74db9 commit 068eed2

File tree

9 files changed

+161
-4
lines changed

9 files changed

+161
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 33 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,31 @@ void quantize(uint dst_idx, uint src_idx)
184184
}
185185
#endif
186186

187+
#if defined(DATA_A_Q1_0)
188+
void quantize(uint dst_idx, uint src_idx)
189+
{
190+
float sum_abs = 0.0;
191+
192+
[[unroll]] for (int j = 0; j < QUANT_K_Q1_0; j++) {
193+
sum_abs += abs(data_s[src_idx + j]);
194+
}
195+
196+
const float d = sum_abs / QUANT_K_Q1_0;
197+
198+
data_q[dst_idx].d = float16_t(d);
199+
200+
[[unroll]] for (int j = 0; j < QUANT_K_Q1_0 / 8; ++j) {
201+
data_q[dst_idx].qs[j] = uint8_t(0);
202+
}
203+
204+
[[unroll]] for (int j = 0; j < QUANT_K_Q1_0; ++j) {
205+
if (data_s[src_idx + j] >= 0.0) {
206+
data_q[dst_idx].qs[j / 8] |= uint8_t(1 << (j % 8));
207+
}
208+
}
209+
}
210+
#endif
211+
187212
#if defined(DATA_A_IQ4_NL)
188213
uint best_index(float x) {
189214
if (x <= kvalues_iq4nl[0]) return 0;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,23 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
8787
}
8888
#endif
8989

90+
#if defined(DATA_A_Q1_0)
91+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
92+
const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u);
93+
return vec2(
94+
(bits & 1u) != 0u ? 1.0f : -1.0f,
95+
(bits & 2u) != 0u ? 1.0f : -1.0f);
96+
}
97+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
98+
const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u);
99+
return vec4(
100+
(bits & 1u) != 0u ? 1.0f : -1.0f,
101+
(bits & 2u) != 0u ? 1.0f : -1.0f,
102+
(bits & 4u) != 0u ? 1.0f : -1.0f,
103+
(bits & 8u) != 0u ? 1.0f : -1.0f);
104+
}
105+
#endif
106+
90107
#if defined(DATA_A_IQ1_S)
91108
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
92109
const uint ib32 = iqs / 32;
@@ -454,6 +471,13 @@ vec2 get_dm(uint ib, uint a_offset) {
454471
}
455472
#endif
456473

474+
#if defined(DATA_A_Q1_0)
475+
vec2 get_dm(uint ib, uint a_offset) {
476+
const float d = float(data_a[a_offset + ib].d);
477+
return vec2(d, 0);
478+
}
479+
#endif
480+
457481
#if defined(DATA_A_MXFP4)
458482
vec2 get_dm(uint ib, uint a_offset) {
459483
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,18 @@ float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2],
1313
return vf16[idx];
1414
}
1515

16+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ1_0 {
17+
block_q1_0 block;
18+
};
19+
20+
float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
21+
{
22+
const float16_t d = bl.block.d;
23+
const uint idx = coordInBlock[1];
24+
const uint bit = (uint(bl.block.qs[(idx & 0x78) >> 3]) >> (idx & 0x7)) & 1u;
25+
return bit != 0u ? d : -d;
26+
}
27+
1628
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
1729
block_q4_0_packed16 block;
1830
};
@@ -685,7 +697,9 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
685697
}
686698
#endif
687699

688-
#if defined(DATA_A_Q4_0)
700+
#if defined(DATA_A_Q1_0)
701+
#define dequantFuncA dequantFuncQ1_0
702+
#elif defined(DATA_A_Q4_0)
689703
#define dequantFuncA dequantFuncQ4_0
690704
#elif defined(DATA_A_Q4_1)
691705
#define dequantFuncA dequantFuncQ4_1
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#version 450
2+
3+
#include "dequant_head.glsl"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_q1_0 data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
void main() {
11+
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
12+
13+
const uint tid = gl_LocalInvocationID.x % 64;
14+
const uint il = tid / 4;
15+
const uint ir = tid % 4;
16+
const uint ib = 4*i + ir;
17+
if (ib >= p.nel / 128) {
18+
return;
19+
}
20+
21+
const uint b_idx = 512*i + 128*ir + 8*il;
22+
23+
const float d = float(data_a[ib].d);
24+
const uint bits = uint(data_a[ib].qs[il]);
25+
26+
[[unroll]] for (uint l = 0; l < 8; ++l) {
27+
data_b[b_idx + l] = D_TYPE((bits & (1u << l)) != 0u ? d : -d);
28+
}
29+
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
130130

131131
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
132132
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
133+
#elif defined(DATA_A_Q1_0)
134+
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
135+
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
136+
137+
const uint ib = idx / 16;
138+
const uint iqs = idx & 0xfu;
139+
140+
const float d = float(data_a[ib].d);
141+
const uint bits = uint(data_a[ib].qs[iqs]);
142+
143+
buf_a[buf_idx ] = FLOAT_TYPEV2((bits & 0x01u) != 0u ? d : -d, (bits & 0x02u) != 0u ? d : -d);
144+
buf_a[buf_idx + 1] = FLOAT_TYPEV2((bits & 0x04u) != 0u ? d : -d, (bits & 0x08u) != 0u ? d : -d);
145+
buf_a[buf_idx + 2] = FLOAT_TYPEV2((bits & 0x10u) != 0u ? d : -d, (bits & 0x20u) != 0u ? d : -d);
146+
buf_a[buf_idx + 3] = FLOAT_TYPEV2((bits & 0x40u) != 0u ? d : -d, (bits & 0x80u) != 0u ? d : -d);
133147
#elif defined(DATA_A_Q2_K)
134148
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
135149
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,22 @@ struct block_q8_0_packed16
188188
#define DATA_A_QUANT_LEGACY
189189
#endif
190190

191+
#define QUANT_K_Q1_0 128
192+
#define QUANT_R_Q1_0 1
193+
194+
struct block_q1_0
195+
{
196+
float16_t d;
197+
uint8_t qs[QUANT_K_Q1_0 / 8];
198+
};
199+
200+
#if defined(DATA_A_Q1_0)
201+
#define QUANT_K QUANT_K_Q1_0
202+
#define QUANT_R QUANT_R_Q1_0
203+
#define QUANT_AUXF 1
204+
#define A_TYPE block_q1_0
205+
#endif
206+
191207
#define QUANT_K_Q8_1 32
192208
#define QUANT_R_Q8_1 1
193209

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ std::string target_cpp = "";
4545
const std::vector<std::string> type_names = {
4646
"f32",
4747
"f16",
48+
"q1_0",
4849
"q4_0",
4950
"q4_1",
5051
"q5_0",
@@ -553,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
553554

554555
for (const auto& tname : type_names) {
555556
std::string load_vec_quant = "2";
556-
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
557+
if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
557558
load_vec_quant = "8";
558559
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
559560
load_vec_quant = "4";
@@ -758,13 +759,13 @@ void process_shaders() {
758759
string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
759760
string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});
760761

761-
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
762+
for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
762763
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
763764
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
764765
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
765766
}
766767

767-
for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
768+
for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
768769
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
769770
string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
770771
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7265,6 +7265,7 @@ static const ggml_type all_types[] = {
72657265
static const ggml_type base_types[] = {
72667266
GGML_TYPE_F32, GGML_TYPE_F16,
72677267
GGML_TYPE_Q8_0, // for I8MM tests
7268+
GGML_TYPE_Q1_0,
72687269
GGML_TYPE_Q4_0,
72697270
GGML_TYPE_Q4_1, // for I8MM tests
72707271
GGML_TYPE_Q4_K,

0 commit comments

Comments
 (0)