Commit 0d0764d
[WebGPU] Implement async tensor api and event api (ggml-org#22099)
* Only run webgpu CI on my fork
* Implement set_tensor_async
* Implement synchronize API
* Implement event creation and deletion API
* Cleanup
* Cleanup
* Comment out jobs for local CI run
* Add webgpu-only workflow
* Delete .github/workflows/build-webgpu.yml
* Cleanup
* Cleanup
* Update API with function handlers
* Run clang-format
* Replace one-shot buffer with a direct queue.WriteBuffer using the buffer context
1 parent 6da7168 commit 0d0764d
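These backend hooks are reached through the public ggml-backend API. A minimal caller-side sketch (hypothetical usage, not part of this commit; it assumes the standard ggml-backend event and async entry points, with dev, backend, tensor, data, and nbytes set up elsewhere):

    // Create an event on the WebGPU device, upload asynchronously,
    // then record and wait for the submitted work.
    ggml_backend_event_t ev = ggml_backend_event_new(dev);

    // Enqueue the host-to-device copy without blocking the CPU.
    ggml_backend_tensor_set_async(backend, tensor, data, /*offset =*/ 0, nbytes);

    // Mark this point in the queue; other CPU work can happen here.
    ggml_backend_event_record(ev, backend);

    // Block until the recorded point has completed on the GPU.
    ggml_backend_event_synchronize(ev);
    ggml_backend_event_free(ev);

    // Alternatively, wait for everything submitted so far:
    ggml_backend_synchronize(backend);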

1 file changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 92 additions & 7 deletions
@@ -2832,22 +2832,107 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
     return GGML_STATUS_SUCCESS;
 }
 
+struct ggml_backend_webgpu_event_context {
+    webgpu_global_context global_ctx;
+    wgpu::Future          future;
+    bool                  recorded = false;
+};
+
+static ggml_backend_event_t ggml_backend_webgpu_device_event_new(ggml_backend_dev_t device) {
+    ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) device->context;
+
+    auto * event_ctx      = new ggml_backend_webgpu_event_context();
+    event_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
+
+    auto * event   = new ggml_backend_event;
+    event->device  = device;
+    event->context = event_ctx;
+    return event;
+}
+
+static void ggml_backend_webgpu_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    GGML_UNUSED(dev);
+    delete static_cast<ggml_backend_webgpu_event_context *>(event->context);
+    delete event;
+}
+
+static void ggml_backend_webgpu_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    GGML_UNUSED(dev);
+    ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context;
+    if (!event_ctx->recorded) {
+        return;
+    }
+    wgpu::WaitStatus status =
+        event_ctx->global_ctx->instance.WaitAny(event_ctx->future, WEBGPU_RUNTIME_WAIT_TIMEOUT_NS);
+    if (status == wgpu::WaitStatus::TimedOut) {
+        GGML_ABORT("ggml_webgpu: event_synchronize timed out after %u ms\n", WEBGPU_RUNTIME_WAIT_TIMEOUT_MS);
+    }
+    event_ctx->recorded = false;
+}
+
+static void ggml_backend_webgpu_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_backend_webgpu_context *       backend_ctx = (ggml_backend_webgpu_context *) backend->context;
+    ggml_backend_webgpu_event_context * event_ctx   = (ggml_backend_webgpu_event_context *) event->context;
+
+    event_ctx->future = backend_ctx->webgpu_ctx->global_ctx->queue.OnSubmittedWorkDone(
+        wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {});
+    event_ctx->recorded = true;
+}
+
+static void ggml_backend_webgpu_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    GGML_UNUSED(backend);
+    ggml_backend_webgpu_device_event_synchronize(nullptr, event);
+}
+
+static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend,
+                                                 ggml_tensor *   tensor,
+                                                 const void *    data,
+                                                 size_t          offset,
+                                                 size_t          size) {
+    GGML_UNUSED(backend);
+    auto * buf_ctx      = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
+    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+
+    // Write aligned portion
+    buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
+
+    if (size % 4 != 0) {
+        // If size is not a multiple of 4, we need to memset the remaining bytes
+        size_t remaining_size = size % 4;
+
+        // pack the remaining bytes into a uint32_t
+        uint32_t val32 = 0;
+
+        for (size_t i = 0; i < remaining_size; i++) {
+            ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
+        }
+        // memset the remaining bytes
+        ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
+                                          total_offset + (size - remaining_size), remaining_size);
+    }
+}
+
+static void ggml_backend_webgpu_synchronize(ggml_backend_t backend) {
+    ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
+    ggml_backend_webgpu_wait_queue(backend_ctx->webgpu_ctx->global_ctx);
+}
+
 static ggml_backend_i ggml_backend_webgpu_i = {
     /* .get_name            = */ ggml_backend_webgpu_name,
     /* .free                = */ ggml_backend_webgpu_free,
-    /* .set_tensor_async    = */ NULL,
+    /* .set_tensor_async    = */ ggml_backend_webgpu_set_tensor_async,
     /* .get_tensor_async    = */ NULL,
     /* .get_tensor_2d_async = */ NULL,
     /* .set_tensor_2d_async = */ NULL,
     /* .cpy_tensor_async    = */ NULL,
-    /* .synchronize         = */ NULL,
+    /* .synchronize         = */ ggml_backend_webgpu_synchronize,
     /* .graph_plan_create   = */ NULL,
     /* .graph_plan_free     = */ NULL,
     /* .graph_plan_update   = */ NULL,
     /* .graph_plan_compute  = */ NULL,
     /* .graph_compute       = */ ggml_backend_webgpu_graph_compute,
-    /* .event_record        = */ NULL,
-    /* .event_wait          = */ NULL,
+    /* .event_record        = */ ggml_backend_webgpu_event_record,
+    /* .event_wait          = */ ggml_backend_webgpu_event_wait,
     /* .graph_optimize      = */ NULL,
 };
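The event implementation above hinges on a single Dawn webgpu_cpp pattern: recording captures a wgpu::Future from Queue::OnSubmittedWorkDone, and synchronizing blocks on that future via Instance::WaitAny. A minimal standalone sketch of the same pattern (illustrative only, not part of the commit; it assumes already-initialized instance and queue objects, with a literal timeout standing in for WEBGPU_RUNTIME_WAIT_TIMEOUT_NS):

    // Record: get a future that completes once all work submitted so far is done.
    // The callback body can be empty because the caller polls the future instead.
    wgpu::Future done = queue.OnSubmittedWorkDone(
        wgpu::CallbackMode::AllowSpontaneous,
        [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {});

    // Synchronize: block until the future resolves, or give up after 1 s.
    wgpu::WaitStatus status = instance.WaitAny(done, /*timeoutNS =*/ 1000000000);
    if (status == wgpu::WaitStatus::TimedOut) {
        // handle the timeout (the backend aborts here)
    }

The recorded flag in the event context guards against waiting on a future that was never recorded, since an event may be synchronized before its first record.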

@@ -3810,9 +3895,9 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
     /* .supports_op        = */ ggml_backend_webgpu_device_supports_op,
     /* .supports_buft      = */ ggml_backend_webgpu_device_supports_buft,
     /* .offload_op         = */ NULL,
-    /* .event_new          = */ NULL,
-    /* .event_free         = */ NULL,
-    /* .event_synchronize  = */ NULL,
+    /* .event_new          = */ ggml_backend_webgpu_device_event_new,
+    /* .event_free         = */ ggml_backend_webgpu_device_event_free,
+    /* .event_synchronize  = */ ggml_backend_webgpu_device_event_synchronize,
 };
 
 /* End GGML Backend Device Interface */
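Finally, the reason ggml_backend_webgpu_set_tensor_async splits the upload: wgpu::Queue::WriteBuffer requires the write size to be a multiple of 4 bytes, so the aligned prefix goes through WriteBuffer while the 0-3 trailing bytes are packed into a uint32_t and written by the buffer-memset path. A standalone sketch of that split (hypothetical helper name, not part of the commit; byte-for-byte equivalent to the per-byte loop in the diff):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Split an upload of `size` bytes into a 4-byte-aligned prefix that is
    // safe to pass to WriteBuffer, plus a 0-3 byte tail packed into the low
    // bytes of a uint32_t for the memset path.
    static void split_unaligned_write(const uint8_t * data, size_t size,
                                      size_t & aligned_size, uint32_t & tail32, size_t & tail_size) {
        aligned_size = (size / 4) * 4;  // largest multiple of 4 not exceeding size
        tail_size    = size % 4;        // 0..3 trailing bytes
        tail32       = 0;
        std::memcpy(&tail32, data + aligned_size, tail_size);  // same layout as the loop above
    }

For example, a 10-byte upload yields an 8-byte WriteBuffer call and a 2-byte tail packed into tail32.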
