Skip to content

Commit 20d3bc2

Browse files
authored
ggml-webgpu: Fix dequantization helpers to not pass in pointers (ggml-org#21872)
* Fix dequantization helpers to not pass in pointers * Increase XIELU precision
1 parent a620695 commit 20d3bc2

File tree

9 files changed

+223
-192
lines changed

9 files changed

+223
-192
lines changed

ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
99
#endif
1010

1111
#ifdef U32_DEQUANT_HELPERS
12-
fn load_u16_at(
13-
buf: ptr<storage, array<u32>, read_write>,
14-
byte_offset: u32) -> u32 {
15-
let word = buf[byte_offset / 4];
16-
let shift = (byte_offset & 0x2) * 8;
17-
return (word >> shift) & 0xFFFF;
12+
#ifdef DECLARE_BYTE_LOADERS_SRC
13+
fn load_u16_at_src(byte_offset: u32) -> u32 {
14+
let word = src[byte_offset / 4u];
15+
let shift = (byte_offset & 0x2u) * 8u;
16+
return (word >> shift) & 0xFFFFu;
1817
}
1918

20-
fn load_u32_at(
21-
buf: ptr<storage, array<u32>, read_write>,
22-
byte_offset: u32) -> u32 {
23-
let word_idx = byte_offset / 4;
24-
let shift = (byte_offset & 0x3) * 8;
25-
let lo = buf[word_idx];
26-
let hi = buf[word_idx + 1];
27-
let shifted = (lo >> shift) | (hi << (32 - shift));
28-
return select(shifted, lo, shift == 0);
19+
fn load_u32_at_src(byte_offset: u32) -> u32 {
20+
let word_idx = byte_offset / 4u;
21+
let shift = (byte_offset & 0x3u) * 8u;
22+
let lo = src[word_idx];
23+
let hi = src[word_idx + 1u];
24+
let shifted = (lo >> shift) | (hi << (32u - shift));
25+
return select(shifted, lo, shift == 0u);
2926
}
3027

31-
fn load_f16_at(
32-
buf: ptr<storage, array<u32>, read_write>,
33-
byte_offset: u32) -> f16 {
34-
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
28+
fn load_f16_at_src(byte_offset: u32) -> f16 {
29+
let packed = unpack2x16float(load_u16_at_src(byte_offset));
3530
return f16(packed[0]);
3631
}
3732

38-
fn load_f16_as_f32_at(
39-
buf: ptr<storage, array<u32>, read_write>,
40-
byte_offset: u32) -> f32 {
41-
let word = buf[byte_offset / 4];
42-
let shift = (byte_offset & 0x2) * 8;
43-
let d_bits = (word >> shift) & 0xFFFF;
33+
fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 {
34+
let word = src[byte_offset / 4u];
35+
let shift = (byte_offset & 0x2u) * 8u;
36+
let d_bits = (word >> shift) & 0xFFFFu;
4437
return unpack2x16float(d_bits)[0];
4538
}
4639
#endif
4740

41+
#ifdef DECLARE_BYTE_LOADERS_SRC0
42+
fn load_u16_at_src0(byte_offset: u32) -> u32 {
43+
let word = src0[byte_offset / 4u];
44+
let shift = (byte_offset & 0x2u) * 8u;
45+
return (word >> shift) & 0xFFFFu;
46+
}
47+
48+
fn load_u32_at_src0(byte_offset: u32) -> u32 {
49+
let word_idx = byte_offset / 4u;
50+
let shift = (byte_offset & 0x3u) * 8u;
51+
let lo = src0[word_idx];
52+
let hi = src0[word_idx + 1u];
53+
let shifted = (lo >> shift) | (hi << (32u - shift));
54+
return select(shifted, lo, shift == 0u);
55+
}
56+
57+
fn load_f16_at_src0(byte_offset: u32) -> f16 {
58+
let packed = unpack2x16float(load_u16_at_src0(byte_offset));
59+
return f16(packed[0]);
60+
}
61+
62+
fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 {
63+
let word = src0[byte_offset / 4u];
64+
let shift = (byte_offset & 0x2u) * 8u;
65+
let d_bits = (word >> shift) & 0xFFFFu;
66+
return unpack2x16float(d_bits)[0];
67+
}
68+
#endif
69+
#endif
70+
4871

4972

5073
#ifdef Q4_1_T

ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
enable f16;
2+
#define DECLARE_BYTE_LOADERS_SRC
23
#include "common_decls.tmpl"
34

5+
46
#ifdef F32_VEC
57
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
68
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
@@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
2830
#ifdef Q4_0
2931
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
3032
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
31-
let d = load_f16_as_f32_at(&src, block_byte_base);
33+
let d = load_f16_as_f32_at_src(block_byte_base);
3234
for (var j: u32 = 0u; j < 4; j++) {
3335
let q_byte_offset = block_byte_base + 2 + j * 4;
34-
let q_packed = load_u32_at(&src, q_byte_offset);
36+
let q_packed = load_u32_at_src(q_byte_offset);
3537
for (var k: u32 = 0; k < 4; k++) {
3638
let q_byte = get_byte(q_packed, k);
3739
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
6668
#ifdef Q5_0
6769
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
6870
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
69-
let d = load_f16_as_f32_at(&src, block_byte_base);
70-
let qh_packed = load_u32_at(&src, block_byte_base + 2);
71+
let d = load_f16_as_f32_at_src(block_byte_base);
72+
let qh_packed = load_u32_at_src(block_byte_base + 2);
7173
for (var j: u32 = 0; j < 4; j++) {
7274
let q_byte_offset = block_byte_base + 6 + j * 4;
73-
let q_packed = load_u32_at(&src, q_byte_offset);
75+
let q_packed = load_u32_at_src(q_byte_offset);
7476

7577
for (var k: u32 = 0; k < 4; k++) {
7678
let q_byte = get_byte(q_packed, k);
@@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
113115
#ifdef Q8_0
114116
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
115117
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
116-
let d = load_f16_as_f32_at(&src, block_byte_base);
118+
let d = load_f16_as_f32_at_src(block_byte_base);
117119
for (var j: u32 = 0u; j < 8u; j++) {
118120
let q_byte_offset = block_byte_base + 2u + j * 4u;
119-
let q_packed = load_u32_at(&src, q_byte_offset);
121+
let q_packed = load_u32_at_src(q_byte_offset);
120122
for (var k: u32 = 0u; k < 4u; k++) {
121123
let q_byte = get_byte_i32(q_packed, k);
122124
let q_val = f32(q_byte) * d;
@@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
162164
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
163165

164166
// Bytes 108-109: f16 scale 'd'
165-
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
167+
let d = load_f16_as_f32_at_src(block_byte_base + 108);
166168

167169
// Bytes 96-107: 12 bytes of scales (3 u32s)
168170
let kmask1: u32 = 0x03030303;
169171
let kmask2: u32 = 0x0f0f0f0f;
170172

171173
var scale_vals: array<u32, 4>;
172-
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
173-
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
174-
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
174+
scale_vals[0] = load_u32_at_src(block_byte_base + 96);
175+
scale_vals[1] = load_u32_at_src(block_byte_base + 100);
176+
scale_vals[2] = load_u32_at_src(block_byte_base + 104);
175177

176178
var tmp: u32 = scale_vals[2];
177179
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
182184
// Bytes 0-31: 32 bytes of hmask (8 u32s)
183185
var hmask_vals: array<u32, 8>;
184186
for (var i: u32 = 0; i < 8; i++) {
185-
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
187+
hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4);
186188
}
187189

188190
// Bytes 32-95: 64 bytes of qs (16 u32s)
189191
var qs_vals: array<u32, 16>;
190192
for (var i: u32 = 0u; i < 16; i++) {
191-
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
193+
qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4);
192194
}
193195

194196
var dst_i = dst_base + offset * 256;
@@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
286288
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
287289

288290
// Bytes 208-209: f16 scale 'd'
289-
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
291+
let d = load_f16_as_f32_at_src(block_byte_base + 208);
290292

291293
// Bytes 0-127: 128 bytes of ql (32 u32s)
292294
var ql_vals: array<u32, 32>;
293295
for (var i: u32 = 0; i < 32; i++) {
294-
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
296+
ql_vals[i] = load_u32_at_src(block_byte_base + i * 4);
295297
}
296298

297299
// Bytes 128-191: 64 bytes of qh (16 u32s)
298300
var qh_vals: array<u32, 16>;
299301
for (var i: u32 = 0; i < 16u; i++) {
300-
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
302+
qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u);
301303
}
302304

303305
// Bytes 192-207: 16 bytes of scales (4 u32s)
304306
var scale_vals: array<u32, 4>;
305307
for (var i: u32 = 0; i < 4; i++) {
306-
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
308+
scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4);
307309
}
308310

309311
var dst_i = dst_base + offset * 256;
@@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
345347
#ifdef IQ2_XXS
346348
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
347349
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
348-
let d = load_f16_as_f32_at(&src, block_byte_base);
350+
let d = load_f16_as_f32_at_src(block_byte_base);
349351
var dst_i = dst_base + offset * 256;
350352
for (var ib: u32 = 0; ib < 32; ib += 4) {
351353
let aux0_offset = block_byte_base + 2 + ib * 2;
352354
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
353-
let aux0 = load_u32_at(&src, aux0_offset);
354-
let aux1 = load_u32_at(&src, aux1_offset);
355+
let aux0 = load_u32_at_src(aux0_offset);
356+
let aux1 = load_u32_at_src(aux1_offset);
355357
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
356358
for (var l: u32 = 0; l < 4; l++) {
357359
let ig = get_byte(aux0, l) * 8;
@@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
373375
#ifdef IQ2_XS
374376
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
375377
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
376-
let d = load_f16_as_f32_at(&src, block_byte_base);
378+
let d = load_f16_as_f32_at_src(block_byte_base);
377379
var dst_i = dst_base + offset * 256;
378380

379381
var scale_vals = array<u32, 2>(
380-
load_u32_at(&src, block_byte_base + 66),
381-
load_u32_at(&src, block_byte_base + 70)
382+
load_u32_at_src(block_byte_base + 66),
383+
load_u32_at_src(block_byte_base + 70)
382384
);
383385

384386
for (var ib: u32 = 0; ib < 32; ib += 4) {
@@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
389391
);
390392
for (var l: u32 = 0; l < 4; l++) {
391393
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
392-
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
394+
let qs_val = load_u32_at_src(qs_offset) & 0xFFFF;
393395
let ig = (qs_val & 511) * 8;
394396
let is = qs_val >> 9;
395397
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
408410
#ifdef IQ2_S
409411
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
410412
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
411-
let d = load_f16_as_f32_at(&src, block_byte_base);
413+
let d = load_f16_as_f32_at_src(block_byte_base);
412414
var dst_i = dst_base + offset * 256;
413415

414416
var qs_vals : array<u32, 16>;
415417
for (var i: u32 = 0; i < 16; i++) {
416-
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
418+
qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
417419
}
418420

419421
var qh_vals: array<u32, 2>;
420-
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
421-
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
422+
qh_vals[0] = load_u32_at_src(block_byte_base + 66);
423+
qh_vals[1] = load_u32_at_src(block_byte_base + 70);
422424

423425
var scale_vals: array<u32, 2>;
424-
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
425-
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
426+
scale_vals[0] = load_u32_at_src(block_byte_base + 74);
427+
scale_vals[1] = load_u32_at_src(block_byte_base + 78);
426428

427429
for (var ib: u32 = 0; ib < 8; ib ++) {
428430
let s = get_byte(scale_vals[ib / 4], ib % 4);
@@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
450452
#ifdef IQ3_XXS
451453
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
452454
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
453-
let d = load_f16_as_f32_at(&src, block_byte_base);
455+
let d = load_f16_as_f32_at_src(block_byte_base);
454456
var dst_i = dst_base + offset * 256;
455457
for (var ib: u32 = 0; ib < 16; ib += 2) {
456458
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
457-
let sc_sign = load_u32_at(&src, sc_sign_offset);
459+
let sc_sign = load_u32_at_src(sc_sign_offset);
458460
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
459461
for (var l: u32 = 0; l < 4; l++) {
460462
let is = (sc_sign >> (7 * l)) & 127;
461463
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
462-
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
464+
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
463465
let ig1 = get_byte(ig_val, 0);
464466
let ig2 = get_byte(ig_val, 1);
465467
for (var j: u32 = 0; j < 4; j++) {
@@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
480482
#ifdef IQ3_S
481483
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
482484
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
483-
let d = load_f16_as_f32_at(&src, block_byte_base);
485+
let d = load_f16_as_f32_at_src(block_byte_base);
484486
var dst_i = dst_base + offset * 256;
485487

486488
var qh_vals = array<u32, 2>(
487-
load_u32_at(&src, block_byte_base + 66),
488-
load_u32_at(&src, block_byte_base + 70)
489+
load_u32_at_src(block_byte_base + 66),
490+
load_u32_at_src(block_byte_base + 70)
489491
);
490492

491493
var sign_vals: array<u32, 8>;
492494
for (var i: u32 = 0; i < 8; i++) {
493-
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
495+
sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4);
494496
}
495497

496-
var scale_vals = load_u32_at(&src, block_byte_base + 106);
498+
var scale_vals = load_u32_at_src(block_byte_base + 106);
497499

498500
for (var ib: u32 = 0; ib < 4; ib++) {
499501
let s = get_byte(scale_vals, ib);
@@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
507509
let sign_w = sign_vals[ib * 2 + k];
508510
for (var l: u32 = 0; l < 4; l++) {
509511
let signs = get_byte(sign_w, l);
510-
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
512+
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
511513
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
512514
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
513515
for (var j: u32 = 0; j < 4; j++) {
@@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
529531
#ifdef IQ1_S
530532
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
531533
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
532-
let d = load_f16_as_f32_at(&src, block_byte_base);
534+
let d = load_f16_as_f32_at_src(block_byte_base);
533535
var dst_i = dst_base + offset * 256;
534536
for (var ib: u32 = 0; ib < 8; ib++) {
535-
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
537+
let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF;
536538
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
537539
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
538-
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
540+
let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4);
539541
for (var l: u32 = 0; l < 4; l++) {
540542
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
541543
for (var j: u32 = 0; j < 8; j++) {
@@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
596598
#ifdef IQ4_NL
597599
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
598600
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
599-
let d = load_f16_as_f32_at(&src, block_byte_base);
601+
let d = load_f16_as_f32_at_src(block_byte_base);
600602
var dst_i = dst_base + offset * 32;
601603
var qs: array<u32, 4>;
602604
for (var i: u32 = 0; i < 4; i++) {
603-
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
605+
qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
604606
}
605607
for (var j: u32 = 0; j < 16; j++) {
606608
let qsb = get_byte(qs[j / 4], j % 4);

0 commit comments

Comments
 (0)