diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index aa3fe06d5a9..01637e2ddab 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -73,8 +73,8 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32 -# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps +# define WEBGPU_MAX_PROFILE_QUERY_COUNT 4096u +# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES (WEBGPU_MAX_PROFILE_QUERY_COUNT * sizeof(uint64_t)) #endif /* Constants */ @@ -159,78 +159,20 @@ struct webgpu_param_arena { ~webgpu_param_arena() { this->cleanup(); } }; -#ifdef GGML_WEBGPU_GPU_PROFILE -struct webgpu_gpu_profile_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - wgpu::QuerySet query_set; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_gpu_profile_buf_pool { - std::vector free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); - // Create a query set for 2 timestamps - wgpu::QuerySetDescriptor ts_query_set_desc = {}; - - ts_query_set_desc.type = wgpu::QueryType::Timestamp; - ts_query_set_desc.count = 2; - wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); - - free.push_back({ host_buf, dev_buf, ts_query_set }); - } - } - - webgpu_gpu_profile_bufs alloc_bufs() { - std::unique_lock lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_gpu_profile_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - bufs.query_set.Destroy(); - } - free.clear(); - } - - ~webgpu_gpu_profile_buf_pool() { this->cleanup(); } -}; -#endif - struct webgpu_encoded_op { uint32_t num_kernels = 0; #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs timestamp_query_bufs; - std::string pipeline_name; + std::vector pipeline_names; #endif }; +struct webgpu_dispatch_desc { + webgpu_pipeline pipeline; + std::vector params; + std::vector bind_group_entries; + std::pair workgroups = { 1, 1 }; +}; + struct webgpu_capabilities { wgpu::Limits limits; bool supports_subgroup_matrix = false; @@ -256,7 +198,7 @@ struct webgpu_global_context_struct { webgpu_capabilities capabilities; // Shared buffer to move data from device to host wgpu::Buffer get_tensor_staging_buf; - // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. + // Global mutex for get_tensor std::recursive_mutex mutex; wgpu::Buffer memset_params_buf; @@ -272,8 +214,6 @@ struct webgpu_global_context_struct { #ifdef GGML_WEBGPU_GPU_PROFILE // Profiling: per-shader GPU time in ms std::unordered_map shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; #endif #ifdef GGML_WEBGPU_DEBUG @@ -312,11 +252,45 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; - webgpu_param_arena param_arena; - wgpu::Buffer set_rows_dev_error_buf; - wgpu::Buffer set_rows_host_error_buf; + webgpu_param_arena param_arena; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; + wgpu::CommandEncoder active_command_encoder; + wgpu::ComputePassEncoder active_compute_pass; size_t memset_bytes_per_thread; + +#ifdef GGML_WEBGPU_GPU_PROFILE + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; +#endif + + ~webgpu_context_struct() { +#ifdef GGML_WEBGPU_GPU_PROFILE + if (this->profile_timestamp_host_buf) { + this->profile_timestamp_host_buf.Destroy(); + this->profile_timestamp_host_buf = nullptr; + } + if (this->profile_timestamp_dev_buf) { + this->profile_timestamp_dev_buf.Destroy(); + this->profile_timestamp_dev_buf = nullptr; + } + if (this->profile_timestamp_query_set) { + this->profile_timestamp_query_set.Destroy(); + this->profile_timestamp_query_set = nullptr; + } +#endif + if (this->set_rows_host_error_buf) { + this->set_rows_host_error_buf.Destroy(); + this->set_rows_host_error_buf = nullptr; + } + if (this->set_rows_dev_error_buf) { + this->set_rows_dev_error_buf.Destroy(); + this->set_rows_dev_error_buf = nullptr; + } + } }; typedef std::shared_ptr webgpu_context; @@ -399,24 +373,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ -#ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, - std::vector & futures) { - if (futures.empty()) { - return; - } - - constexpr size_t max_futures_per_wait = 64; - - while (!futures.empty()) { - ctx->instance.WaitAny(std::min(max_futures_per_wait, futures.size()), futures.data(), UINT64_MAX); - futures.erase(std::remove_if(futures.begin(), futures.end(), - [](const wgpu::FutureWaitInfo & info) { return info.completed; }), - futures.end()); - } -} -#endif - template static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, T callback_status, @@ -436,22 +392,8 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, } } -#ifdef __EMSCRIPTEN__ -EM_JS(int, ggml_webgpu_is_ios_browser, (), { - const ua = navigator.userAgent; - return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0; -}); -#endif - // TODO: these next two functions may want tuning across different platforms and workloads, static uint32_t ggml_backend_webgpu_get_max_inflight_batches() { -#ifdef __EMSCRIPTEN__ - // iOS has very strict limits on the number of in-flight GPU commands, - // so we need to throttle to avoid failures. - if (ggml_webgpu_is_ios_browser()) { - return 1; - } -#endif return UINT32_MAX; } @@ -524,118 +466,77 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -#ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, - const std::vector & commands, - std::vector & futures) { - for (const auto & command : commands) { - auto label = command.pipeline_name; - auto ts_bufs = command.timestamp_query_bufs; - - wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); - } else { - const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange(); - // WebGPU timestamps are in ns; convert to ms - double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; - ctx->shader_gpu_time_ms[label] += elapsed_ms; - } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); - }); - futures.push_back({ f }); - } -} -#endif - -static webgpu_encoded_op ggml_backend_webgpu_build_multi( - webgpu_global_context & ctx, - webgpu_param_arena & param_arena, - wgpu::CommandEncoder & encoder, - const std::vector & pipelines, - const std::vector> & params_list, - const std::vector> & bind_group_entries_list, - const std::vector> & workgroups_list) { - GGML_ASSERT(pipelines.size() == params_list.size()); - GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); - GGML_ASSERT(pipelines.size() == workgroups_list.size()); - +static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & ctx, + const std::vector & dispatches) { webgpu_encoded_op result = {}; std::vector bind_groups; std::vector param_offsets; - result.num_kernels = pipelines.size(); + result.num_kernels = dispatches.size(); - for (size_t i = 0; i < pipelines.size(); i++) { - const size_t param_size = params_list[i].size() * sizeof(uint32_t); - const size_t param_offset = param_arena.alloc_slot(param_size); + for (size_t i = 0; i < dispatches.size(); i++) { + const webgpu_dispatch_desc & dispatch = dispatches[i]; + const size_t param_size = dispatch.params.size() * sizeof(uint32_t); + const size_t param_offset = ctx->param_arena.alloc_slot(param_size); - std::vector entries = bind_group_entries_list[i]; + std::vector entries = dispatch.bind_group_entries; uint32_t params_binding_num = entries.size(); entries.push_back({ .binding = params_binding_num, - .buffer = param_arena.buffer, + .buffer = ctx->param_arena.buffer, .offset = param_offset, - .size = param_arena.slot_size }); + .size = ctx->param_arena.slot_size }); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); + bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = entries.size(); bind_group_desc.entries = entries.data(); - bind_group_desc.label = pipelines[i].name.c_str(); - bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); + bind_group_desc.label = dispatch.pipeline.name.c_str(); + bind_groups.push_back(ctx->global_ctx->device.CreateBindGroup(&bind_group_desc)); param_offsets.push_back(param_offset); } for (size_t i = 0; i < param_offsets.size(); i++) { - ctx->queue.WriteBuffer(param_arena.buffer, param_offsets[i], params_list[i].data(), - params_list[i].size() * sizeof(uint32_t)); + ctx->global_ctx->queue.WriteBuffer(ctx->param_arena.buffer, param_offsets[i], dispatches[i].params.data(), + dispatches[i].params.size() * sizeof(uint32_t)); } + #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); + for (size_t i = 0; i < dispatches.size(); i++) { + GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; + wgpu::PassTimestampWrites ts_writes = { .querySet = ctx->profile_timestamp_query_set, + .beginningOfPassWriteIndex = query_begin, + .endOfPassWriteIndex = query_end }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + + pass.SetPipeline(dispatches[i].pipeline.pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + result.pipeline_names.push_back(dispatches[i].pipeline.name); } - - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); #else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); -#endif - for (size_t i = 0; i < pipelines.size(); i++) { - pass.SetPipeline(pipelines[i].pipeline); - pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + for (size_t i = 0; i < dispatches.size(); i++) { + ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); + ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); + ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); } - pass.End(); - -#ifdef GGML_WEBGPU_GPU_PROFILE - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); - result.timestamp_query_bufs = ts_bufs; - result.pipeline_name = pipelines.front().name; #endif + return result; } -static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_global_context & ctx, - webgpu_param_arena & param_arena, - wgpu::CommandEncoder & encoder, +static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_context & ctx, webgpu_pipeline & pipeline, std::vector params, std::vector bind_group_entries, uint32_t wg_x, uint32_t wg_y = 1) { - return ggml_backend_webgpu_build_multi(ctx, param_arena, encoder, - { - pipeline - }, - { std::move(params) }, { std::move(bind_group_entries) }, - { { wg_x, wg_y } }); + return ggml_backend_webgpu_build_multi( + ctx, { + { pipeline, std::move(params), std::move(bind_group_entries), { wg_x, wg_y } }, + }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -784,10 +685,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 return flags; } -static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, @@ -825,14 +723,13 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -891,13 +788,10 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; @@ -949,14 +843,13 @@ static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1011,14 +904,13 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1068,18 +960,17 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * src3, - ggml_tensor * src4, - ggml_tensor * src5, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1154,14 +1045,13 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, h, n_seqs); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); } -static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { @@ -1224,7 +1114,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, 1); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1235,11 +1125,10 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1291,14 +1180,13 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); @@ -1437,15 +1325,14 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1457,10 +1344,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // Get or create pipeline webgpu_pipeline gather_pipeline, main_pipeline; - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx); main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx); @@ -1520,10 +1404,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); - pipelines.push_back(gather_pipeline); - params_list.push_back(std::move(gather_params)); - entries_list.push_back(std::move(gather_entries)); - workgroups_list.push_back({ gather_wg_x, gather_wg_y }); + dispatches.push_back({ + gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y } + }); // params for mul_mat_id.wgsl std::vector main_params = { @@ -1588,24 +1471,21 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - pipelines.push_back(main_pipeline); - params_list.push_back(std::move(main_params)); - entries_list.push_back(std::move(main_entries)); - workgroups_list.push_back({ wg_x, wg_y }); + dispatches.push_back({ + main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y } + }); - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } #ifndef __EMSCRIPTEN__ -static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = *(float *) dst->op_params; float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); @@ -1897,40 +1777,33 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, const uint64_t split_wg_total = (uint64_t) wg_x * nwg; GGML_ASSERT(split_wg_total <= UINT32_MAX); - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; if (use_blk) { - pipelines.push_back(blk_pipeline); - params_list.push_back(std::move(blk_params)); - entries_list.push_back(std::move(blk_entries)); - workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count }); + dispatches.push_back({ + blk_pipeline, + std::move(blk_params), + std::move(blk_entries), + { blk_nblk0, blk_nblk1 * blk_batch_count } + }); } - pipelines.push_back(pipeline); - params_list.push_back(std::move(split_params)); - entries_list.push_back(std::move(split_entries)); - workgroups_list.push_back({ (uint32_t) split_wg_total, 1u }); + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); if (use_vec_reduce) { - pipelines.push_back(reduce_pipeline); - params_list.push_back(std::move(reduce_params)); - entries_list.push_back(std::move(reduce_entries)); - workgroups_list.push_back({ (uint32_t) nrows, 1u }); + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } + }); } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } #endif // __EMSCRIPTEN__ -static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); @@ -2005,14 +1878,13 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2108,14 +1980,13 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; @@ -2165,13 +2036,10 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { ne, @@ -2210,13 +2078,10 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -2256,16 +2121,14 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, }; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, - ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); } -static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2362,14 +2225,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2428,13 +2290,10 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2482,15 +2341,14 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2566,14 +2424,10 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, - ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); } -static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2595,13 +2449,10 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2659,10 +2510,7 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1]; const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2]; - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst; const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); @@ -2686,14 +2534,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } }; - pipelines.push_back(argsort_pipeline); - params_list.push_back(std::move(init_params)); - entries_list.push_back(std::move(init_entries)); - workgroups_list.push_back({ wg_x_init, wg_y_init }); + dispatches.push_back({ + argsort_pipeline, std::move(init_params), std::move(init_entries), { wg_x_init, wg_y_init } + }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } bool in_is_tmp = start_in_tmp; @@ -2745,23 +2591,18 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, const uint32_t total_wg_merge = nm * nrows; const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg); const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge); - workgroups_list.push_back({ wg_x_merge, wg_y_merge }); - pipelines.push_back(argsort_merge_pipeline); - params_list.push_back(std::move(merge_params)); - entries_list.push_back(std::move(merge_entries)); + dispatches.push_back({ + argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge } + }); len <<= 1; in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2786,13 +2627,10 @@ static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool total_sum = dst->op == GGML_OP_SUM; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2821,13 +2659,11 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * node) { +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { return std::nullopt; } @@ -2850,20 +2686,20 @@ static std::optional ggml_webgpu_encode_node(webgpu_context return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - return ggml_webgpu_cpy(ctx, encoder, src0, node); + return ggml_webgpu_cpy(ctx, src0, node); case GGML_OP_SET: - return ggml_webgpu_set(ctx, encoder, src0, src1, node); + return ggml_webgpu_set(ctx, src0, src1, node); case GGML_OP_SET_ROWS: - return ggml_webgpu_set_rows(ctx, encoder, src0, src1, node); + return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: - return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node); + return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: - return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node); + return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_MUL_MAT_ID: - return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: #ifndef __EMSCRIPTEN__ - return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node); + return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); #else return std::nullopt; #endif @@ -2871,22 +2707,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return ggml_webgpu_binary_op(ctx, encoder, src0, src1, node); + return ggml_webgpu_binary_op(ctx, src0, src1, node); case GGML_OP_CONCAT: - return ggml_webgpu_concat(ctx, encoder, src0, src1, node); + return ggml_webgpu_concat(ctx, src0, src1, node); case GGML_OP_REPEAT: - return ggml_webgpu_repeat(ctx, encoder, src0, node); + return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return ggml_webgpu_row_norm(ctx, encoder, src0, node); + return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: - return ggml_webgpu_rope(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: - return ggml_webgpu_glu(ctx, encoder, src0, src1, node); + return ggml_webgpu_glu(ctx, src0, src1, node); case GGML_OP_SCALE: - return ggml_webgpu_scale(ctx, encoder, src0, node); + return ggml_webgpu_scale(ctx, src0, node); case GGML_OP_SOFT_MAX: - return ggml_webgpu_soft_max(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: case GGML_OP_CLAMP: case GGML_OP_FILL: @@ -2897,32 +2733,80 @@ static std::optional ggml_webgpu_encode_node(webgpu_context case GGML_OP_COS: case GGML_OP_DIAG: case GGML_OP_TRI: - return ggml_webgpu_unary_op(ctx, encoder, src0, node); + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SOLVE_TRI: - return ggml_webgpu_solve_tri(ctx, encoder, src0, src1, node); + return ggml_webgpu_solve_tri(ctx, src0, src1, node); case GGML_OP_SSM_CONV: - return ggml_webgpu_ssm_conv(ctx, encoder, src0, src1, node); + return ggml_webgpu_ssm_conv(ctx, src0, src1, node); case GGML_OP_GATED_DELTA_NET: - return ggml_webgpu_gated_delta_net(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node->src[5], - node); + return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: - return ggml_webgpu_pad(ctx, encoder, src0, node); + return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: - return ggml_webgpu_argmax(ctx, encoder, src0, node); + return ggml_webgpu_argmax(ctx, src0, node); case GGML_OP_ARGSORT: case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k - return ggml_webgpu_argsort(ctx, encoder, src0, node); + return ggml_webgpu_argsort(ctx, src0, node); case GGML_OP_CUMSUM: - return ggml_webgpu_cumsum(ctx, encoder, src0, node); + return ggml_webgpu_cumsum(ctx, src0, node); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: - return ggml_webgpu_sum_rows(ctx, encoder, src0, node); + return ggml_webgpu_sum_rows(ctx, src0, node); default: return std::nullopt; } } +#ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_collect_profile_results(webgpu_context & ctx, + const std::vector & pipeline_names, + uint32_t & num_inflight_batches) { + if (pipeline_names.empty()) { + return; + } + + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.ResolveQuerySet(ctx->profile_timestamp_query_set, 0, ctx->profile_timestamp_query_count, + ctx->profile_timestamp_dev_buf, 0); + encoder.CopyBufferToBuffer(ctx->profile_timestamp_dev_buf, 0, ctx->profile_timestamp_host_buf, 0, + ctx->profile_timestamp_query_count * sizeof(uint64_t)); + + wgpu::CommandBuffer profile_commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, profile_commands, num_inflight_batches); + + const size_t mapped_size = ctx->profile_timestamp_query_count * sizeof(uint64_t); + GGML_ASSERT(ctx->profile_timestamp_query_count == 2 * pipeline_names.size()); + + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->profile_timestamp_host_buf, wgpu::MapMode::Read, 0, + mapped_size); + const uint64_t * ts_data = (const uint64_t *) ctx->profile_timestamp_host_buf.GetConstMappedRange(0, mapped_size); + + for (size_t i = 0; i < pipeline_names.size(); ++i) { + // WebGPU timestamps are in ns; convert to ms. + const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; + ctx->global_ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + } + + ctx->profile_timestamp_host_buf.Unmap(); +} +#endif + +static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) { + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, + ctx->set_rows_host_error_buf.GetSize()); + wgpu::CommandBuffer commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, commands, num_inflight_batches); + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, + ctx->set_rows_host_error_buf.GetSize()); + const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + ctx->set_rows_host_error_buf.Unmap(); +} + static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); @@ -2932,69 +2816,77 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); std::vector commands; + + uint32_t num_batched_kernels = 0; + uint32_t num_inflight_batches = 0; + bool contains_set_rows = false; + bool batch_compute_passes = true; + #ifdef GGML_WEBGPU_GPU_PROFILE - std::vector profile_futures; + ctx->profile_timestamp_query_count = 0; + batch_compute_passes = false; + std::vector profile_pipeline_names; #endif - uint32_t num_batched_kernels = 0; - uint32_t num_inflight_batches = 0; - bool contains_set_rows = false; - wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, batch_encoder, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; +#ifdef GGML_WEBGPU_GPU_PROFILE + profile_pipeline_names.insert(profile_pipeline_names.end(), cmd->pipeline_names.begin(), + cmd->pipeline_names.end()); +#endif } if (num_batched_kernels >= ctx->global_ctx->command_submit_batch_size) { + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + } num_batched_kernels = 0; - wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); -#endif + + // reset state for next batch + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } ctx->param_arena.reset(); commands.clear(); - batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); } } - if (!commands.empty()) { - wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + ctx->active_compute_pass = nullptr; + } + + if (num_batched_kernels > 0) { + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); -#endif ctx->param_arena.reset(); commands.clear(); } + ctx->active_command_encoder = nullptr; + +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_results(ctx, profile_pipeline_names, num_inflight_batches); +#endif - // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking. if (contains_set_rows) { - wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, - ctx->set_rows_host_error_buf.GetSize()); - wgpu::CommandBuffer set_rows_commands = encoder.Finish(); - ggml_backend_webgpu_submit_commands(ctx, set_rows_commands, num_inflight_batches); + ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches); } ggml_backend_webgpu_wait_queue(ctx->global_ctx); - if (contains_set_rows) { - ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, - ctx->set_rows_host_error_buf.GetSize()); - const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - ctx->set_rows_host_error_buf.Unmap(); - } - -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx->global_ctx, profile_futures); -#endif WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3535,14 +3427,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { "memset_params_buf"); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); -#ifdef GGML_WEBGPU_GPU_PROFILE - // Initialize buffer pool for timestamp queries, used for profiling - ctx->webgpu_global_ctx->timestamp_query_buf_pool.init( - ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); -#endif - GGML_LOG_INFO( "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " "device_desc: %s\n", @@ -3567,6 +3451,19 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_webgpu_create_buffer( + webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_host_buf, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "profile_timestamp_host_buf"); + wgpu::QuerySetDescriptor query_set_desc = {}; + query_set_desc.type = wgpu::QueryType::Timestamp; + query_set_desc.count = WEBGPU_MAX_PROFILE_QUERY_COUNT; + webgpu_ctx->profile_timestamp_query_set = webgpu_ctx->global_ctx->device.CreateQuerySet(&query_set_desc); +#endif + #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,