Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 13 additions & 21 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,9 @@
// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256

// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256

#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256

// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
// Requires at least two (and multiple of 2) k-quant blocks per tile
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4

// default size for legacy matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256
Expand All @@ -78,6 +69,7 @@ struct ggml_webgpu_shader_lib_context {
bool inplace = false;
bool overlap = false;
bool src_overlap = false;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
Expand Down Expand Up @@ -575,7 +567,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {

// Specialization choices recorded when a mul_mat_vec shader variant is built.
// Each value mirrors a preprocessor define baked into the compiled WGSL
// (WG_SIZE / TILE_K / OUTPUTS_PER_WG), so later dispatch code can size
// workgroup counts consistently with the shader it pairs with.
struct ggml_webgpu_mul_mat_vec_shader_decisions {
uint32_t wg_size;         // threads per workgroup (WG_SIZE define)
uint32_t tile_k;          // elements of the K dimension covered per tile (TILE_K define)
uint32_t outputs_per_wg;  // output values produced by one workgroup (OUTPUTS_PER_WG define)
uint32_t vec_size;        // 4 when the vectorized (vec4) path is compiled, otherwise 1
};
Expand Down Expand Up @@ -1326,7 +1317,7 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
Expand All @@ -1337,7 +1328,8 @@ class ggml_webgpu_shader_lib {
}

std::vector<std::string> defines;
std::string variant = "mul_mat_vec";
std::string variant = "mul_mat_vec";
const char * shader_src = wgsl_mul_mat_vec;

// src0 type (matrix row)
switch (context.src0->type) {
Expand Down Expand Up @@ -1386,25 +1378,25 @@ class ggml_webgpu_shader_lib {
defines.push_back(key.vectorized ? "VEC" : "SCALAR");

uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;

if (key.src0_type >= GGML_TYPE_Q2_K) {
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
}

defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
if (key.vectorized) {
variant += "_vectorized";
}

auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
auto processed = preprocessor.preprocess(shader_src, defines);
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
decisions->wg_size = wg_size;
decisions->tile_k = tile_k;
decisions->outputs_per_wg = outputs_per_wg;
decisions->vec_size = key.vectorized ? 4 : 1;

Expand Down
28 changes: 17 additions & 11 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ struct webgpu_dispatch_desc {

struct webgpu_capabilities {
wgpu::Limits limits;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;

uint32_t sg_mat_m = 0;
Expand Down Expand Up @@ -1164,14 +1165,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q6_K:
use_fast = true;
break;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
use_fast = !is_vec;
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q2_K:
use_fast = true;
break;
default:
break;
Expand All @@ -1182,10 +1180,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
}

ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;

shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
Expand Down Expand Up @@ -1287,7 +1287,8 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;

// Get or create pipeline
webgpu_pipeline gather_pipeline, main_pipeline;
webgpu_pipeline gather_pipeline;
webgpu_pipeline main_pipeline;

std::vector<webgpu_dispatch_desc> dispatches;

Expand Down Expand Up @@ -3040,6 +3041,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
// we require f16 support
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);

#ifndef __EMSCRIPTEN__
// Accept f16 subgroup matrix configurations (square or non-square).
Expand Down Expand Up @@ -3072,11 +3075,14 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
#ifndef __EMSCRIPTEN__
required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
required_features.push_back(wgpu::FeatureName::Subgroups);
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
}
#endif

if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) {
required_features.push_back(wgpu::FeatureName::Subgroups);
}

#ifdef GGML_WEBGPU_GPU_PROFILE
required_features.push_back(wgpu::FeatureName::TimestampQuery);
#endif
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ fn load_u16_at_src0(byte_offset: u32) -> u32 {
return (word >> shift) & 0xFFFFu;
}

// Loads the 4-byte-aligned u32 word that contains byte_offset: the low two
// offset bits are masked off (& ~3u) before converting to a word index (/ 4u).
// The caller extracts the 16-bit half it needs via `& 0xFFFFu` or `>> 16u`.
// Used by the k-quant shaders: per the original note this aligned whole-word
// load performs better than the general shifting path in load_u32_at_src0.
fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 {
return src0[(byte_offset & ~3u) / 4u];
}

fn load_u32_at_src0(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 0x3u) * 8u;
Expand Down
Loading
Loading