@@ -3724,7 +3724,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
37243724 ctx->memset_pipeline = ggml_webgpu_create_pipeline (ctx->device , wgsl_memset, " memset" , constants);
37253725}
37263726
3727- static bool create_webgpu_device (ggml_backend_webgpu_reg_context * ctx) {
3727+ static void create_webgpu_device (ggml_backend_webgpu_reg_context * ctx) {
37283728 wgpu::RequestAdapterOptions options = {};
37293729
37303730#ifndef __EMSCRIPTEN__
@@ -3762,10 +3762,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
37623762 ctx->webgpu_global_ctx ->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size ();
37633763 ctx->webgpu_global_ctx ->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches ();
37643764 ctx->webgpu_global_ctx ->vendor = info.vendor ;
3765- wgpu::SupportedFeatures features;
3766- ctx->webgpu_global_ctx ->adapter .GetFeatures (&features);
3767- // we require f16 support
3768- GGML_ASSERT (ctx->webgpu_global_ctx ->adapter .HasFeature (wgpu::FeatureName::ShaderF16));
37693765 ctx->webgpu_global_ctx ->capabilities .supports_subgroups =
37703766 ctx->webgpu_global_ctx ->adapter .HasFeature (wgpu::FeatureName::Subgroups);
37713767 // for dot4I8packed
@@ -3877,7 +3873,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
38773873 " device_desc: %s\n " ,
38783874 info.vendorID , std::string (info.vendor ).c_str (), std::string (info.architecture ).c_str (), info.deviceID ,
38793875 std::string (info.device ).c_str (), std::string (info.description ).c_str ());
3880- return true ;
38813876}
38823877
38833878static webgpu_context initialize_webgpu_context (ggml_backend_dev_t dev) {
@@ -4507,16 +4502,24 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
45074502 UINT64_MAX);
45084503 }
45094504
4510- if (adapter != nullptr ) {
4505+ // WebGPU backend requires f16 support and, on native, implicit device synchronization.
4506+ if (adapter != nullptr && adapter.HasFeature (wgpu::FeatureName::ShaderF16)
4507+ #ifndef __EMSCRIPTEN__
4508+ && adapter.HasFeature (wgpu::FeatureName::ImplicitDeviceSynchronization)
4509+ #endif
4510+ ) {
45114511 ctx->device_count = 1 ;
45124512 }
45134513
45144514 return ®
45154515}
45164516
45174517ggml_backend_t ggml_backend_webgpu_init (void ) {
4518- ggml_backend_dev_t dev = ggml_backend_reg_dev_get (ggml_backend_webgpu_reg (), 0 );
4519-
4518+ ggml_backend_reg_t reg = ggml_backend_webgpu_reg ();
4519+ if (ggml_backend_reg_dev_count (reg) == 0 ) {
4520+ return nullptr ;
4521+ }
4522+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get (reg, 0 );
45204523 return ggml_backend_webgpu_backend_init (dev, nullptr );
45214524}
45224525
0 commit comments