Skip to content

Commit 7c90850

Browse files
authored
ggml-webgpu: improve MTP inference by using mat-vec path for small batches (ggml-org#24811)
* ggml-webgpu: improve small-batch decoding
* Add barrier to the NUM_COLS loop in mul-mat-vec
1 parent 035cd8f commit 7c90850

8 files changed

Lines changed: 687 additions & 596 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
905905
ggml_type src0_type;
906906
ggml_type src1_type;
907907
int vectorized;
908+
uint32_t num_cols;
908909
bool use_mmvq;
909910

910911
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
911912
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
912-
use_mmvq == other.use_mmvq;
913+
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
913914
}
914915
};
915916

@@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
919920
ggml_webgpu_hash_combine(seed, key.src0_type);
920921
ggml_webgpu_hash_combine(seed, key.src1_type);
921922
ggml_webgpu_hash_combine(seed, key.vectorized);
923+
ggml_webgpu_hash_combine(seed, key.num_cols);
922924
ggml_webgpu_hash_combine(seed, key.use_mmvq);
923925
return seed;
924926
}
@@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
993995
ggml_type src0_type;
994996
ggml_type src1_type;
995997
uint32_t n_experts;
998+
uint32_t num_cols;
996999
int vectorized;
9971000

9981001
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
9991002
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
1000-
vectorized == other.vectorized;
1003+
num_cols == other.num_cols && vectorized == other.vectorized;
10011004
}
10021005
};
10031006

@@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
10071010
ggml_webgpu_hash_combine(seed, key.src0_type);
10081011
ggml_webgpu_hash_combine(seed, key.src1_type);
10091012
ggml_webgpu_hash_combine(seed, key.n_experts);
1013+
ggml_webgpu_hash_combine(seed, key.num_cols);
10101014
ggml_webgpu_hash_combine(seed, key.vectorized);
10111015
return seed;
10121016
}
@@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
11071111
const ggml_tensor * src1,
11081112
bool supports_dot_product,
11091113
const std::string & vendor) {
1110-
if (src1->ne[1] == 1) {
1114+
if (src1->ne[1] <= 4) {
11111115
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
11121116
if (supports_dp4a && supports_dot_product) {
11131117
switch (src1->type) {
@@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib {
18891893
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
18901894
1 :
18911895
0;
1896+
key.num_cols = context.dst->ne[1];
18921897
key.use_mmvq =
18931898
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
18941899

@@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib {
20042009
if (key.vectorized) {
20052010
variant += "_vectorized";
20062011
}
2012+
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
20072013

20082014
auto processed = preprocessor.preprocess(shader_src, defines);
20092015
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
@@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib {
24212427
if (key.vectorized) {
24222428
variant += "_vectorized";
24232429
}
2430+
defines.push_back(std::string("NUM_COLS=1"));
24242431

24252432
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
24262433

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
14181418
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
14191419
const size_t q8_src1_align_offset = ROUNDUP_POW2(
14201420
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
1421-
const size_t q8_src1_binding_size =
1422-
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
1423-
WEBGPU_STORAGE_BUF_BINDING_MULT);
1421+
const size_t q8_src1_binding_size = ROUNDUP_POW2(
1422+
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
1423+
WEBGPU_STORAGE_BUF_BINDING_MULT);
14241424

14251425
std::vector<uint32_t> q8_params = {
14261426
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1427+
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
14271428
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
14281429
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
14291430
(uint32_t) src1->ne[0],
1431+
(uint32_t) src1->ne[1],
14301432
(uint32_t) src1->ne[2],
14311433
(uint32_t) src1->ne[3],
14321434
};
@@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
14421444
uint32_t q8_wg_x = 1;
14431445
uint32_t q8_wg_y = 1;
14441446
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
1445-
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
1447+
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
14461448
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
14471449
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
14481450

@@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
14561458
ggml_tensor * src1,
14571459
ggml_tensor * dst) {
14581460
// Determine if this is a mat-vec operation
1459-
bool is_vec = (dst->ne[1] == 1);
1461+
bool use_mat_vec = (dst->ne[1] <= 4);
14601462

14611463
// use MMVQ path for mat-vec
14621464
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
@@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
14821484
webgpu_pipeline pipeline;
14831485
std::vector<webgpu_dispatch_desc> dispatches;
14841486

1485-
if (is_vec) {
1487+
if (use_mat_vec) {
14861488
if (use_mmvq) {
14871489
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
14881490
}
@@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
15291531
uint32_t wg_y = 1;
15301532
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
15311533

1532-
if (is_vec) {
1534+
if (use_mat_vec) {
15331535
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
15341536

15351537
uint32_t batches = dst->ne[2] * dst->ne[3];
@@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
36913693
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
36923694
ctx->webgpu_global_ctx->vendor);
36933695
if (use_mmvq) {
3694-
const size_t q8_src1_size =
3695-
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
3696+
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
3697+
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
36963698
res = ROUNDUP_POW2(res + q8_src1_size +
36973699
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
36983700
WEBGPU_STORAGE_BUF_BINDING_MULT);

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,7 @@ fn main(
103103

104104
#ifdef USE_SUBGROUP_REDUCTION
105105
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
106-
let subgroup_total = subgroupAdd(acc[row]);
106+
let subgroup_total = subgroupAdd(acc[0][row]);
107107
if (subgroup_invocation_id == 0u) {
108108
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
109109
}
@@ -126,7 +126,7 @@ fn main(
126126

127127
#ifdef USE_WORKGROUP_REDUCTION
128128
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
129-
partial_sums[partial_index(row, thread_id)] = acc[row];
129+
partial_sums[partial_index(row, thread_id)] = acc[0][row];
130130
}
131131

132132
workgroupBarrier();

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl

Lines changed: 45 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -91,61 +91,67 @@ fn main(
9191
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
9292

9393
#ifdef MMVQ
94-
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
94+
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
9595
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
9696
#else
9797
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
9898
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
9999
#endif
100100

101-
#ifdef USE_SUBGROUP_REDUCTION
102-
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
103-
let subgroup_total = subgroupAdd(acc[row]);
104-
if (subgroup_invocation_id == 0u) {
105-
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
106-
}
107-
}
101+
for (var col = 0u;col < NUM_COLS;col += 1) {
108102

109-
workgroupBarrier();
103+
#ifdef USE_SUBGROUP_REDUCTION
104+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
105+
let subgroup_total = subgroupAdd(acc[col][row]);
106+
if (subgroup_invocation_id == 0u) {
107+
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
108+
}
109+
}
110110

111-
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
112-
let output_row = row_base + row;
113-
var row_acc = 0.0f;
114-
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
115-
row_acc += partial_sums[partial_index(row, k)];
116-
}
117-
let row_total = subgroupAdd(row_acc);
118-
if (subgroup_invocation_id == 0) {
119-
dst[dst_idx_base + row] = row_total;
120-
}
121-
}
111+
workgroupBarrier();
112+
113+
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
114+
let output_row = row_base + row;
115+
var row_acc = 0.0f;
116+
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
117+
row_acc += partial_sums[partial_index(row, k)];
118+
}
119+
let row_total = subgroupAdd(row_acc);
120+
if (subgroup_invocation_id == 0) {
121+
dst[dst_idx_base + col * params.m + row] = row_total;
122+
}
123+
}
122124
#endif
123125

124126
#ifdef USE_WORKGROUP_REDUCTION
125-
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
126-
partial_sums[partial_index(row, thread_id)] = acc[row];
127-
}
127+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
128+
partial_sums[partial_index(row, thread_id)] = acc[col][row];
129+
}
128130

129-
workgroupBarrier();
131+
workgroupBarrier();
130132

131-
var stride = WG_SIZE / 2u;
133+
var stride = WG_SIZE / 2u;
132134

133-
while (stride > 0) {
134-
if (thread_id < stride) {
135-
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
136-
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
135+
while (stride > 0) {
136+
if (thread_id < stride) {
137+
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
138+
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
139+
}
140+
}
141+
142+
workgroupBarrier();
143+
stride = stride / 2;
137144
}
138-
}
139145

140-
workgroupBarrier();
141-
stride = stride / 2;
142-
}
146+
if (thread_id < OUTPUTS_PER_WG) {
147+
let output_row = row_base + thread_id;
148+
if (output_row < params.m) {
149+
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
150+
}
151+
}
152+
#endif
153+
154+
workgroupBarrier();
143155

144-
if (thread_id < OUTPUTS_PER_WG) {
145-
let output_row = row_base + thread_id;
146-
if (output_row < params.m) {
147-
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
148-
}
149156
}
150-
#endif
151157
}

0 commit comments

Comments
 (0)