Skip to content

Commit 434b2a1

Browse files
authored
ggml-webgpu: add Q1_0 support (ggml-org#22374)
* add fast matmul matvec q1_0 kernel * ggml-webgpu: drop redundant zero-fills in Q1_0 shmem init
1 parent 983ca89 commit 434b2a1

5 files changed

Lines changed: 94 additions & 2 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,7 @@ class ggml_webgpu_shader_lib {
12871287
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
12881288

12891289
switch (key.src_type) {
1290+
case GGML_TYPE_Q1_0:
12901291
case GGML_TYPE_Q4_0:
12911292
case GGML_TYPE_Q5_0:
12921293
case GGML_TYPE_Q8_0:
@@ -1323,7 +1324,9 @@ class ggml_webgpu_shader_lib {
13231324

13241325
defines.push_back("DST_TYPE=f32");
13251326

1326-
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
1327+
if (key.src_type == GGML_TYPE_Q1_0) {
1328+
defines.push_back("BLOCK_SIZE=128u");
1329+
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
13271330
key.src_type == GGML_TYPE_IQ4_NL) {
13281331
defines.push_back("BLOCK_SIZE=32u");
13291332
} else if (key.src_type >= GGML_TYPE_Q2_K) {
@@ -1657,7 +1660,9 @@ class ggml_webgpu_shader_lib {
16571660
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
16581661
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
16591662

1660-
if (key.src0_type >= GGML_TYPE_Q2_K) {
1663+
if (key.src0_type == GGML_TYPE_Q1_0) {
1664+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
1665+
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
16611666
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
16621667
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
16631668
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
13891389
case GGML_TYPE_Q5_K:
13901390
case GGML_TYPE_Q3_K:
13911391
case GGML_TYPE_Q2_K:
1392+
case GGML_TYPE_Q1_0:
13921393
use_fast = true;
13931394
break;
13941395
case GGML_TYPE_IQ1_S:
@@ -3736,6 +3737,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
37363737

37373738
static bool ggml_webgpu_supported_qtype(ggml_type type) {
37383739
switch (type) {
3740+
case GGML_TYPE_Q1_0:
37393741
case GGML_TYPE_Q4_0:
37403742
case GGML_TYPE_Q4_1:
37413743
case GGML_TYPE_Q5_0:
@@ -3830,6 +3832,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
38303832
switch (src0->type) {
38313833
case GGML_TYPE_F32:
38323834
case GGML_TYPE_F16:
3835+
case GGML_TYPE_Q1_0:
38333836
case GGML_TYPE_Q4_0:
38343837
case GGML_TYPE_Q4_1:
38353838
case GGML_TYPE_Q5_0:
@@ -3868,6 +3871,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
38683871
switch (src0->type) {
38693872
case GGML_TYPE_F32:
38703873
case GGML_TYPE_F16:
3874+
case GGML_TYPE_Q1_0:
38713875
case GGML_TYPE_Q4_0:
38723876
case GGML_TYPE_Q4_1:
38733877
case GGML_TYPE_Q5_0:

ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
2727
}
2828
#endif
2929

30+
#ifdef Q1_0
31+
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
32+
let block_byte_base = (src_base + offset) * 18;
33+
let d = load_f16_as_f32_at_src(block_byte_base);
34+
for (var j: u32 = 0u; j < 4u; j++) {
35+
let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u);
36+
let dst_base128 = dst_base + offset * 128u + j * 32u;
37+
for (var k: u32 = 0; k < 4u; k++) {
38+
let q_byte = get_byte(q_packed, k);
39+
for (var bit: u32 = 0; bit < 8u; bit++) {
40+
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
41+
dst[dst_base128 + k * 8u + bit] = w;
42+
}
43+
}
44+
}
45+
}
46+
#endif
47+
3048
#ifdef Q4_0
3149
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
3250
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,39 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
6161
#endif // INIT_SRC1_SHMEM_FLOAT
6262
#endif
6363

64+
#ifdef INIT_SRC0_SHMEM_Q1_0
65+
const BLOCK_SIZE = 128u;
66+
const BLOCK_SIZE_BYTES = 18u;
67+
const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration
68+
69+
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
70+
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
71+
let tile_m = i / TILE_K;
72+
let tile_k_start = i % TILE_K;
73+
let global_m = offset_m + tile_m;
74+
let global_k_start = k_outer + tile_k_start;
75+
76+
if (global_m >= params.m) {
77+
break;
78+
}
79+
80+
let block_k = global_k_start / BLOCK_SIZE;
81+
let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u;
82+
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
83+
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
84+
let d = load_f16_at_src0(block_byte_base);
85+
let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu;
86+
87+
for (var bit = 0u; bit < NQ; bit++) {
88+
let global_k = global_k_start + bit;
89+
if (global_k < params.k) {
90+
shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
91+
}
92+
}
93+
}
94+
}
95+
#endif // INIT_SRC0_SHMEM_Q1_0
96+
6497
#ifdef INIT_SRC0_SHMEM_Q4_0
6598
const BLOCK_SIZE = 32u;
6699
const BLOCK_SIZE_BYTES = 18u;

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,38 @@ fn main(
128128
}
129129
#endif
130130

131+
#ifdef MUL_ACC_Q1_0
132+
#define BLOCK_SIZE 128
133+
#define BLOCK_SIZE_BYTES 18
134+
#define THREADS_PER_BLOCK 16
135+
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
136+
137+
let num_blocks = params.k / BLOCK_SIZE;
138+
let thread_within_block = thread_id % THREADS_PER_BLOCK;
139+
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
140+
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
141+
var x_block: array<f32, ELEMS_PER_THREAD>;
142+
for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
143+
x_block[i] = f32(src1[x_base + i]);
144+
}
145+
146+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
147+
let output_row = row_base + row;
148+
if (output_row < params.m) {
149+
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
150+
let d = f32(load_f16_at_src0(block_byte_base));
151+
let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu;
152+
var row_sum = 0.0;
153+
for (var bit = 0u; bit < 8u; bit++) {
154+
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
155+
row_sum += w * x_block[bit];
156+
}
157+
acc[row] += row_sum;
158+
}
159+
}
160+
}
161+
#endif
162+
131163
#ifdef MUL_ACC_Q4_0
132164
#define BLOCK_SIZE 32
133165
#define BLOCK_SIZE_BYTES 18

0 commit comments

Comments
 (0)