Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 48 additions & 25 deletions ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
#endif

#ifdef U32_DEQUANT_HELPERS
fn load_u16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
return (word >> shift) & 0xFFFF;
#ifdef DECLARE_BYTE_LOADERS_SRC
fn load_u16_at_src(byte_offset: u32) -> u32 {
let word = src[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
return (word >> shift) & 0xFFFFu;
}

fn load_u32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4;
let shift = (byte_offset & 0x3) * 8;
let lo = buf[word_idx];
let hi = buf[word_idx + 1];
let shifted = (lo >> shift) | (hi << (32 - shift));
return select(shifted, lo, shift == 0);
fn load_u32_at_src(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 0x3u) * 8u;
let lo = src[word_idx];
let hi = src[word_idx + 1u];
let shifted = (lo >> shift) | (hi << (32u - shift));
return select(shifted, lo, shift == 0u);
}

fn load_f16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
fn load_f16_at_src(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at_src(byte_offset));
return f16(packed[0]);
}

fn load_f16_as_f32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
let d_bits = (word >> shift) & 0xFFFF;
fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 {
let word = src[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
let d_bits = (word >> shift) & 0xFFFFu;
return unpack2x16float(d_bits)[0];
}
#endif

#ifdef DECLARE_BYTE_LOADERS_SRC0
fn load_u16_at_src0(byte_offset: u32) -> u32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
return (word >> shift) & 0xFFFFu;
}

fn load_u32_at_src0(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 0x3u) * 8u;
let lo = src0[word_idx];
let hi = src0[word_idx + 1u];
let shifted = (lo >> shift) | (hi << (32u - shift));
return select(shifted, lo, shift == 0u);
}

fn load_f16_at_src0(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at_src0(byte_offset));
return f16(packed[0]);
}

fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
let d_bits = (word >> shift) & 0xFFFFu;
return unpack2x16float(d_bits)[0];
}
#endif
#endif



#ifdef Q4_1_T
Expand Down
90 changes: 46 additions & 44 deletions ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC
#include "common_decls.tmpl"


#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
Expand Down Expand Up @@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 4; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
Expand Down Expand Up @@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let qh_packed = load_u32_at(&src, block_byte_base + 2);
let d = load_f16_as_f32_at_src(block_byte_base);
let qh_packed = load_u32_at_src(block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);

for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
Expand Down Expand Up @@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 8u; j++) {
let q_byte_offset = block_byte_base + 2u + j * 4u;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
Expand Down Expand Up @@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes

// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
let d = load_f16_as_f32_at_src(block_byte_base + 108);

// Bytes 96-107: 12 bytes of scales (3 u32s)
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;

var scale_vals: array<u32, 4>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
scale_vals[0] = load_u32_at_src(block_byte_base + 96);
scale_vals[1] = load_u32_at_src(block_byte_base + 100);
scale_vals[2] = load_u32_at_src(block_byte_base + 104);

var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
Expand All @@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4);
}

// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4);
}

var dst_i = dst_base + offset * 256;
Expand Down Expand Up @@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes

// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
let d = load_f16_as_f32_at_src(block_byte_base + 208);

// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
ql_vals[i] = load_u32_at_src(block_byte_base + i * 4);
}

// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16u; i++) {
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u);
}

// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4);
}

var dst_i = dst_base + offset * 256;
Expand Down Expand Up @@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src, aux0_offset);
let aux1 = load_u32_at(&src, aux1_offset);
let aux0 = load_u32_at_src(aux0_offset);
let aux1 = load_u32_at_src(aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
Expand All @@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;

var scale_vals = array<u32, 2>(
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
load_u32_at_src(block_byte_base + 66),
load_u32_at_src(block_byte_base + 70)
);

for (var ib: u32 = 0; ib < 32; ib += 4) {
Expand All @@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
);
for (var l: u32 = 0; l < 4; l++) {
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
let qs_val = load_u32_at_src(qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
Expand All @@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;

var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
}

var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
qh_vals[0] = load_u32_at_src(block_byte_base + 66);
qh_vals[1] = load_u32_at_src(block_byte_base + 70);

var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
scale_vals[0] = load_u32_at_src(block_byte_base + 74);
scale_vals[1] = load_u32_at_src(block_byte_base + 78);

for (var ib: u32 = 0; ib < 8; ib ++) {
let s = get_byte(scale_vals[ib / 4], ib % 4);
Expand Down Expand Up @@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src, sc_sign_offset);
let sc_sign = load_u32_at_src(sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
Expand All @@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;

var qh_vals = array<u32, 2>(
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
load_u32_at_src(block_byte_base + 66),
load_u32_at_src(block_byte_base + 70)
);

var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4);
}

var scale_vals = load_u32_at(&src, block_byte_base + 106);
var scale_vals = load_u32_at_src(block_byte_base + 106);

for (var ib: u32 = 0; ib < 4; ib++) {
let s = get_byte(scale_vals, ib);
Expand All @@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
Expand All @@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
Expand Down Expand Up @@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 32;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);
Expand Down
Loading