From 3c36b5562e3b875762644758f597538df1f3e7a6 Mon Sep 17 00:00:00 2001 From: Neha Abbas Date: Wed, 8 Apr 2026 17:27:23 -0700 Subject: [PATCH 1/9] merged properly, but slow q3_k and q5_k with u32 indexing --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 20 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 27 +- .../wgsl-shaders/common_decls.tmpl | 7 + .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 793 ++++++++++++++++-- src/ggml-webgpu.cpp | 0 5 files changed, 758 insertions(+), 89 deletions(-) create mode 100644 src/ggml-webgpu.cpp diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c10157766d9..6b894374bc5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -46,16 +46,16 @@ // Must be multiple of 4 to work with vectorized paths, and must divide // mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 8 +#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 1024 -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 8 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 1024 // Requires 32 threads per output (wg_size/outputs_per_wg == 32) #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 // Requires at least two (and multiple of 2) k-quant blocks per tile -#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512 +#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 2048 // default size for legacy matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 @@ -1734,11 +1734,11 @@ class ggml_webgpu_shader_lib { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? 
(int) ggml_get_unary_op(context.dst) : context.dst->op; ggml_webgpu_unary_pipeline_key key = { - .type = context.dst->type, - .op = op, - .is_unary = is_unary, - .inplace = context.inplace, - .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), + .type = context.dst->type, + .op = op, + .is_unary = is_unary, + .inplace = context.inplace, + .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), }; auto it = unary_pipelines.find(key); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b8df0f4dd05..beab945e804 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1701,8 +1701,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); + const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); const bool use_blk = use_vec && has_mask; @@ -3262,7 +3262,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const size_t q_tile = sg_mat_m; const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!kv_direct) { bytes_per_kv += std::max(Q->ne[0], V->ne[0]); } @@ -3471,9 +3471,16 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { std::vector required_features = { wgpu::FeatureName::ShaderF16 }; #ifndef __EMSCRIPTEN__ + // required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); + // if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + // required_features.push_back(wgpu::FeatureName::Subgroups); + // required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + // } + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); + required_features.push_back(wgpu::FeatureName::Subgroups); + if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } #endif @@ -3790,12 +3797,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 
0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); if (min_bytes > limit_bytes) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index feb0bca3f84..aa9b7e110d4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -26,6 +26,13 @@ fn load_src0_u32_at(byte_offset: u32) -> u32 { return (lo >> shift) | (hi << (32u - shift)); } +// Always reads the 4-byte-aligned word containing byte_offset. +// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u. +// this is used in k-quants for better performance +fn load_src0_u32_at_aligned(byte_offset: u32) -> u32 { + return src0[(byte_offset & ~3u) / 4u]; +} + fn load_src0_f16_at(byte_offset: u32) -> f16 { let packed = unpack2x16float(load_src0_u16_at(byte_offset)); return f16(packed[0]); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 6525f23bdfc..e87221146c7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,7 +1,20 @@ +enable subgroups; enable f16; #include "common_decls.tmpl" +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + #ifdef VEC #define VEC_SIZE 4 @@ -53,8 +66,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { const BLOCK_SIZE = 32; const BLOCK_SIZE_BYTES = 18u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const NQ = 16u; +const WEIGHTS_PER_F16 = 4u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { @@ -63,12 +76,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; let d = f32(load_src0_f16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_src0_u32_at(block_byte_base + 2u + 2u * (block_offset + j)); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -86,8 +97,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { const BLOCK_SIZE = 32; const BLOCK_SIZE_BYTES = 20u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const NQ = 16u; +const WEIGHTS_PER_F16 = 4u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { @@ -96,13 +107,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / 
WEIGHTS_PER_F16; let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; let d = f32(load_src0_f16_at(block_byte_base)); let m = f32(load_src0_f16_at(block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_src0_u32_at(block_byte_base + 4u + 2u * (block_offset + j)); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -120,8 +129,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { const BLOCK_SIZE = 32; const BLOCK_SIZE_BYTES = 22u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const NQ = 16u; +const WEIGHTS_PER_F16 = 4u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { @@ -130,42 +139,34 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; let d = f32(load_src0_f16_at(block_byte_base)); let qh_packed = load_src0_u32_at(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); - + let q_packed = load_src0_u32_at(block_byte_base + 6u + 2u * (block_offset + j * 2u)); let j_adjusted = j + (block_offset / 2u); - for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; } - } } return local_sum; } #endif - #ifdef MUL_ACC_Q5_1 const BLOCK_SIZE = 32; const BLOCK_SIZE_BYTES = 24u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const NQ = 16u; +const WEIGHTS_PER_F16 = 4u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { @@ -174,42 +175,34 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; let d = f32(load_src0_f16_at(block_byte_base)); let m = load_src0_f16_at(block_byte_base + 2u); let qh_packed = load_src0_u32_at(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); - + let q_packed = 
load_src0_u32_at(block_byte_base + 8u + 2u * (block_offset + j * 2u)); let j_adjusted = j + (block_offset / 2u); - for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m); let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m); - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; } - } } return local_sum; } #endif - #ifdef MUL_ACC_Q8_0 const BLOCK_SIZE = 32; const BLOCK_SIZE_BYTES = 34u; -const NQ = 16u; // number of weights per thread +const NQ = 16u; const WEIGHTS_PER_F16 = 2u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -219,13 +212,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; let d = f32(load_src0_f16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_src0_u32_at(block_byte_base + 2u + 2u * (block_offset + j)); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -237,12 +227,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { } #endif - #ifdef MUL_ACC_Q8_1 const BLOCK_SIZE = 32; const BLOCK_SIZE_BYTES = 36u; -const NQ = 16u; // number of weights per thread +const NQ = 16u; const WEIGHTS_PER_F16 = 2u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -252,14 +241,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; let d = f32(load_src0_f16_at(block_byte_base)); let m = load_src0_f16_at(block_byte_base + 2u); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_src0_u32_at(block_byte_base + 4u + 2u * (block_offset + j)); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -271,19 +257,674 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { } #endif -#ifdef MUL_ACC_Q6_K +#ifdef MUL_ACC_Q2_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; +const Q2K_BLOCK_SIZE = 256u; +const Q2K_BLOCK_SIZE_BYTES = 84u; // 42 f16s * 2 -fn byte_of(v: u32, b: u32) -> u32 { - return (v >> (b * 8u)) & 0xFFu; +fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + let ix = tig / 8u; + let it = tig % 8u; + let iq = it / 4u; + let ir = it % 4u; + let is = (8u * ir) / 16u; + + let nb = tile_size / Q2K_BLOCK_SIZE; + let k_block_start = k_outer / Q2K_BLOCK_SIZE; + let y4_offset = 128u * iq + 8u * ir; + + let sc0_byte = 8u * iq + is; 
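+    // Offsets here and below assume ggml's block_q2_K layout (84 bytes):
+    // scales[16] at byte 0, qs[64] at byte 16, f16 d at byte 80, f16 dmin at
+    // byte 82. Each scale byte packs a 4-bit scale (low nibble) and a 4-bit
+    // min (high nibble), so a weight dequantizes as
+    //   w = d * (sc & 0xF) * q - dmin * (sc >> 4)
+    // which is what the sumf update at the bottom of the loop computes.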
+ let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u; + + var sumf = 0.0; + + for (var ib = ix; ib < nb; ib += 4u) { + let bbase = (idx_base + k_block_start + ib) * Q2K_BLOCK_SIZE_BYTES; + + let dall = f32(load_src0_f16_at(bbase + 80u)); + let dmin = f32(load_src0_f16_at(bbase + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_src0_u32_at_aligned(bbase + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_src0_u32_at_aligned(bbase + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_src0_u32_at_aligned(bbase + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_src0_u32_at_aligned(bbase + sc6_byte), sc6_byte & 3u); + + let qs_u32_0 = load_src0_u32_at_aligned(bbase + qs_byte); + let qs_u32_1 = load_src0_u32_at_aligned(bbase + qs_byte + 4u); + let qs0 = qs_u32_0 & 0xFFFFu; + let qs1 = qs_u32_0 >> 16u; + let qs2 = qs_u32_1 & 0xFFFFu; + let qs3 = qs_u32_1 >> 16u; + + let y_base = ib * Q2K_BLOCK_SIZE + y4_offset; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + // i=0: j=0,1 + { + let y00 = f32(shared_vector[y_base ]); sumy[0] += y00; + let y01 = f32(shared_vector[y_base + 1u]); sumy[0] += y01; + let y10 = f32(shared_vector[y_base + 32u]); sumy[1] += y10; + let y11 = f32(shared_vector[y_base + 33u]); sumy[1] += y11; + let y20 = f32(shared_vector[y_base + 64u]); sumy[2] += y20; + let y21 = f32(shared_vector[y_base + 65u]); sumy[2] += y21; + let y30 = f32(shared_vector[y_base + 96u]); sumy[3] += y30; + let y31 = f32(shared_vector[y_base + 97u]); sumy[3] += y31; + acc1[0] += y00 * f32(qs0 & 0x0003u); + acc2[0] += y01 * f32(qs0 & 0x0300u); + acc1[1] += y10 * f32(qs0 & 0x000Cu); + acc2[1] += y11 * f32(qs0 & 0x0C00u); + acc1[2] += y20 * f32(qs0 & 0x0030u); + acc2[2] += y21 * f32(qs0 & 0x3000u); + acc1[3] += y30 * f32(qs0 & 0x00C0u); + acc2[3] += y31 * f32(qs0 & 0xC000u); + } + // i=2: j=2,3 + { + let y00 = f32(shared_vector[y_base + 2u]); sumy[0] += y00; + let y01 = f32(shared_vector[y_base + 3u]); sumy[0] += y01; + let y10 = f32(shared_vector[y_base + 34u]); sumy[1] += y10; + let y11 = f32(shared_vector[y_base + 35u]); sumy[1] += y11; + let y20 = f32(shared_vector[y_base + 66u]); sumy[2] += y20; + let y21 = f32(shared_vector[y_base + 67u]); sumy[2] += y21; + let y30 = f32(shared_vector[y_base + 98u]); sumy[3] += y30; + let y31 = f32(shared_vector[y_base + 99u]); sumy[3] += y31; + acc1[0] += y00 * f32(qs1 & 0x0003u); + acc2[0] += y01 * f32(qs1 & 0x0300u); + acc1[1] += y10 * f32(qs1 & 0x000Cu); + acc2[1] += y11 * f32(qs1 & 0x0C00u); + acc1[2] += y20 * f32(qs1 & 0x0030u); + acc2[2] += y21 * f32(qs1 & 0x3000u); + acc1[3] += y30 * f32(qs1 & 0x00C0u); + acc2[3] += y31 * f32(qs1 & 0xC000u); + } + // i=4: j=4,5 + { + let y00 = f32(shared_vector[y_base + 4u]); sumy[0] += y00; + let y01 = f32(shared_vector[y_base + 5u]); sumy[0] += y01; + let y10 = f32(shared_vector[y_base + 36u]); sumy[1] += y10; + let y11 = f32(shared_vector[y_base + 37u]); sumy[1] += y11; + let y20 = f32(shared_vector[y_base + 68u]); sumy[2] += y20; + let y21 = f32(shared_vector[y_base + 69u]); sumy[2] += y21; + let y30 = f32(shared_vector[y_base + 100u]); sumy[3] += y30; + let y31 = f32(shared_vector[y_base + 101u]); sumy[3] += y31; + acc1[0] += y00 * f32(qs2 & 0x0003u); + acc2[0] += y01 * f32(qs2 & 0x0300u); + acc1[1] += y10 * f32(qs2 & 0x000Cu); + acc2[1] += y11 * f32(qs2 & 0x0C00u); + acc1[2] += y20 * f32(qs2 & 0x0030u); + acc2[2] += y21 * f32(qs2 & 0x3000u); + 
acc1[3] += y30 * f32(qs2 & 0x00C0u); + acc2[3] += y31 * f32(qs2 & 0xC000u); + } + // i=6: j=6,7 + { + let y00 = f32(shared_vector[y_base + 6u]); sumy[0] += y00; + let y01 = f32(shared_vector[y_base + 7u]); sumy[0] += y01; + let y10 = f32(shared_vector[y_base + 38u]); sumy[1] += y10; + let y11 = f32(shared_vector[y_base + 39u]); sumy[1] += y11; + let y20 = f32(shared_vector[y_base + 70u]); sumy[2] += y20; + let y21 = f32(shared_vector[y_base + 71u]); sumy[2] += y21; + let y30 = f32(shared_vector[y_base + 102u]); sumy[3] += y30; + let y31 = f32(shared_vector[y_base + 103u]); sumy[3] += y31; + acc1[0] += y00 * f32(qs3 & 0x0003u); + acc2[0] += y01 * f32(qs3 & 0x0300u); + acc1[1] += y10 * f32(qs3 & 0x000Cu); + acc2[1] += y11 * f32(qs3 & 0x0C00u); + acc1[2] += y20 * f32(qs3 & 0x0030u); + acc2[2] += y21 * f32(qs3 & 0x3000u); + acc1[3] += y30 * f32(qs3 & 0x00C0u); + acc2[3] += y31 * f32(qs3 & 0xC000u); + } + + sumf += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + + return sumf; } +#endif -fn sbyte_of(v: u32, b: u32) -> i32 { - let raw = i32((v >> (b * 8u)) & 0xFFu); - return select(raw, raw - 256, raw >= 128); +#ifdef MUL_ACC_Q3_K + +const Q3K_BLOCK_SIZE = 256u; +const Q3K_BLOCK_SIZE_BYTES = 110u; // 55 f16s * 2 + +fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + let tid = tig / 4u; + let ix = tig % 4u; + let ip = tid / 4u; + let il = 2u * ((tid % 4u) / 2u); + let ir = tid % 2u; + let l0 = 8u * ir; + + let nb = tile_size / Q3K_BLOCK_SIZE; + let k_block_start = k_outer / Q3K_BLOCK_SIZE; + + let q_byte = 32u + 32u * ip + l0; + let h_byte = l0; + let y_offset = 128u * ip + 32u * il + l0; + + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; + + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; + + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } + + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } + + var sumf1 = 0.0; + var sumf2 = 0.0; + + for (var i = ix; i < nb; i += 4u) { + let bbase = (idx_base + k_block_start + i) * Q3K_BLOCK_SIZE_BYTES; + + let d_all = f32(load_src0_f16_at(bbase + 108u)); + + // Scale unpacking + let a_base = 96u; + let a_il0_u32 = load_src0_u32_at_aligned(bbase + a_base + il * 2u); + let a_il0 = select(a_il0_u32 & 0xFFFFu, a_il0_u32 >> 16u, (il & 1u) != 0u); + let a_il1_u32 = load_src0_u32_at_aligned(bbase + a_base + (il + 1u) * 2u); + let a_il1 = select(a_il1_u32 & 0xFFFFu, a_il1_u32 >> 16u, ((il + 1u) & 1u) != 0u); + let a_45_u32 = load_src0_u32_at_aligned(bbase + a_base + 8u); + let a_4 = a_45_u32 & 0xFFFFu; + let a_5 = a_45_u32 >> 16u; + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = 
((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let sc0 = i32(byte_of(scales32, 0u)) - 32; + let sc1 = i32(byte_of(scales32, 1u)) - 32; + let sc2 = i32(byte_of(scales32, 2u)) - 32; + let sc3 = i32(byte_of(scales32, 3u)) - 32; + + let y_base = i * Q3K_BLOCK_SIZE + y_offset; + var yl: array; + for (var l = 0u; l < 8u; l++) { + yl[l + 0] = f32(shared_vector[y_base + l ]); + yl[l + 8] = f32(shared_vector[y_base + l + 16u]); + yl[l + 16] = f32(shared_vector[y_base + l + 32u]); + yl[l + 24] = f32(shared_vector[y_base + l + 48u]); + } + + // First qs/h loop: q[0..3], h[0..3] + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = load_src0_u32_at_aligned(bbase + q_byte + (l & ~2u)); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = load_src0_u32_at_aligned(bbase + h_byte + (l & ~2u)); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += yl[l + 0u] * f32(qs & qm0); + s2 += yl[l + 1u] * f32(qs & qm1); + s3 += select(0.0, yl[l + 0u], (hv & hm0) == 0u) + + select(0.0, yl[l + 1u], (hv & hm1) == 0u); + s4 += yl[l + 16u] * f32(qs & qm2); + s5 += yl[l + 17u] * f32(qs & qm3); + s6 += select(0.0, yl[l + 16u], (hv & hm2) == 0u) + + select(0.0, yl[l + 17u], (hv & hm3) == 0u); + } + let d1 = d_all * (s1 + (1.0/256.0)*s2 - s3*v1); + let d2 = d_all * (s4 + (1.0/256.0)*s5 - s6*v2); + sumf1 += d1 * f32(sc0); + sumf2 += d2 * f32(sc2); + + // Second qs/h loop: q[8..11], h[8..11] (16 bytes further) + s1 = 0.0; s2 = 0.0; s3 = 0.0; + s4 = 0.0; s5 = 0.0; s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = load_src0_u32_at_aligned(bbase + q_byte + 16u + (l & ~2u)); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = load_src0_u32_at_aligned(bbase + h_byte + 16u + (l & ~2u)); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += yl[l + 8u] * f32(qs & qm0); + s2 += yl[l + 9u] * f32(qs & qm1); + s3 += select(0.0, yl[l + 8u], (hv & hm0) == 0u) + + select(0.0, yl[l + 9u], (hv & hm1) == 0u); + s4 += yl[l + 24u] * f32(qs & qm2); + s5 += yl[l + 25u] * f32(qs & qm3); + s6 += select(0.0, yl[l + 24u], (hv & hm2) == 0u) + + select(0.0, yl[l + 25u], (hv & hm3) == 0u); + } + let d3 = d_all * (s1 + (1.0/256.0)*s2 - s3*v1); + let d4 = d_all * (s4 + (1.0/256.0)*s5 - s6*v2); + sumf1 += d3 * f32(sc1); + sumf2 += d4 * f32(sc3); + } + + return (sumf1 + 0.25 * sumf2) / f32(1u << shift); +} +#endif + +#ifdef MUL_ACC_Q4_K + +const Q4K_BLOCK_SIZE = 256u; +const Q4K_BLOCK_SIZE_BYTES = 144u; // 72 f16s * 2 + +fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + let ix = tig / 8u; + let it = tig % 8u; + let iq = it / 4u; + let ir = it % 4u; + + let nb = tile_size / Q4K_BLOCK_SIZE; + let k_block_start = k_outer / Q4K_BLOCK_SIZE; + + let y_offset = 64u * iq + 8u * ir; + + let sc0_byte = 4u + iq * 2u; + let sc2_byte = 4u + (iq + 2u) * 2u; + let sc4_byte = 4u + (iq + 4u) * 2u; + let q1_byte = 16u + (16u * iq + 4u * ir) * 2u; + let q2_byte = q1_byte + 64u; + + var sumf = 0.0; + + for (var ib = ix; ib < nb; ib += 4u) { + let bbase = (idx_base + k_block_start + ib) * Q4K_BLOCK_SIZE_BYTES; + + let d = f32(load_src0_f16_at(bbase + 0u)); + let dmin = f32(load_src0_f16_at(bbase + 2u)); + + let sc0_u32 = load_src0_u32_at_aligned(bbase + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_src0_u32_at_aligned(bbase + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 
16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_src0_u32_at_aligned(bbase + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = ((sc4 ) & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u ) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let sc8_0 = sc16_0 & 0xFFu; + let sc8_1 = (sc16_0 >> 8u) & 0xFFu; + let sc8_2 = sc16_1 & 0xFFu; + let sc8_3 = (sc16_1 >> 8u) & 0xFFu; + let sc8_4 = sc16_2 & 0xFFu; + let sc8_5 = (sc16_2 >> 8u) & 0xFFu; + let sc8_6 = sc16_3 & 0xFFu; + let sc8_7 = (sc16_3 >> 8u) & 0xFFu; + + let q1_u32_0 = load_src0_u32_at_aligned(bbase + q1_byte); + let q1_u32_1 = load_src0_u32_at_aligned(bbase + q1_byte + 4u); + let q2_u32_0 = load_src0_u32_at_aligned(bbase + q2_byte); + let q2_u32_1 = load_src0_u32_at_aligned(bbase + q2_byte + 4u); + + let q1_0 = q1_u32_0 & 0xFFFFu; + let q1_1 = q1_u32_0 >> 16u; + let q1_2 = q1_u32_1 & 0xFFFFu; + let q1_3 = q1_u32_1 >> 16u; + let q2_0 = q2_u32_0 & 0xFFFFu; + let q2_1 = q2_u32_0 >> 16u; + let q2_2 = q2_u32_1 & 0xFFFFu; + let q2_3 = q2_u32_1 >> 16u; + + let y_base = ib * Q4K_BLOCK_SIZE + y_offset; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + // i=0: yl[0,1,8,9], yh[0,1,8,9] + { + let yl0 = f32(shared_vector[y_base + 0u]); sumy[0] += yl0; + let yl1 = f32(shared_vector[y_base + 1u]); sumy[0] += yl1; + let yl8 = f32(shared_vector[y_base + 32u]); sumy[1] += yl8; + let yl9 = f32(shared_vector[y_base + 33u]); sumy[1] += yl9; + let yh0 = f32(shared_vector[y_base + 128u]); sumy[2] += yh0; + let yh1 = f32(shared_vector[y_base + 129u]); sumy[2] += yh1; + let yh8 = f32(shared_vector[y_base + 160u]); sumy[3] += yh8; + let yh9 = f32(shared_vector[y_base + 161u]); sumy[3] += yh9; + acc1[0] += yl0 * f32(q1_0 & 0x000Fu); + acc1[1] += yl1 * f32(q1_0 & 0x0F00u); + acc1[2] += yl8 * f32(q1_0 & 0x00F0u); + acc1[3] += yl9 * f32(q1_0 & 0xF000u); + acc2[0] += yh0 * f32(q2_0 & 0x000Fu); + acc2[1] += yh1 * f32(q2_0 & 0x0F00u); + acc2[2] += yh8 * f32(q2_0 & 0x00F0u); + acc2[3] += yh9 * f32(q2_0 & 0xF000u); + } + // i=1: yl[2,3,10,11], yh[2,3,10,11] + { + let yl0 = f32(shared_vector[y_base + 2u]); sumy[0] += yl0; + let yl1 = f32(shared_vector[y_base + 3u]); sumy[0] += yl1; + let yl8 = f32(shared_vector[y_base + 34u]); sumy[1] += yl8; + let yl9 = f32(shared_vector[y_base + 35u]); sumy[1] += yl9; + let yh0 = f32(shared_vector[y_base + 130u]); sumy[2] += yh0; + let yh1 = f32(shared_vector[y_base + 131u]); sumy[2] += yh1; + let yh8 = f32(shared_vector[y_base + 162u]); sumy[3] += yh8; + let yh9 = f32(shared_vector[y_base + 163u]); sumy[3] += yh9; + acc1[0] += yl0 * f32(q1_1 & 0x000Fu); + acc1[1] += yl1 * f32(q1_1 & 0x0F00u); + acc1[2] += yl8 * f32(q1_1 & 0x00F0u); + acc1[3] += yl9 * f32(q1_1 & 0xF000u); + acc2[0] += yh0 * f32(q2_1 & 0x000Fu); + acc2[1] += yh1 * f32(q2_1 & 0x0F00u); + acc2[2] += yh8 * f32(q2_1 & 0x00F0u); + acc2[3] += yh9 * f32(q2_1 & 0xF000u); + } + // i=2: yl[4,5,12,13], yh[4,5,12,13] + { + let yl0 = f32(shared_vector[y_base + 4u]); sumy[0] += yl0; + let yl1 = f32(shared_vector[y_base + 5u]); sumy[0] += yl1; + let yl8 = f32(shared_vector[y_base + 36u]); sumy[1] += yl8; + let yl9 = f32(shared_vector[y_base + 37u]); sumy[1] += yl9; + let yh0 = f32(shared_vector[y_base + 132u]); sumy[2] += yh0; + let yh1 = f32(shared_vector[y_base + 133u]); sumy[2] += yh1; + let yh8 = f32(shared_vector[y_base + 164u]); sumy[3] += yh8; + let yh9 = f32(shared_vector[y_base + 
165u]); sumy[3] += yh9; + acc1[0] += yl0 * f32(q1_2 & 0x000Fu); + acc1[1] += yl1 * f32(q1_2 & 0x0F00u); + acc1[2] += yl8 * f32(q1_2 & 0x00F0u); + acc1[3] += yl9 * f32(q1_2 & 0xF000u); + acc2[0] += yh0 * f32(q2_2 & 0x000Fu); + acc2[1] += yh1 * f32(q2_2 & 0x0F00u); + acc2[2] += yh8 * f32(q2_2 & 0x00F0u); + acc2[3] += yh9 * f32(q2_2 & 0xF000u); + } + // i=3: yl[6,7,14,15], yh[6,7,14,15] + { + let yl0 = f32(shared_vector[y_base + 6u]); sumy[0] += yl0; + let yl1 = f32(shared_vector[y_base + 7u]); sumy[0] += yl1; + let yl8 = f32(shared_vector[y_base + 38u]); sumy[1] += yl8; + let yl9 = f32(shared_vector[y_base + 39u]); sumy[1] += yl9; + let yh0 = f32(shared_vector[y_base + 134u]); sumy[2] += yh0; + let yh1 = f32(shared_vector[y_base + 135u]); sumy[2] += yh1; + let yh8 = f32(shared_vector[y_base + 166u]); sumy[3] += yh8; + let yh9 = f32(shared_vector[y_base + 167u]); sumy[3] += yh9; + acc1[0] += yl0 * f32(q1_3 & 0x000Fu); + acc1[1] += yl1 * f32(q1_3 & 0x0F00u); + acc1[2] += yl8 * f32(q1_3 & 0x00F0u); + acc1[3] += yl9 * f32(q1_3 & 0xF000u); + acc2[0] += yh0 * f32(q2_3 & 0x000Fu); + acc2[1] += yh1 * f32(q2_3 & 0x0F00u); + acc2[2] += yh8 * f32(q2_3 & 0x00F0u); + acc2[3] += yh9 * f32(q2_3 & 0xF000u); + } + + sumf += d * ((acc1[0] + (1.0/256.0)*acc1[1]) * f32(sc8_0) + + (acc1[2] + (1.0/256.0)*acc1[3]) * f32(sc8_1) * (1.0/16.0) + + (acc2[0] + (1.0/256.0)*acc2[1]) * f32(sc8_4) + + (acc2[2] + (1.0/256.0)*acc2[3]) * f32(sc8_5) * (1.0/16.0)) + - dmin * (sumy[0] * f32(sc8_2) + sumy[1] * f32(sc8_3) + + sumy[2] * f32(sc8_6) + sumy[3] * f32(sc8_7)); + } + + return sumf; +} +#endif + +#ifdef MUL_ACC_Q5_K + +const Q5K_BLOCK_SIZE = 256u; +const Q5K_BLOCK_SIZE_BYTES = 176u; // 88 f16s * 2 + +fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + let tid = tig / 4u; + let ix = tig % 4u; + let iq = tid / 4u; + let ir = tid % 4u; + let l0 = 8u * ir; + + let nb = tile_size / Q5K_BLOCK_SIZE; + let k_block_start = k_outer / Q5K_BLOCK_SIZE; + + let y_offset = 64u * iq + l0; + let q1_byte = 48u + 32u * iq + l0; + let q2_byte = q1_byte + 64u; + let qh_byte = 16u + l0; + let sc0_byte = 4u + iq * 2u; + let sc2_byte = 4u + (iq + 2u) * 2u; + let sc4_byte = 4u + (iq + 4u) * 2u; + + let hm1 = 1u << (2u * iq); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + var sumf = 0.0; + + for (var ib = ix; ib < nb; ib += 4u) { + let bbase = (idx_base + k_block_start + ib) * Q5K_BLOCK_SIZE_BYTES; + + let d = f32(load_src0_f16_at(bbase + 0u)); + let dmin = f32(load_src0_f16_at(bbase + 2u)); + + let sc0_u32 = load_src0_u32_at_aligned(bbase + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_src0_u32_at_aligned(bbase + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_src0_u32_at_aligned(bbase + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = ((sc4 ) & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u ) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let sc8_0 = sc16_0 & 0xFFu; + let sc8_1 = (sc16_0 >> 8u) & 0xFFu; + let sc8_2 = sc16_1 & 0xFFu; + let sc8_3 = (sc16_1 >> 8u) & 0xFFu; + let sc8_4 = sc16_2 & 0xFFu; + let sc8_5 = (sc16_2 >> 8u) & 0xFFu; + let sc8_6 = sc16_3 & 0xFFu; + let sc8_7 = (sc16_3 >> 8u) & 0xFFu; + + let f0 = f32(sc8_0); + let f1_lo = f32(sc8_1) * (1.0/16.0); + let f1_hi = f32(sc8_1) * 16.0; + let f4 = f32(sc8_4); + let f5_lo = f32(sc8_5) * 
(1.0/16.0); + let f5_hi = f32(sc8_5) * 16.0; + + let q1_u32_0 = load_src0_u32_at_aligned(bbase + q1_byte); + let q1_u32_1 = load_src0_u32_at_aligned(bbase + q1_byte + 4u); + let q2_u32_0 = load_src0_u32_at_aligned(bbase + q2_byte); + let q2_u32_1 = load_src0_u32_at_aligned(bbase + q2_byte + 4u); + let qh_u32_0 = load_src0_u32_at_aligned(bbase + qh_byte); + let qh_u32_1 = load_src0_u32_at_aligned(bbase + qh_byte + 4u); + + let y_base = ib * Q5K_BLOCK_SIZE + y_offset; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + // l=0 + { + let q1b = byte_of(q1_u32_0, 0u); + let q2b = byte_of(q2_u32_0, 0u); + let qhb = byte_of(qh_u32_0, 0u); + let yl0 = f32(shared_vector[y_base + 0u]); sumy[0] += yl0; + let yl8 = f32(shared_vector[y_base + 32u]); sumy[1] += yl8; + let yh0 = f32(shared_vector[y_base +128u]); sumy[2] += yh0; + let yh8 = f32(shared_vector[y_base +160u]); sumy[3] += yh8; + acc1[0] += yl0 * f32(q1b & 0x0Fu); + acc1[1] += yl8 * f32(q1b & 0xF0u); + acc1[2] += yh0 * f32(q2b & 0x0Fu); + acc1[3] += yh8 * f32(q2b & 0xF0u); + acc2[0] += yl0 * f32((qhb & hm1) != 0u); + acc2[1] += yl8 * f32((qhb & hm2) != 0u); + acc2[2] += yh0 * f32((qhb & hm3) != 0u); + acc2[3] += yh8 * f32((qhb & hm4) != 0u); + } + // l=1 + { + let q1b = byte_of(q1_u32_0, 1u); + let q2b = byte_of(q2_u32_0, 1u); + let qhb = byte_of(qh_u32_0, 1u); + let yl0 = f32(shared_vector[y_base + 1u]); sumy[0] += yl0; + let yl8 = f32(shared_vector[y_base + 33u]); sumy[1] += yl8; + let yh0 = f32(shared_vector[y_base +129u]); sumy[2] += yh0; + let yh8 = f32(shared_vector[y_base +161u]); sumy[3] += yh8; + acc1[0] += yl0 * f32(q1b & 0x0Fu); + acc1[1] += yl8 * f32(q1b & 0xF0u); + acc1[2] += yh0 * f32(q2b & 0x0Fu); + acc1[3] += yh8 * f32(q2b & 0xF0u); + acc2[0] += yl0 * f32((qhb & hm1) != 0u); + acc2[1] += yl8 * f32((qhb & hm2) != 0u); + acc2[2] += yh0 * f32((qhb & hm3) != 0u); + acc2[3] += yh8 * f32((qhb & hm4) != 0u); + } + // l=2 + { + let q1b = byte_of(q1_u32_0, 2u); + let q2b = byte_of(q2_u32_0, 2u); + let qhb = byte_of(qh_u32_0, 2u); + let yl0 = f32(shared_vector[y_base + 2u]); sumy[0] += yl0; + let yl8 = f32(shared_vector[y_base + 34u]); sumy[1] += yl8; + let yh0 = f32(shared_vector[y_base +130u]); sumy[2] += yh0; + let yh8 = f32(shared_vector[y_base +162u]); sumy[3] += yh8; + acc1[0] += yl0 * f32(q1b & 0x0Fu); + acc1[1] += yl8 * f32(q1b & 0xF0u); + acc1[2] += yh0 * f32(q2b & 0x0Fu); + acc1[3] += yh8 * f32(q2b & 0xF0u); + acc2[0] += yl0 * f32((qhb & hm1) != 0u); + acc2[1] += yl8 * f32((qhb & hm2) != 0u); + acc2[2] += yh0 * f32((qhb & hm3) != 0u); + acc2[3] += yh8 * f32((qhb & hm4) != 0u); + } + // l=3 + { + let q1b = byte_of(q1_u32_0, 3u); + let q2b = byte_of(q2_u32_0, 3u); + let qhb = byte_of(qh_u32_0, 3u); + let yl0 = f32(shared_vector[y_base + 3u]); sumy[0] += yl0; + let yl8 = f32(shared_vector[y_base + 35u]); sumy[1] += yl8; + let yh0 = f32(shared_vector[y_base +131u]); sumy[2] += yh0; + let yh8 = f32(shared_vector[y_base +163u]); sumy[3] += yh8; + acc1[0] += yl0 * f32(q1b & 0x0Fu); + acc1[1] += yl8 * f32(q1b & 0xF0u); + acc1[2] += yh0 * f32(q2b & 0x0Fu); + acc1[3] += yh8 * f32(q2b & 0xF0u); + acc2[0] += yl0 * f32((qhb & hm1) != 0u); + acc2[1] += yl8 * f32((qhb & hm2) != 0u); + acc2[2] += yh0 * f32((qhb & hm3) != 0u); + acc2[3] += yh8 * f32((qhb & hm4) != 0u); + } + // l=4 + { + let q1b = byte_of(q1_u32_1, 0u); + let q2b = byte_of(q2_u32_1, 0u); + let qhb = byte_of(qh_u32_1, 0u); + let yl0 = f32(shared_vector[y_base + 4u]); sumy[0] += yl0; + let yl8 = 
f32(shared_vector[y_base + 36u]); sumy[1] += yl8; + let yh0 = f32(shared_vector[y_base +132u]); sumy[2] += yh0; + let yh8 = f32(shared_vector[y_base +164u]); sumy[3] += yh8; + acc1[0] += yl0 * f32(q1b & 0x0Fu); + acc1[1] += yl8 * f32(q1b & 0xF0u); + acc1[2] += yh0 * f32(q2b & 0x0Fu); + acc1[3] += yh8 * f32(q2b & 0xF0u); + acc2[0] += yl0 * f32((qhb & hm1) != 0u); + acc2[1] += yl8 * f32((qhb & hm2) != 0u); + acc2[2] += yh0 * f32((qhb & hm3) != 0u); + acc2[3] += yh8 * f32((qhb & hm4) != 0u); + } + // l=5 + { + let q1b = byte_of(q1_u32_1, 1u); + let q2b = byte_of(q2_u32_1, 1u); + let qhb = byte_of(qh_u32_1, 1u); + let yl0 = f32(shared_vector[y_base + 5u]); sumy[0] += yl0; + let yl8 = f32(shared_vector[y_base + 37u]); sumy[1] += yl8; + let yh0 = f32(shared_vector[y_base +133u]); sumy[2] += yh0; + let yh8 = f32(shared_vector[y_base +165u]); sumy[3] += yh8; + acc1[0] += yl0 * f32(q1b & 0x0Fu); + acc1[1] += yl8 * f32(q1b & 0xF0u); + acc1[2] += yh0 * f32(q2b & 0x0Fu); + acc1[3] += yh8 * f32(q2b & 0xF0u); + acc2[0] += yl0 * f32((qhb & hm1) != 0u); + acc2[1] += yl8 * f32((qhb & hm2) != 0u); + acc2[2] += yh0 * f32((qhb & hm3) != 0u); + acc2[3] += yh8 * f32((qhb & hm4) != 0u); + } + // l=6 + { + let q1b = byte_of(q1_u32_1, 2u); + let q2b = byte_of(q2_u32_1, 2u); + let qhb = byte_of(qh_u32_1, 2u); + let yl0 = f32(shared_vector[y_base + 6u]); sumy[0] += yl0; + let yl8 = f32(shared_vector[y_base + 38u]); sumy[1] += yl8; + let yh0 = f32(shared_vector[y_base +134u]); sumy[2] += yh0; + let yh8 = f32(shared_vector[y_base +166u]); sumy[3] += yh8; + acc1[0] += yl0 * f32(q1b & 0x0Fu); + acc1[1] += yl8 * f32(q1b & 0xF0u); + acc1[2] += yh0 * f32(q2b & 0x0Fu); + acc1[3] += yh8 * f32(q2b & 0xF0u); + acc2[0] += yl0 * f32((qhb & hm1) != 0u); + acc2[1] += yl8 * f32((qhb & hm2) != 0u); + acc2[2] += yh0 * f32((qhb & hm3) != 0u); + acc2[3] += yh8 * f32((qhb & hm4) != 0u); + } + // l=7 + { + let q1b = byte_of(q1_u32_1, 3u); + let q2b = byte_of(q2_u32_1, 3u); + let qhb = byte_of(qh_u32_1, 3u); + let yl0 = f32(shared_vector[y_base + 7u]); sumy[0] += yl0; + let yl8 = f32(shared_vector[y_base + 39u]); sumy[1] += yl8; + let yh0 = f32(shared_vector[y_base +135u]); sumy[2] += yh0; + let yh8 = f32(shared_vector[y_base +167u]); sumy[3] += yh8; + acc1[0] += yl0 * f32(q1b & 0x0Fu); + acc1[1] += yl8 * f32(q1b & 0xF0u); + acc1[2] += yh0 * f32(q2b & 0x0Fu); + acc1[3] += yh8 * f32(q2b & 0xF0u); + acc2[0] += yl0 * f32((qhb & hm1) != 0u); + acc2[1] += yl8 * f32((qhb & hm2) != 0u); + acc2[2] += yh0 * f32((qhb & hm3) != 0u); + acc2[3] += yh8 * f32((qhb & hm4) != 0u); + } + + sumf += d * (f0 * acc1[0] + f0 * 16.0 * acc2[0] + + f1_lo * acc1[1] + f1_hi * acc2[1] + + f4 * acc1[2] + f4 * 16.0 * acc2[2] + + f5_lo * acc1[3] + f5_hi * acc2[3]) + - dmin * (sumy[0]*f32(sc8_2) + sumy[1]*f32(sc8_3) + + sumy[2]*f32(sc8_6) + sumy[3]*f32(sc8_7)); + } + + return sumf; } +#endif + +#ifdef MUL_ACC_Q6_K + +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 210u; // 105 f16s * 2 fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let tid = tig / 2u; @@ -373,24 +1014,27 @@ struct MulMatParams { broadcast3: u32 }; -// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +// SRC0_TYPE and SRC1_TYPE are defined above (VEC/SCALAR for float, U32_DEQUANT_HELPERS for quantized) @group(0) @binding(0) var src0: array; // M rows, K columns @group(0) @binding(1) var src1: array; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M 
rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG; +// const THREADS_PER_OUTPUT = 32u; // Shared memory for collaborative loading and reduction -var shared_vector: array; // Cache vector tile +// padded by + 1 to serialize reads (perf improvement for legacy quants?) +var shared_vector: array; // Cache vector tile var partial_sums: array; // For reduction @compute @workgroup_size(WG_SIZE) fn main( @builtin(local_invocation_id) local_id: vec3, @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { + @builtin(num_workgroups) num_wg: vec3, + @builtin(subgroup_size) subgroup_size: u32) { let thread_id = local_id.x; // Handle batch dimensions @@ -442,22 +1086,33 @@ fn main( workgroupBarrier(); } - // Store partial sums and reduce within each partition - partial_sums[thread_id] = local_sum; + let subgroup_local_id = thread_in_group % subgroup_size; + let subgroup_in_group = thread_in_group / subgroup_size; + let subgroup_total = subgroupAdd(local_sum); + + if (subgroup_local_id == 0u) { + partial_sums[thread_group * THREADS_PER_OUTPUT + subgroup_in_group] = subgroup_total; + } workgroupBarrier(); - let group_base = thread_group * THREADS_PER_OUTPUT; - let thread_base = group_base + thread_in_group; - var offset: u32 = THREADS_PER_OUTPUT / 2; - while (offset > 0) { - if (thread_in_group < offset) { - partial_sums[thread_base] += partial_sums[thread_base + offset]; + + if (thread_in_group == 0u && output_row < params.m) { + var row_total = 0.0; + let num_subgroups = THREADS_PER_OUTPUT / subgroup_size; + for (var s = 0u; s < num_subgroups; s++) { + row_total += partial_sums[thread_group * THREADS_PER_OUTPUT + s]; } - offset = offset / 2; - workgroupBarrier(); + #ifdef SCALAR + dst[dst_idx] = row_total; + #endif + #ifdef VEC + partial_sums[thread_group * THREADS_PER_OUTPUT] = row_total; + #endif } - // Store back to global memory - if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { - dst[dst_idx / VEC_SIZE] = store_val(group_base); + #ifdef VEC + workgroupBarrier(); + if (output_row < params.m && thread_group % VEC_SIZE == 0u && thread_in_group == 0u) { + dst[dst_idx / VEC_SIZE] = store_val(thread_group * THREADS_PER_OUTPUT); } -} + #endif +} \ No newline at end of file diff --git a/src/ggml-webgpu.cpp b/src/ggml-webgpu.cpp new file mode 100644 index 00000000000..e69de29bb2d From 3c9e474cc66a5ad6ac6e36bd7ff8e1de21a2bfb5 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 14 Apr 2026 08:25:50 -0700 Subject: [PATCH 2/9] Start on new mat-vec --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 42 +++-- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 15 +- .../wgsl-shaders/mul_mat_vec_row_tiled.wgsl | 146 ++++++++++++++++++ 3 files changed, 191 insertions(+), 12 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6b894374bc5..ebd62703157 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -46,7 +46,7 @@ // Must be multiple of 4 to work with vectorized paths, and must divide // mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 8 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4 #define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 1024 #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 8 @@ -78,6 +78,7 @@ struct ggml_webgpu_shader_lib_context { bool 
inplace = false; bool overlap = false; bool src_overlap = false; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -592,9 +593,13 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + int use_subgroup_reduction; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; + return src0_type == other.src0_type && + src1_type == other.src1_type && + vectorized == other.vectorized && + use_subgroup_reduction == other.use_subgroup_reduction; } }; @@ -604,6 +609,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_subgroup_reduction); return seed; } }; @@ -613,6 +619,7 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t tile_k; uint32_t outputs_per_wg; uint32_t vec_size; + bool use_subgroup_reduction; }; struct ggml_webgpu_mul_mat_pipeline_key { @@ -1332,14 +1339,18 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool use_row_tiled_float = + context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16; ggml_webgpu_mul_mat_vec_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type, // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + .vectorized = (!use_row_tiled_float && + context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0, + .use_subgroup_reduction = use_row_tiled_float && context.supports_subgroups, }; auto it = mul_mat_vec_pipelines.find(key); @@ -1348,19 +1359,24 @@ class ggml_webgpu_shader_lib { } std::vector defines; - std::string variant = "mul_mat_vec"; + std::string variant = use_row_tiled_float ? "mul_mat_vec_row_tiled" : "mul_mat_vec"; + const char * shader_src = use_row_tiled_float ? wgsl_mul_mat_vec_row_tiled : wgsl_mul_mat_vec; // src0 type (matrix row) switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_INNER_TYPE=f32"); - defines.push_back("MUL_ACC_FLOAT"); variant += "_f32"; + if (!use_row_tiled_float) { + defines.push_back("MUL_ACC_FLOAT"); + } break; case GGML_TYPE_F16: defines.push_back("SRC0_INNER_TYPE=f16"); - defines.push_back("MUL_ACC_FLOAT"); variant += "_f16"; + if (!use_row_tiled_float) { + defines.push_back("MUL_ACC_FLOAT"); + } break; default: { @@ -1394,7 +1410,9 @@ class ggml_webgpu_shader_lib { } // VEC/SCALAR controls - defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + if (!use_row_tiled_float) { + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + } uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; @@ -1409,15 +1427,21 @@ class ggml_webgpu_shader_lib { } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + if (use_row_tiled_float) { + defines.push_back(key.use_subgroup_reduction ? 
"USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += key.use_subgroup_reduction ? "_subgroup_reduce" : "_workgroup_reduce"; + } else { + defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); + } - auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); + auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; decisions->tile_k = tile_k; decisions->outputs_per_wg = outputs_per_wg; decisions->vec_size = key.vectorized ? 4 : 1; + decisions->use_subgroup_reduction = key.use_subgroup_reduction; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index beab945e804..c4bc01b0db4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -233,6 +233,7 @@ struct webgpu_encoded_op { struct webgpu_capabilities { wgpu::Limits limits; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; @@ -1322,6 +1323,8 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + use_fast = true; + break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1329,7 +1332,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q6_K: - use_fast = true; + use_fast = !is_vec || ctx->global_ctx->capabilities.supports_subgroups; break; case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -1351,6 +1354,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, .src1 = src1, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups, .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, @@ -3443,6 +3447,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + ctx->webgpu_global_ctx->capabilities.supports_subgroups = + ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); #ifndef __EMSCRIPTEN__ // Only support square f16 matrices of size 8 or 16 for now @@ -3466,7 +3472,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; + ctx->webgpu_global_ctx->capabilities.max_subgroup_size = + ctx->webgpu_global_ctx->capabilities.supports_subgroups ? 
info.subgroupMaxSize : 1u;
 
     // Initialize device
     std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
 #ifndef __EMSCRIPTEN__
@@ -3478,7 +3485,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     // }
 
     required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
-    required_features.push_back(wgpu::FeatureName::Subgroups);
+    if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) {
+        required_features.push_back(wgpu::FeatureName::Subgroups);
+    }
 
     if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
         required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl
new file mode 100644
index 00000000000..8f207411a69
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl
@@ -0,0 +1,146 @@
+#ifdef USE_SUBGROUP_REDUCTION
+enable subgroups;
+#endif
+enable f16;
+
+#include "common_decls.tmpl"
+
+struct MulMatParams {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    m: u32,
+    n: u32,
+    k: u32,
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_INNER_TYPE>;
+@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_INNER_TYPE>;
+@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+// Flattened as [row][thread] to keep each row's reduction contiguous in memory.
+var<workgroup> partial_sums: array<f32, OUTPUTS_PER_WG * WG_SIZE>;
+
+fn partial_index(row: u32, thread: u32) -> u32 {
+    return row * WG_SIZE + thread;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(workgroup_id) wg_id: vec3<u32>,
+    @builtin(num_workgroups) num_wg: vec3<u32>
+#ifdef USE_SUBGROUP_REDUCTION
+    , @builtin(subgroup_id) subgroup_id: u32,
+    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
+    @builtin(num_subgroups) num_subgroups: u32
+#endif
+) {
+    let thread_id = local_id.x;
+
+    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
+    let batch_idx = wg_linear / output_groups;
+    if (batch_idx >= total_batches) {
+        return;
+    }
+
+    let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG;
+
+    let dst2_stride = params.m * params.n;
+    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+    let src03_idx = dst3_idx / params.broadcast3;
+    let src13_idx = dst3_idx;
+    let src02_idx = dst2_idx / params.broadcast2;
+    let src12_idx = dst2_idx;
+
+    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
+    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+    let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
+
+    var acc: array<f32, OUTPUTS_PER_WG>;
+    for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+        acc[row] = 0.0;
+    }
+
+    // Each thread walks K with unit-stride loads from the vector and updates
+    // a small block of output rows held in registers.
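+    // Concretely: with OUTPUTS_PER_WG = 4 (the float-path default in this
+    // patch) and a hypothetical WG_SIZE of 256, thread 3 loads src1 elements
+    // k = 3, 259, 515, ... and folds each into acc[0..3], so every vector
+    // element fetched from memory is reused across 4 output rows.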
+ for (var k = thread_id; k < params.k; k += WG_SIZE) { + let x = f32(src1[src1_idx_base + k]); + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = src0_batch_offset + output_row * params.stride_01 + k; + acc[row] += f32(src0[src0_idx]) * x; + } + } + } + +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } + + workgroupBarrier(); + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + var row_total = 0.0; + for (var s = 0u; s < num_subgroups; s++) { + row_total += partial_sums[partial_index(thread_id, s)]; + } + dst[dst_idx_base + thread_id] = row_total; + } + } +#endif + +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; + } + + workgroupBarrier(); + + var stride = WG_SIZE / 2u; + loop { + if (stride == 0u) { + break; + } + + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2u; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0u)]; + } + } +#endif +} From 0bcf75c13ad2c02da82a7f510a726cb30ed93cbb Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 14 Apr 2026 09:27:43 -0700 Subject: [PATCH 3/9] New format float paths working --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 26 +++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 17 ++--- .../wgsl-shaders/mul_mat_vec_row_tiled.wgsl | 67 ++++++++++++------- 3 files changed, 57 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index ebd62703157..53b3b460799 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -593,13 +593,11 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; - int use_subgroup_reduction; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && - vectorized == other.vectorized && - use_subgroup_reduction == other.use_subgroup_reduction; + vectorized == other.vectorized; } }; @@ -609,7 +607,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); - ggml_webgpu_hash_combine(seed, key.use_subgroup_reduction); return seed; } }; @@ -619,7 +616,6 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t tile_k; uint32_t outputs_per_wg; uint32_t vec_size; - bool use_subgroup_reduction; }; struct ggml_webgpu_mul_mat_pipeline_key { @@ -1344,13 +1340,11 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type, - // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float - .vectorized = (!use_row_tiled_float && - context.src0->ne[0] % 4 == 0 && 
context.dst->ne[0] % 4 == 0 && + .vectorized = (context.src0->ne[0] % 4 == 0 && + (use_row_tiled_float || context.dst->ne[0] % 4 == 0) && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : - 0, - .use_subgroup_reduction = use_row_tiled_float && context.supports_subgroups, + 0 }; auto it = mul_mat_vec_pipelines.find(key); @@ -1410,9 +1404,7 @@ class ggml_webgpu_shader_lib { } // VEC/SCALAR controls - if (!use_row_tiled_float) { - defines.push_back(key.vectorized ? "VEC" : "SCALAR"); - } + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; @@ -1429,11 +1421,14 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); if (use_row_tiled_float) { - defines.push_back(key.use_subgroup_reduction ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); - variant += key.use_subgroup_reduction ? "_subgroup_reduce" : "_workgroup_reduce"; + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; } else { defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); } + if (key.vectorized) { + variant += "_vectorized"; + } auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); @@ -1441,7 +1436,6 @@ class ggml_webgpu_shader_lib { decisions->tile_k = tile_k; decisions->outputs_per_wg = outputs_per_wg; decisions->vec_size = key.vectorized ? 4 : 1; - decisions->use_subgroup_reduction = key.use_subgroup_reduction; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c4bc01b0db4..d1516434981 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3472,28 +3472,21 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->webgpu_global_ctx->capabilities.max_subgroup_size = - ctx->webgpu_global_ctx->capabilities.supports_subgroups ? 
info.subgroupMaxSize : 1u; + ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16 }; #ifndef __EMSCRIPTEN__ - // required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); - // if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { - // required_features.push_back(wgpu::FeatureName::Subgroups); - // required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); - // } - required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); - if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) { - required_features.push_back(wgpu::FeatureName::Subgroups); - } - if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } #endif + if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) { + required_features.push_back(wgpu::FeatureName::Subgroups); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl index 8f207411a69..8a532c37bf3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl @@ -5,6 +5,26 @@ enable f16; #include "common_decls.tmpl" +#ifdef VEC +#define VEC_SIZE 4u +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); +} +#endif + +#ifdef SCALAR +#define VEC_SIZE 1u +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(src0_val) * f32(src1_val); +} +#endif + struct MulMatParams { offset_src0: u32, offset_src1: u32, @@ -24,8 +44,8 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array; -@group(0) @binding(1) var src1: array; +@group(0) @binding(0) var src0: array; +@group(0) @binding(1) var src1: array; @group(0) @binding(2) var dst: array; @group(0) @binding(3) var params: MulMatParams; @@ -45,7 +65,8 @@ fn main( #ifdef USE_SUBGROUP_REDUCTION , @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, - @builtin(num_subgroups) num_subgroups: u32 + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 #endif ) { let thread_id = local_id.x; @@ -74,19 +95,19 @@ fn main( let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; var acc: array; - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - acc[row] = 0.0; - } - // Each thread walks K with unit-stride loads from the vector and updates + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates // a small block of output rows held in registers. 
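+    // Illustrative sketch of the vectorized walk (assumes VEC, so
+    // VEC_SIZE = 4u and the bindings hold vec4 elements): a row of k = 4096
+    // scalars becomes k_vec = 1024 vec4 loads. With WG_SIZE = 256, thread 7
+    // touches vec4 indices 7, 263, 519, ... and each step folds four scalar
+    // products into one dot:
+    //
+    //   acc[row] += f32(dot(SRC1_TYPE(src0_val), src1_val));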
- for (var k = thread_id; k < params.k; k += WG_SIZE) { - let x = f32(src1[src1_idx_base + k]); + for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { - let src0_idx = src0_batch_offset + output_row * params.stride_01 + k; - acc[row] += f32(src0[src0_idx]) * x; + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); } } } @@ -101,15 +122,14 @@ fn main( workgroupBarrier(); - if (thread_id < OUTPUTS_PER_WG) { - let output_row = row_base + thread_id; - if (output_row < params.m) { - var row_total = 0.0; - for (var s = 0u; s < num_subgroups; s++) { - row_total += partial_sums[partial_index(thread_id, s)]; - } - dst[dst_idx_base + thread_id] = row_total; + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; } + let row_total = subgroupAdd(row_acc); + dst[dst_idx_base + row] = row_total; } #endif @@ -121,11 +141,8 @@ fn main( workgroupBarrier(); var stride = WG_SIZE / 2u; - loop { - if (stride == 0u) { - break; - } + while (stride > 0) { if (thread_id < stride) { for (var row = 0u; row < OUTPUTS_PER_WG; row++) { partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; @@ -133,13 +150,13 @@ fn main( } workgroupBarrier(); - stride = stride / 2u; + stride = stride / 2; } if (thread_id < OUTPUTS_PER_WG) { let output_row = row_base + thread_id; if (output_row < params.m) { - dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0u)]; + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; } } #endif From 01bd9127564cc401d73c3907cfff5787a98f61d2 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 14 Apr 2026 10:17:16 -0700 Subject: [PATCH 4/9] Working q4_0 --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 27 ++++++----- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- .../wgsl-shaders/mul_mat_vec_row_tiled.wgsl | 47 ++++++++++++++++++- 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 53b3b460799..1854f351e5e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1335,13 +1335,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - const bool use_row_tiled_float = - context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16; + const bool use_row_tiled = + context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16 || context.src0->type == GGML_TYPE_Q4_0; ggml_webgpu_mul_mat_vec_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type, .vectorized = (context.src0->ne[0] % 4 == 0 && - (use_row_tiled_float || context.dst->ne[0] % 4 == 0) && + (use_row_tiled || context.dst->ne[0] % 4 == 0) && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0 @@ -1353,24 +1353,27 @@ class ggml_webgpu_shader_lib { } std::vector defines; - std::string variant = use_row_tiled_float ? 
"mul_mat_vec_row_tiled" : "mul_mat_vec"; - const char * shader_src = use_row_tiled_float ? wgsl_mul_mat_vec_row_tiled : wgsl_mul_mat_vec; + std::string variant = use_row_tiled ? "mul_mat_vec_row_tiled" : "mul_mat_vec"; + const char * shader_src = use_row_tiled ? wgsl_mul_mat_vec_row_tiled : wgsl_mul_mat_vec; // src0 type (matrix row) switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); variant += "_f32"; - if (!use_row_tiled_float) { - defines.push_back("MUL_ACC_FLOAT"); - } break; case GGML_TYPE_F16: defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); variant += "_f16"; - if (!use_row_tiled_float) { - defines.push_back("MUL_ACC_FLOAT"); - } + break; + case GGML_TYPE_Q4_0: + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_Q4_0"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + variant += "_q4_0"; break; default: { @@ -1420,7 +1423,7 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); - if (use_row_tiled_float) { + if (use_row_tiled) { defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; } else { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d1516434981..a6c01be23d4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1323,9 +1323,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: use_fast = true; break; - case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl index 8a532c37bf3..278d46b9cfa 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl @@ -5,6 +5,10 @@ enable f16; #include "common_decls.tmpl" +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 +#endif + #ifdef VEC #define VEC_SIZE 4u #define SRC0_TYPE vec4 @@ -96,6 +100,7 @@ fn main( var acc: array; +#ifdef MUL_ACC_FLOAT let k_vec = params.k / VEC_SIZE; let src1_idx_base_vec = src1_idx_base / VEC_SIZE; @@ -111,6 +116,44 @@ fn main( } } } +#endif + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_src0_f16_at(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_src0_u32_at(block_byte_base + 2u + 4u * 
thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } +#endif #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { @@ -129,7 +172,9 @@ fn main( row_acc += partial_sums[partial_index(row, k)]; } let row_total = subgroupAdd(row_acc); - dst[dst_idx_base + row] = row_total; + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } } #endif From f839c1032149dc2a40c9845defe84d8d1c93b8f0 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 14 Apr 2026 10:27:06 -0700 Subject: [PATCH 5/9] Work on remaining legacy q-types --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 11 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 +- .../wgsl-shaders/mul_mat_vec_row_tiled.wgsl | 192 ++++++++++++++++++ 3 files changed, 197 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 1854f351e5e..41f1f28a231 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1336,7 +1336,9 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool use_row_tiled = - context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16 || context.src0->type == GGML_TYPE_Q4_0; + context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16 || context.src0->type == GGML_TYPE_Q4_0 || + context.src0->type == GGML_TYPE_Q4_1 || context.src0->type == GGML_TYPE_Q5_0 || context.src0->type == GGML_TYPE_Q5_1 || + context.src0->type == GGML_TYPE_Q8_0 || context.src0->type == GGML_TYPE_Q8_1; ggml_webgpu_mul_mat_vec_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type, @@ -1368,13 +1370,6 @@ class ggml_webgpu_shader_lib { defines.push_back("MUL_ACC_FLOAT"); variant += "_f16"; break; - case GGML_TYPE_Q4_0: - defines.push_back("BYTE_HELPERS"); - defines.push_back("MUL_ACC_Q4_0"); - defines.push_back("U32_DEQUANT_HELPERS"); - defines.push_back("SRC0_INNER_TYPE=u32"); - variant += "_q4_0"; - break; default: { // Quantized types: use helpers but accumulate in f16 diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a6c01be23d4..4f77906bab3 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1324,13 +1324,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: - use_fast = true; - break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + use_fast = true; + break; case GGML_TYPE_Q6_K: use_fast = !is_vec || ctx->global_ctx->capabilities.supports_subgroups; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl index 278d46b9cfa..09498bf98d9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl @@ -155,6 +155,198 @@ fn main( } #endif +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define 
ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_src0_f16_at(block_byte_base)); + let m = f32(load_src0_f16_at(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_src0_f16_at(block_byte_base)); + let qh_packed = load_src0_u32_at(block_byte_base + 2u); + let q_packed = load_src0_u32_at(block_byte_base + 6u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_Q5_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 24 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if 
(output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_src0_f16_at(block_byte_base)); + let m = f32(load_src0_f16_at(block_byte_base + 2u)); + let qh_packed = load_src0_u32_at(block_byte_base + 4u); + let q_packed = load_src0_u32_at(block_byte_base + 8u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 34 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_src0_f16_at(block_byte_base)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_src0_u32_at(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_Q8_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 36 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_src0_f16_at(block_byte_base)); + let m = f32(load_src0_f16_at(block_byte_base + 2u)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } +#endif + #ifdef 
USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let subgroup_total = subgroupAdd(acc[row]); From ba96122580a95a7866463df003aadd2f06f7d272 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 14 Apr 2026 17:26:21 -0700 Subject: [PATCH 6/9] port k-quants to new matvec --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 4 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 11 +- .../wgsl-shaders/mul_mat_vec_row_tiled.wgsl | 464 ++++++++++++++++++ 3 files changed, 470 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 41f1f28a231..b7f46b9bc29 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1338,7 +1338,9 @@ class ggml_webgpu_shader_lib { const bool use_row_tiled = context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16 || context.src0->type == GGML_TYPE_Q4_0 || context.src0->type == GGML_TYPE_Q4_1 || context.src0->type == GGML_TYPE_Q5_0 || context.src0->type == GGML_TYPE_Q5_1 || - context.src0->type == GGML_TYPE_Q8_0 || context.src0->type == GGML_TYPE_Q8_1; + context.src0->type == GGML_TYPE_Q8_0 || context.src0->type == GGML_TYPE_Q8_1 || context.src0->type == GGML_TYPE_Q6_K || + context.src0->type == GGML_TYPE_Q4_K || context.src0->type == GGML_TYPE_Q5_K || context.src0->type == GGML_TYPE_Q3_K || + context.src0->type == GGML_TYPE_Q2_K; ggml_webgpu_mul_mat_vec_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4f77906bab3..b351108f631 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1329,17 +1329,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: - use_fast = true; - break; case GGML_TYPE_Q6_K: - use_fast = !is_vec || ctx->global_ctx->capabilities.supports_subgroups; - break; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat - use_fast = !is_vec; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q2_K: + use_fast = true; break; default: break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl index 09498bf98d9..dd03435d357 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl @@ -7,6 +7,15 @@ enable f16; #ifdef U32_DEQUANT_HELPERS #define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} #endif #ifdef VEC @@ -347,6 +356,461 @@ fn main( } #endif +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 84 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let iq = lane / 4u; + let ir = lane % 4u; + let is = ir / 2u; + + let y_offset = 128u * iq + 8u * ir + 4u * phase; + let sc0_byte = 8u * iq + is; + let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 
6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 64u + i]); + x_block[i + 12u] = f32(src1[x_base + 96u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_src0_f16_at(block_byte_base + 80u)); + let dmin = f32(load_src0_f16_at(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_src0_u32_at_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } + } +#endif + + +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; + + let q_byte = 32u + 32u * ip + l0 + 16u * phase; + let h_byte = l0 + 16u * phase; + let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; + + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; + + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; + + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 
0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } + + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 8u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 8u] = f32(src1[x_base + 32u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_src0_f16_at(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_src0_u16_at(block_byte_base + a_base + il * 2u); + let a_il1 = load_src0_u16_at(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_src0_u16_at(block_byte_base + a_base + 8u); + let a_5 = load_src0_u16_at(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); + + let q_u32_0 = load_src0_u32_at(block_byte_base + q_byte + 0u); + let q_u32_1 = load_src0_u32_at(block_byte_base + q_byte + 4u); + let h_u32_0 = load_src0_u32_at(block_byte_base + h_byte + 0u); + let h_u32_1 = load_src0_u32_at(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } + } + } +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 144 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 32u * im + l0; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let num_blocks = params.k / BLOCK_SIZE; + + 
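+    // Worked example of the q4_K mapping above (illustrative; the numbers
+    // follow from the index math): a q4_K block packs 256 weights into 144
+    // bytes (f16 d, f16 dmin, 12 scale/min bytes, 128 quant bytes). For
+    // tid = 13: il = 3, ir = 1, im = 1, in = 1, l0 = 4u * (2u * 1u + 1u) = 12,
+    // so this thread reads y at block offsets 76..79, 108..111, 204..207 and
+    // 236..239, and the packed quants at bytes 16 + 44 = 60 and 80 + 44 = 124.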
for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_src0_f16_at(block_byte_base + 0u)); + let dmin = f32(load_src0_f16_at(block_byte_base + 2u)); + + let sc0_u32 = load_src0_u32_at_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_src0_u32_at_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_src0_u32_at_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_src0_u32_at_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_src0_u32_at_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } + } +#endif + +#ifdef MUL_ACC_Q5_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = 
f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_src0_f16_at(block_byte_base + 0u)); + let dmin = f32(load_src0_f16_at(block_byte_base + 2u)); + + let sc0_u32 = load_src0_u32_at_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_src0_u32_at_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_src0_u32_at_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_src0_u32_at_aligned(block_byte_base + q_offset); + let q2_u32 = load_src0_u32_at_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_src0_u32_at_aligned(block_byte_base + qh_offset); + + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } + } +#endif + +#ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; + + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; + + let num_blocks = params.k / BLOCK_SIZE; + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var l = 0u; l < 4u; l++) { + x_block[l] = f32(src1[x_base + l]); + x_block[l + 4u] = f32(src1[x_base + 32u + l]); + x_block[l + 8u] = f32(src1[x_base + 64u + l]); + x_block[l + 12u] = f32(src1[x_base + 96u + l]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = 
row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_src0_f16_at(block_byte_base + 208u)); + let ql1_u32 = load_src0_u32_at(block_byte_base + q_offset_l); + let ql2_u32 = load_src0_u32_at(block_byte_base + q_offset_l + 32u); + let qh_u32 = load_src0_u32_at(block_byte_base + 128u + q_offset_h); + let sc_u32_0 = load_src0_u32_at(block_byte_base + sc_base_byte); + let sc_u32_1 = load_src0_u32_at(block_byte_base + sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += x_block[l] * dq0; + sums[1] += x_block[l + 4u] * dq1; + sums[2] += x_block[l + 8u] * dq2; + sums[3] += x_block[l + 12u] * dq3; + } + + acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + } + } +#endif + #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let subgroup_total = subgroupAdd(acc[row]); From b4b6ffc46a5c0fe9575b9d1c6688065cfcd44db7 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 14 Apr 2026 17:37:31 -0700 Subject: [PATCH 7/9] remove old shader --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 28 +- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 1636 +++++++---------- .../wgsl-shaders/mul_mat_vec_row_tiled.wgsl | 864 --------- 3 files changed, 697 insertions(+), 1831 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index b7f46b9bc29..2a25c199fc6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -49,11 +49,11 @@ #define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4 #define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 1024 -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 8 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 1024 // Requires 32 threads per output (wg_size/outputs_per_wg == 32) -#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 // Requires at least two (and multiple of 2) k-quant blocks per tile #define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 2048 @@ -613,7 +613,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t wg_size; - uint32_t tile_k; uint32_t outputs_per_wg; uint32_t vec_size; }; @@ -1335,17 +1334,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - const bool use_row_tiled = - context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16 || context.src0->type == GGML_TYPE_Q4_0 || - context.src0->type == GGML_TYPE_Q4_1 || context.src0->type == GGML_TYPE_Q5_0 || context.src0->type == GGML_TYPE_Q5_1 || - context.src0->type == GGML_TYPE_Q8_0 || context.src0->type == 
GGML_TYPE_Q8_1 || context.src0->type == GGML_TYPE_Q6_K || - context.src0->type == GGML_TYPE_Q4_K || context.src0->type == GGML_TYPE_Q5_K || context.src0->type == GGML_TYPE_Q3_K || - context.src0->type == GGML_TYPE_Q2_K; ggml_webgpu_mul_mat_vec_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type, .vectorized = (context.src0->ne[0] % 4 == 0 && - (use_row_tiled || context.dst->ne[0] % 4 == 0) && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0 @@ -1357,8 +1349,8 @@ class ggml_webgpu_shader_lib { } std::vector defines; - std::string variant = use_row_tiled ? "mul_mat_vec_row_tiled" : "mul_mat_vec"; - const char * shader_src = use_row_tiled ? wgsl_mul_mat_vec_row_tiled : wgsl_mul_mat_vec; + std::string variant = "mul_mat_vec"; + const char * shader_src = wgsl_mul_mat_vec; // src0 type (matrix row) switch (context.src0->type) { @@ -1407,25 +1399,18 @@ class ggml_webgpu_shader_lib { defines.push_back(key.vectorized ? "VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; if (key.src0_type >= GGML_TYPE_Q2_K) { - tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; } else if (key.src0_type >= GGML_TYPE_Q4_0) { - tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); - if (use_row_tiled) { - defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); - variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; - } else { - defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); - } + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; if (key.vectorized) { variant += "_vectorized"; } @@ -1433,7 +1418,6 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; - decisions->tile_k = tile_k; decisions->outputs_per_wg = outputs_per_wg; decisions->vec_size = key.vectorized ? 
4 : 1; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index e87221146c7..dd03435d357 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,10 +1,13 @@ +#ifdef USE_SUBGROUP_REDUCTION enable subgroups; +#endif enable f16; #include "common_decls.tmpl" #ifdef U32_DEQUANT_HELPERS #define SRC0_TYPE u32 + fn byte_of(v: u32, b: u32) -> u32 { return (v >> (b * 8u)) & 0xFFu; } @@ -16,403 +19,443 @@ fn sbyte_of(v: u32, b: u32) -> i32 { #endif #ifdef VEC - -#define VEC_SIZE 4 -#define DST_TYPE vec4 +#define VEC_SIZE 4u #define SRC0_TYPE vec4 #define SRC1_TYPE vec4 fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(dot(SRC1_TYPE(src0_val), src1_val)); } - -fn store_val(group_base: u32) -> vec4 { - return vec4(partial_sums[group_base], - partial_sums[group_base + THREADS_PER_OUTPUT], - partial_sums[group_base + THREADS_PER_OUTPUT * 2], - partial_sums[group_base + THREADS_PER_OUTPUT * 3]); -} #endif #ifdef SCALAR - -#define VEC_SIZE 1 -#define DST_TYPE f32 +#define VEC_SIZE 1u #define SRC0_TYPE SRC0_INNER_TYPE #define SRC1_TYPE SRC1_INNER_TYPE fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(src0_val) * f32(src1_val); } +#endif -fn store_val(group_base: u32) -> f32 { - return partial_sums[group_base]; +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array; +@group(0) @binding(1) var src1: array; +@group(0) @binding(2) var dst: array; + +@group(0) @binding(3) var params: MulMatParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. 
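+// Illustrative sketch (assumes WG_SIZE = 256): partial_index(row, thread)
+// = row * WG_SIZE + thread, so row 2 owns slots 512..767. A stride-s step of
+// the tree reduction then adds partial_sums[row * 256u + t + s] into
+// partial_sums[row * 256u + t], staying inside one row's contiguous segment
+// instead of striding across rows.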
+var partial_sums: array; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; } + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 #endif +) { + let thread_id = local_id.x; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; + + var acc: array; #ifdef MUL_ACC_FLOAT -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { - let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; - let b = shared_vector[i / VEC_SIZE]; - local_sum += inner_dot(a, b); + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates + // a small block of output rows held in registers. 
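+    // Note on the design (illustrative summary): acc[] stays in registers for
+    // the whole K walk. The removed TILE_K variant staged src1 through the
+    // workgroup-shared shared_vector tile by tile; this path needs no
+    // workgroupBarrier() until the single reduction at the end of the kernel.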
+ for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); + } + } } - return local_sum; -} #endif #ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 18u; -const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_packed = load_src0_u32_at(block_byte_base + 2u + 2u * (block_offset + j)); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_src0_f16_at(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_src0_u32_at(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } } } - return local_sum; -} #endif #ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 20u; -const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) 
-> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = f32(load_src0_f16_at(block_byte_base + 2u)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_packed = load_src0_u32_at(block_byte_base + 4u + 2u * (block_offset + j)); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_src0_f16_at(block_byte_base)); + let m = f32(load_src0_f16_at(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } } } - return local_sum; -} #endif #ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 22u; -const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let qh_packed = load_src0_u32_at(block_byte_base + 2u); - - for (var j = 0u; j < 2; j++) { - let q_packed = load_src0_u32_at(block_byte_base + 6u + 2u * (block_offset + j * 2u)); - let j_adjusted = j + (block_offset / 2u); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + for (var row = 0u; row < 
 
 #ifdef MUL_ACC_Q5_0
+#define BLOCK_SIZE 32
+#define BLOCK_SIZE_BYTES 22
+#define THREADS_PER_BLOCK 4
+#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
+
+    let num_blocks = params.k / BLOCK_SIZE;
+    let thread_within_block = thread_id % THREADS_PER_BLOCK;
+    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
+        var x_block: array<f32, 8>;
+        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
+            x_block[i] = f32(src1[x_base + i]);
+            x_block[i + 4] = f32(src1[x_base + i + 16]);
+        }
-const BLOCK_SIZE = 32;
-const BLOCK_SIZE_BYTES = 22u;
-const NQ = 16u;
-const WEIGHTS_PER_F16 = 4u;
-const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
-        let blck_idx = i / BLOCK_SIZE;
-        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
-        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_src0_f16_at(block_byte_base));
-        let qh_packed = load_src0_u32_at(block_byte_base + 2u);
-
-        for (var j = 0u; j < 2; j++) {
-            let q_packed = load_src0_u32_at(block_byte_base + 6u + 2u * (block_offset + j * 2u));
-            let j_adjusted = j + (block_offset / 2u);
-            for (var k: u32 = 0; k < 4; k++) {
-                let q_byte = get_byte(q_packed, k);
-                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
-                let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
-                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
-                let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
-                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
-                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+
+        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+            let output_row = row_base + row;
+            if (output_row < params.m) {
+                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
+                let d = f32(load_src0_f16_at(block_byte_base));
+                let qh_packed = load_src0_u32_at(block_byte_base + 2u);
+                let q_packed = load_src0_u32_at(block_byte_base + 6u + 4u * thread_within_block);
+                let qh_shift = thread_within_block * 4u;
+                var row_sum = 0.0;
+
+                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
+                    let q_byte = get_byte(q_packed, byte_idx);
+                    let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
+                    let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
+                    let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d;
+                    let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d;
+                    row_sum += q_lo * x_block[byte_idx];
+                    row_sum += q_hi * x_block[byte_idx + 4u];
+                }
+                acc[row] += row_sum;
             }
         }
     }
-    return local_sum;
-}
 #endif
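The Q5_0 masks above pack each element's fifth bit into the 32-bit qh word: bit j extends nibble j's low half and bit j+16 extends the high half, with the 5-bit result recentered by -16. The same reconstruction on the host, under the same block_q5_0 assumptions (22 bytes: f16 d, u32 qh, 16 nibble bytes; hypothetical helper, d pre-converted):

    #include <cstdint>

    // One 22-byte Q5_0 block: f16 d, u32 qh (one high bit per element), 16 quant bytes.
    void dequant_q5_0_block(const uint8_t *qs, uint32_t qh, float d, float out[32]) {
        for (int j = 0; j < 16; j++) {
            uint32_t lo = (qs[j] & 0x0Fu) | (((qh >> j)        & 1u) << 4);  // fifth bit j
            uint32_t hi = (qs[j] >> 4)    | (((qh >> (j + 16)) & 1u) << 4);  // fifth bit j+16
            out[j]      = (float(lo) - 16.0f) * d;
            out[j + 16] = (float(hi) - 16.0f) * d;
        }
    }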
 
 #ifdef MUL_ACC_Q5_1
+#define BLOCK_SIZE 32
+#define BLOCK_SIZE_BYTES 24
+#define THREADS_PER_BLOCK 4
+#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
+
+    let num_blocks = params.k / BLOCK_SIZE;
+    let thread_within_block = thread_id % THREADS_PER_BLOCK;
+    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
+        var x_block: array<f32, 8>;
+        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
+            x_block[i] = f32(src1[x_base + i]);
+            x_block[i + 4] = f32(src1[x_base + i + 16]);
+        }
-const BLOCK_SIZE = 32;
-const BLOCK_SIZE_BYTES = 24u;
-const NQ = 16u;
-const WEIGHTS_PER_F16 = 4u;
-const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
-        let blck_idx = i / BLOCK_SIZE;
-        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
-        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_src0_f16_at(block_byte_base));
-        let m = load_src0_f16_at(block_byte_base + 2u);
-        let qh_packed = load_src0_u32_at(block_byte_base + 4u);
-
-        for (var j = 0u; j < 2; j++) {
-            let q_packed = load_src0_u32_at(block_byte_base + 8u + 2u * (block_offset + j * 2u));
-            let j_adjusted = j + (block_offset / 2u);
-            for (var k: u32 = 0; k < 4; k++) {
-                let q_byte = get_byte(q_packed, k);
-                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
-                let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
-                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
-                let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
-                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
-                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+
+        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+            let output_row = row_base + row;
+            if (output_row < params.m) {
+                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
+                let d = f32(load_src0_f16_at(block_byte_base));
+                let m = f32(load_src0_f16_at(block_byte_base + 2u));
+                let qh_packed = load_src0_u32_at(block_byte_base + 4u);
+                let q_packed = load_src0_u32_at(block_byte_base + 8u + 4u * thread_within_block);
+                let qh_shift = thread_within_block * 4u;
+                var row_sum = 0.0;
+
+                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
+                    let q_byte = get_byte(q_packed, byte_idx);
+                    let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
+                    let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
+                    let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m;
+                    let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m;
+                    row_sum += q_lo * x_block[byte_idx];
+                    row_sum += q_hi * x_block[byte_idx + 4u];
+                }
+                acc[row] += row_sum;
             }
         }
     }
-    return local_sum;
-}
 #endif
 
 #ifdef MUL_ACC_Q8_0
+#define BLOCK_SIZE 32
+#define BLOCK_SIZE_BYTES 34
+#define THREADS_PER_BLOCK 4
+#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
+
+    let num_blocks = params.k / BLOCK_SIZE;
+    let thread_within_block = thread_id % THREADS_PER_BLOCK;
+    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
+        var x_block: array<f32, 8>;
+        for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
+            x_block[i] = f32(src1[x_base + i]);
+        }
-const BLOCK_SIZE = 32;
-const BLOCK_SIZE_BYTES = 34u;
-const NQ = 16u;
-const WEIGHTS_PER_F16 = 2u;
-const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
-        let blck_idx = i / BLOCK_SIZE;
-        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
-        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_src0_f16_at(block_byte_base));
-        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_packed = load_src0_u32_at(block_byte_base + 2u + 2u * (block_offset + j));
-            for (var k: u32 = 0; k < 4; k++) {
-                let q_byte = get_byte_i32(q_packed, k);
-                let q_val = f32(q_byte) * d;
-                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+
+        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+            let output_row = row_base + row;
+            if (output_row < params.m) {
+                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
+                let d = f32(load_src0_f16_at(block_byte_base));
+                var row_sum = 0.0;
+
+                for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
+                    let q_packed = load_src0_u32_at(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx));
+                    for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
+                        let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d;
+                        row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
+                    }
+                }
+                acc[row] += row_sum;
             }
         }
     }
-    return local_sum;
-}
 #endif
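Q8_0 needs no bit surgery: a block is just an f16 scale followed by 32 signed bytes, so each row's contribution reduces to a scaled dot product. A host-side equivalent of the inner loop (hypothetical helper name):

    #include <cstdint>

    // One 34-byte Q8_0 block: f16 d, then 32 signed quant bytes.
    float dot_q8_0_block(const int8_t *qs, float d, const float *x /* 32 values */) {
        float sum = 0.0f;
        for (int j = 0; j < 32; j++) {
            sum += float(qs[j]) * x[j];
        }
        return sum * d;  // factor the scale out of the inner loop
    }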
 
 #ifdef MUL_ACC_Q8_1
+#define BLOCK_SIZE 32
+#define BLOCK_SIZE_BYTES 36
+#define THREADS_PER_BLOCK 4
+#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
+
+    let num_blocks = params.k / BLOCK_SIZE;
+    let thread_within_block = thread_id % THREADS_PER_BLOCK;
+    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
+        var x_block: array<f32, 8>;
+        for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
+            x_block[i] = f32(src1[x_base + i]);
+        }
-const BLOCK_SIZE = 32;
-const BLOCK_SIZE_BYTES = 36u;
-const NQ = 16u;
-const WEIGHTS_PER_F16 = 2u;
-const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
-        let blck_idx = i / BLOCK_SIZE;
-        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
-        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_src0_f16_at(block_byte_base));
-        let m = load_src0_f16_at(block_byte_base + 2u);
-        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_packed = load_src0_u32_at(block_byte_base + 4u + 2u * (block_offset + j));
-            for (var k: u32 = 0; k < 4; k++) {
-                let q_byte = get_byte_i32(q_packed, k);
-                let q_val = f32(q_byte) * d + f32(m);
-                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+
+        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+            let output_row = row_base + row;
+            if (output_row < params.m) {
+                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
+                let d = f32(load_src0_f16_at(block_byte_base));
+                let m = f32(load_src0_f16_at(block_byte_base + 2u));
+                var row_sum = 0.0;
+
+                for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
+                    let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx));
+                    for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
+                        let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m;
+                        row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
+                    }
+                }
+                acc[row] += row_sum;
             }
         }
    }
-    return local_sum;
-}
 #endif
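The K-quants that follow work on 256-element super-blocks instead of 32-element blocks. For Q2_K the 84 bytes decode as 16 scale bytes (low nibble = sub-block scale, high nibble = sub-block min), 64 bytes of 2-bit quants, then f16 d and dmin at bytes 80 and 82; the shader pre-scales dmin by 1/16 because it multiplies un-shifted high nibbles. A reference expansion of one 16-element sub-block, assuming the standard ggml block_q2_K ordering (hypothetical helper, d/dmin pre-converted):

    #include <cstdint>

    // One 84-byte Q2_K super-block: scales[16], qs[64], f16 d, f16 dmin.
    void dequant_q2_k_subblock(const uint8_t *block, float d, float dmin,
                               int sb /* 0..15 */, float out[16]) {
        const uint8_t sc = block[sb];                       // scale nibble | min nibble
        const uint8_t *q = block + 16                       // start of 2-bit quants
                         + 32 * (sb / 8) + 16 * (sb % 2);   // 32-byte group, 16-byte run
        const int shift  = 2 * ((sb % 8) / 2);              // which 2-bit field of each byte
        for (int l = 0; l < 16; l++) {
            out[l] = d * float(sc & 0xF) * float((q[l] >> shift) & 3)
                   - dmin * float(sc >> 4);
        }
    }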
 
 #ifdef MUL_ACC_Q2_K
+#define BLOCK_SIZE 256
+#define BLOCK_SIZE_BYTES 84
+#define THREADS_PER_BLOCK 16
 
-const Q2K_BLOCK_SIZE = 256u;
-const Q2K_BLOCK_SIZE_BYTES = 84u;  // 42 f16s * 2
-
-fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    let ix = tig / 8u;
-    let it = tig % 8u;
-    let iq = it / 4u;
-    let ir = it % 4u;
-    let is = (8u * ir) / 16u;
+    let tid = thread_id % THREADS_PER_BLOCK;
+    let block_group = thread_id / THREADS_PER_BLOCK;
+    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
 
-    let nb = tile_size / Q2K_BLOCK_SIZE;
-    let k_block_start = k_outer / Q2K_BLOCK_SIZE;
-    let y4_offset = 128u * iq + 8u * ir;
+    let lane = tid / 2u;
+    let phase = tid % 2u;
+    let iq = lane / 4u;
+    let ir = lane % 4u;
+    let is = ir / 2u;
+    let y_offset = 128u * iq + 8u * ir + 4u * phase;
 
     let sc0_byte = 8u * iq + is;
     let sc2_byte = 8u * iq + is + 2u;
     let sc4_byte = 8u * iq + is + 4u;
     let sc6_byte = 8u * iq + is + 6u;
-    let qs_byte = 16u + (16u * iq + 4u * ir) * 2u;
-
-    var sumf = 0.0;
-
-    for (var ib = ix; ib < nb; ib += 4u) {
-        let bbase = (idx_base + k_block_start + ib) * Q2K_BLOCK_SIZE_BYTES;
-
-        let dall = f32(load_src0_f16_at(bbase + 80u));
-        let dmin = f32(load_src0_f16_at(bbase + 82u)) * (1.0 / 16.0);
-
-        let sc0 = byte_of(load_src0_u32_at_aligned(bbase + sc0_byte), sc0_byte & 3u);
-        let sc2 = byte_of(load_src0_u32_at_aligned(bbase + sc2_byte), sc2_byte & 3u);
-        let sc4 = byte_of(load_src0_u32_at_aligned(bbase + sc4_byte), sc4_byte & 3u);
-        let sc6 = byte_of(load_src0_u32_at_aligned(bbase + sc6_byte), sc6_byte & 3u);
-
-        let qs_u32_0 = load_src0_u32_at_aligned(bbase + qs_byte);
-        let qs_u32_1 = load_src0_u32_at_aligned(bbase + qs_byte + 4u);
-        let qs0 = qs_u32_0 & 0xFFFFu;
-        let qs1 = qs_u32_0 >> 16u;
-        let qs2 = qs_u32_1 & 0xFFFFu;
-        let qs3 = qs_u32_1 >> 16u;
-
-        let y_base = ib * Q2K_BLOCK_SIZE + y4_offset;
-
-        var sumy = vec4(0.0, 0.0, 0.0, 0.0);
-        var acc1 = vec4(0.0, 0.0, 0.0, 0.0);
-        var acc2 = vec4(0.0, 0.0, 0.0, 0.0);
-
-        // i=0: j=0,1
-        {
-            let y00 = f32(shared_vector[y_base      ]); sumy[0] += y00;
-            let y01 = f32(shared_vector[y_base +  1u]); sumy[0] += y01;
-            let y10 = f32(shared_vector[y_base + 32u]); sumy[1] += y10;
-            let y11 = f32(shared_vector[y_base + 33u]); sumy[1] += y11;
-            let y20 = f32(shared_vector[y_base + 64u]); sumy[2] += y20;
-            let y21 = f32(shared_vector[y_base + 65u]); sumy[2] += y21;
-            let y30 = f32(shared_vector[y_base + 96u]); sumy[3] += y30;
-            let y31 = f32(shared_vector[y_base + 97u]); sumy[3] += y31;
-            acc1[0] += y00 * f32(qs0 & 0x0003u);
-            acc2[0] += y01 * f32(qs0 & 0x0300u);
-            acc1[1] += y10 * f32(qs0 & 0x000Cu);
-            acc2[1] += y11 * f32(qs0 & 0x0C00u);
-            acc1[2] += y20 * f32(qs0 & 0x0030u);
-            acc2[2] += y21 * f32(qs0 & 0x3000u);
-            acc1[3] += y30 * f32(qs0 & 0x00C0u);
-            acc2[3] += y31 * f32(qs0 & 0xC000u);
-        }
-        // i=2: j=2,3
-        {
-            let y00 = f32(shared_vector[y_base +  2u]); sumy[0] += y00;
-            let y01 = f32(shared_vector[y_base +  3u]); sumy[0] += y01;
-            let y10 = f32(shared_vector[y_base + 34u]); sumy[1] += y10;
-            let y11 = f32(shared_vector[y_base + 35u]); sumy[1] += y11;
-            let y20 = f32(shared_vector[y_base + 66u]); sumy[2] += y20;
-            let y21 = f32(shared_vector[y_base + 67u]); sumy[2] += y21;
-            let y30 = f32(shared_vector[y_base + 98u]); sumy[3] += y30;
-            let y31 = f32(shared_vector[y_base + 99u]); sumy[3] += y31;
-            acc1[0] += y00 * f32(qs1 & 0x0003u);
-            acc2[0] += y01 * f32(qs1 & 0x0300u);
-            acc1[1] += y10 * f32(qs1 & 0x000Cu);
-            acc2[1] += y11 * f32(qs1 & 0x0C00u);
-            acc1[2] += y20 * f32(qs1 & 0x0030u);
-            acc2[2] += y21 * f32(qs1 & 0x3000u);
-            acc1[3] += y30 * f32(qs1 & 0x00C0u);
-            acc2[3] += y31 * f32(qs1 & 0xC000u);
-        }
-        // i=4: j=4,5
-        {
-            let y00 = f32(shared_vector[y_base +  4u]); sumy[0] += y00;
-            let y01 = f32(shared_vector[y_base +  5u]); sumy[0] += y01;
-            let y10 = f32(shared_vector[y_base + 36u]); sumy[1] += y10;
-            let y11 = f32(shared_vector[y_base + 37u]); sumy[1] += y11;
-            let y20 = f32(shared_vector[y_base + 68u]); sumy[2] += y20;
-            let y21 = f32(shared_vector[y_base + 69u]); sumy[2] += y21;
-            let y30 = f32(shared_vector[y_base + 100u]); sumy[3] += y30;
-            let y31 = f32(shared_vector[y_base + 101u]); sumy[3] += y31;
-            acc1[0] += y00 * f32(qs2 & 0x0003u);
-            acc2[0] += y01 * f32(qs2 & 0x0300u);
-            acc1[1] += y10 * f32(qs2 & 0x000Cu);
-            acc2[1] += y11 * f32(qs2 & 0x0C00u);
-            acc1[2] += y20 * f32(qs2 & 0x0030u);
-            acc2[2] += y21 * f32(qs2 & 0x3000u);
-            acc1[3] += y30 * f32(qs2 & 0x00C0u);
-            acc2[3] += y31 * f32(qs2 & 0xC000u);
-        }
-        // i=6: j=6,7
-        {
-            let y00 = f32(shared_vector[y_base +  6u]); sumy[0] += y00;
-            let y01 = f32(shared_vector[y_base +  7u]); sumy[0] += y01;
-            let y10 = f32(shared_vector[y_base + 38u]); sumy[1] += y10;
-            let y11 = f32(shared_vector[y_base + 39u]); sumy[1] += y11;
-            let y20 = f32(shared_vector[y_base + 70u]); sumy[2] += y20;
-            let y21 = f32(shared_vector[y_base + 71u]); sumy[2] += y21;
-            let y30 = f32(shared_vector[y_base + 102u]); sumy[3] += y30;
-            let y31 = f32(shared_vector[y_base + 103u]); sumy[3] += y31;
-            acc1[0] += y00 * f32(qs3 & 0x0003u);
-            acc2[0] += y01 * f32(qs3 & 0x0300u);
-            acc1[1] += y10 * f32(qs3 & 0x000Cu);
-            acc2[1] += y11 * f32(qs3 & 0x0C00u);
-            acc1[2] += y20 * f32(qs3 & 0x0030u);
-            acc2[2] += y21 * f32(qs3 & 0x3000u);
-            acc1[3] += y30 * f32(qs3 & 0x00C0u);
-            acc2[3] += y31 * f32(qs3 & 0xC000u);
+    let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase;
+
+    let num_blocks = params.k / BLOCK_SIZE;
+
+    for (var block = block_group; block < num_blocks; block += num_block_groups) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
+        var x_block: array<f32, 16>;
+        for (var i = 0u; i < 4u; i++) {
+            x_block[i] = f32(src1[x_base + i]);
+            x_block[i + 4u] = f32(src1[x_base + 32u + i]);
+            x_block[i + 8u] = f32(src1[x_base + 64u + i]);
+            x_block[i + 12u] = 
f32(src1[x_base + 96u + i]); } - sumf += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_src0_f16_at(block_byte_base + 80u)); + let dmin = f32(load_src0_f16_at(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_src0_u32_at_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } } - - return sumf; -} #endif -#ifdef MUL_ACC_Q3_K -const Q3K_BLOCK_SIZE = 256u; -const Q3K_BLOCK_SIZE_BYTES = 110u; // 55 f16s * 2 +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 -fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - let tid = tig / 4u; - let ix = tig % 4u; - let ip = tid / 4u; - let il = 2u * ((tid % 4u) / 2u); - let ir = tid % 2u; - let l0 = 8u * ir; + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - let nb = tile_size / Q3K_BLOCK_SIZE; - let k_block_start = k_outer / Q3K_BLOCK_SIZE; + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; - let q_byte = 32u + 32u * ip 
+ l0;
-    let h_byte = l0;
-    let y_offset = 128u * ip + 32u * il + l0;
+    let q_byte = 32u + 32u * ip + l0 + 16u * phase;
+    let h_byte = l0 + 16u * phase;
+    let y_offset = 128u * ip + 32u * il + l0 + 16u * phase;
 
     let s_shift1 = 4u * ip;
     let s_shift2 = s_shift1 + il;
@@ -437,682 +480,385 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; }
     }
 
-    var sumf1 = 0.0;
-    var sumf2 = 0.0;
-
-    for (var i = ix; i < nb; i += 4u) {
-        let bbase = (idx_base + k_block_start + i) * Q3K_BLOCK_SIZE_BYTES;
-
-        let d_all = f32(load_src0_f16_at(bbase + 108u));
-
-        // Scale unpacking
-        let a_base = 96u;
-        let a_il0_u32 = load_src0_u32_at_aligned(bbase + a_base + il * 2u);
-        let a_il0 = select(a_il0_u32 & 0xFFFFu, a_il0_u32 >> 16u, (il & 1u) != 0u);
-        let a_il1_u32 = load_src0_u32_at_aligned(bbase + a_base + (il + 1u) * 2u);
-        let a_il1 = select(a_il1_u32 & 0xFFFFu, a_il1_u32 >> 16u, ((il + 1u) & 1u) != 0u);
-        let a_45_u32 = load_src0_u32_at_aligned(bbase + a_base + 8u);
-        let a_4 = a_45_u32 & 0xFFFFu;
-        let a_5 = a_45_u32 >> 16u;
-
-        var scales32 = a_4 | (a_5 << 16u);
-        let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u;
-        scales32 = a_il0 | (a_il1 << 16u);
-        scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32;
-
-        let sc0 = i32(byte_of(scales32, 0u)) - 32;
-        let sc1 = i32(byte_of(scales32, 1u)) - 32;
-        let sc2 = i32(byte_of(scales32, 2u)) - 32;
-        let sc3 = i32(byte_of(scales32, 3u)) - 32;
-
-        let y_base = i * Q3K_BLOCK_SIZE + y_offset;
-        var yl: array<f32, 32>;
-        for (var l = 0u; l < 8u; l++) {
-            yl[l + 0] = f32(shared_vector[y_base + l      ]);
-            yl[l + 8] = f32(shared_vector[y_base + l + 16u]);
-            yl[l + 16] = f32(shared_vector[y_base + l + 32u]);
-            yl[l + 24] = f32(shared_vector[y_base + l + 48u]);
-        }
+    let num_blocks = params.k / BLOCK_SIZE;
 
-        // First qs/h loop: q[0..3], h[0..3]
-        var s1 = 0.0; var s2 = 0.0; var s3 = 0.0;
-        var s4 = 0.0; var s5 = 0.0; var s6 = 0.0;
-
-        for (var l = 0u; l < 8u; l += 2u) {
-            let q_u32 = load_src0_u32_at_aligned(bbase + q_byte + (l & ~2u));
-            let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u);
-            let h_u32 = load_src0_u32_at_aligned(bbase + h_byte + (l & ~2u));
-            let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u);
-
-            s1 += yl[l + 0u] * f32(qs & qm0);
-            s2 += yl[l + 1u] * f32(qs & qm1);
-            s3 += select(0.0, yl[l + 0u], (hv & hm0) == 0u) +
-                  select(0.0, yl[l + 1u], (hv & hm1) == 0u);
-            s4 += yl[l + 16u] * f32(qs & qm2);
-            s5 += yl[l + 17u] * f32(qs & qm3);
-            s6 += select(0.0, yl[l + 16u], (hv & hm2) == 0u) +
-                  select(0.0, yl[l + 17u], (hv & hm3) == 0u);
+    for (var block = block_group; block < num_blocks; block += num_block_groups) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
+        var x_block: array<f32, 16>;
+        for (var i = 0u; i < 8u; i++) {
+            x_block[i] = f32(src1[x_base + i]);
+            x_block[i + 8u] = f32(src1[x_base + 32u + i]);
         }
-        let d1 = d_all * (s1 + (1.0/256.0)*s2 - s3*v1);
-        let d2 = d_all * (s4 + (1.0/256.0)*s5 - s6*v2);
-        sumf1 += d1 * f32(sc0);
-        sumf2 += d2 * f32(sc2);
-
-        // Second qs/h loop: q[8..11], h[8..11] (16 bytes further)
-        s1 = 0.0; s2 = 0.0; s3 = 0.0;
-        s4 = 0.0; s5 = 0.0; s6 = 0.0;
-
-        for (var l = 0u; l < 8u; l += 2u) {
-            let q_u32 = load_src0_u32_at_aligned(bbase + q_byte + 16u + (l & ~2u));
-            let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u);
-            let h_u32 = load_src0_u32_at_aligned(bbase + h_byte + 16u + (l & ~2u));
-            let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u);
-
-            s1 += yl[l + 8u] * f32(qs & qm0);
-            s2 += yl[l 
+ 9u] * f32(qs & qm1); - s3 += select(0.0, yl[l + 8u], (hv & hm0) == 0u) + - select(0.0, yl[l + 9u], (hv & hm1) == 0u); - s4 += yl[l + 24u] * f32(qs & qm2); - s5 += yl[l + 25u] * f32(qs & qm3); - s6 += select(0.0, yl[l + 24u], (hv & hm2) == 0u) + - select(0.0, yl[l + 25u], (hv & hm3) == 0u); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_src0_f16_at(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_src0_u16_at(block_byte_base + a_base + il * 2u); + let a_il1 = load_src0_u16_at(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_src0_u16_at(block_byte_base + a_base + 8u); + let a_5 = load_src0_u16_at(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); + + let q_u32_0 = load_src0_u32_at(block_byte_base + q_byte + 0u); + let q_u32_1 = load_src0_u32_at(block_byte_base + q_byte + 4u); + let h_u32_0 = load_src0_u32_at(block_byte_base + h_byte + 0u); + let h_u32_1 = load_src0_u32_at(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } } - let d3 = d_all * (s1 + (1.0/256.0)*s2 - s3*v1); - let d4 = d_all * (s4 + (1.0/256.0)*s5 - s6*v2); - sumf1 += d3 * f32(sc1); - sumf2 += d4 * f32(sc3); } - - return (sumf1 + 0.25 * sumf2) / f32(1u << shift); -} #endif #ifdef MUL_ACC_Q4_K - -const Q4K_BLOCK_SIZE = 256u; -const Q4K_BLOCK_SIZE_BYTES = 144u; // 72 f16s * 2 - -fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - let ix = tig / 8u; - let it = tig % 8u; - let iq = it / 4u; - let ir = it % 4u; - - let nb = tile_size / Q4K_BLOCK_SIZE; - let k_block_start = k_outer / Q4K_BLOCK_SIZE; - - let y_offset = 64u * iq + 8u * ir; - - let sc0_byte = 4u + iq * 2u; - let sc2_byte = 4u + (iq + 2u) * 2u; - let sc4_byte = 4u + (iq + 4u) * 2u; - let q1_byte = 16u + (16u * iq + 4u * ir) * 2u; - let q2_byte = q1_byte + 64u; - - var sumf = 0.0; - - for (var ib = ix; ib < nb; ib += 4u) { - let bbase = (idx_base + k_block_start + ib) * Q4K_BLOCK_SIZE_BYTES; - - let d = f32(load_src0_f16_at(bbase + 0u)); - let dmin = f32(load_src0_f16_at(bbase + 2u)); - - let sc0_u32 = load_src0_u32_at_aligned(bbase + sc0_byte); - let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); - let sc2_u32 = 
load_src0_u32_at_aligned(bbase + sc2_byte); - let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); - let sc4_u32 = load_src0_u32_at_aligned(bbase + sc4_byte); - let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); - - let sc16_0 = sc0 & 0x3F3Fu; - let sc16_1 = sc2 & 0x3F3Fu; - let sc16_2 = ((sc4 ) & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); - let sc16_3 = ((sc4 >> 4u ) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); - - let sc8_0 = sc16_0 & 0xFFu; - let sc8_1 = (sc16_0 >> 8u) & 0xFFu; - let sc8_2 = sc16_1 & 0xFFu; - let sc8_3 = (sc16_1 >> 8u) & 0xFFu; - let sc8_4 = sc16_2 & 0xFFu; - let sc8_5 = (sc16_2 >> 8u) & 0xFFu; - let sc8_6 = sc16_3 & 0xFFu; - let sc8_7 = (sc16_3 >> 8u) & 0xFFu; - - let q1_u32_0 = load_src0_u32_at_aligned(bbase + q1_byte); - let q1_u32_1 = load_src0_u32_at_aligned(bbase + q1_byte + 4u); - let q2_u32_0 = load_src0_u32_at_aligned(bbase + q2_byte); - let q2_u32_1 = load_src0_u32_at_aligned(bbase + q2_byte + 4u); - - let q1_0 = q1_u32_0 & 0xFFFFu; - let q1_1 = q1_u32_0 >> 16u; - let q1_2 = q1_u32_1 & 0xFFFFu; - let q1_3 = q1_u32_1 >> 16u; - let q2_0 = q2_u32_0 & 0xFFFFu; - let q2_1 = q2_u32_0 >> 16u; - let q2_2 = q2_u32_1 & 0xFFFFu; - let q2_3 = q2_u32_1 >> 16u; - - let y_base = ib * Q4K_BLOCK_SIZE + y_offset; - - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - - // i=0: yl[0,1,8,9], yh[0,1,8,9] - { - let yl0 = f32(shared_vector[y_base + 0u]); sumy[0] += yl0; - let yl1 = f32(shared_vector[y_base + 1u]); sumy[0] += yl1; - let yl8 = f32(shared_vector[y_base + 32u]); sumy[1] += yl8; - let yl9 = f32(shared_vector[y_base + 33u]); sumy[1] += yl9; - let yh0 = f32(shared_vector[y_base + 128u]); sumy[2] += yh0; - let yh1 = f32(shared_vector[y_base + 129u]); sumy[2] += yh1; - let yh8 = f32(shared_vector[y_base + 160u]); sumy[3] += yh8; - let yh9 = f32(shared_vector[y_base + 161u]); sumy[3] += yh9; - acc1[0] += yl0 * f32(q1_0 & 0x000Fu); - acc1[1] += yl1 * f32(q1_0 & 0x0F00u); - acc1[2] += yl8 * f32(q1_0 & 0x00F0u); - acc1[3] += yl9 * f32(q1_0 & 0xF000u); - acc2[0] += yh0 * f32(q2_0 & 0x000Fu); - acc2[1] += yh1 * f32(q2_0 & 0x0F00u); - acc2[2] += yh8 * f32(q2_0 & 0x00F0u); - acc2[3] += yh9 * f32(q2_0 & 0xF000u); - } - // i=1: yl[2,3,10,11], yh[2,3,10,11] - { - let yl0 = f32(shared_vector[y_base + 2u]); sumy[0] += yl0; - let yl1 = f32(shared_vector[y_base + 3u]); sumy[0] += yl1; - let yl8 = f32(shared_vector[y_base + 34u]); sumy[1] += yl8; - let yl9 = f32(shared_vector[y_base + 35u]); sumy[1] += yl9; - let yh0 = f32(shared_vector[y_base + 130u]); sumy[2] += yh0; - let yh1 = f32(shared_vector[y_base + 131u]); sumy[2] += yh1; - let yh8 = f32(shared_vector[y_base + 162u]); sumy[3] += yh8; - let yh9 = f32(shared_vector[y_base + 163u]); sumy[3] += yh9; - acc1[0] += yl0 * f32(q1_1 & 0x000Fu); - acc1[1] += yl1 * f32(q1_1 & 0x0F00u); - acc1[2] += yl8 * f32(q1_1 & 0x00F0u); - acc1[3] += yl9 * f32(q1_1 & 0xF000u); - acc2[0] += yh0 * f32(q2_1 & 0x000Fu); - acc2[1] += yh1 * f32(q2_1 & 0x0F00u); - acc2[2] += yh8 * f32(q2_1 & 0x00F0u); - acc2[3] += yh9 * f32(q2_1 & 0xF000u); - } - // i=2: yl[4,5,12,13], yh[4,5,12,13] - { - let yl0 = f32(shared_vector[y_base + 4u]); sumy[0] += yl0; - let yl1 = f32(shared_vector[y_base + 5u]); sumy[0] += yl1; - let yl8 = f32(shared_vector[y_base + 36u]); sumy[1] += yl8; - let yl9 = f32(shared_vector[y_base + 37u]); sumy[1] += yl9; - let yh0 = f32(shared_vector[y_base + 132u]); sumy[2] += yh0; - let yh1 = f32(shared_vector[y_base + 133u]); sumy[2] += yh1; - let yh8 = 
f32(shared_vector[y_base + 164u]); sumy[3] += yh8;
-            let yh9 = f32(shared_vector[y_base + 165u]); sumy[3] += yh9;
-            acc1[0] += yl0 * f32(q1_2 & 0x000Fu);
-            acc1[1] += yl1 * f32(q1_2 & 0x0F00u);
-            acc1[2] += yl8 * f32(q1_2 & 0x00F0u);
-            acc1[3] += yl9 * f32(q1_2 & 0xF000u);
-            acc2[0] += yh0 * f32(q2_2 & 0x000Fu);
-            acc2[1] += yh1 * f32(q2_2 & 0x0F00u);
-            acc2[2] += yh8 * f32(q2_2 & 0x00F0u);
-            acc2[3] += yh9 * f32(q2_2 & 0xF000u);
-        }
-        // i=3: yl[6,7,14,15], yh[6,7,14,15]
-        {
-            let yl0 = f32(shared_vector[y_base + 6u]); sumy[0] += yl0;
-            let yl1 = f32(shared_vector[y_base + 7u]); sumy[0] += yl1;
-            let yl8 = f32(shared_vector[y_base + 38u]); sumy[1] += yl8;
-            let yl9 = f32(shared_vector[y_base + 39u]); sumy[1] += yl9;
-            let yh0 = f32(shared_vector[y_base + 134u]); sumy[2] += yh0;
-            let yh1 = f32(shared_vector[y_base + 135u]); sumy[2] += yh1;
-            let yh8 = f32(shared_vector[y_base + 166u]); sumy[3] += yh8;
-            let yh9 = f32(shared_vector[y_base + 167u]); sumy[3] += yh9;
-            acc1[0] += yl0 * f32(q1_3 & 0x000Fu);
-            acc1[1] += yl1 * f32(q1_3 & 0x0F00u);
-            acc1[2] += yl8 * f32(q1_3 & 0x00F0u);
-            acc1[3] += yl9 * f32(q1_3 & 0xF000u);
-            acc2[0] += yh0 * f32(q2_3 & 0x000Fu);
-            acc2[1] += yh1 * f32(q2_3 & 0x0F00u);
-            acc2[2] += yh8 * f32(q2_3 & 0x00F0u);
-            acc2[3] += yh9 * f32(q2_3 & 0xF000u);
+#define BLOCK_SIZE 256
+#define BLOCK_SIZE_BYTES 144
+#define THREADS_PER_BLOCK 16
+
+    let tid = thread_id % THREADS_PER_BLOCK;
+    let block_group = thread_id / THREADS_PER_BLOCK;
+    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
+
+    let il = tid / 4u;
+    let ir = tid % 4u;
+    let im = il / 2u;
+    let in = il % 2u;
+    let l0 = 4u * (2u * ir + in);
+
+    let y_offset = 64u * im + l0;
+    let q_offset = 32u * im + l0;
+    let sc0_byte = 4u + im * 2u;
+    let sc2_byte = 4u + (im + 2u) * 2u;
+    let sc4_byte = 4u + (im + 4u) * 2u;
+
+    let num_blocks = params.k / BLOCK_SIZE;
+
+    for (var block = block_group; block < num_blocks; block += num_block_groups) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
+        var x_block: array<f32, 16>;
+        for (var i = 0u; i < 4u; i++) {
+            x_block[i] = f32(src1[x_base + i]);
+            x_block[i + 4u] = f32(src1[x_base + 32u + i]);
+            x_block[i + 8u] = f32(src1[x_base + 128u + i]);
+            x_block[i + 12u] = f32(src1[x_base + 160u + i]);
         }
-        sumf += d * ((acc1[0] + (1.0/256.0)*acc1[1]) * f32(sc8_0) +
-                     (acc1[2] + (1.0/256.0)*acc1[3]) * f32(sc8_1) * (1.0/16.0) +
-                     (acc2[0] + (1.0/256.0)*acc2[1]) * f32(sc8_4) +
-                     (acc2[2] + (1.0/256.0)*acc2[3]) * f32(sc8_5) * (1.0/16.0))
-              - dmin * (sumy[0] * f32(sc8_2) + sumy[1] * f32(sc8_3) +
-                        sumy[2] * f32(sc8_6) + sumy[3] * f32(sc8_7));
+        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+            let output_row = row_base + row;
+            if (output_row < params.m) {
+                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
+
+                let d = f32(load_src0_f16_at(block_byte_base + 0u));
+                let dmin = f32(load_src0_f16_at(block_byte_base + 2u));
+
+                let sc0_u32 = load_src0_u32_at_aligned(block_byte_base + sc0_byte);
+                let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
+                let sc2_u32 = load_src0_u32_at_aligned(block_byte_base + sc2_byte);
+                let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
+                let sc4_u32 = load_src0_u32_at_aligned(block_byte_base + sc4_byte);
+                let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);
+
+                let sc16_0 = sc0 & 0x3F3Fu;
+                let sc16_1 = sc2 & 0x3F3Fu;
+                let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
+                let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 
0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_src0_u32_at_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_src0_u32_at_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } } - - return sumf; -} #endif #ifdef MUL_ACC_Q5_K - -const Q5K_BLOCK_SIZE = 256u; -const Q5K_BLOCK_SIZE_BYTES = 176u; // 88 f16s * 2 - -fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - let tid = tig / 4u; - let ix = tig % 4u; - let iq = tid / 4u; - let ir = tid % 4u; - let l0 = 8u * ir; - - let nb = tile_size / Q5K_BLOCK_SIZE; - let k_block_start = k_outer / Q5K_BLOCK_SIZE; - - let y_offset = 64u * iq + l0; - let q1_byte = 48u + 32u * iq + l0; - let q2_byte = q1_byte + 64u; - let qh_byte = 16u + l0; - let sc0_byte = 4u + iq * 2u; - let sc2_byte = 4u + (iq + 2u) * 2u; - let sc4_byte = 4u + (iq + 4u) * 2u; - - let hm1 = 1u << (2u * iq); +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); let hm2 = hm1 << 1u; let hm3 = hm1 << 4u; let hm4 = hm2 << 4u; - var sumf = 0.0; - - for (var ib = ix; ib < nb; ib += 4u) { - let bbase = (idx_base + k_block_start + ib) * Q5K_BLOCK_SIZE_BYTES; - - let d = f32(load_src0_f16_at(bbase + 0u)); - let dmin = f32(load_src0_f16_at(bbase + 2u)); - - let sc0_u32 = load_src0_u32_at_aligned(bbase + sc0_byte); - let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); - let sc2_u32 = load_src0_u32_at_aligned(bbase + sc2_byte); - let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); - let sc4_u32 = load_src0_u32_at_aligned(bbase + sc4_byte); - let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); - - let sc16_0 = sc0 & 0x3F3Fu; - let sc16_1 = sc2 & 0x3F3Fu; - let sc16_2 = ((sc4 ) & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); - let sc16_3 = ((sc4 >> 4u ) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); - - let sc8_0 = sc16_0 & 0xFFu; - let sc8_1 = (sc16_0 >> 8u) & 0xFFu; - let sc8_2 = sc16_1 & 0xFFu; - let sc8_3 = (sc16_1 >> 8u) & 0xFFu; - let sc8_4 = sc16_2 & 0xFFu; - let sc8_5 = (sc16_2 >> 8u) & 0xFFu; - let sc8_6 = sc16_3 & 0xFFu; - 
let sc8_7 = (sc16_3 >> 8u) & 0xFFu; - - let f0 = f32(sc8_0); - let f1_lo = f32(sc8_1) * (1.0/16.0); - let f1_hi = f32(sc8_1) * 16.0; - let f4 = f32(sc8_4); - let f5_lo = f32(sc8_5) * (1.0/16.0); - let f5_hi = f32(sc8_5) * 16.0; - - let q1_u32_0 = load_src0_u32_at_aligned(bbase + q1_byte); - let q1_u32_1 = load_src0_u32_at_aligned(bbase + q1_byte + 4u); - let q2_u32_0 = load_src0_u32_at_aligned(bbase + q2_byte); - let q2_u32_1 = load_src0_u32_at_aligned(bbase + q2_byte + 4u); - let qh_u32_0 = load_src0_u32_at_aligned(bbase + qh_byte); - let qh_u32_1 = load_src0_u32_at_aligned(bbase + qh_byte + 4u); - - let y_base = ib * Q5K_BLOCK_SIZE + y_offset; - - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - - // l=0 - { - let q1b = byte_of(q1_u32_0, 0u); - let q2b = byte_of(q2_u32_0, 0u); - let qhb = byte_of(qh_u32_0, 0u); - let yl0 = f32(shared_vector[y_base + 0u]); sumy[0] += yl0; - let yl8 = f32(shared_vector[y_base + 32u]); sumy[1] += yl8; - let yh0 = f32(shared_vector[y_base +128u]); sumy[2] += yh0; - let yh8 = f32(shared_vector[y_base +160u]); sumy[3] += yh8; - acc1[0] += yl0 * f32(q1b & 0x0Fu); - acc1[1] += yl8 * f32(q1b & 0xF0u); - acc1[2] += yh0 * f32(q2b & 0x0Fu); - acc1[3] += yh8 * f32(q2b & 0xF0u); - acc2[0] += yl0 * f32((qhb & hm1) != 0u); - acc2[1] += yl8 * f32((qhb & hm2) != 0u); - acc2[2] += yh0 * f32((qhb & hm3) != 0u); - acc2[3] += yh8 * f32((qhb & hm4) != 0u); - } - // l=1 - { - let q1b = byte_of(q1_u32_0, 1u); - let q2b = byte_of(q2_u32_0, 1u); - let qhb = byte_of(qh_u32_0, 1u); - let yl0 = f32(shared_vector[y_base + 1u]); sumy[0] += yl0; - let yl8 = f32(shared_vector[y_base + 33u]); sumy[1] += yl8; - let yh0 = f32(shared_vector[y_base +129u]); sumy[2] += yh0; - let yh8 = f32(shared_vector[y_base +161u]); sumy[3] += yh8; - acc1[0] += yl0 * f32(q1b & 0x0Fu); - acc1[1] += yl8 * f32(q1b & 0xF0u); - acc1[2] += yh0 * f32(q2b & 0x0Fu); - acc1[3] += yh8 * f32(q2b & 0xF0u); - acc2[0] += yl0 * f32((qhb & hm1) != 0u); - acc2[1] += yl8 * f32((qhb & hm2) != 0u); - acc2[2] += yh0 * f32((qhb & hm3) != 0u); - acc2[3] += yh8 * f32((qhb & hm4) != 0u); - } - // l=2 - { - let q1b = byte_of(q1_u32_0, 2u); - let q2b = byte_of(q2_u32_0, 2u); - let qhb = byte_of(qh_u32_0, 2u); - let yl0 = f32(shared_vector[y_base + 2u]); sumy[0] += yl0; - let yl8 = f32(shared_vector[y_base + 34u]); sumy[1] += yl8; - let yh0 = f32(shared_vector[y_base +130u]); sumy[2] += yh0; - let yh8 = f32(shared_vector[y_base +162u]); sumy[3] += yh8; - acc1[0] += yl0 * f32(q1b & 0x0Fu); - acc1[1] += yl8 * f32(q1b & 0xF0u); - acc1[2] += yh0 * f32(q2b & 0x0Fu); - acc1[3] += yh8 * f32(q2b & 0xF0u); - acc2[0] += yl0 * f32((qhb & hm1) != 0u); - acc2[1] += yl8 * f32((qhb & hm2) != 0u); - acc2[2] += yh0 * f32((qhb & hm3) != 0u); - acc2[3] += yh8 * f32((qhb & hm4) != 0u); - } - // l=3 - { - let q1b = byte_of(q1_u32_0, 3u); - let q2b = byte_of(q2_u32_0, 3u); - let qhb = byte_of(qh_u32_0, 3u); - let yl0 = f32(shared_vector[y_base + 3u]); sumy[0] += yl0; - let yl8 = f32(shared_vector[y_base + 35u]); sumy[1] += yl8; - let yh0 = f32(shared_vector[y_base +131u]); sumy[2] += yh0; - let yh8 = f32(shared_vector[y_base +163u]); sumy[3] += yh8; - acc1[0] += yl0 * f32(q1b & 0x0Fu); - acc1[1] += yl8 * f32(q1b & 0xF0u); - acc1[2] += yh0 * f32(q2b & 0x0Fu); - acc1[3] += yh8 * f32(q2b & 0xF0u); - acc2[0] += yl0 * f32((qhb & hm1) != 0u); - acc2[1] += yl8 * f32((qhb & hm2) != 0u); - acc2[2] += yh0 * f32((qhb & hm3) != 0u); - acc2[3] += yh8 * f32((qhb & hm4) != 0u); - } - // l=4 - { - let q1b 
= byte_of(q1_u32_1, 0u);
-            let q2b = byte_of(q2_u32_1, 0u);
-            let qhb = byte_of(qh_u32_1, 0u);
-            let yl0 = f32(shared_vector[y_base +  4u]); sumy[0] += yl0;
-            let yl8 = f32(shared_vector[y_base + 36u]); sumy[1] += yl8;
-            let yh0 = f32(shared_vector[y_base +132u]); sumy[2] += yh0;
-            let yh8 = f32(shared_vector[y_base +164u]); sumy[3] += yh8;
-            acc1[0] += yl0 * f32(q1b & 0x0Fu);
-            acc1[1] += yl8 * f32(q1b & 0xF0u);
-            acc1[2] += yh0 * f32(q2b & 0x0Fu);
-            acc1[3] += yh8 * f32(q2b & 0xF0u);
-            acc2[0] += yl0 * f32((qhb & hm1) != 0u);
-            acc2[1] += yl8 * f32((qhb & hm2) != 0u);
-            acc2[2] += yh0 * f32((qhb & hm3) != 0u);
-            acc2[3] += yh8 * f32((qhb & hm4) != 0u);
-        }
-        // l=5
-        {
-            let q1b = byte_of(q1_u32_1, 1u);
-            let q2b = byte_of(q2_u32_1, 1u);
-            let qhb = byte_of(qh_u32_1, 1u);
-            let yl0 = f32(shared_vector[y_base +  5u]); sumy[0] += yl0;
-            let yl8 = f32(shared_vector[y_base + 37u]); sumy[1] += yl8;
-            let yh0 = f32(shared_vector[y_base +133u]); sumy[2] += yh0;
-            let yh8 = f32(shared_vector[y_base +165u]); sumy[3] += yh8;
-            acc1[0] += yl0 * f32(q1b & 0x0Fu);
-            acc1[1] += yl8 * f32(q1b & 0xF0u);
-            acc1[2] += yh0 * f32(q2b & 0x0Fu);
-            acc1[3] += yh8 * f32(q2b & 0xF0u);
-            acc2[0] += yl0 * f32((qhb & hm1) != 0u);
-            acc2[1] += yl8 * f32((qhb & hm2) != 0u);
-            acc2[2] += yh0 * f32((qhb & hm3) != 0u);
-            acc2[3] += yh8 * f32((qhb & hm4) != 0u);
-        }
-        // l=6
-        {
-            let q1b = byte_of(q1_u32_1, 2u);
-            let q2b = byte_of(q2_u32_1, 2u);
-            let qhb = byte_of(qh_u32_1, 2u);
-            let yl0 = f32(shared_vector[y_base +  6u]); sumy[0] += yl0;
-            let yl8 = f32(shared_vector[y_base + 38u]); sumy[1] += yl8;
-            let yh0 = f32(shared_vector[y_base +134u]); sumy[2] += yh0;
-            let yh8 = f32(shared_vector[y_base +166u]); sumy[3] += yh8;
-            acc1[0] += yl0 * f32(q1b & 0x0Fu);
-            acc1[1] += yl8 * f32(q1b & 0xF0u);
-            acc1[2] += yh0 * f32(q2b & 0x0Fu);
-            acc1[3] += yh8 * f32(q2b & 0xF0u);
-            acc2[0] += yl0 * f32((qhb & hm1) != 0u);
-            acc2[1] += yl8 * f32((qhb & hm2) != 0u);
-            acc2[2] += yh0 * f32((qhb & hm3) != 0u);
-            acc2[3] += yh8 * f32((qhb & hm4) != 0u);
-        }
-        // l=7
-        {
-            let q1b = byte_of(q1_u32_1, 3u);
-            let q2b = byte_of(q2_u32_1, 3u);
-            let qhb = byte_of(qh_u32_1, 3u);
-            let yl0 = f32(shared_vector[y_base +  7u]); sumy[0] += yl0;
-            let yl8 = f32(shared_vector[y_base + 39u]); sumy[1] += yl8;
-            let yh0 = f32(shared_vector[y_base +135u]); sumy[2] += yh0;
-            let yh8 = f32(shared_vector[y_base +167u]); sumy[3] += yh8;
-            acc1[0] += yl0 * f32(q1b & 0x0Fu);
-            acc1[1] += yl8 * f32(q1b & 0xF0u);
-            acc1[2] += yh0 * f32(q2b & 0x0Fu);
-            acc1[3] += yh8 * f32(q2b & 0xF0u);
-            acc2[0] += yl0 * f32((qhb & hm1) != 0u);
-            acc2[1] += yl8 * f32((qhb & hm2) != 0u);
-            acc2[2] += yh0 * f32((qhb & hm3) != 0u);
-            acc2[3] += yh8 * f32((qhb & hm4) != 0u);
+    let num_blocks = params.k / BLOCK_SIZE;
+
+    for (var block = block_group; block < num_blocks; block += num_block_groups) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
+        var x_block: array<f32, 16>;
+        for (var i = 0u; i < 4u; i++) {
+            x_block[i] = f32(src1[x_base + i]);
+            x_block[i + 4u] = f32(src1[x_base + 32u + i]);
+            x_block[i + 8u] = f32(src1[x_base + 128u + i]);
+            x_block[i + 12u] = f32(src1[x_base + 160u + i]);
         }
-        sumf += d * (f0 * acc1[0] + f0 * 16.0 * acc2[0] +
-                     f1_lo * acc1[1] + f1_hi * acc2[1] +
-                     f4 * acc1[2] + f4 * 16.0 * acc2[2] +
-                     f5_lo * acc1[3] + f5_hi * acc2[3])
-              - dmin * (sumy[0]*f32(sc8_2) + sumy[1]*f32(sc8_3) +
-                        sumy[2]*f32(sc8_6) + sumy[3]*f32(sc8_7));
+        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+            let output_row = row_base + row;
+            if (output_row < params.m) {
+                let 
block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_src0_f16_at(block_byte_base + 0u)); + let dmin = f32(load_src0_f16_at(block_byte_base + 2u)); + + let sc0_u32 = load_src0_u32_at_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_src0_u32_at_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_src0_u32_at_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_src0_u32_at_aligned(block_byte_base + q_offset); + let q2_u32 = load_src0_u32_at_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_src0_u32_at_aligned(block_byte_base + qh_offset); + + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } } - - return sumf; -} #endif #ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; // 105 f16s * 2 + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; -fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - let tid = tig / 2u; - let ix = tig % 2u; - let ip = tid / 8u; - let il = tid % 8u; - let l0 = 4u * il; - let is = 8u * ip + l0 / 16u; + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; - let y_offset = 128u * ip + l0; - let q_offset_l = 64u * ip + l0; - let q_offset_h = 32u * ip + l0; + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; - let nb = tile_size / BLOCK_SIZE; - let k_block_start = k_outer / BLOCK_SIZE; - - // Aligned scale byte position (is can be odd) + let num_blocks = params.k / BLOCK_SIZE; let sc_base_byte = 192u + (is & ~3u); - let sc_byte_pos = is & 3u; - - var local_sum = 0.0; - - for (var i = ix; i < nb; i += 2u) { - let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - - let d = 
f32(load_src0_f16_at(bbase + 208u));
-
-        let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
-        let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
-        let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
-        let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
-        let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
-
-        let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
-        let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
-        let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
-        let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
-
-        var sums = vec4(0.0, 0.0, 0.0, 0.0);
+    let sc_byte_pos = is & 3u;
 
+    for (var block = block_group; block < num_blocks; block += num_block_groups) {
+        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
+        var x_block: array<f32, 16>;
         for (var l = 0u; l < 4u; l++) {
-            let y_base = i * BLOCK_SIZE + y_offset + l;
-            let yl0 = f32(shared_vector[y_base]);
-            let yl1 = f32(shared_vector[y_base + 32u]);
-            let yl2 = f32(shared_vector[y_base + 64u]);
-            let yl3 = f32(shared_vector[y_base + 96u]);
-
-            let q1b = byte_of(ql1_u32, l);
-            let q2b = byte_of(ql2_u32, l);
-            let qhb = byte_of(qh_u32, l);
-
-            let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
-            let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
-            let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u)      )) - 32);
-            let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
-
-            sums[0] += yl0 * dq0;
-            sums[1] += yl1 * dq1;
-            sums[2] += yl2 * dq2;
-            sums[3] += yl3 * dq3;
+            x_block[l] = f32(src1[x_base + l]);
+            x_block[l + 4u] = f32(src1[x_base + 32u + l]);
+            x_block[l + 8u] = f32(src1[x_base + 64u + l]);
+            x_block[l + 12u] = f32(src1[x_base + 96u + l]);
         }
 
-        local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
-                          sums[2] * f32(sc4) + sums[3] * f32(sc6));
+        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+            let output_row = row_base + row;
+            if (output_row < params.m) {
+                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
+
+                let d = f32(load_src0_f16_at(block_byte_base + 208u));
+                let ql1_u32 = load_src0_u32_at(block_byte_base + q_offset_l);
+                let ql2_u32 = load_src0_u32_at(block_byte_base + q_offset_l + 32u);
+                let qh_u32 = load_src0_u32_at(block_byte_base + 128u + q_offset_h);
+                let sc_u32_0 = load_src0_u32_at(block_byte_base + sc_base_byte);
+                let sc_u32_1 = load_src0_u32_at(block_byte_base + sc_base_byte + 4u);
+
+                let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
+                let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
+                let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
+                let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
+
+                var sums = vec4(0.0, 0.0, 0.0, 0.0);
+
+                for (var l = 0u; l < 4u; l++) {
+                    let q1b = byte_of(ql1_u32, l);
+                    let q2b = byte_of(ql2_u32, l);
+                    let qhb = byte_of(qh_u32, l);
+
+                    let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
+                    let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
+                    let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32);
+                    let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
+
+                    sums[0] += x_block[l] * dq0;
+                    sums[1] += x_block[l + 4u] * dq1;
+                    sums[2] += x_block[l + 8u] * dq2;
+                    sums[3] += x_block[l + 12u] * dq3;
+                }
+
+                acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
+                                 sums[2] * f32(sc4) + sums[3] * f32(sc6));
+            }
+        }
     }
-
-    return local_sum;
-}
 #endif
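For Q6_K the 210 bytes decode as ql[128] (low 4 bits), qh[64] (high 2 bits), 16 signed scale bytes, and an f16 d at byte 208 -- exactly what the dq0..dq3 reconstruction above unpacks. A host-side reference for one 128-element half, assuming the standard ggml block_q6_K ordering (hypothetical helper, d pre-converted):

    #include <cstdint>

    // Expand 128 elements of a Q6_K super-block half, matching the shader's dq0..dq3.
    void dequant_q6_k_half(const uint8_t *ql, const uint8_t *qh,
                           const int8_t *sc, float d, float out[128]) {
        for (int l = 0; l < 32; l++) {
            const int is = l / 16;
            const int q0 = int((ql[l]      & 0xF) | ((qh[l] & 0x03) << 4)) - 32;
            const int q1 = int((ql[l + 32] & 0xF) | ((qh[l] & 0x0C) << 2)) - 32;
            const int q2 = int((ql[l]      >>  4) | ((qh[l] & 0x30)     )) - 32;
            const int q3 = int((ql[l + 32] >>  4) | ((qh[l] & 0xC0) >> 2)) - 32;
            out[l]      = d * sc[is]     * q0;
            out[l + 32] = d * sc[is + 2] * q1;
            out[l + 64] = d * sc[is + 4] * q2;
            out[l + 96] = d * sc[is + 6] * q3;
        }
    }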
 
-struct MulMatParams {
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-    m: u32,
-    n: u32,
-    k: u32,
-    stride_01: u32,
-    stride_11: u32,
-    stride_02: u32,
-    stride_12: u32,
-    stride_03: u32,
-    stride_13: u32,
-    bs02: u32,
-    bs03: u32,
-    broadcast2: u32,
-    broadcast3: u32
-};
-
-// SRC0_TYPE and SRC1_TYPE are defined above (VEC/SCALAR for float, U32_DEQUANT_HELPERS for quantized)
-@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>;  // M rows, K columns
-@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;  // K rows, N columns (transposed)
-@group(0) @binding(2) var<storage, read_write> dst: array<SRC1_TYPE>;   // M rows, N columns (transposed)
-
-@group(0) @binding(3) var<uniform> params: MulMatParams;
-
-const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
-// const THREADS_PER_OUTPUT = 32u;
-
-// Shared memory for collaborative loading and reduction
-// padded by + 1 to serialize reads (perf improvement for legacy quants?)
-var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K / VEC_SIZE + 1u>;  // Cache vector tile
-var<workgroup> partial_sums: array<f32, WG_SIZE>;                        // For reduction
-
-@compute @workgroup_size(WG_SIZE)
-fn main(
-    @builtin(local_invocation_id) local_id: vec3<u32>,
-    @builtin(workgroup_id) wg_id: vec3<u32>,
-    @builtin(num_workgroups) num_wg: vec3<u32>,
-    @builtin(subgroup_size) subgroup_size: u32) {
-    let thread_id = local_id.x;
-
-    // Handle batch dimensions
-    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
-    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
-    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
-    let batch_idx = wg_linear / output_groups;
-    if (batch_idx >= total_batches) {
-        return;
+#ifdef USE_SUBGROUP_REDUCTION
+    for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+        let subgroup_total = subgroupAdd(acc[row]);
+        if (subgroup_invocation_id == 0u) {
+            partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
+        }
     }
-    // Which of the outputs does this thread belong to?
-    let thread_group = thread_id / THREADS_PER_OUTPUT;
-    let thread_in_group = thread_id % THREADS_PER_OUTPUT;
-
-    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
-    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
-
-    let dst2_stride = params.m * params.n;
-    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
-    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
-    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
-    let src03_idx = dst3_idx / params.broadcast3;
-    let src13_idx = dst3_idx;
-    let src02_idx = dst2_idx / params.broadcast2;
-    let src12_idx = dst2_idx;
-
-    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
-    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
-    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
-
-    var local_sum = 0.0;
-
-    // Each thread processes multiple K elements and accumulates
-    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
-        let tile_size = min(TILE_K, params.k - k_tile);
+    workgroupBarrier();
 
-        // Cooperatively load vector tile into shared memory (all threads)
-        for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
-            shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
+    for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
+        let output_row = row_base + row;
+        var row_acc = 0.0f;
+        for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
+            row_acc += partial_sums[partial_index(row, k)];
         }
-
-        workgroupBarrier();
-
-        if (output_row < params.m) {
-            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
+        let row_total = subgroupAdd(row_acc);
+        if (subgroup_invocation_id == 0) {
+            dst[dst_idx_base + row] = row_total;
         }
-
-        workgroupBarrier();
     }
+#endif
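The subgroup path reduces in two stages: every subgroup first collapses its per-row accumulators with subgroupAdd and lane 0 deposits them at partial_index(row, subgroup_id); a second pass then assigns rows to subgroups and sums the surviving partials. A serial C++ model of what one row's total works out to (names are illustrative, not the shader's API):

    #include <vector>

    // Stage-two equivalent: sum the per-subgroup partials for one output row.
    // partial_sums is flattened as [row][slot] with stride wg_size, mirroring
    // partial_index(row, slot) = row * WG_SIZE + slot in the shader.
    float reduce_row(const std::vector<float> &partial_sums,
                     int row, int wg_size, int num_subgroups) {
        float total = 0.0f;
        for (int sg = 0; sg < num_subgroups; sg++) {
            total += partial_sums[row * wg_size + sg];
        }
        return total;
    }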
 
-    let subgroup_local_id = thread_in_group % subgroup_size;
-    let subgroup_in_group = thread_in_group / subgroup_size;
-    let subgroup_total = subgroupAdd(local_sum);
-
-    if (subgroup_local_id == 0u) {
-        partial_sums[thread_group * THREADS_PER_OUTPUT + subgroup_in_group] = subgroup_total;
+#ifdef USE_WORKGROUP_REDUCTION
+    for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+        partial_sums[partial_index(row, thread_id)] = acc[row];
     }
+    workgroupBarrier();
 
-    if (thread_in_group == 0u && output_row < params.m) {
-        var row_total = 0.0;
-        let num_subgroups = THREADS_PER_OUTPUT / subgroup_size;
-        for (var s = 0u; s < num_subgroups; s++) {
-            row_total += partial_sums[thread_group * THREADS_PER_OUTPUT + s];
+    var stride = WG_SIZE / 2u;
+
+    while (stride > 0) {
+        if (thread_id < stride) {
+            for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
+                partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
+            }
         }
-        #ifdef SCALAR
-        dst[dst_idx] = row_total;
-        #endif
-        #ifdef VEC
-        partial_sums[thread_group * THREADS_PER_OUTPUT] = row_total;
-        #endif
+
+        workgroupBarrier();
+        stride = stride / 2;
     }
 
-    #ifdef VEC
-    workgroupBarrier();
-    if (output_row < params.m && thread_group % VEC_SIZE == 0u && thread_in_group == 0u) {
-        dst[dst_idx / VEC_SIZE] = store_val(thread_group * THREADS_PER_OUTPUT);
+    if (thread_id < OUTPUTS_PER_WG) {
+        let output_row = row_base + thread_id;
+        if (output_row < params.m) {
+            dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
+        }
     }
-    #endif
-}
\ No newline at end of file
+#endif
+}
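The workgroup fallback is a classic stride-halving tree reduction over the same [row][thread] flattening. A CPU model of the loop, assuming WG_SIZE is a power of two as the shader's halving requires (hypothetical helper):

    #include <vector>

    // Pair up partials until column 0 holds each row's total; on the GPU a
    // workgroupBarrier() separates iterations, here the serial loop suffices.
    void tree_reduce(std::vector<float> &partial_sums, int outputs_per_wg, int wg_size) {
        for (int stride = wg_size / 2; stride > 0; stride /= 2) {
            for (int t = 0; t < stride; t++) {                // threads with thread_id < stride
                for (int row = 0; row < outputs_per_wg; row++) {
                    partial_sums[row * wg_size + t] += partial_sums[row * wg_size + t + stride];
                }
            }
        }
    }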
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl
deleted file mode 100644
index dd03435d357..00000000000
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_row_tiled.wgsl
+++ /dev/null
@@ -1,864 +0,0 @@
-#ifdef USE_SUBGROUP_REDUCTION
-enable subgroups;
-#endif
-enable f16;
-
-#include "common_decls.tmpl"
-
-#ifdef U32_DEQUANT_HELPERS
-#define SRC0_TYPE u32
-
-fn byte_of(v: u32, b: u32) -> u32 {
-    return (v >> (b * 8u)) & 0xFFu;
-}
-
-fn sbyte_of(v: u32, b: u32) -> i32 {
-    let raw = i32((v >> (b * 8u)) & 0xFFu);
-    return select(raw, raw - 256, raw >= 128);
-}
-#endif
-
-#ifdef VEC
-#define VEC_SIZE 4u
-#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
-#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
-
-fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
-    return f32(dot(SRC1_TYPE(src0_val), src1_val));
-}
-#endif
-
-#ifdef SCALAR
-#define VEC_SIZE 1u
-#define SRC0_TYPE SRC0_INNER_TYPE
-#define SRC1_TYPE SRC1_INNER_TYPE
-
-fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
-    return f32(src0_val) * f32(src1_val);
-}
-#endif
-
-struct MulMatParams {
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-    m: u32,
-    n: u32,
-    k: u32,
-    stride_01: u32,
-    stride_11: u32,
-    stride_02: u32,
-    stride_12: u32,
-    stride_03: u32,
-    stride_13: u32,
-    bs02: u32,
-    bs03: u32,
-    broadcast2: u32,
-    broadcast3: u32
-};
-
-@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>;
-@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;
-@group(0) @binding(2) var<storage, read_write> dst: array<SRC1_TYPE>;
-
-@group(0) @binding(3) var<uniform> params: MulMatParams;
-
-// Flattened as [row][thread] to keep each row's reduction contiguous in memory.
-var<workgroup> partial_sums: array<f32, OUTPUTS_PER_WG * WG_SIZE>;
-
-fn partial_index(row: u32, thread: u32) -> u32 {
-    return row * WG_SIZE + thread;
-}
-
-@compute @workgroup_size(WG_SIZE)
-fn main(
-    @builtin(local_invocation_id) local_id: vec3<u32>,
-    @builtin(workgroup_id) wg_id: vec3<u32>,
-    @builtin(num_workgroups) num_wg: vec3<u32>
-#ifdef USE_SUBGROUP_REDUCTION
-    , @builtin(subgroup_id) subgroup_id: u32,
-    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
-    @builtin(num_subgroups) num_subgroups: u32,
-    @builtin(subgroup_size) subgroup_size: u32
-#endif
-) {
-    let thread_id = local_id.x;
-
-    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
-    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
-    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
-    let batch_idx = wg_linear / output_groups;
-    if (batch_idx >= total_batches) {
-        return;
-    }
-
-    let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG;
-
-    let dst2_stride = params.m * params.n;
-    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
-    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
-    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
-    let src03_idx = dst3_idx / params.broadcast3;
-    let src13_idx = dst3_idx;
-    let src02_idx = dst2_idx / params.broadcast2;
-    let src12_idx = dst2_idx;
-
-    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
-    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
-    let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
-
-    var acc: array<f32, OUTPUTS_PER_WG>;
-
-#ifdef MUL_ACC_FLOAT
-    let k_vec = params.k / VEC_SIZE;
-    let src1_idx_base_vec = src1_idx_base / VEC_SIZE;
-
-    // Each thread walks K, loads from the vector, and updates
-    // a small block of output rows held in registers.
-    for (var k = thread_id; k < k_vec; k += WG_SIZE) {
-        let x = src1[src1_idx_base_vec + k];
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k;
-                acc[row] += inner_dot(src0[src0_idx], x);
-            }
-        }
-    }
-#endif
-
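The comment above is the heart of the float path: stream K once, and amortize each src1 load across a strip of output rows held in registers. Serially it amounts to the following, with a hypothetical flat row-major layout for src0:

    #include <cstddef>

    // Register-tiled matvec strip: one read of x per K element serves
    // outputs_per_wg rows, instead of one read per (row, k) pair.
    void matvec_row_strip(const float *src0, const float *x, float *acc,
                          size_t k, size_t stride_01, int outputs_per_wg) {
        for (size_t kk = 0; kk < k; kk++) {
            const float xv = x[kk];
            for (int row = 0; row < outputs_per_wg; row++) {
                acc[row] += src0[row * stride_01 + kk] * xv;
            }
        }
    }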
-    for (var k = thread_id; k < k_vec; k += WG_SIZE) {
-        let x = src1[src1_idx_base_vec + k];
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k;
-                acc[row] += inner_dot(src0[src0_idx], x);
-            }
-        }
-    }
-#endif
-
-#ifdef MUL_ACC_Q4_0
-#define BLOCK_SIZE 32
-#define BLOCK_SIZE_BYTES 18
-#define THREADS_PER_BLOCK 4
-#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
-
-    let num_blocks = params.k / BLOCK_SIZE;
-    let thread_within_block = thread_id % THREADS_PER_BLOCK;
-    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
-        var x_block: array<f32, ELEMS_PER_THREAD>;
-        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-            x_block[i + 4] = f32(src1[x_base + i + 16]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-                let d = f32(load_src0_f16_at(block_byte_base));
-                var row_sum = 0.0;
-
-                let q_packed = load_src0_u32_at(block_byte_base + 2u + 4u * thread_within_block);
-                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
-                    let q_byte = get_byte(q_packed, byte_idx);
-                    let q_lo = (f32(q_byte & 0xFu) - 8.0) * d;
-                    let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d;
-                    row_sum += q_lo * x_block[byte_idx];
-                    row_sum += q_hi * x_block[byte_idx + 4u];
-                }
-                acc[row] += row_sum;
-            }
-        }
-    }
-#endif
-
-#ifdef MUL_ACC_Q4_1
-#define BLOCK_SIZE 32
-#define BLOCK_SIZE_BYTES 20
-#define THREADS_PER_BLOCK 4
-#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
-
-    let num_blocks = params.k / BLOCK_SIZE;
-    let thread_within_block = thread_id % THREADS_PER_BLOCK;
-    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
-        var x_block: array<f32, ELEMS_PER_THREAD>;
-        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-            x_block[i + 4] = f32(src1[x_base + i + 16]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-                let d = f32(load_src0_f16_at(block_byte_base));
-                let m = f32(load_src0_f16_at(block_byte_base + 2u));
-                var row_sum = 0.0;
-
-                let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * thread_within_block);
-                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
-                    let q_byte = get_byte(q_packed, byte_idx);
-                    let q_lo = f32(q_byte & 0xFu) * d + m;
-                    let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m;
-                    row_sum += q_lo * x_block[byte_idx];
-                    row_sum += q_hi * x_block[byte_idx + 4u];
-                }
-                acc[row] += row_sum;
-            }
-        }
-    }
-#endif
-
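-// Q5_0: each 4-bit nibble gains a fifth bit from the packed qh mask; the
-// combined 5-bit value is dequantized as (q - 16) * d.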
-#ifdef MUL_ACC_Q5_0
-#define BLOCK_SIZE 32
-#define BLOCK_SIZE_BYTES 22
-#define THREADS_PER_BLOCK 4
-#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
-
-    let num_blocks = params.k / BLOCK_SIZE;
-    let thread_within_block = thread_id % THREADS_PER_BLOCK;
-    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
-        var x_block: array<f32, ELEMS_PER_THREAD>;
-        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-            x_block[i + 4] = f32(src1[x_base + i + 16]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-                let d = f32(load_src0_f16_at(block_byte_base));
-                let qh_packed = load_src0_u32_at(block_byte_base + 2u);
-                let q_packed = load_src0_u32_at(block_byte_base + 6u + 4u * thread_within_block);
-                let qh_shift = thread_within_block * 4u;
-                var row_sum = 0.0;
-
-                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
-                    let q_byte = get_byte(q_packed, byte_idx);
-                    let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
-                    let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
-                    let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d;
-                    let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d;
-                    row_sum += q_lo * x_block[byte_idx];
-                    row_sum += q_hi * x_block[byte_idx + 4u];
-                }
-                acc[row] += row_sum;
-            }
-        }
-    }
-#endif
-
-#ifdef MUL_ACC_Q5_1
-#define BLOCK_SIZE 32
-#define BLOCK_SIZE_BYTES 24
-#define THREADS_PER_BLOCK 4
-#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
-
-    let num_blocks = params.k / BLOCK_SIZE;
-    let thread_within_block = thread_id % THREADS_PER_BLOCK;
-    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
-        var x_block: array<f32, ELEMS_PER_THREAD>;
-        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-            x_block[i + 4] = f32(src1[x_base + i + 16]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-                let d = f32(load_src0_f16_at(block_byte_base));
-                let m = f32(load_src0_f16_at(block_byte_base + 2u));
-                let qh_packed = load_src0_u32_at(block_byte_base + 4u);
-                let q_packed = load_src0_u32_at(block_byte_base + 8u + 4u * thread_within_block);
-                let qh_shift = thread_within_block * 4u;
-                var row_sum = 0.0;
-
-                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
-                    let q_byte = get_byte(q_packed, byte_idx);
-                    let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
-                    let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
-                    let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m;
-                    let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m;
-                    row_sum += q_lo * x_block[byte_idx];
-                    row_sum += q_hi * x_block[byte_idx + 4u];
-                }
-                acc[row] += row_sum;
-            }
-        }
-    }
-#endif
-
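-// Q8_0 / Q8_1 store full signed bytes, so each thread copies its 8 src1 values
-// once and get_byte_i32 sign-extends each quant byte before scaling by d.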
-#ifdef MUL_ACC_Q8_0
-#define BLOCK_SIZE 32
-#define BLOCK_SIZE_BYTES 34
-#define THREADS_PER_BLOCK 4
-#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
-
-    let num_blocks = params.k / BLOCK_SIZE;
-    let thread_within_block = thread_id % THREADS_PER_BLOCK;
-    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
-        var x_block: array<f32, ELEMS_PER_THREAD>;
-        for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-                let d = f32(load_src0_f16_at(block_byte_base));
-                var row_sum = 0.0;
-
-                for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
-                    let q_packed = load_src0_u32_at(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx));
-                    for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
-                        let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d;
-                        row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
-                    }
-                }
-                acc[row] += row_sum;
-            }
-        }
-    }
-#endif
-
-#ifdef MUL_ACC_Q8_1
-#define BLOCK_SIZE 32
-#define BLOCK_SIZE_BYTES 36
-#define THREADS_PER_BLOCK 4
-#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
-
-    let num_blocks = params.k / BLOCK_SIZE;
-    let thread_within_block = thread_id % THREADS_PER_BLOCK;
-    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
-        var x_block: array<f32, ELEMS_PER_THREAD>;
-        for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-                let d = f32(load_src0_f16_at(block_byte_base));
-                let m = f32(load_src0_f16_at(block_byte_base + 2u));
-                var row_sum = 0.0;
-
-                for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
-                    let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx));
-                    for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
-                        let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m;
-                        row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
-                    }
-                }
-                acc[row] += row_sum;
-            }
-        }
-    }
-#endif
-
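-// The k-quants below use 256-element superblocks; THREADS_PER_BLOCK = 16
-// threads cooperate on one superblock, each owning a fixed 16-element slice.
-// Q2_K packs 2-bit quants with 4-bit sub-block scales and mins (d/dmin are the
-// f16 values at bytes 80/82).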
-#ifdef MUL_ACC_Q2_K
-#define BLOCK_SIZE 256
-#define BLOCK_SIZE_BYTES 84
-#define THREADS_PER_BLOCK 16
-
-    let tid = thread_id % THREADS_PER_BLOCK;
-    let block_group = thread_id / THREADS_PER_BLOCK;
-    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
-
-    let lane = tid / 2u;
-    let phase = tid % 2u;
-    let iq = lane / 4u;
-    let ir = lane % 4u;
-    let is = ir / 2u;
-
-    let y_offset = 128u * iq + 8u * ir + 4u * phase;
-    let sc0_byte = 8u * iq + is;
-    let sc2_byte = 8u * iq + is + 2u;
-    let sc4_byte = 8u * iq + is + 4u;
-    let sc6_byte = 8u * iq + is + 6u;
-    let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase;
-
-    let num_blocks = params.k / BLOCK_SIZE;
-
-    for (var block = block_group; block < num_blocks; block += num_block_groups) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
-        var x_block: array<f32, 16>;
-        for (var i = 0u; i < 4u; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-            x_block[i + 4u] = f32(src1[x_base + 32u + i]);
-            x_block[i + 8u] = f32(src1[x_base + 64u + i]);
-            x_block[i + 12u] = f32(src1[x_base + 96u + i]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-
-                let dall = f32(load_src0_f16_at(block_byte_base + 80u));
-                let dmin = f32(load_src0_f16_at(block_byte_base + 82u)) * (1.0 / 16.0);
-
-                let sc0 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc0_byte), sc0_byte & 3u);
-                let sc2 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc2_byte), sc2_byte & 3u);
-                let sc4 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc4_byte), sc4_byte & 3u);
-                let sc6 = byte_of(load_src0_u32_at_aligned(block_byte_base + sc6_byte), sc6_byte & 3u);
-
-                let q_u32 = load_src0_u32_at_aligned(block_byte_base + qs_byte);
-                let qs0 = q_u32 & 0xFFFFu;
-                let qs1 = q_u32 >> 16u;
-
-                var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-                var acc1 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-                var acc2 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-
-                sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3];
-                sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7];
-                sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11];
-                sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15];
-
-                acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u);
-                acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u);
-                acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu);
-                acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u);
-                acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u);
-                acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u);
-                acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u);
-                acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u);
-
-                acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) +
-                                    (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 +
-                                    (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 +
-                                    (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) -
-                            dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) +
-                                    sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u));
-            }
-        }
-    }
-#endif
-
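-// Q3_K: 2-bit low quants (qs at byte 32) gain a high bit from the hmask at
-// byte 0; 6-bit scales are unpacked from the 12 bytes at offset 96, d at 108.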
-#ifdef MUL_ACC_Q3_K
-#define BLOCK_SIZE 256
-#define BLOCK_SIZE_BYTES 110
-#define THREADS_PER_BLOCK 16
-
-    let tid = thread_id % THREADS_PER_BLOCK;
-    let block_group = thread_id / THREADS_PER_BLOCK;
-    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
-
-    let lane = tid / 2u;
-    let phase = tid % 2u;
-    let ip = lane / 4u;
-    let il = 2u * ((lane % 4u) / 2u);
-    let ir = lane % 2u;
-    let l0 = 8u * ir;
-
-    let q_byte = 32u + 32u * ip + l0 + 16u * phase;
-    let h_byte = l0 + 16u * phase;
-    let y_offset = 128u * ip + 32u * il + l0 + 16u * phase;
-
-    let s_shift1 = 4u * ip;
-    let s_shift2 = s_shift1 + il;
-
-    let v1 = select(64.0, 4.0, il == 0u);
-    let v2 = 4.0 * v1;
-    let shift = 2u * il;
-
-    var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32;
-    if (il == 0u) {
-        qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u;
-    } else {
-        qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u;
-    }
-
-    let mm_idx = 2u * ip + il / 2u;
-    var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32;
-    switch (mm_idx) {
-        case 0u: { hm0 = 0x0001u; hm1 = 0x0100u; hm2 = 0x0002u; hm3 = 0x0200u; }
-        case 1u: { hm0 = 0x0004u; hm1 = 0x0400u; hm2 = 0x0008u; hm3 = 0x0800u; }
-        case 2u: { hm0 = 0x0010u; hm1 = 0x1000u; hm2 = 0x0020u; hm3 = 0x2000u; }
-        default: { hm0 = 0x0040u; hm1 = 0x4000u; hm2 = 0x0080u; hm3 = 0x8000u; }
-    }
-
-    let num_blocks = params.k / BLOCK_SIZE;
-
-    for (var block = block_group; block < num_blocks; block += num_block_groups) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
-        var x_block: array<f32, 16>;
-        for (var i = 0u; i < 8u; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-            x_block[i + 8u] = f32(src1[x_base + 32u + i]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-
-                let d = f32(load_src0_f16_at(block_byte_base + 108u));
-                let a_base = 96u;
-                let a_il0 = load_src0_u16_at(block_byte_base + a_base + il * 2u);
-                let a_il1 = load_src0_u16_at(block_byte_base + a_base + (il + 1u) * 2u);
-                let a_4 = load_src0_u16_at(block_byte_base + a_base + 8u);
-                let a_5 = load_src0_u16_at(block_byte_base + a_base + 10u);
-
-                var scales32 = a_4 | (a_5 << 16u);
-                let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u;
-                scales32 = a_il0 | (a_il1 << 16u);
-                scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32;
-
-                let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32);
-                let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32);
-
-                let q_u32_0 = load_src0_u32_at(block_byte_base + q_byte + 0u);
-                let q_u32_1 = load_src0_u32_at(block_byte_base + q_byte + 4u);
-                let h_u32_0 = load_src0_u32_at(block_byte_base + h_byte + 0u);
-                let h_u32_1 = load_src0_u32_at(block_byte_base + h_byte + 4u);
-
-                var s1 = 0.0; var s2 = 0.0; var s3 = 0.0;
-                var s4 = 0.0; var s5 = 0.0; var s6 = 0.0;
-
-                for (var l = 0u; l < 8u; l += 2u) {
-                    let q_u32 = select(q_u32_0, q_u32_1, l >= 4u);
-                    let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u);
-                    let h_u32 = select(h_u32_0, h_u32_1, l >= 4u);
-                    let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u);
-
-                    s1 += x_block[l + 0u] * f32(qs & qm0);
-                    s2 += x_block[l + 1u] * f32(qs & qm1);
-                    s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) +
-                          select(0.0, x_block[l + 1u], (hv & hm1) == 0u);
-                    s4 += x_block[l + 8u] * f32(qs & qm2);
-                    s5 += x_block[l + 9u] * f32(qs & qm3);
-                    s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) +
-                          select(0.0, x_block[l + 9u], (hv & hm3) == 0u);
-                }
-
-                let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1);
-                let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2);
-                acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift);
-            }
-        }
-    }
-#endif
-
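-// Q4_K: d/dmin at bytes 0/2, 6-bit scales and mins packed into 12 bytes from
-// offset 4, then 4-bit quants from byte 16; each slice contributes
-// d * scale * q - dmin * min.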
-#ifdef MUL_ACC_Q4_K
-#define BLOCK_SIZE 256
-#define BLOCK_SIZE_BYTES 144
-#define THREADS_PER_BLOCK 16
-
-    let tid = thread_id % THREADS_PER_BLOCK;
-    let block_group = thread_id / THREADS_PER_BLOCK;
-    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
-
-    let il = tid / 4u;
-    let ir = tid % 4u;
-    let im = il / 2u;
-    let in = il % 2u;
-    let l0 = 4u * (2u * ir + in);
-
-    let y_offset = 64u * im + l0;
-    let q_offset = 32u * im + l0;
-    let sc0_byte = 4u + im * 2u;
-    let sc2_byte = 4u + (im + 2u) * 2u;
-    let sc4_byte = 4u + (im + 4u) * 2u;
-
-    let num_blocks = params.k / BLOCK_SIZE;
-
-    for (var block = block_group; block < num_blocks; block += num_block_groups) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
-        var x_block: array<f32, 16>;
-        for (var i = 0u; i < 4u; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-            x_block[i + 4u] = f32(src1[x_base + 32u + i]);
-            x_block[i + 8u] = f32(src1[x_base + 128u + i]);
-            x_block[i + 12u] = f32(src1[x_base + 160u + i]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-
-                let d = f32(load_src0_f16_at(block_byte_base + 0u));
-                let dmin = f32(load_src0_f16_at(block_byte_base + 2u));
-
-                let sc0_u32 = load_src0_u32_at_aligned(block_byte_base + sc0_byte);
-                let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
-                let sc2_u32 = load_src0_u32_at_aligned(block_byte_base + sc2_byte);
-                let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
-                let sc4_u32 = load_src0_u32_at_aligned(block_byte_base + sc4_byte);
-                let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);
-
-                let sc16_0 = sc0 & 0x3F3Fu;
-                let sc16_1 = sc2 & 0x3F3Fu;
-                let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
-                let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u);
-
-                let scale0 = f32(sc16_0 & 0xFFu);
-                let scale1 = f32((sc16_0 >> 8u) & 0xFFu);
-                let min0 = f32(sc16_1 & 0xFFu);
-                let min1 = f32((sc16_1 >> 8u) & 0xFFu);
-                let scale2 = f32(sc16_2 & 0xFFu);
-                let scale3 = f32((sc16_2 >> 8u) & 0xFFu);
-                let min2 = f32(sc16_3 & 0xFFu);
-                let min3 = f32((sc16_3 >> 8u) & 0xFFu);
-
-                let q1_u32 = load_src0_u32_at_aligned(block_byte_base + 16u + q_offset);
-                let q2_u32 = load_src0_u32_at_aligned(block_byte_base + 80u + q_offset);
-
-                var dot = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-                var sumx = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-                for (var i = 0u; i < 4u; i++) {
-                    let q1b = byte_of(q1_u32, i);
-                    let q2b = byte_of(q2_u32, i);
-                    dot[0] += x_block[i] * f32(q1b & 0x0Fu);
-                    dot[1] += x_block[i + 4u] * f32(q1b >> 4u);
-                    dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu);
-                    dot[3] += x_block[i + 12u] * f32(q2b >> 4u);
-                    sumx[0] += x_block[i];
-                    sumx[1] += x_block[i + 4u];
-                    sumx[2] += x_block[i + 8u];
-                    sumx[3] += x_block[i + 12u];
-                }
-
-                acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) -
-                            dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3);
-            }
-        }
-    }
-#endif
-
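-// Q5_K: same scale/min packing as Q4_K, plus a fifth quant bit taken from the
-// qh bytes at offset 16; the 4-bit quants start at byte 48.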
-#ifdef MUL_ACC_Q5_K
-#define BLOCK_SIZE 256
-#define BLOCK_SIZE_BYTES 176
-#define THREADS_PER_BLOCK 16
-
-    let tid = thread_id % THREADS_PER_BLOCK;
-    let block_group = thread_id / THREADS_PER_BLOCK;
-    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
-
-    let il = tid / 4u;
-    let ir = tid % 4u;
-    let im = il / 2u;
-    let in = il % 2u;
-    let l0 = 4u * (2u * ir + in);
-
-    let y_offset = 64u * im + l0;
-    let q_offset = 48u + 32u * im + l0;
-    let qh_offset = 16u + 8u * ir + 4u * in;
-    let sc0_byte = 4u + im * 2u;
-    let sc2_byte = 4u + (im + 2u) * 2u;
-    let sc4_byte = 4u + (im + 4u) * 2u;
-
-    let hm1 = 1u << (2u * im);
-    let hm2 = hm1 << 1u;
-    let hm3 = hm1 << 4u;
-    let hm4 = hm2 << 4u;
-
-    let num_blocks = params.k / BLOCK_SIZE;
-
-    for (var block = block_group; block < num_blocks; block += num_block_groups) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
-        var x_block: array<f32, 16>;
-        for (var i = 0u; i < 4u; i++) {
-            x_block[i] = f32(src1[x_base + i]);
-            x_block[i + 4u] = f32(src1[x_base + 32u + i]);
-            x_block[i + 8u] = f32(src1[x_base + 128u + i]);
-            x_block[i + 12u] = f32(src1[x_base + 160u + i]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-
-                let d = f32(load_src0_f16_at(block_byte_base + 0u));
-                let dmin = f32(load_src0_f16_at(block_byte_base + 2u));
-
-                let sc0_u32 = load_src0_u32_at_aligned(block_byte_base + sc0_byte);
-                let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
-                let sc2_u32 = load_src0_u32_at_aligned(block_byte_base + sc2_byte);
-                let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
-                let sc4_u32 = load_src0_u32_at_aligned(block_byte_base + sc4_byte);
-                let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);
-
-                let sc16_0 = sc0 & 0x3F3Fu;
-                let sc16_1 = sc2 & 0x3F3Fu;
-                let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
-                let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u);
-
-                let f0 = f32(sc16_0 & 0xFFu);
-                let f1 = f32((sc16_0 >> 8u) & 0xFFu);
-                let m0 = f32(sc16_1 & 0xFFu);
-                let m1 = f32((sc16_1 >> 8u) & 0xFFu);
-                let f4 = f32(sc16_2 & 0xFFu);
-                let f5 = f32((sc16_2 >> 8u) & 0xFFu);
-                let m4 = f32(sc16_3 & 0xFFu);
-                let m5 = f32((sc16_3 >> 8u) & 0xFFu);
-
-                let q1_u32 = load_src0_u32_at_aligned(block_byte_base + q_offset);
-                let q2_u32 = load_src0_u32_at_aligned(block_byte_base + q_offset + 64u);
-                let qh_u32 = load_src0_u32_at_aligned(block_byte_base + qh_offset);
-
-                var vals = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-                var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-                for (var i = 0u; i < 4u; i++) {
-                    let q1b = byte_of(q1_u32, i);
-                    let q2b = byte_of(q2_u32, i);
-                    let qhb = byte_of(qh_u32, i);
-
-                    let yl0 = x_block[i];
-                    let yl8 = x_block[i + 4u];
-                    let yh0 = x_block[i + 8u];
-                    let yh8 = x_block[i + 12u];
-
-                    sumy[0] += yl0;
-                    sumy[1] += yl8;
-                    sumy[2] += yh0;
-                    sumy[3] += yh8;
-
-                    let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u));
-                    let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u));
-                    let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u));
-                    let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u));
-
-                    vals[0] += yl0 * q0;
-                    vals[1] += yl8 * q1;
-                    vals[2] += yh0 * q2;
-                    vals[3] += yh8 * q3;
-                }
-
-                acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) -
-                            dmin * (sumy[0] * m0 + sumy[1] * m1 +
-                                    sumy[2] * m4 + sumy[3] * m5);
-            }
-        }
-    }
-#endif
-
-#ifdef MUL_ACC_Q6_K
-#define BLOCK_SIZE 256
-#define BLOCK_SIZE_BYTES 210
-#define THREADS_PER_BLOCK 16
-
-    let tid = thread_id % THREADS_PER_BLOCK;
-    let block_group = thread_id / THREADS_PER_BLOCK;
-    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
-
-    let ip = tid / 8u;
-    let il = tid % 8u;
-    let l0 = 4u * il;
-    let is = 8u * ip + l0 / 16u;
-
-    let y_offset = 128u * ip + l0;
-    let q_offset_l = 64u * ip + l0;
-    let q_offset_h = 32u * ip + l0;
-
-    let num_blocks = params.k / BLOCK_SIZE;
-    let sc_base_byte = 192u + (is & ~3u);
-    let sc_byte_pos = is & 3u;
-
-    for (var block = block_group; block < num_blocks; block += num_block_groups) {
-        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
-        var x_block: array<f32, 16>;
-        for (var l = 0u; l < 4u; l++) {
-            x_block[l] = f32(src1[x_base + l]);
-            x_block[l + 4u] = f32(src1[x_base + 32u + l]);
-            x_block[l + 8u] = f32(src1[x_base + 64u + l]);
-            x_block[l + 12u] = f32(src1[x_base + 96u + l]);
-        }
-
-        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-            let output_row = row_base + row;
-            if (output_row < params.m) {
-                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
-
-                let d = f32(load_src0_f16_at(block_byte_base + 208u));
-                let ql1_u32 = load_src0_u32_at(block_byte_base + q_offset_l);
-                let ql2_u32 = load_src0_u32_at(block_byte_base + q_offset_l + 32u);
-                let qh_u32 = load_src0_u32_at(block_byte_base + 128u + q_offset_h);
-                let sc_u32_0 = load_src0_u32_at(block_byte_base + sc_base_byte);
-                let sc_u32_1 = load_src0_u32_at(block_byte_base + sc_base_byte + 4u);
-
-                let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
-                let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
-                let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
-                let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
-
-                var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
-
-                for (var l = 0u; l < 4u; l++) {
-                    let q1b = byte_of(ql1_u32, l);
-                    let q2b = byte_of(ql2_u32, l);
-                    let qhb = byte_of(qh_u32, l);
-
-                    let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
-                    let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
-                    let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32);
-                    let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
-
-                    sums[0] += x_block[l] * dq0;
-                    sums[1] += x_block[l + 4u] * dq1;
-                    sums[2] += x_block[l + 8u] * dq2;
-                    sums[3] += x_block[l + 12u] * dq3;
-                }
-
-                acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
-                                 sums[2] * f32(sc4) + sums[3] * f32(sc6));
-            }
-        }
-    }
-#endif
-
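-// Final reduction: the subgroup path folds each row with subgroupAdd and then
-// combines the per-subgroup partials, one subgroup per row; the fallback is a
-// shared-memory binary tree that halves the active stride each iteration.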
-#ifdef USE_SUBGROUP_REDUCTION
-    for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-        let subgroup_total = subgroupAdd(acc[row]);
-        if (subgroup_invocation_id == 0u) {
-            partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
-        }
-    }
-
-    workgroupBarrier();
-
-    for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
-        let output_row = row_base + row;
-        var row_acc = 0.0f;
-        for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
-            row_acc += partial_sums[partial_index(row, k)];
-        }
-        let row_total = subgroupAdd(row_acc);
-        if (subgroup_invocation_id == 0) {
-            dst[dst_idx_base + row] = row_total;
-        }
-    }
-#endif
-
-#ifdef USE_WORKGROUP_REDUCTION
-    for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-        partial_sums[partial_index(row, thread_id)] = acc[row];
-    }
-
-    workgroupBarrier();
-
-    var stride = WG_SIZE / 2u;
-
-    while (stride > 0) {
-        if (thread_id < stride) {
-            for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
-                partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
-            }
-        }
-
-        workgroupBarrier();
-        stride = stride / 2;
-    }
-
-    if (thread_id < OUTPUTS_PER_WG) {
-        let output_row = row_base + thread_id;
-        if (output_row < params.m) {
-            dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
-        }
-    }
-#endif
-}

From ca49e73ad67ac5138fd70ceb106d93299a6a7ead Mon Sep 17 00:00:00 2001
From: Reese Levine
Date: Fri, 17 Apr 2026 10:43:43 -0700
Subject: [PATCH 8/9] Remove old constants, format

---
 ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index e00273b54c8..9d88f98050e 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -44,18 +44,9 @@
 // Matrix-vector multiplication parameters
 #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256

-// Must be multiple of 4 to work with vectorized paths, and must divide
-// mul_mat_vec wg size
-#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
-#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 1024
-
+#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG    4
 #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
-#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 1024
-
-// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
-#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
-// Requires at least two (and multiple of 2) k-quant blocks per tile
-#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 2048
+#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG      4

 // default size for legacy matrix multiplication
 #define WEBGPU_MUL_MAT_WG_SIZE 256

From b92011ef9dc8bc441fc817a6d49601bac1f0fc80 Mon Sep 17 00:00:00 2001
From: Reese Levine
Date: Sun, 19 Apr 2026 15:41:42 -0700
Subject: [PATCH 9/9] remove accidental file

---
 src/ggml-webgpu.cpp | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 src/ggml-webgpu.cpp

diff --git a/src/ggml-webgpu.cpp b/src/ggml-webgpu.cpp
deleted file mode 100644
index e69de29bb2d..00000000000