Skip to content

Commit 6da7168

Browse files
authored
ggml-webgpu: Add fused RMS_NORM + MUL (ggml-org#21983)
* fused rms_norm_mul + mul
* Add GGML_WEBGPU_DISABLE_FUSION for being able to disable kernel fusion.
* Decouple num_fused_ops from webgpu_context; misc cleanup
* Fix eps handling and remove disable_fusion.
* Fix not to use c++20 initializers.
1 parent 8bccdbb commit 6da7168

3 files changed

Lines changed: 349 additions & 18 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,26 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
194194
}
195195
};
196196

197+
/** RMS_NORM + MUL **/
198+
199+
struct ggml_webgpu_rms_norm_mul_pipeline_key {
200+
bool inplace;
201+
bool src_overlap;
202+
203+
bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
204+
return inplace == other.inplace && src_overlap == other.src_overlap;
205+
}
206+
};
207+
208+
struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
209+
size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
210+
size_t seed = 0;
211+
ggml_webgpu_hash_combine(seed, key.inplace);
212+
ggml_webgpu_hash_combine(seed, key.src_overlap);
213+
return seed;
214+
}
215+
};
216+
197217
/** Pad **/
198218
struct ggml_webgpu_pad_pipeline_key {
199219
bool circular;
@@ -517,7 +537,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
517537
const size_t q_tile = context.sg_mat_m;
518538
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
519539
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
520-
size_t bytes_per_kv = 0;
540+
size_t bytes_per_kv = 0;
521541
if (!key.kv_direct) {
522542
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
523543
}
@@ -755,16 +775,17 @@ class ggml_webgpu_shader_lib {
755775
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
756776
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
757777
row_norm_pipelines; // op/inplace
778+
758779
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
759-
get_rows_pipelines; // src_type, vectorized
780+
get_rows_pipelines; // src_type, vectorized
760781
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
761-
unary_pipelines; // type/op/inplace
782+
unary_pipelines; // type/op/inplace
762783
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
763-
scale_pipelines; // inplace
784+
scale_pipelines; // inplace
764785
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
765-
solve_tri_pipelines; // type
786+
solve_tri_pipelines; // type
766787
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
767-
ssm_conv_pipelines; // type/vectorized
788+
ssm_conv_pipelines; // type/vectorized
768789
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
769790
webgpu_pipeline,
770791
ggml_webgpu_gated_delta_net_pipeline_key_hash>
@@ -813,6 +834,11 @@ class ggml_webgpu_shader_lib {
813834
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
814835
conv2d_pipelines;
815836

837+
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
838+
webgpu_pipeline,
839+
ggml_webgpu_rms_norm_mul_pipeline_key_hash>
840+
rms_norm_mul_pipelines;
841+
816842
public:
817843
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
818844

@@ -1828,6 +1854,39 @@ class ggml_webgpu_shader_lib {
18281854
return unary_pipelines[key];
18291855
}
18301856

1857+
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
1858+
ggml_webgpu_rms_norm_mul_pipeline_key key = {};
1859+
key.inplace = context.inplace;
1860+
key.src_overlap = context.src_overlap;
1861+
1862+
auto it = rms_norm_mul_pipelines.find(key);
1863+
if (it != rms_norm_mul_pipelines.end()) {
1864+
return it->second;
1865+
}
1866+
1867+
std::vector<std::string> defines;
1868+
std::string op_name = "RMS_NORM_MUL";
1869+
std::string variant = op_name;
1870+
1871+
if (key.inplace) {
1872+
defines.push_back("INPLACE");
1873+
variant += "_inplace";
1874+
} else if (key.src_overlap) {
1875+
defines.push_back("SRC_OVERLAP");
1876+
variant += "_src_overlap";
1877+
}
1878+
1879+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1880+
1881+
auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
1882+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1883+
decisions->wg_size = context.max_wg_size;
1884+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1885+
pipeline.context = decisions;
1886+
rms_norm_mul_pipelines[key] = pipeline;
1887+
return rms_norm_mul_pipelines[key];
1888+
}
1889+
18311890
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
18321891
ggml_webgpu_binary_pipeline_key key = {};
18331892
key.type = context.dst->type;

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 145 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,6 +1972,94 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor *
19721972
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
19731973
}
19741974

1975+
static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context & ctx,
1976+
ggml_tensor * rn_src,
1977+
ggml_tensor * rn_dst,
1978+
ggml_tensor * mul_src0,
1979+
ggml_tensor * mul_src1,
1980+
ggml_tensor * dst) {
1981+
ggml_tensor * mul_src;
1982+
1983+
if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) {
1984+
mul_src = mul_src1;
1985+
} else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) {
1986+
mul_src = mul_src0;
1987+
} else {
1988+
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
1989+
}
1990+
1991+
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
1992+
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
1993+
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
1994+
1995+
uint32_t offset_merged_rn_src = 0;
1996+
uint32_t offset_merged_mul_src = 0;
1997+
size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src);
1998+
size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);
1999+
2000+
if (src_overlap) {
2001+
size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
2002+
offset_merged_rn_src =
2003+
(uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type));
2004+
offset_merged_mul_src =
2005+
(uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type));
2006+
}
2007+
2008+
std::vector<uint32_t> params = {
2009+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
2010+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
2011+
offset_merged_rn_src,
2012+
offset_merged_mul_src,
2013+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2014+
(uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
2015+
(uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
2016+
(uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)),
2017+
(uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)),
2018+
(uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)),
2019+
(uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)),
2020+
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
2021+
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
2022+
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
2023+
(uint32_t) mul_src->ne[0],
2024+
(uint32_t) mul_src->ne[1],
2025+
(uint32_t) mul_src->ne[2],
2026+
(uint32_t) mul_src->ne[3],
2027+
(uint32_t) dst->ne[0],
2028+
(uint32_t) dst->ne[1],
2029+
(uint32_t) dst->ne[2],
2030+
(uint32_t) dst->ne[3],
2031+
ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0)) // epsilon, treated as f32 in the shader
2032+
};
2033+
2034+
std::vector<wgpu::BindGroupEntry> entries;
2035+
2036+
if (inplace) {
2037+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
2038+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
2039+
} else if (src_overlap) {
2040+
size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
2041+
size_t merged_end =
2042+
std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
2043+
mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
2044+
entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
2045+
merged_end - merged_offset));
2046+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
2047+
} else {
2048+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
2049+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
2050+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
2051+
}
2052+
2053+
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
2054+
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
2055+
shader_lib_ctx.inplace = inplace;
2056+
shader_lib_ctx.src_overlap = src_overlap;
2057+
2058+
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
2059+
2060+
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
2061+
}
2062+
19752063
static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
19762064
bool inplace = ggml_webgpu_tensor_equal(src, dst);
19772065

@@ -2468,15 +2556,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor
24682556
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
24692557
}
24702558

2559+
static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) {
2560+
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
2561+
return false;
2562+
}
2563+
2564+
// additional constraints specific to this fusion
2565+
const ggml_tensor * rms_norm = cgraph->nodes[node_idx];
2566+
const ggml_tensor * mul = cgraph->nodes[node_idx + 1];
2567+
2568+
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
2569+
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
2570+
// rms_norm only supports f32
2571+
if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) {
2572+
return false;
2573+
}
2574+
// if rms_norm is the B operand, then we don't handle broadcast
2575+
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
2576+
return false;
2577+
}
2578+
// rms_norm shader assumes contiguous rows
2579+
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
2580+
return false;
2581+
}
2582+
2583+
return true;
2584+
}
2585+
24712586
// Returns the encoded command, or std::nullopt if the operation is a no-op
2472-
static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
2587+
static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
2588+
ggml_cgraph * cgraph,
2589+
int node_idx,
2590+
int & num_encoded_ops) {
2591+
ggml_tensor ** nodes = cgraph->nodes;
2592+
ggml_tensor * node = nodes[node_idx];
2593+
24732594
if (ggml_is_empty(node)) {
24742595
return std::nullopt;
24752596
}
24762597
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
24772598
return std::nullopt;
24782599
}
2479-
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
2600+
WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")");
24802601

24812602
ggml_tensor * src0 = node->src[0];
24822603
ggml_tensor * src1 = node->src[1];
@@ -2519,6 +2640,13 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
25192640
case GGML_OP_REPEAT:
25202641
return ggml_webgpu_repeat(ctx, src0, node);
25212642
case GGML_OP_RMS_NORM:
2643+
if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) {
2644+
num_encoded_ops = 2;
2645+
ggml_tensor * mul_node = nodes[node_idx + 1];
2646+
return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node);
2647+
} else {
2648+
return ggml_webgpu_row_norm(ctx, src0, node);
2649+
}
25222650
case GGML_OP_L2_NORM:
25232651
return ggml_webgpu_row_norm(ctx, src0, node);
25242652
case GGML_OP_ROPE:
@@ -2629,6 +2757,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
26292757
uint32_t num_inflight_batches = 0;
26302758
bool contains_set_rows = false;
26312759
bool batch_compute_passes = true;
2760+
int num_encoded_ops = 1;
2761+
int node_idx = 0;
26322762

26332763
#ifdef GGML_WEBGPU_GPU_PROFILE
26342764
ctx->profile_timestamp_query_count = 0;
@@ -2641,11 +2771,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
26412771
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
26422772
}
26432773

2644-
for (int i = 0; i < cgraph->n_nodes; i++) {
2645-
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
2774+
while (node_idx < cgraph->n_nodes) {
2775+
if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) {
26462776
contains_set_rows = true;
26472777
}
2648-
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
2778+
if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) {
26492779
commands.push_back(*cmd);
26502780
num_batched_kernels += cmd.value().num_kernels;
26512781
#ifdef GGML_WEBGPU_GPU_PROFILE
@@ -2670,6 +2800,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
26702800
ctx->param_arena.reset();
26712801
commands.clear();
26722802
}
2803+
2804+
node_idx += num_encoded_ops;
2805+
num_encoded_ops = 1;
26732806
}
26742807

26752808
if (ctx->active_compute_pass) {
@@ -3237,7 +3370,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
32373370
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
32383371
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
32393372
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
3240-
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
3373+
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
32413374
webgpu_ctx->param_arena.init(
32423375
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
32433376
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
@@ -3487,12 +3620,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
34873620
break;
34883621
}
34893622
// Head dimensions must fit in workgroup memory with minimum tile sizes
3490-
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
3491-
const bool has_mask = op->src[3] != nullptr;
3492-
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
3493-
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
3494-
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
3495-
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
3623+
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
3624+
const bool has_mask = op->src[3] != nullptr;
3625+
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
3626+
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
3627+
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
3628+
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
34963629
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
34973630
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
34983631
if (min_bytes > limit_bytes) {

0 commit comments

Comments
 (0)