Skip to content

Commit f839c10

Browse files
committed
Work on remaining legacy q-types
1 parent 01bd912 commit f839c10

3 files changed

Lines changed: 197 additions & 10 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,7 +1336,9 @@ class ggml_webgpu_shader_lib {
13361336

13371337
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
13381338
const bool use_row_tiled =
1339-
context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16 || context.src0->type == GGML_TYPE_Q4_0;
1339+
context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16 || context.src0->type == GGML_TYPE_Q4_0 ||
1340+
context.src0->type == GGML_TYPE_Q4_1 || context.src0->type == GGML_TYPE_Q5_0 || context.src0->type == GGML_TYPE_Q5_1 ||
1341+
context.src0->type == GGML_TYPE_Q8_0 || context.src0->type == GGML_TYPE_Q8_1;
13401342
ggml_webgpu_mul_mat_vec_pipeline_key key = {
13411343
.src0_type = context.src0->type,
13421344
.src1_type = context.src1->type,
@@ -1368,13 +1370,6 @@ class ggml_webgpu_shader_lib {
13681370
defines.push_back("MUL_ACC_FLOAT");
13691371
variant += "_f16";
13701372
break;
1371-
case GGML_TYPE_Q4_0:
1372-
defines.push_back("BYTE_HELPERS");
1373-
defines.push_back("MUL_ACC_Q4_0");
1374-
defines.push_back("U32_DEQUANT_HELPERS");
1375-
defines.push_back("SRC0_INNER_TYPE=u32");
1376-
variant += "_q4_0";
1377-
break;
13781373
default:
13791374
{
13801375
// Quantized types: use helpers but accumulate in f16

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,13 +1324,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
13241324
case GGML_TYPE_F32:
13251325
case GGML_TYPE_F16:
13261326
case GGML_TYPE_Q4_0:
1327-
use_fast = true;
1328-
break;
13291327
case GGML_TYPE_Q4_1:
13301328
case GGML_TYPE_Q5_0:
13311329
case GGML_TYPE_Q5_1:
13321330
case GGML_TYPE_Q8_0:
13331331
case GGML_TYPE_Q8_1:
1332+
use_fast = true;
1333+
break;
13341334
case GGML_TYPE_Q6_K:
13351335
use_fast = !is_vec || ctx->global_ctx->capabilities.supports_subgroups;
13361336
break;

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,198 @@ fn main(
155155
}
156156
#endif
157157

158+
#ifdef MUL_ACC_Q4_1
159+
#define BLOCK_SIZE 32
160+
#define BLOCK_SIZE_BYTES 20
161+
#define THREADS_PER_BLOCK 4
162+
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
163+
164+
let num_blocks = params.k / BLOCK_SIZE;
165+
let thread_within_block = thread_id % THREADS_PER_BLOCK;
166+
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
167+
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
168+
var x_block: array<f32, ELEMS_PER_THREAD>;
169+
for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
170+
x_block[i] = f32(src1[x_base + i]);
171+
x_block[i + 4] = f32(src1[x_base + i + 16]);
172+
}
173+
174+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
175+
let output_row = row_base + row;
176+
if (output_row < params.m) {
177+
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
178+
let d = f32(load_src0_f16_at(block_byte_base));
179+
let m = f32(load_src0_f16_at(block_byte_base + 2u));
180+
var row_sum = 0.0;
181+
182+
let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * thread_within_block);
183+
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
184+
let q_byte = get_byte(q_packed, byte_idx);
185+
let q_lo = f32(q_byte & 0xFu) * d + m;
186+
let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m;
187+
row_sum += q_lo * x_block[byte_idx];
188+
row_sum += q_hi * x_block[byte_idx + 4u];
189+
}
190+
acc[row] += row_sum;
191+
}
192+
}
193+
}
194+
#endif
195+
196+
#ifdef MUL_ACC_Q5_0
197+
#define BLOCK_SIZE 32
198+
#define BLOCK_SIZE_BYTES 22
199+
#define THREADS_PER_BLOCK 4
200+
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
201+
202+
let num_blocks = params.k / BLOCK_SIZE;
203+
let thread_within_block = thread_id % THREADS_PER_BLOCK;
204+
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
205+
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
206+
var x_block: array<f32, ELEMS_PER_THREAD>;
207+
for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
208+
x_block[i] = f32(src1[x_base + i]);
209+
x_block[i + 4] = f32(src1[x_base + i + 16]);
210+
}
211+
212+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
213+
let output_row = row_base + row;
214+
if (output_row < params.m) {
215+
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
216+
let d = f32(load_src0_f16_at(block_byte_base));
217+
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
218+
let q_packed = load_src0_u32_at(block_byte_base + 6u + 4u * thread_within_block);
219+
let qh_shift = thread_within_block * 4u;
220+
var row_sum = 0.0;
221+
222+
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
223+
let q_byte = get_byte(q_packed, byte_idx);
224+
let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
225+
let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
226+
let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d;
227+
let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d;
228+
row_sum += q_lo * x_block[byte_idx];
229+
row_sum += q_hi * x_block[byte_idx + 4u];
230+
}
231+
acc[row] += row_sum;
232+
}
233+
}
234+
}
235+
#endif
236+
237+
#ifdef MUL_ACC_Q5_1
238+
#define BLOCK_SIZE 32
239+
#define BLOCK_SIZE_BYTES 24
240+
#define THREADS_PER_BLOCK 4
241+
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
242+
243+
let num_blocks = params.k / BLOCK_SIZE;
244+
let thread_within_block = thread_id % THREADS_PER_BLOCK;
245+
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
246+
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
247+
var x_block: array<f32, ELEMS_PER_THREAD>;
248+
for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
249+
x_block[i] = f32(src1[x_base + i]);
250+
x_block[i + 4] = f32(src1[x_base + i + 16]);
251+
}
252+
253+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
254+
let output_row = row_base + row;
255+
if (output_row < params.m) {
256+
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
257+
let d = f32(load_src0_f16_at(block_byte_base));
258+
let m = f32(load_src0_f16_at(block_byte_base + 2u));
259+
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
260+
let q_packed = load_src0_u32_at(block_byte_base + 8u + 4u * thread_within_block);
261+
let qh_shift = thread_within_block * 4u;
262+
var row_sum = 0.0;
263+
264+
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
265+
let q_byte = get_byte(q_packed, byte_idx);
266+
let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
267+
let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
268+
let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m;
269+
let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m;
270+
row_sum += q_lo * x_block[byte_idx];
271+
row_sum += q_hi * x_block[byte_idx + 4u];
272+
}
273+
acc[row] += row_sum;
274+
}
275+
}
276+
}
277+
#endif
278+
279+
#ifdef MUL_ACC_Q8_0
280+
#define BLOCK_SIZE 32
281+
#define BLOCK_SIZE_BYTES 34
282+
#define THREADS_PER_BLOCK 4
283+
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
284+
285+
let num_blocks = params.k / BLOCK_SIZE;
286+
let thread_within_block = thread_id % THREADS_PER_BLOCK;
287+
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
288+
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
289+
var x_block: array<f32, ELEMS_PER_THREAD>;
290+
for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
291+
x_block[i] = f32(src1[x_base + i]);
292+
}
293+
294+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
295+
let output_row = row_base + row;
296+
if (output_row < params.m) {
297+
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
298+
let d = f32(load_src0_f16_at(block_byte_base));
299+
var row_sum = 0.0;
300+
301+
for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
302+
let q_packed = load_src0_u32_at(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx));
303+
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
304+
let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d;
305+
row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
306+
}
307+
}
308+
acc[row] += row_sum;
309+
}
310+
}
311+
}
312+
#endif
313+
314+
#ifdef MUL_ACC_Q8_1
315+
#define BLOCK_SIZE 32
316+
#define BLOCK_SIZE_BYTES 36
317+
#define THREADS_PER_BLOCK 4
318+
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
319+
320+
let num_blocks = params.k / BLOCK_SIZE;
321+
let thread_within_block = thread_id % THREADS_PER_BLOCK;
322+
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
323+
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
324+
var x_block: array<f32, ELEMS_PER_THREAD>;
325+
for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
326+
x_block[i] = f32(src1[x_base + i]);
327+
}
328+
329+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
330+
let output_row = row_base + row;
331+
if (output_row < params.m) {
332+
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
333+
let d = f32(load_src0_f16_at(block_byte_base));
334+
let m = f32(load_src0_f16_at(block_byte_base + 2u));
335+
var row_sum = 0.0;
336+
337+
for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
338+
let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx));
339+
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
340+
let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m;
341+
row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
342+
}
343+
}
344+
acc[row] += row_sum;
345+
}
346+
}
347+
}
348+
#endif
349+
158350
#ifdef USE_SUBGROUP_REDUCTION
159351
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
160352
let subgroup_total = subgroupAdd(acc[row]);

0 commit comments

Comments
 (0)