Skip to content

Commit 117435f

Browse files
Constannnnnt authored and jjhartmann committed
ggml-webgpu: address quantization precision and backend lifecycle management (ggml-org#21521)
* ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after ggml-org#20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS struct for NaN f16 canonicalization * Fix numerical precision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unnecessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidental removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops --------- Co-authored-by: Jeremy J. Hartmann <jeremy@mtion.tv>
1 parent 0c2ca72 commit 117435f

8 files changed

Lines changed: 383 additions & 330 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,32 @@ class ggml_webgpu_shader_lib {
11151115
std::string type_upper = type_str;
11161116
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
11171117

1118+
switch (key.src_type)
1119+
{
1120+
case GGML_TYPE_Q4_0:
1121+
case GGML_TYPE_Q5_0:
1122+
case GGML_TYPE_Q8_0:
1123+
case GGML_TYPE_Q3_K:
1124+
case GGML_TYPE_Q6_K:
1125+
case GGML_TYPE_IQ2_XXS:
1126+
case GGML_TYPE_IQ2_XS:
1127+
case GGML_TYPE_IQ2_S:
1128+
case GGML_TYPE_IQ3_XXS:
1129+
case GGML_TYPE_IQ3_S:
1130+
case GGML_TYPE_IQ1_S:
1131+
case GGML_TYPE_IQ4_NL:
1132+
{
1133+
// Quantized types using u32 buffers for portability.
1134+
defines.push_back("SRC_TYPE=u32");
1135+
defines.push_back("U32_DEQUANT_HELPERS");
1136+
break;
1137+
}
1138+
default:
1139+
{
1140+
defines.push_back(std::string("SRC_TYPE=") + type_str);
1141+
}
1142+
}
1143+
11181144
defines.push_back("BYTE_HELPERS");
11191145
defines.push_back(type_upper + "_T");
11201146
defines.push_back(type_upper);
@@ -1125,7 +1151,6 @@ class ggml_webgpu_shader_lib {
11251151
variant += "_";
11261152
variant += type_str;
11271153

1128-
defines.push_back(std::string("SRC_TYPE=") + type_str);
11291154
defines.push_back("DST_TYPE=f32");
11301155

11311156
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
@@ -1593,11 +1618,35 @@ class ggml_webgpu_shader_lib {
15931618
break;
15941619
default:
15951620
{
1596-
// quantized types
15971621
std::string type_upper = src0_name;
15981622
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
15991623

1600-
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
1624+
switch (context.src0->type)
1625+
{
1626+
case GGML_TYPE_Q4_0:
1627+
case GGML_TYPE_Q5_0:
1628+
case GGML_TYPE_Q8_0:
1629+
case GGML_TYPE_Q3_K:
1630+
case GGML_TYPE_Q6_K:
1631+
case GGML_TYPE_IQ2_XXS:
1632+
case GGML_TYPE_IQ2_XS:
1633+
case GGML_TYPE_IQ2_S:
1634+
case GGML_TYPE_IQ3_XXS:
1635+
case GGML_TYPE_IQ3_S:
1636+
case GGML_TYPE_IQ1_S:
1637+
case GGML_TYPE_IQ4_NL:
1638+
{
1639+
// Quantized types using u32 buffers for portability.
1640+
defines.push_back("SRC0_TYPE=u32");
1641+
defines.push_back("U32_DEQUANT_HELPERS");
1642+
break;
1643+
}
1644+
default:
1645+
{
1646+
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
1647+
}
1648+
}
1649+
16011650
defines.push_back("BYTE_HELPERS");
16021651
defines.push_back(type_upper + "_T");
16031652
defines.push_back(type_upper);

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
9797

9898
/* End Constants */
9999

100+
static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
101+
#ifdef __EMSCRIPTEN__
102+
return wgpu::CallbackMode::AllowProcessEvents;
103+
#else
104+
return wgpu::CallbackMode::AllowSpontaneous;
105+
#endif
106+
}
107+
100108
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
101109
// their locations.
102110
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
@@ -474,7 +482,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
474482

475483
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
476484
ctx->queue.OnSubmittedWorkDone(
477-
wgpu::CallbackMode::AllowSpontaneous,
485+
ggml_webgpu_callback_mode(),
478486
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
479487
callback_status = status;
480488
callback_message = std::string(message);
@@ -494,7 +502,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
494502
std::string callback_message;
495503

496504
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
497-
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
505+
buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
498506
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
499507
callback_status = status;
500508
callback_message = std::string(message);
@@ -526,7 +534,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
526534
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
527535
wgpu::CommandBuffer commands = encoder.Finish();
528536
ctx->queue.Submit(1, &commands);
529-
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
537+
if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0,
538+
ctx->debug_host_buf.GetSize())) {
539+
GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n");
540+
return;
541+
}
530542
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
531543
std::cout << "debug[0]: " << debug_data[0] << "\n";
532544
ctx->debug_host_buf.Unmap();
@@ -542,7 +554,7 @@ static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context &
542554
auto ts_bufs = command.timestamp_query_bufs;
543555

544556
wgpu::Future f = ts_bufs.host_buf.MapAsync(
545-
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
557+
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
546558
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
547559
if (status != wgpu::MapAsyncStatus::Success) {
548560
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
@@ -3420,7 +3432,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
34203432

34213433
ctx->webgpu_global_ctx->instance.WaitAny(
34223434
ctx->webgpu_global_ctx->instance.RequestAdapter(
3423-
&options, wgpu::CallbackMode::AllowSpontaneous,
3435+
&options, ggml_webgpu_callback_mode(),
34243436
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
34253437
if (status != wgpu::RequestAdapterStatus::Success) {
34263438
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
@@ -3491,8 +3503,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
34913503
dev_desc.requiredFeatures = required_features.data();
34923504
dev_desc.requiredFeatureCount = required_features.size();
34933505
dev_desc.SetDeviceLostCallback(
3494-
wgpu::CallbackMode::AllowSpontaneous,
3495-
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
3506+
ggml_webgpu_callback_mode(),
3507+
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
34963508
if (reason == wgpu::DeviceLostReason::Destroyed) {
34973509
return;
34983510
}
@@ -3525,7 +3537,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
35253537

35263538
ctx->webgpu_global_ctx->instance.WaitAny(
35273539
ctx->webgpu_global_ctx->adapter.RequestDevice(
3528-
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
3540+
&dev_desc, ggml_webgpu_callback_mode(),
35293541
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
35303542
if (status != wgpu::RequestDeviceStatus::Success) {
35313543
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
@@ -4046,6 +4058,13 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
40464058
ctx.name = GGML_WEBGPU_NAME;
40474059
ctx.device_count = 0;
40484060

4061+
// Keep one Dawn/WebGPU instance alive for the lifetime of the static backend
4062+
// registry. Recreating it on repeated registry lookups can invalidate
4063+
// adapter/device references that are still held by the backend/device layer.
4064+
if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) {
4065+
return &reg;
4066+
}
4067+
40494068
wgpu::InstanceDescriptor instance_descriptor{};
40504069
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
40514070
instance_descriptor.requiredFeatures = instance_features.data();
@@ -4063,11 +4082,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
40634082
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
40644083
ctx.webgpu_global_ctx->instance = std::move(inst);
40654084

4085+
// Probe for adapter support
40664086
wgpu::Adapter adapter;
40674087
if (ctx.webgpu_global_ctx->instance != nullptr) {
40684088
wgpu::RequestAdapterOptions options = {};
40694089

4070-
// probe for adapter support
40714090
ctx.webgpu_global_ctx->instance.WaitAny(
40724091
ctx.webgpu_global_ctx->instance.RequestAdapter(
40734092
&options, wgpu::CallbackMode::AllowSpontaneous,

ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl

Lines changed: 30 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,44 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
99
#endif
1010

1111
#ifdef U32_DEQUANT_HELPERS
12-
fn load_src0_u16_at(byte_offset: u32) -> u32 {
13-
let word = src0[byte_offset / 4u];
14-
let shift = (byte_offset & 2u) * 8u;
15-
return (word >> shift) & 0xFFFFu;
12+
fn load_u16_at(
13+
buf: ptr<storage, array<u32>, read_write>,
14+
byte_offset: u32) -> u32 {
15+
let word = buf[byte_offset / 4];
16+
let shift = (byte_offset & 0x2) * 8;
17+
return (word >> shift) & 0xFFFF;
1618
}
1719

18-
fn load_src0_u32_at(byte_offset: u32) -> u32 {
19-
let word_idx = byte_offset / 4u;
20-
let shift = (byte_offset & 3u) * 8u;
21-
let lo = src0[word_idx];
22-
if (shift == 0u) {
23-
return lo;
24-
}
25-
let hi = src0[word_idx + 1u];
26-
return (lo >> shift) | (hi << (32u - shift));
20+
fn load_u32_at(
21+
buf: ptr<storage, array<u32>, read_write>,
22+
byte_offset: u32) -> u32 {
23+
let word_idx = byte_offset / 4;
24+
let shift = (byte_offset & 0x3) * 8;
25+
let lo = buf[word_idx];
26+
let hi = buf[word_idx + 1];
27+
let shifted = (lo >> shift) | (hi << (32 - shift));
28+
return select(shifted, lo, shift == 0);
2729
}
2830

29-
fn load_src0_f16_at(byte_offset: u32) -> f16 {
30-
let packed = unpack2x16float(load_src0_u16_at(byte_offset));
31+
fn load_f16_at(
32+
buf: ptr<storage, array<u32>, read_write>,
33+
byte_offset: u32) -> f16 {
34+
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
3135
return f16(packed[0]);
3236
}
33-
#endif
3437

35-
#ifdef Q4_0_T
36-
struct q4_0 {
37-
d: f16,
38-
qs: array<f16, 8>
39-
};
38+
fn load_f16_as_f32_at(
39+
buf: ptr<storage, array<u32>, read_write>,
40+
byte_offset: u32) -> f32 {
41+
let word = buf[byte_offset / 4];
42+
let shift = (byte_offset & 0x2) * 8;
43+
let d_bits = (word >> shift) & 0xFFFF;
44+
return unpack2x16float(d_bits)[0];
45+
}
4046
#endif
4147

48+
49+
4250
#ifdef Q4_1_T
4351
struct q4_1 {
4452
d: f16,
@@ -47,13 +55,6 @@ struct q4_1 {
4755
};
4856
#endif
4957

50-
#ifdef Q5_0_T
51-
struct q5_0 {
52-
d: f16,
53-
qh: array<f16, 2>,
54-
qs: array<f16, 8>
55-
};
56-
#endif
5758

5859
#ifdef Q5_1_T
5960
struct q5_1 {
@@ -64,12 +65,6 @@ struct q5_1 {
6465
};
6566
#endif
6667

67-
#ifdef Q8_0_T
68-
struct q8_0 {
69-
d: f16,
70-
qs: array<f16, 16>
71-
};
72-
#endif
7368

7469
#ifdef Q8_1_T
7570
struct q8_1 {
@@ -88,14 +83,6 @@ struct q2_K {
8883
};
8984
#endif
9085

91-
#ifdef Q3_K_T
92-
struct q3_K {
93-
hmask: array<f16, 16>,
94-
qs: array<f16, 32>,
95-
scales: array<f16, 6>,
96-
d: f16
97-
};
98-
#endif
9986

10087
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
10188
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
@@ -132,64 +119,6 @@ struct q5_K {
132119
};
133120
#endif
134121

135-
#ifdef Q6_K_T
136-
struct q6_K {
137-
ql: array<f16, 64>,
138-
qh: array<f16, 32>,
139-
scales: array<f16, 8>,
140-
d: f16
141-
};
142-
#endif
143-
144-
#ifdef IQ2_XXS_T
145-
struct iq2_xxs {
146-
d: f16,
147-
qs: array<f16, 32>
148-
};
149-
#endif
150-
151-
#ifdef IQ2_XS_T
152-
struct iq2_xs {
153-
d: f16,
154-
qs: array<f16, 32>,
155-
scales: array<f16, 4>
156-
};
157-
#endif
158-
159-
#ifdef IQ2_S_T
160-
struct iq2_s {
161-
d: f16,
162-
qs: array<f16, 32>,
163-
qh: array<f16, 4>,
164-
scales: array<f16, 4>
165-
};
166-
#endif
167-
168-
#ifdef IQ3_XXS_T
169-
struct iq3_xxs {
170-
d: f16,
171-
qs: array<f16, 48>
172-
};
173-
#endif
174-
175-
#ifdef IQ3_S_T
176-
struct iq3_s {
177-
d: f16,
178-
qs: array<f16, 32>,
179-
qh: array<f16, 4>,
180-
signs: array<f16, 16>,
181-
scales: array<f16, 2>
182-
};
183-
#endif
184-
185-
#ifdef IQ1_S_T
186-
struct iq1_s {
187-
d: f16,
188-
qs: array<f16, 16>,
189-
qh: array<f16, 8>
190-
};
191-
#endif
192-
193122
#ifdef IQ1_M_T
194123
struct iq1_m {
195124
qs: array<u32, 8>,
@@ -198,17 +127,9 @@ struct iq1_m {
198127
};
199128
#endif
200129

201-
#ifdef IQ4_NL_T
202-
struct iq4_nl {
203-
d: f16,
204-
qs: array<f16, 8>,
205-
};
206-
#endif
207-
208130
#ifdef IQ4_XS_T
209131
struct iq4_xs {
210-
d: f16,
211-
scales_h: f16,
132+
d_scales_h: u32,
212133
scales_l: u32,
213134
qs: array<u32, 32>
214135
};

0 commit comments

Comments (0)