Skip to content

Commit 5eb8fff

Browse files
neha-ha (Neha Abbas) and reeselevine
authored
ggml-webgpu: updated matrix-vector multiplication (ggml-org#21738)
* merged properly, but slow q3_k and q5_k with u32 indexing
* Start on new mat-vec
* New format float paths working
* Working q4_0
* Work on remaining legacy q-types
* port k-quants to new matvec
* remove old shader
* Remove old constants, format
* remove accidental file

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
1 parent f8ea065 commit 5eb8fff

4 files changed

Lines changed: 788 additions & 383 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,9 @@
4444
// Matrix-vector multiplication parameters
4545
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
4646

47-
// Must be multiple of 4 to work with vectorized paths, and must divide
48-
// mul_mat_vec wg size
49-
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
50-
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
51-
52-
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
53-
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
54-
55-
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
56-
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
57-
// Requires at least two (and multiple of 2) k-quant blocks per tile
58-
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
47+
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
48+
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
49+
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
5950

6051
// default size for legacy matrix multiplication
6152
#define WEBGPU_MUL_MAT_WG_SIZE 256
@@ -78,6 +69,7 @@ struct ggml_webgpu_shader_lib_context {
7869
bool inplace = false;
7970
bool overlap = false;
8071
bool src_overlap = false;
72+
bool supports_subgroups = false;
8173
bool supports_subgroup_matrix = false;
8274
uint32_t sg_mat_m = 0;
8375
uint32_t sg_mat_n = 0;
@@ -575,7 +567,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
575567

576568
struct ggml_webgpu_mul_mat_vec_shader_decisions {
577569
uint32_t wg_size;
578-
uint32_t tile_k;
579570
uint32_t outputs_per_wg;
580571
uint32_t vec_size;
581572
};
@@ -1326,7 +1317,7 @@ class ggml_webgpu_shader_lib {
13261317
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
13271318
key.src0_type = context.src0->type;
13281319
key.src1_type = context.src1->type;
1329-
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
1320+
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
13301321
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
13311322
1 :
13321323
0;
@@ -1337,7 +1328,8 @@ class ggml_webgpu_shader_lib {
13371328
}
13381329

13391330
std::vector<std::string> defines;
1340-
std::string variant = "mul_mat_vec";
1331+
std::string variant = "mul_mat_vec";
1332+
const char * shader_src = wgsl_mul_mat_vec;
13411333

13421334
// src0 type (matrix row)
13431335
switch (context.src0->type) {
@@ -1386,25 +1378,25 @@ class ggml_webgpu_shader_lib {
13861378
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
13871379

13881380
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
1389-
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
13901381
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
13911382

13921383
if (key.src0_type >= GGML_TYPE_Q2_K) {
1393-
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
13941384
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
13951385
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
1396-
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
13971386
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
13981387
}
13991388

14001389
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1401-
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
14021390
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
1391+
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
1392+
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
1393+
if (key.vectorized) {
1394+
variant += "_vectorized";
1395+
}
14031396

1404-
auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
1397+
auto processed = preprocessor.preprocess(shader_src, defines);
14051398
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
14061399
decisions->wg_size = wg_size;
1407-
decisions->tile_k = tile_k;
14081400
decisions->outputs_per_wg = outputs_per_wg;
14091401
decisions->vec_size = key.vectorized ? 4 : 1;
14101402

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ struct webgpu_dispatch_desc {
181181

182182
struct webgpu_capabilities {
183183
wgpu::Limits limits;
184+
bool supports_subgroups = false;
184185
bool supports_subgroup_matrix = false;
185186

186187
uint32_t sg_mat_m = 0;
@@ -1164,14 +1165,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
11641165
case GGML_TYPE_Q8_0:
11651166
case GGML_TYPE_Q8_1:
11661167
case GGML_TYPE_Q6_K:
1167-
use_fast = true;
1168-
break;
1169-
case GGML_TYPE_Q2_K:
1170-
case GGML_TYPE_Q3_K:
11711168
case GGML_TYPE_Q4_K:
11721169
case GGML_TYPE_Q5_K:
1173-
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
1174-
use_fast = !is_vec;
1170+
case GGML_TYPE_Q3_K:
1171+
case GGML_TYPE_Q2_K:
1172+
use_fast = true;
11751173
break;
11761174
default:
11771175
break;
@@ -1182,10 +1180,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
11821180
}
11831181

11841182
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
1185-
shader_lib_ctx.src0 = src0;
1186-
shader_lib_ctx.src1 = src1;
1187-
shader_lib_ctx.dst = dst;
1183+
1184+
shader_lib_ctx.src0 = src0;
1185+
shader_lib_ctx.src1 = src1;
1186+
shader_lib_ctx.dst = dst;
11881187
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
1188+
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
11891189
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
11901190
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
11911191
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
@@ -1287,7 +1287,8 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
12871287
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
12881288

12891289
// Get or create pipeline
1290-
webgpu_pipeline gather_pipeline, main_pipeline;
1290+
webgpu_pipeline gather_pipeline;
1291+
webgpu_pipeline main_pipeline;
12911292

12921293
std::vector<webgpu_dispatch_desc> dispatches;
12931294

@@ -3040,6 +3041,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
30403041
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
30413042
// we require f16 support
30423043
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
3044+
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
3045+
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
30433046

30443047
#ifndef __EMSCRIPTEN__
30453048
// Accept f16 subgroup matrix configurations (square or non-square).
@@ -3072,11 +3075,14 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
30723075
#ifndef __EMSCRIPTEN__
30733076
required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
30743077
if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
3075-
required_features.push_back(wgpu::FeatureName::Subgroups);
30763078
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
30773079
}
30783080
#endif
30793081

3082+
if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) {
3083+
required_features.push_back(wgpu::FeatureName::Subgroups);
3084+
}
3085+
30803086
#ifdef GGML_WEBGPU_GPU_PROFILE
30813087
required_features.push_back(wgpu::FeatureName::TimestampQuery);
30823088
#endif

ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ fn load_u16_at_src0(byte_offset: u32) -> u32 {
4545
return (word >> shift) & 0xFFFFu;
4646
}
4747

48+
// Always reads the 4-byte-aligned word containing byte_offset.
49+
// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u.
50+
// This is used by the k-quants for better performance.
51+
fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 {
52+
return src0[(byte_offset & ~3u) / 4u];
53+
}
54+
4855
fn load_u32_at_src0(byte_offset: u32) -> u32 {
4956
let word_idx = byte_offset / 4u;
5057
let shift = (byte_offset & 0x3u) * 8u;

0 commit comments

Comments
 (0)