Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 17 additions & 34 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim

/* Constants */

#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 64u
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
Expand All @@ -97,14 +97,6 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim

/* End Constants */

// Returns the wgpu::CallbackMode handed to the asynchronous WebGPU calls in this file.
// Emscripten builds get AllowProcessEvents; every other build gets AllowSpontaneous.
static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
#ifdef __EMSCRIPTEN__
return wgpu::CallbackMode::AllowProcessEvents;
#else
return wgpu::CallbackMode::AllowSpontaneous;
#endif
}

// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
// their locations.
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
Expand Down Expand Up @@ -445,34 +437,25 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status,
}

#ifdef __EMSCRIPTEN__
// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures.
// This JavaScript helper returns 1 when navigator.userAgent contains 'iPhone' or 'iPad', and 0 otherwise.
EM_JS(int, ggml_webgpu_is_ios_browser, (), {
const ua = navigator.userAgent;
return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0;
});
#endif

static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) {
// TODO: these next two functions may want tuning across different platforms and workloads.
static uint32_t ggml_backend_webgpu_get_max_inflight_batches() {
#ifdef __EMSCRIPTEN__
// iOS has very strict limits on the number of in-flight GPU commands,
// so we need to throttle to avoid failures.
if (ggml_webgpu_is_ios_browser()) {
return 1;
}
#else
GGML_UNUSED(info);
#endif

return UINT32_MAX;
}

static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) {
#ifdef __EMSCRIPTEN__
if (ggml_webgpu_is_ios_browser()) {
return 16;
}
#else
GGML_UNUSED(info);
#endif

static uint32_t ggml_backend_webgpu_get_command_submit_batch_size() {
return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE;
}

Expand All @@ -482,7 +465,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {

const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
ctx->queue.OnSubmittedWorkDone(
ggml_webgpu_callback_mode(),
wgpu::CallbackMode::AllowSpontaneous,
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
callback_status = status;
callback_message = std::string(message);
Expand All @@ -502,7 +485,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
std::string callback_message;

const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
callback_status = status;
callback_message = std::string(message);
Expand Down Expand Up @@ -546,15 +529,15 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
#endif

#ifdef GGML_WEBGPU_GPU_PROFILE
static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx,
const std::vector<webgpu_command> & commands,
std::vector<wgpu::FutureWaitInfo> & futures) {
static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx,
const std::vector<webgpu_encoded_op> & commands,
std::vector<wgpu::FutureWaitInfo> & futures) {
for (const auto & command : commands) {
auto label = command.pipeline_name;
auto ts_bufs = command.timestamp_query_bufs;

wgpu::Future f = ts_bufs.host_buf.MapAsync(
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
Expand Down Expand Up @@ -3432,7 +3415,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {

ctx->webgpu_global_ctx->instance.WaitAny(
ctx->webgpu_global_ctx->instance.RequestAdapter(
&options, ggml_webgpu_callback_mode(),
&options, wgpu::CallbackMode::AllowSpontaneous,
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
Expand All @@ -3453,8 +3436,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
}
#endif
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info);
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info);
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
wgpu::SupportedFeatures features;
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
// we require f16 support
Expand Down Expand Up @@ -3505,7 +3488,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
dev_desc.requiredFeatures = required_features.data();
dev_desc.requiredFeatureCount = required_features.size();
dev_desc.SetDeviceLostCallback(
ggml_webgpu_callback_mode(),
wgpu::CallbackMode::AllowSpontaneous,
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
if (reason == wgpu::DeviceLostReason::Destroyed) {
return;
Expand Down Expand Up @@ -3539,7 +3522,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {

ctx->webgpu_global_ctx->instance.WaitAny(
ctx->webgpu_global_ctx->adapter.RequestDevice(
&dev_desc, ggml_webgpu_callback_mode(),
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
Expand Down
35 changes: 14 additions & 21 deletions ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);

// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
}

// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
// Inner loop over 2 shifts per group
Expand All @@ -523,15 +517,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var sc: u32;
var mn: u32;

let scale_base = block_byte_base + 4u;

if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);

sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
Expand Down Expand Up @@ -578,11 +574,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);

// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
}

// The original loop processes elements in groups of 64
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
Expand All @@ -603,15 +594,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var sc: u32;
var mn: u32;

let scale_base = block_byte_base + 4u;

if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);

sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
Expand Down
12 changes: 6 additions & 6 deletions ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ enable f16;
#include "mul_mat_decls.tmpl"

#ifdef VEC
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
return vec4<f32>(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]);
}
#endif

#ifdef SCALAR
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
return f32(acc[tm][tn]);
fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
return acc[tm][tn];
}
#endif

Expand Down Expand Up @@ -98,7 +98,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;

var acc: array<array<f16, TILE_N>, TILE_M>;
var acc: array<array<f32, TILE_N>, TILE_M>;

for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {

Expand All @@ -122,7 +122,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let src1_idx = src1_n * TILE_K + k_inner;
let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
for (var tm = 0u; tm < TILE_M; tm++) {
acc[tm][tn] += src0_tile[tm] * src1_val;
acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ enable chromium_experimental_subgroup_matrix;
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"

// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, because f16 accumulation causes NaNs.
// See https://github.com/ggml-org/llama.cpp/issues/21602

#ifdef VEC
fn store_dst(shmem_idx: u32, dst_idx: u32) {
dst[dst_idx] = vec4<f32>(
Expand Down
Loading