Skip to content

Commit 151f3a9

Browse files
authored
ggml-webgpu: Check earlier for WebGPU required features (ggml-org#23879)
1 parent b22da25 commit 151f3a9

1 file changed

Lines changed: 12 additions & 9 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3724,7 +3724,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
37243724
ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
37253725
}
37263726

3727-
static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
3727+
static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
37283728
wgpu::RequestAdapterOptions options = {};
37293729

37303730
#ifndef __EMSCRIPTEN__
@@ -3762,10 +3762,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
37623762
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
37633763
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
37643764
ctx->webgpu_global_ctx->vendor = info.vendor;
3765-
wgpu::SupportedFeatures features;
3766-
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
3767-
// we require f16 support
3768-
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
37693765
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
37703766
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
37713767
// for dot4I8packed
@@ -3877,7 +3873,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
38773873
"device_desc: %s\n",
38783874
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
38793875
std::string(info.device).c_str(), std::string(info.description).c_str());
3880-
return true;
38813876
}
38823877

38833878
static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
@@ -4507,16 +4502,24 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
45074502
UINT64_MAX);
45084503
}
45094504

4510-
if (adapter != nullptr) {
4505+
// WebGPU backend requires f16 support and, on native, implicit device synchronization.
4506+
if (adapter != nullptr && adapter.HasFeature(wgpu::FeatureName::ShaderF16)
4507+
#ifndef __EMSCRIPTEN__
4508+
&& adapter.HasFeature(wgpu::FeatureName::ImplicitDeviceSynchronization)
4509+
#endif
4510+
) {
45114511
ctx->device_count = 1;
45124512
}
45134513

45144514
return ®
45154515
}
45164516

45174517
ggml_backend_t ggml_backend_webgpu_init(void) {
4518-
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
4519-
4518+
ggml_backend_reg_t reg = ggml_backend_webgpu_reg();
4519+
if (ggml_backend_reg_dev_count(reg) == 0) {
4520+
return nullptr;
4521+
}
4522+
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);
45204523
return ggml_backend_webgpu_backend_init(dev, nullptr);
45214524
}
45224525

0 commit comments

Comments
 (0)