@@ -1972,6 +1972,94 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor *
19721972 return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x);
19731973}
19741974
1975+ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul (webgpu_context & ctx,
1976+ ggml_tensor * rn_src,
1977+ ggml_tensor * rn_dst,
1978+ ggml_tensor * mul_src0,
1979+ ggml_tensor * mul_src1,
1980+ ggml_tensor * dst) {
1981+ ggml_tensor * mul_src;
1982+
1983+ if (ggml_webgpu_tensor_equal (rn_dst, mul_src0)) {
1984+ mul_src = mul_src1;
1985+ } else if (ggml_webgpu_tensor_equal (rn_dst, mul_src1)) {
1986+ mul_src = mul_src0;
1987+ } else {
1988+ GGML_ABORT (" rms_norm must be equal to the one of mul_src0 and mul_src1" );
1989+ }
1990+
1991+ bool inplace = (ggml_webgpu_tensor_equal (rn_dst, mul_src0) && ggml_webgpu_tensor_equal (mul_src1, dst)) ||
1992+ (ggml_webgpu_tensor_equal (rn_dst, mul_src1) && ggml_webgpu_tensor_equal (mul_src0, dst));
1993+ bool src_overlap = ggml_webgpu_tensor_overlap (rn_src, mul_src);
1994+
1995+ uint32_t offset_merged_rn_src = 0 ;
1996+ uint32_t offset_merged_mul_src = 0 ;
1997+ size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset (ctx, rn_src);
1998+ size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset (ctx, mul_src);
1999+
2000+ if (src_overlap) {
2001+ size_t min_offset = std::min (rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
2002+ offset_merged_rn_src =
2003+ (uint32_t ) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size (rn_src->type ));
2004+ offset_merged_mul_src =
2005+ (uint32_t ) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size (mul_src->type ));
2006+ }
2007+
2008+ std::vector<uint32_t > params = {
2009+ (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, rn_src) / ggml_type_size (rn_src->type )),
2010+ (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, mul_src) / ggml_type_size (mul_src->type )),
2011+ offset_merged_rn_src,
2012+ offset_merged_mul_src,
2013+ (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )),
2014+ (uint32_t ) (rn_src->nb [1 ] / ggml_type_size (rn_src->type )),
2015+ (uint32_t ) (rn_src->nb [2 ] / ggml_type_size (rn_src->type )),
2016+ (uint32_t ) (rn_src->nb [3 ] / ggml_type_size (rn_src->type )),
2017+ (uint32_t ) (mul_src->nb [1 ] / ggml_type_size (mul_src->type )),
2018+ (uint32_t ) (mul_src->nb [2 ] / ggml_type_size (mul_src->type )),
2019+ (uint32_t ) (mul_src->nb [3 ] / ggml_type_size (mul_src->type )),
2020+ (uint32_t ) (dst->nb [1 ] / ggml_type_size (dst->type )),
2021+ (uint32_t ) (dst->nb [2 ] / ggml_type_size (dst->type )),
2022+ (uint32_t ) (dst->nb [3 ] / ggml_type_size (dst->type )),
2023+ (uint32_t ) mul_src->ne [0 ],
2024+ (uint32_t ) mul_src->ne [1 ],
2025+ (uint32_t ) mul_src->ne [2 ],
2026+ (uint32_t ) mul_src->ne [3 ],
2027+ (uint32_t ) dst->ne [0 ],
2028+ (uint32_t ) dst->ne [1 ],
2029+ (uint32_t ) dst->ne [2 ],
2030+ (uint32_t ) dst->ne [3 ],
2031+ ggml_webgpu_u32_from_f32 (ggml_get_op_params_f32 (rn_dst, 0 )) // epsilon, treated as f32 in the shader
2032+ };
2033+
2034+ std::vector<wgpu::BindGroupEntry> entries;
2035+
2036+ if (inplace) {
2037+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 0 , rn_src));
2038+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 1 , mul_src));
2039+ } else if (src_overlap) {
2040+ size_t merged_offset = std::min (rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
2041+ size_t merged_end =
2042+ std::max (rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size (ctx, rn_src),
2043+ mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size (ctx, mul_src));
2044+ entries.push_back (ggml_webgpu_make_bind_group_entry (0 , ggml_webgpu_tensor_buf (rn_src), merged_offset,
2045+ merged_end - merged_offset));
2046+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 1 , dst));
2047+ } else {
2048+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 0 , rn_src));
2049+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 1 , mul_src));
2050+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 2 , dst));
2051+ }
2052+
2053+ ggml_webgpu_shader_lib_context shader_lib_ctx = {};
2054+ shader_lib_ctx.max_wg_size = ctx->global_ctx ->capabilities .limits .maxComputeInvocationsPerWorkgroup ;
2055+ shader_lib_ctx.inplace = inplace;
2056+ shader_lib_ctx.src_overlap = src_overlap;
2057+
2058+ webgpu_pipeline pipeline = ctx->shader_lib ->get_rms_norm_mul_pipeline (shader_lib_ctx);
2059+
2060+ return ggml_backend_webgpu_build (ctx, pipeline, params, entries, ggml_nrows (dst));
2061+ }
2062+
19752063static webgpu_encoded_op ggml_webgpu_row_norm (webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
19762064 bool inplace = ggml_webgpu_tensor_equal (src, dst);
19772065
@@ -2468,15 +2556,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor
24682556 return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x);
24692557}
24702558
2559+ static bool ggml_webgpu_can_fuse_rms_norm_mul (const struct ggml_cgraph * cgraph, int node_idx) {
2560+ if (!ggml_can_fuse (cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
2561+ return false ;
2562+ }
2563+
2564+ // additional constraints specific to this fusion
2565+ const ggml_tensor * rms_norm = cgraph->nodes [node_idx];
2566+ const ggml_tensor * mul = cgraph->nodes [node_idx + 1 ];
2567+
2568+ GGML_ASSERT (rms_norm->src [0 ]->type == GGML_TYPE_F32);
2569+ GGML_ASSERT (rms_norm->type == GGML_TYPE_F32);
2570+ // rms_norm only supports f32
2571+ if (mul->src [0 ]->type != GGML_TYPE_F32 || mul->src [1 ]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) {
2572+ return false ;
2573+ }
2574+ // if rms_norm is the B operand, then we don't handle broadcast
2575+ if (rms_norm == mul->src [1 ] && !ggml_are_same_shape (mul->src [0 ], rms_norm)) {
2576+ return false ;
2577+ }
2578+ // rms_norm shader assumes contiguous rows
2579+ if (!ggml_is_contiguous_rows (mul->src [0 ]) || !ggml_is_contiguous_rows (mul->src [1 ])) {
2580+ return false ;
2581+ }
2582+
2583+ return true ;
2584+ }
2585+
24712586// Returns the encoded command, or std::nullopt if the operation is a no-op
2472- static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node (webgpu_context ctx, ggml_tensor * node) {
2587+ static std::optional<webgpu_encoded_op> ggml_webgpu_encode (webgpu_context ctx,
2588+ ggml_cgraph * cgraph,
2589+ int node_idx,
2590+ int & num_encoded_ops) {
2591+ ggml_tensor ** nodes = cgraph->nodes ;
2592+ ggml_tensor * node = nodes[node_idx];
2593+
24732594 if (ggml_is_empty (node)) {
24742595 return std::nullopt ;
24752596 }
24762597 if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0 ) {
24772598 return std::nullopt ;
24782599 }
2479- WEBGPU_LOG_DEBUG (" ggml_webgpu_encode_node (" << node << " , " << ggml_op_name (node->op ) << " )" );
2600+ WEBGPU_LOG_DEBUG (" ggml_webgpu_encode (" << node << " , " << ggml_op_name (node->op ) << " )" );
24802601
24812602 ggml_tensor * src0 = node->src [0 ];
24822603 ggml_tensor * src1 = node->src [1 ];
@@ -2519,6 +2640,13 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
25192640 case GGML_OP_REPEAT:
25202641 return ggml_webgpu_repeat (ctx, src0, node);
25212642 case GGML_OP_RMS_NORM:
2643+ if (ggml_webgpu_can_fuse_rms_norm_mul (cgraph, node_idx)) {
2644+ num_encoded_ops = 2 ;
2645+ ggml_tensor * mul_node = nodes[node_idx + 1 ];
2646+ return ggml_webgpu_rms_norm_mul (ctx, src0, node, mul_node->src [0 ], mul_node->src [1 ], mul_node);
2647+ } else {
2648+ return ggml_webgpu_row_norm (ctx, src0, node);
2649+ }
25222650 case GGML_OP_L2_NORM:
25232651 return ggml_webgpu_row_norm (ctx, src0, node);
25242652 case GGML_OP_ROPE:
@@ -2629,6 +2757,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
26292757 uint32_t num_inflight_batches = 0 ;
26302758 bool contains_set_rows = false ;
26312759 bool batch_compute_passes = true ;
2760+ int num_encoded_ops = 1 ;
2761+ int node_idx = 0 ;
26322762
26332763#ifdef GGML_WEBGPU_GPU_PROFILE
26342764 ctx->profile_timestamp_query_count = 0 ;
@@ -2641,11 +2771,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
26412771 ctx->active_compute_pass = ctx->active_command_encoder .BeginComputePass ();
26422772 }
26432773
2644- for ( int i = 0 ; i < cgraph->n_nodes ; i++ ) {
2645- if (cgraph->nodes [i ]->op == GGML_OP_SET_ROWS) {
2774+ while (node_idx < cgraph->n_nodes ) {
2775+ if (cgraph->nodes [node_idx ]->op == GGML_OP_SET_ROWS) {
26462776 contains_set_rows = true ;
26472777 }
2648- if (auto cmd = ggml_webgpu_encode_node (ctx, cgraph-> nodes [i] )) {
2778+ if (auto cmd = ggml_webgpu_encode (ctx, cgraph, node_idx, num_encoded_ops )) {
26492779 commands.push_back (*cmd);
26502780 num_batched_kernels += cmd.value ().num_kernels ;
26512781#ifdef GGML_WEBGPU_GPU_PROFILE
@@ -2670,6 +2800,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
26702800 ctx->param_arena .reset ();
26712801 commands.clear ();
26722802 }
2803+
2804+ node_idx += num_encoded_ops;
2805+ num_encoded_ops = 1 ;
26732806 }
26742807
26752808 if (ctx->active_compute_pass ) {
@@ -3237,7 +3370,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
32373370 ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context ;
32383371 webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
32393372 webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx ;
3240- webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx ->device );
3373+ webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx ->device );
32413374 webgpu_ctx->param_arena .init (
32423375 webgpu_ctx->global_ctx ->device , WEBGPU_PARAMS_BUF_SIZE_BYTES,
32433376 webgpu_ctx->global_ctx ->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
@@ -3487,12 +3620,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
34873620 break ;
34883621 }
34893622 // Head dimensions must fit in workgroup memory with minimum tile sizes
3490- size_t limit_bytes = ctx->webgpu_global_ctx ->capabilities .limits .maxComputeWorkgroupStorageSize ;
3491- const bool has_mask = op->src [3 ] != nullptr ;
3492- const bool kv_direct = src1->type == GGML_TYPE_F16 &&
3493- (src0->ne [0 ] % ctx->webgpu_global_ctx ->capabilities .sg_mat_k ) == 0 &&
3494- (src1->ne [1 ] % GGML_WEBGPU_KV_SEQ_PAD) == 0 ;
3495- const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes (
3623+ size_t limit_bytes = ctx->webgpu_global_ctx ->capabilities .limits .maxComputeWorkgroupStorageSize ;
3624+ const bool has_mask = op->src [3 ] != nullptr ;
3625+ const bool kv_direct = src1->type == GGML_TYPE_F16 &&
3626+ (src0->ne [0 ] % ctx->webgpu_global_ctx ->capabilities .sg_mat_k ) == 0 &&
3627+ (src1->ne [1 ] % GGML_WEBGPU_KV_SEQ_PAD) == 0 ;
3628+ const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes (
34963629 ctx->webgpu_global_ctx ->capabilities .sg_mat_m , ctx->webgpu_global_ctx ->capabilities .sg_mat_n ,
34973630 (uint32_t ) src0->ne [0 ], (uint32_t ) src2->ne [0 ], has_mask, kv_direct);
34983631 if (min_bytes > limit_bytes) {
0 commit comments