diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
index 0d3501c34a2..62fe72ee3b1 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
@@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
 #endif
 
 #ifdef U32_DEQUANT_HELPERS
-fn load_u16_at(
-    buf: ptr<storage, array<u32>, read_write>,
-    byte_offset: u32) -> u32 {
-    let word = buf[byte_offset / 4];
-    let shift = (byte_offset & 0x2) * 8;
-    return (word >> shift) & 0xFFFF;
+#ifdef DECLARE_BYTE_LOADERS_SRC
+fn load_u16_at_src(byte_offset: u32) -> u32 {
+    let word = src[byte_offset / 4u];
+    let shift = (byte_offset & 0x2u) * 8u;
+    return (word >> shift) & 0xFFFFu;
 }
 
-fn load_u32_at(
-    buf: ptr<storage, array<u32>, read_write>,
-    byte_offset: u32) -> u32 {
-    let word_idx = byte_offset / 4;
-    let shift = (byte_offset & 0x3) * 8;
-    let lo = buf[word_idx];
-    let hi = buf[word_idx + 1];
-    let shifted = (lo >> shift) | (hi << (32 - shift));
-    return select(shifted, lo, shift == 0);
+fn load_u32_at_src(byte_offset: u32) -> u32 {
+    let word_idx = byte_offset / 4u;
+    let shift = (byte_offset & 0x3u) * 8u;
+    let lo = src[word_idx];
+    let hi = src[word_idx + 1u];
+    let shifted = (lo >> shift) | (hi << (32u - shift));
+    return select(shifted, lo, shift == 0u);
 }
 
-fn load_f16_at(
-    buf: ptr<storage, array<u32>, read_write>,
-    byte_offset: u32) -> f16 {
-    let packed = unpack2x16float(load_u16_at(buf, byte_offset));
+fn load_f16_at_src(byte_offset: u32) -> f16 {
+    let packed = unpack2x16float(load_u16_at_src(byte_offset));
     return f16(packed[0]);
 }
 
-fn load_f16_as_f32_at(
-    buf: ptr<storage, array<u32>, read_write>,
-    byte_offset: u32) -> f32 {
-    let word = buf[byte_offset / 4];
-    let shift = (byte_offset & 0x2) * 8;
-    let d_bits = (word >> shift) & 0xFFFF;
+fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 {
+    let word = src[byte_offset / 4u];
+    let shift = (byte_offset & 0x2u) * 8u;
+    let d_bits = (word >> shift) & 0xFFFFu;
     return unpack2x16float(d_bits)[0];
 }
 #endif
+#ifdef DECLARE_BYTE_LOADERS_SRC0
+fn load_u16_at_src0(byte_offset: u32) -> u32 {
+    let word = src0[byte_offset / 4u];
+    let shift = (byte_offset & 0x2u) * 8u;
+    return (word >> shift) & 0xFFFFu;
+}
+
+fn load_u32_at_src0(byte_offset: u32) -> u32 {
+    let word_idx = byte_offset / 4u;
+    let shift = (byte_offset & 0x3u) * 8u;
+    let lo = src0[word_idx];
+    let hi = src0[word_idx + 1u];
+    let shifted = (lo >> shift) | (hi << (32u - shift));
+    return select(shifted, lo, shift == 0u);
+}
+
+fn load_f16_at_src0(byte_offset: u32) -> f16 {
+    let packed = unpack2x16float(load_u16_at_src0(byte_offset));
+    return f16(packed[0]);
+}
+
+fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 {
+    let word = src0[byte_offset / 4u];
+    let shift = (byte_offset & 0x2u) * 8u;
+    let d_bits = (word >> shift) & 0xFFFFu;
+    return unpack2x16float(d_bits)[0];
+}
+#endif
+#endif
+
 #ifdef Q4_1_T
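Note on the change above: the old helpers took `ptr<storage, array<u32>, read_write>` parameters, which WGSL only permits behind the `unrestricted_pointer_parameters` language feature, so each consumer now opts into a loader set bound to its own buffer (`_src` or `_src0`) via a `#define` ahead of the `#include`. A minimal sketch of what the `_src` variants assume to be in scope — the binding index is illustrative, each including shader declares its own:

```wgsl
// Assumption for illustration: a module-scope u32 view of the buffer.
@group(0) @binding(0) var<storage, read_write> src: array<u32>;

// With that in scope, a block's leading f16 scale is read as:
fn example_read_scale(block_byte_base: u32) -> f32 {
    return load_f16_as_f32_at_src(block_byte_base);
}
```

Also worth noting in `load_u32_at_src`: when `byte_offset` is word-aligned, `shift == 0u` and `hi << (32u - shift)` would be a full-width shift, which has no portable WGSL semantics, so the `select()` returns `lo` unshifted for that case.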
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
index 3c8b84c9ac3..1415798fa6b 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
@@ -1,6 +1,8 @@
 enable f16;
 
+#define DECLARE_BYTE_LOADERS_SRC
 #include "common_decls.tmpl"
+
 #ifdef F32_VEC
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
@@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef Q4_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     for (var j: u32 = 0u; j < 4; j++) {
         let q_byte_offset = block_byte_base + 2 + j * 4;
-        let q_packed = load_u32_at(&src, q_byte_offset);
+        let q_packed = load_u32_at_src(q_byte_offset);
         for (var k: u32 = 0; k < 4; k++) {
             let q_byte = get_byte(q_packed, k);
             let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef Q5_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
-    let qh_packed = load_u32_at(&src, block_byte_base + 2);
+    let d = load_f16_as_f32_at_src(block_byte_base);
+    let qh_packed = load_u32_at_src(block_byte_base + 2);
     for (var j: u32 = 0; j < 4; j++) {
         let q_byte_offset = block_byte_base + 6 + j * 4;
-        let q_packed = load_u32_at(&src, q_byte_offset);
+        let q_packed = load_u32_at_src(q_byte_offset);
         for (var k: u32 = 0; k < 4; k++) {
             let q_byte = get_byte(q_packed, k);
 
@@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef Q8_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     for (var j: u32 = 0u; j < 8u; j++) {
         let q_byte_offset = block_byte_base + 2u + j * 4u;
-        let q_packed = load_u32_at(&src, q_byte_offset);
+        let q_packed = load_u32_at_src(q_byte_offset);
         for (var k: u32 = 0u; k < 4u; k++) {
             let q_byte = get_byte_i32(q_packed, k);
             let q_val = f32(q_byte) * d;
@@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
 
     // Bytes 108-109: f16 scale 'd'
-    let d = load_f16_as_f32_at(&src, block_byte_base + 108);
+    let d = load_f16_as_f32_at_src(block_byte_base + 108);
 
     // Bytes 96-107: 12 bytes of scales (3 u32s)
     let kmask1: u32 = 0x03030303;
     let kmask2: u32 = 0x0f0f0f0f;
 
     var scale_vals: array<u32, 3>;
-    scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
-    scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
-    scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
+    scale_vals[0] = load_u32_at_src(block_byte_base + 96);
+    scale_vals[1] = load_u32_at_src(block_byte_base + 100);
+    scale_vals[2] = load_u32_at_src(block_byte_base + 104);
 
     var tmp: u32 = scale_vals[2];
     scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     // Bytes 0-31: 32 bytes of hmask (8 u32s)
     var hmask_vals: array<u32, 8>;
     for (var i: u32 = 0; i < 8; i++) {
-        hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
+        hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4);
     }
 
     // Bytes 32-95: 64 bytes of qs (16 u32s)
     var qs_vals: array<u32, 16>;
     for (var i: u32 = 0u; i < 16; i++) {
-        qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
+        qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4);
     }
 
     var dst_i = dst_base + offset * 256;
@@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
 
     // Bytes 208-209: f16 scale 'd'
-    let d = load_f16_as_f32_at(&src, block_byte_base + 208);
+    let d = load_f16_as_f32_at_src(block_byte_base + 208);
 
     // Bytes 0-127: 128 bytes of ql (32 u32s)
     var ql_vals: array<u32, 32>;
     for (var i: u32 = 0; i < 32; i++) {
-        ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
+        ql_vals[i] = load_u32_at_src(block_byte_base + i * 4);
     }
 
     // Bytes 128-191: 64 bytes of qh (16 u32s)
     var qh_vals: array<u32, 16>;
     for (var i: u32 = 0; i < 16u; i++) {
-        qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
+        qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u);
     }
 
     // Bytes 192-207: 16 bytes of scales (4 u32s)
     var scale_vals: array<u32, 4>;
     for (var i: u32 = 0; i < 4; i++) {
-        scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
+        scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4);
     }
 
     var dst_i = dst_base + offset * 256;
@@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ2_XXS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
     for (var ib: u32 = 0; ib < 32; ib += 4) {
         let aux0_offset = block_byte_base + 2 + ib * 2;
         let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
-        let aux0 = load_u32_at(&src, aux0_offset);
-        let aux1 = load_u32_at(&src, aux1_offset);
+        let aux0 = load_u32_at_src(aux0_offset);
+        let aux1 = load_u32_at_src(aux1_offset);
         let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
         for (var l: u32 = 0; l < 4; l++) {
             let ig = get_byte(aux0, l) * 8;
@@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ2_XS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
 
     var scale_vals = array<u32, 2>(
-        load_u32_at(&src, block_byte_base + 66),
-        load_u32_at(&src, block_byte_base + 70)
+        load_u32_at_src(block_byte_base + 66),
+        load_u32_at_src(block_byte_base + 70)
     );
 
     for (var ib: u32 = 0; ib < 32; ib += 4) {
@@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         );
         for (var l: u32 = 0; l < 4; l++) {
             let qs_offset = block_byte_base + 2 + (ib + l) * 2;
-            let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
+            let qs_val = load_u32_at_src(qs_offset) & 0xFFFF;
             let ig = (qs_val & 511) * 8;
             let is = qs_val >> 9;
             let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ2_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
 
     var qs_vals : array<u32, 16>;
     for (var i: u32 = 0; i < 16; i++) {
-        qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
+        qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
     }
 
     var qh_vals: array<u32, 2>;
-    qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
-    qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
+    qh_vals[0] = load_u32_at_src(block_byte_base + 66);
+    qh_vals[1] = load_u32_at_src(block_byte_base + 70);
 
     var scale_vals: array<u32, 2>;
-    scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
-    scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
+    scale_vals[0] = load_u32_at_src(block_byte_base + 74);
+    scale_vals[1] = load_u32_at_src(block_byte_base + 78);
 
     for (var ib: u32 = 0; ib < 8; ib ++) {
         let s = get_byte(scale_vals[ib / 4], ib % 4);
@@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ3_XXS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
     for (var ib: u32 = 0; ib < 16; ib += 2) {
         let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
-        let sc_sign = load_u32_at(&src, sc_sign_offset);
+        let sc_sign = load_u32_at_src(sc_sign_offset);
         let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
         for (var l: u32 = 0; l < 4; l++) {
             let is = (sc_sign >> (7 * l)) & 127;
             let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
+            let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
             let ig1 = get_byte(ig_val, 0);
             let ig2 = get_byte(ig_val, 1);
             for (var j: u32 = 0; j < 4; j++) {
@@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ3_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
 
     var qh_vals = array<u32, 2>(
-        load_u32_at(&src, block_byte_base + 66),
-        load_u32_at(&src, block_byte_base + 70)
+        load_u32_at_src(block_byte_base + 66),
+        load_u32_at_src(block_byte_base + 70)
     );
 
     var sign_vals: array<u32, 8>;
     for (var i: u32 = 0; i < 8; i++) {
-        sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
+        sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4);
     }
 
-    var scale_vals = load_u32_at(&src, block_byte_base + 106);
+    var scale_vals = load_u32_at_src(block_byte_base + 106);
 
     for (var ib: u32 = 0; ib < 4; ib++) {
         let s = get_byte(scale_vals, ib);
@@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
             let sign_w = sign_vals[ib * 2 + k];
             for (var l: u32 = 0; l < 4; l++) {
                 let signs = get_byte(sign_w, l);
-                let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
+                let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
                 let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
                 let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
                 for (var j: u32 = 0; j < 4; j++) {
@@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ1_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
     for (var ib: u32 = 0; ib < 8; ib++) {
-        let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
+        let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF;
         let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
         let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
-        let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
+        let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4);
         for (var l: u32 = 0; l < 4; l++) {
             let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
             for (var j: u32 = 0; j < 8; j++) {
@@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ4_NL
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 32;
     var qs: array<u32, 4>;
     for (var i: u32 = 0; i < 4; i++) {
-        qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
+        qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
     }
     for (var j: u32 = 0; j < 16; j++) {
         let qsb = get_byte(qs[j / 4], j % 4);
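All the quantized `copy_elements` variants above follow one pattern: compute the block's byte base from its stride, pull the f16 scale(s), then unpack the packed quants. As a sketch of the simplest case, Q4_0 (18-byte blocks: a 2-byte f16 `d` followed by 16 bytes carrying 32 4-bit quants) — the single-element entry point here is hypothetical; only `get_byte` and the `_src` loaders come from this patch:

```wgsl
// Illustrative only: dequantize element `elem` (0..31) of Q4_0 block
// `block_idx`. Elements 0..15 sit in the low nibbles of bytes 0..15,
// elements 16..31 in the high nibbles of the same bytes.
fn dequant_q4_0_one(block_idx: u32, elem: u32) -> f32 {
    let base = block_idx * 18u;            // block stride: 18 bytes
    let d = load_f16_as_f32_at_src(base);  // f16 scale, bytes 0-1
    let b = get_byte(load_u32_at_src(base + 2u + (elem & 15u)), 0u);
    let nib = select(b >> 4u, b & 0xFu, elem < 16u);
    return (f32(nib) - 8.0) * d;
}
```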
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
index fdabaf09b2e..fcbefdeb802 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
@@ -1,7 +1,9 @@
 enable f16;
 
+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 
 #ifdef FLOAT
 
 const BLOCK_SIZE = 1u;
@@ -21,11 +23,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef Q4_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var sum: f32 = 0.0;
     for (var j: u32 = 0; j < 4; j++) {
         let q_byte_offset = block_byte_base + 2 + j * 4;
-        let q_packed = load_u32_at(&src0, q_byte_offset);
+        let q_packed = load_u32_at_src0(q_byte_offset);
         for (var k: u32 = 0; k < 4; k++) {
             let q_byte = get_byte(q_packed, k);
             let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
@@ -63,12 +65,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef Q5_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var sum: f32 = 0.0;
-    let qh_packed = load_u32_at(&src0, block_byte_base + 2);
+    let qh_packed = load_u32_at_src0(block_byte_base + 2);
     for (var j: u32 = 0; j < 4; j++) {
         let q_byte_offset = block_byte_base + 6 + j * 4;
-        let q_packed = load_u32_at(&src0, q_byte_offset);
+        let q_packed = load_u32_at_src0(q_byte_offset);
         for (var k: u32 = 0; k < 4; k++) {
             let q_byte = get_byte(q_packed, k);
             let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
@@ -110,11 +112,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef Q8_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var sum: f32 = 0.0;
     for (var j: u32 = 0; j < 8; j++) {
         let q_byte_offset = block_byte_base + 2 + j * 4;
-        let q_packed = load_u32_at(&src0, q_byte_offset);
+        let q_packed = load_u32_at_src0(q_byte_offset);
         for (var k: u32 = 0u; k < 4u; k++) {
             let q_byte = get_byte_i32(q_packed, k);
             let q_val = f32(q_byte) * d;
@@ -184,7 +186,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
 
     // Bytes 108-109: f16 scale 'd'
-    let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
+    let d = load_f16_as_f32_at_src0(block_byte_base + 108);
 
     // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
     // and 2-bits from the last 4 bytes
@@ -192,9 +194,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let kmask1: u32 = 0x03030303;
     let kmask2: u32 = 0x0f0f0f0f;
     var scale_vals: array<u32, 3>;
-    scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
-    scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
-    scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
+    scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
+    scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
+    scale_vals[2] = load_u32_at_src0(block_byte_base + 104);
 
     var tmp: u32 = scale_vals[2];
     scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@@ -205,13 +207,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     // Bytes 0-31: 32 bytes of hmask (8 u32s)
     var hmask_vals: array<u32, 8>;
     for (var i: u32 = 0; i < 8; i++) {
-        hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
+        hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
     }
 
     // Bytes 32-95: 64 bytes of qs (16 u32s)
     var qs_vals: array<u32, 16>;
     for (var i: u32 = 0u; i < 16; i++) {
-        qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
+        qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
     }
 
     var sum = 0.0;
@@ -313,24 +315,24 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
 
     // Bytes 208-209: f16 scale 'd'
-    let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
+    let d = load_f16_as_f32_at_src0(block_byte_base + 208);
 
     // Bytes 0-127: 128 bytes of ql (32 u32s)
     var ql_vals: array<u32, 32>;
     for (var i: u32 = 0; i < 32; i++) {
-        ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
+        ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
     }
 
     // Bytes 128-191: 64 bytes of qh (16 u32s)
     var qh_vals: array<u32, 16>;
     for (var i: u32 = 0; i < 16; i++) {
-        qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
+        qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
     }
 
     // Bytes 192-207: 16 bytes of scales (4 u32s)
     var scale_vals: array<u32, 4>;
     for (var i: u32 = 0; i < 4; i++) {
-        scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
+        scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
     }
 
     var sum = 0.0;
@@ -374,14 +376,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ2_XXS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 32; ib += 4) {
         let aux0_offset = block_byte_base + 2 + ib * 2;
         let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
-        let aux0 = load_u32_at(&src0, aux0_offset);
-        let aux1 = load_u32_at(&src0, aux1_offset);
+        let aux0 = load_u32_at_src0(aux0_offset);
+        let aux1 = load_u32_at_src0(aux1_offset);
         let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
         for (var l: u32 = 0; l < 4; l++) {
             let ig = get_byte(aux0, l) * 8;
@@ -402,12 +404,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ2_XS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
 
     var scale_vals = array<u32, 2>(
-        load_u32_at(&src0, block_byte_base + 66),
-        load_u32_at(&src0, block_byte_base + 70)
+        load_u32_at_src0(block_byte_base + 66),
+        load_u32_at_src0(block_byte_base + 70)
     );
 
     var sum = 0.0;
@@ -419,7 +421,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
         );
         for (var l: u32 = 0; l < 4; l++) {
             let qs_offset = block_byte_base + 2 + (ib + l) * 2;
-            let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
+            let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
             let ig = (qs_val & 511) * 8;
             let is = qs_val >> 9;
             let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -439,21 +441,21 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ2_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
 
     var qs_vals : array<u32, 16>;
     for (var i: u32 = 0; i < 16; i++) {
-        qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
+        qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
     }
 
     var qh_vals: array<u32, 2>;
-    qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
-    qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
+    qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
+    qh_vals[1] = load_u32_at_src0(block_byte_base + 70);
 
     var scale_vals: array<u32, 2>;
-    scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
-    scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
+    scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
+    scale_vals[1] = load_u32_at_src0(block_byte_base + 78);
 
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 8; ib ++) {
@@ -483,17 +485,17 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ3_XXS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 16; ib += 2) {
         let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
-        let sc_sign = load_u32_at(&src0, sc_sign_offset);
+        let sc_sign = load_u32_at_src0(sc_sign_offset);
         let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
         for (var l: u32 = 0; l < 4; l++) {
             let is = (sc_sign >> (7 * l)) & 127;
             let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
+            let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
             let ig1 = get_byte(ig_val, 0);
             let ig2 = get_byte(ig_val, 1);
             for (var j: u32 = 0; j < 4; j++) {
@@ -515,20 +517,20 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ3_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
 
     var qh_vals = array<u32, 2>(
-        load_u32_at(&src0, block_byte_base + 66),
-        load_u32_at(&src0, block_byte_base + 70)
+        load_u32_at_src0(block_byte_base + 66),
+        load_u32_at_src0(block_byte_base + 70)
     );
 
     var sign_vals: array<u32, 8>;
     for (var i: u32 = 0; i < 8; i++) {
-        sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
+        sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
     }
 
-    var scale_vals = load_u32_at(&src0, block_byte_base + 106);
+    var scale_vals = load_u32_at_src0(block_byte_base + 106);
 
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 4; ib++) {
@@ -543,7 +545,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
             let sign_w = sign_vals[ib * 2 + k];
             for (var l: u32 = 0; l < 4; l++) {
                 let signs = get_byte(sign_w, l);
-                let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
+                let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
                 let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
                 let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
                 for (var j: u32 = 0; j < 4; j++) {
@@ -566,14 +568,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ1_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 8; ib++) {
-        let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
+        let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
         let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
         let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
-        let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
+        let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
         for (var l: u32 = 0; l < 4; l++) {
             let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
             for (var j: u32 = 0; j < 8; j++) {
@@ -638,12 +640,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ4_NL
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 32;
     var sum = 0.0;
     var qs: array<u32, 4>;
     for (var i: u32 = 0; i < 4; i++) {
-        qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
+        qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
     }
     for (var j: u32 = 0; j < 16; j++) {
         let qsb = get_byte(qs[j / 4], j % 4);
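The `multiply_add` overloads above share a contract with the `FLOAT` path: one call consumes one src0 block (`BLOCK_SIZE` weights) against the matching run of src1 values, so the caller only ever steps in block units. A hypothetical sketch of that driver loop — `params.k` and the `*_row_base` names are assumptions, not code from this patch:

```wgsl
// Sketch of the implied calling convention; the real main() lives
// outside these hunks.
var acc: f32 = 0.0;
for (var b: u32 = 0u; b < params.k / BLOCK_SIZE; b++) {
    acc += multiply_add(src0_row_base, src1_row_base, b);
}
```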
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
index 56a76a6e6c4..5a323818260 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
@@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
-        let d = load_f16_at(&src0, block_byte_base);
+        let d = load_f16_at_src0(block_byte_base);
 
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k = 0u; k < 4u; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
-        let d = load_f16_at(&src0, block_byte_base);
-        let m = load_f16_at(&src0, block_byte_base + 2u);
+        let d = load_f16_at_src0(block_byte_base);
+        let m = load_f16_at_src0(block_byte_base + 2u);
 
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k = 0u; k < 4u; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_lo = f16(q_byte & 0xF) * d + m;
@@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = load_f16_at(&src0, block_byte_base);
-        let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
+        let d = load_f16_at_src0(block_byte_base);
+        let qh_packed = load_u32_at_src0(block_byte_base + 2u);
 
         for (var j = 0u; j < 2; j++) {
             let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
 
             let j_adjusted = j + (block_offset / 2u);
 
@@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = load_f16_at(&src0, block_byte_base);
-        let m = load_f16_at(&src0, block_byte_base + 2u);
-        let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
+        let d = load_f16_at_src0(block_byte_base);
+        let m = load_f16_at_src0(block_byte_base + 2u);
+        let qh_packed = load_u32_at_src0(block_byte_base + 4u);
 
         for (var j = 0u; j < 2; j++) {
             let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
 
             let j_adjusted = j + (block_offset / 2u);
 
@@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
-        let d = load_f16_at(&src0, block_byte_base);
+        let d = load_f16_at_src0(block_byte_base);
 
         for (var j = 0u; j < F16_PER_THREAD; j+=2) {
             let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
 
             for (var k = 0u; k < 4u; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
@@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
     if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
-        let d = load_f16_at(&src0, block_byte_base);
-        let m = load_f16_at(&src0, block_byte_base + 2u);
+        let d = load_f16_at_src0(block_byte_base);
+        let m = load_f16_at_src0(block_byte_base + 2u);
 
         for (var j = 0u; j < F16_PER_THREAD; j+=2) {
             let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
 
             for (var k = 0u; k < 4u; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
@@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = load_f16_at(&src0, block_byte_base + 80u);
-        let dmin = load_f16_at(&src0, block_byte_base + 82u);
+        let d = load_f16_at_src0(block_byte_base + 80u);
+        let dmin = load_f16_at_src0(block_byte_base + 82u);
 
         // Decode the element at position k_in_block
         let block_of_32 = k_in_block / 32u;
@@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         let is = k_in_block / 16u;
-        let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u));
+        let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u));
         let sc = get_byte(sc_packed, is % 4u);
 
         let dl = d * f16(sc & 0xFu);
         let ml = dmin * f16(sc >> 4u);
 
         let q_idx = q_b_idx + k + l;
-        let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
+        let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
         let q_byte = get_byte(q_packed, q_idx % 4u);
 
         let qs_val = (q_byte >> shift) & 3u;
@@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = load_f16_at(&src0, block_byte_base + 108u);
+        let d = load_f16_at_src0(block_byte_base + 108u);
 
         // Load and unpack scales
         let kmask1: u32 = 0x03030303u;
@@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
 
         var scale_vals: array<u32, 4>;
         for (var i: u32 = 0u; i < 4u; i++) {
-            scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i);
+            scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i);
         }
 
         var tmp: u32 = scale_vals[2];
@@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         // Load hmask and qs arrays
         var hmask_vals: array<u32, 8>;
         for (var i: u32 = 0u; i < 8u; i++) {
-            hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i);
+            hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i);
         }
 
         var qs_vals: array<u32, 16>;
         for (var i: u32 = 0u; i < 16u; i++) {
-            qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i);
+            qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i);
         }
 
         let half = k_in_block / 128u; // 0 or 1
@@ -499,8 +499,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = load_f16_at(&src0, block_byte_base);
-        let dmin = load_f16_at(&src0, block_byte_base + 2u);
+        let d = load_f16_at_src0(block_byte_base);
+        let dmin = load_f16_at_src0(block_byte_base + 2u);
 
         // Map k_in_block to loop structure:
        // Outer loop over 64-element groups (alternating q_b_idx)
@@ -520,14 +520,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let scale_base = block_byte_base + 4u;
 
         if (is < 4u) {
-            let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
-            let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
+            let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
+            let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
             sc = sc_byte & 63u;
             mn = min_byte & 63u;
         } else {
-            let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
-            let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
-            let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
+            let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
+            let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
+            let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
 
             sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
             mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@@ -537,7 +537,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let ml = dmin * f16(mn);
 
         let q_idx = q_b_idx + l;
-        let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
+        let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
         let q_byte = get_byte(q_packed, q_idx % 4u);
 
         let qs_val = (q_byte >> shift) & 0xFu;
@@ -571,8 +571,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
         let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
 
-        let d = load_f16_at(&src0, block_byte_base);
-        let dmin = load_f16_at(&src0, block_byte_base + 2u);
+        let d = load_f16_at_src0(block_byte_base);
+        let dmin = load_f16_at_src0(block_byte_base + 2u);
 
         // The original loop processes elements in groups of 64
@@ -597,14 +597,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let scale_base = block_byte_base + 4u;
 
         if (is < 4u) {
-            let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
-            let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
+            let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
+            let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
             sc = sc_byte & 63u;
             mn = min_byte & 63u;
         } else {
-            let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
-            let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
-            let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
+            let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
+            let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
+            let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
 
             sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
             mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@@ -614,11 +614,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let ml = dmin * f16(mn);
 
         let q_idx = q_b_idx + l;
-        let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
+        let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
         let q_byte = get_byte(q_packed, q_idx % 4u);
 
-        let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
+        let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
         let qh_byte = get_byte(qh_packed, l % 4u);
 
@@ -666,17 +666,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         // Load only ql13 word needed
         let ql13_flat = ql_b_idx + l;
-        let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
+        let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
         let ql13_b = get_byte(ql13, 0u);
 
         // Load only ql24 word needed
         let ql24_flat = ql_b_idx + l + 32u;
-        let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
+        let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
         let ql24_b = get_byte(ql24, 0u);
 
         // Load only qh word needed
         let qh_flat = qh_b_idx + l;
-        let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
+        let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
        let qh_b = get_byte(qh, 0u);
 
         let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
@@ -687,10 +687,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         // Load only the scale word needed
         let is = l / 16u;
         let sc_idx = sc_b_idx + is + quarter * 2u;
-        let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
+        let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
         let sc_val = get_byte_i32(sc, 0u);
 
-        let d = load_f16_at(&src0, block_byte_base + 208u);
+        let d = load_f16_at_src0(block_byte_base + 208u);
 
         var q_val: f16;
         if (quarter == 0u) {
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl
index 5f763a6400a..91039ff2546 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl
@@ -1,6 +1,8 @@
 enable f16;
 
+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 #include "mul_mat_decls.tmpl"
 
 #ifdef VEC
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
index ee37e6d249c..98bbdeb83ba 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
@@ -1,6 +1,8 @@
 enable f16;
 
+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 #include "mul_mat_decls.tmpl"
 
 #ifdef VEC
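Each matmul shader gets the identical two-line change, and the ordering is load-bearing: `common_decls.tmpl` only emits the `_src0` loaders when the define is already set, and `mul_mat_decls.tmpl` calls them, so the sequence must be:

```wgsl
// Required order (as in the hunks above):
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"

#include "mul_mat_decls.tmpl"
```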
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
index 4151ce430b0..d86a72ce6e0 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
@@ -3,7 +3,9 @@ enable f16;
 enable subgroups;
 enable chromium_experimental_subgroup_matrix;
 
+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 #include "mul_mat_decls.tmpl"
 
 // TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs.
@@ -196,4 +198,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
             }
         }
     }
-
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
index 6f6bcaf7940..9f7b3e32eca 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
@@ -1,7 +1,9 @@
 enable f16;
 
+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 
 #ifdef VEC
 #define VEC_SIZE 4
@@ -65,10 +67,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
+        let d = f32(load_f16_at_src0(block_byte_base));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -98,11 +100,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
-        let m = f32(load_f16_at(&src0, block_byte_base + 2u));
+        let d = f32(load_f16_at_src0(block_byte_base));
+        let m = f32(load_f16_at_src0(block_byte_base + 2u));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@@ -132,12 +134,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
-        let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
+        let d = f32(load_f16_at_src0(block_byte_base));
+        let qh_packed = load_u32_at_src0(block_byte_base + 2u);
 
         for (var j = 0u; j < 2; j++) {
             let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
 
             let j_adjusted = j + (block_offset / 2u);
 
@@ -176,13 +178,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
-        let m = load_f16_at(&src0, block_byte_base + 2u);
-        let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
+        let d = f32(load_f16_at_src0(block_byte_base));
+        let m = load_f16_at_src0(block_byte_base + 2u);
+        let qh_packed = load_u32_at_src0(block_byte_base + 4u);
 
         for (var j = 0u; j < 2; j++) {
             let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
 
             let j_adjusted = j + (block_offset / 2u);
 
@@ -221,11 +223,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
+        let d = f32(load_f16_at_src0(block_byte_base));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
                 let q_val = f32(q_byte) * d;
@@ -254,12 +256,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
-        let m = load_f16_at(&src0, block_byte_base + 2u);
+        let d = f32(load_f16_at_src0(block_byte_base));
+        let m = load_f16_at_src0(block_byte_base + 2u);
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
                 let q_val = f32(q_byte) * d + f32(m);
@@ -309,13 +311,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = ix; i < nb; i += 2u) {
         let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
-        let d = f32(load_f16_at(&src0, bbase + 208u));
+        let d = f32(load_f16_at_src0(bbase + 208u));
 
-        let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
-        let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
-        let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
-        let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
-        let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
+        let ql1_u32 = load_u32_at_src0(bbase + q_offset_l);
+        let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u);
+        let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h);
+        let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte);
+        let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u);
 
         let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
         let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
index 8c334817ccd..b8f1bca1284 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
@@ -147,15 +147,12 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
                                    -9.010913, 9.010913)));
 #endif
 #ifdef XIELU
+    let val = f32(src[params.offset_src + src_idx]);
     let res =
-        select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
-                src[params.offset_src + src_idx]) *
-                   TYPE(params.alpha_n) +
-               TYPE(params.beta) * src[params.offset_src + src_idx],
-               TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
-                       src[params.offset_src + src_idx] +
-                   TYPE(params.beta) * src[params.offset_src + src_idx],
-               src[params.offset_src + src_idx] > 0.0);
+        TYPE(select(
+            ((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val,
+            params.alpha_p * val * val + params.beta * val,
+            val > 0.0));
 #endif
 #ifdef SOFTPLUS
     let src_f32 = f32(src[params.offset_src + src_idx]);
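The XIELU rewrite in the last hunk is a behavior-preserving cleanup, up to precision (for f16 tensors the arithmetic now runs in f32, since the `params.*` push constants are f32): the repeated `src[params.offset_src + src_idx]` loads collapse into a single `val`, and the cast to `TYPE` happens once on the selected result instead of on every constant. For reference, the scalar form it computes:

```wgsl
// Restatement of the XIELU branch above; parameter names mirror params.*.
fn xielu_f32(x: f32, alpha_n: f32, alpha_p: f32, beta: f32, eps: f32) -> f32 {
    if (x > 0.0) {
        return alpha_p * x * x + beta * x;  // positive side: quadratic
    }
    return ((exp(min(x, eps)) - 1.0) - x) * alpha_n + beta * x;
}
```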