Skip to content

Commit c36c422

Browse files
authored
[WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches (#27998)
### Description: ### Summary Fuse the QMoE 1-token decode path to reduce GPU dispatches from 17 (1 + k×4) to 5 (gate + fc1 + swiglu + fc2 + mix), improving token generation throughput by ~21% on Meteor Lake for the gpt-oss-20b MoE model (19 → 23 tps). ### Motivation The QMoE operator processes Mixture-of-Experts layers by selecting top-k experts (k=4) per token. In the original 1-token decode path, each expert is processed serially with 4 dispatches (gather + fc1 + swiglu + fc2 + mix), totaling 17 GPU dispatches per QMoE call. Since each dispatch has M=1, the GPU is underutilized and CPU dispatch overhead dominates. ### Approach For the 1-token path (num_rows == 1): **Gate1Token** — Select top-k experts and output an [indirect_experts](vscode-file://vscode-app/c:/Users/jiajiaqin/AppData/Local/Programs/Microsoft%20VS%20Code/ce099c1ed2/resources/app/out/vs/code/electron-browser/workbench/workbench.html) buffer mapping row index → expert index **Batched fc1 MatMulNBits** — Run a single M=k matmul with [per_row_weight_indirect](vscode-file://vscode-app/c:/Users/jiajiaqin/AppData/Local/Programs/Microsoft%20VS%20Code/ce099c1ed2/resources/app/out/vs/code/electron-browser/workbench/workbench.html) mode, where each row selects a different expert's weights via the indirect buffer **SwiGLU** — Apply activation on all k rows at once **Batched fc2 MatMulNBits** — Same per-row expert selection for the down projection **FusedFinalMix** — Accumulate all k weighted expert results into the output ### Follow-ups Fuse Batched fc1 MatMulNBits + SwiGLU Fuse Batched fc2 MatMulNBits + FusedFinalMix Finally, we only need three shaders: Gate1Token, fused Batched fc1 MatMulNBits, fused batched fc2 MatMulNBits.
1 parent 87b0643 commit c36c422

10 files changed

Lines changed: 297 additions & 137 deletions

onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
// Fused FinalMix for 1-token MoE: processes all k experts in one dispatch.
5+
// The kernel is dispatched over hidden_size; each thread computes one output
6+
// element and accumulates the weighted contributions from all k experts.
7+
// in: fc2_outputs [k, hidden_size] — concatenated fc2 results for all k experts
8+
// in: router_values [1, num_experts] — softmax weights per expert
9+
// in: indirect_experts [k] — which expert index each row corresponds to
10+
// out: output [1, hidden_size] — accumulated weighted output
11+
// uniform: hidden_size, k
12+
13+
$MAIN {
14+
let out_idx = workgroup_idx * workgroup_size_x + local_idx;
15+
if (out_idx >= uniforms.hidden_size) {
16+
return;
17+
}
18+
var acc = output_element_t(0);
19+
for (var i = 0u; i < uniforms.k; i++) {
20+
let expert_idx = indirect_experts[i];
21+
let router_value = router_values[expert_idx];
22+
acc += router_value * fc2_outputs[i * uniforms.hidden_size + out_idx];
23+
}
24+
output[out_idx] = acc;
25+
} // MAIN

onnxruntime/contrib_ops/webgpu/moe/qmoe.cc

Lines changed: 81 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -124,44 +124,41 @@ class SwigLuProgram final : public Program<SwigLuProgram> {
124124
private:
125125
};
126126

127-
class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
127+
class FusedFinalMix1TokenProgram final : public Program<FusedFinalMix1TokenProgram> {
128128
public:
129-
QMoEFinalMixProgram() : Program<QMoEFinalMixProgram>{"QMoEFinalMix"} {}
129+
FusedFinalMix1TokenProgram() : Program<FusedFinalMix1TokenProgram>{"QmoeFusedFinalMix1Token"} {}
130130

131131
Status GenerateShaderCode(ShaderHelper& shader) const override {
132132
shader.AddInput("fc2_outputs", ShaderUsage::UseElementTypeAlias);
133133
shader.AddInput("router_values", ShaderUsage::UseElementTypeAlias);
134-
shader.AddInput("expert_tokens", ShaderUsage::UseElementTypeAlias);
134+
shader.AddInput("indirect_experts", ShaderUsage::UseElementTypeAlias);
135135
shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);
136-
137-
return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix.wgsl.template");
136+
return WGSL_TEMPLATE_APPLY(shader, "moe/fused_final_mix_1token.wgsl.template");
138137
}
139138

140139
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
141140
{"hidden_size", ProgramUniformVariableDataType::Uint32},
142-
{"num_experts", ProgramUniformVariableDataType::Uint32},
143-
{"expert_idx", ProgramUniformVariableDataType::Uint32},
144-
{"token_offset", ProgramUniformVariableDataType::Uint32});
145-
146-
private:
141+
{"k", ProgramUniformVariableDataType::Uint32});
147142
};
148143

149-
class QMoEFinalMix1TokenProgram final : public Program<QMoEFinalMix1TokenProgram> {
144+
class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
150145
public:
151-
QMoEFinalMix1TokenProgram() : Program<QMoEFinalMix1TokenProgram>{"QMoEFinalMix1TokenProgram"} {}
146+
QMoEFinalMixProgram() : Program<QMoEFinalMixProgram>{"QMoEFinalMix"} {}
152147

153148
Status GenerateShaderCode(ShaderHelper& shader) const override {
154149
shader.AddInput("fc2_outputs", ShaderUsage::UseElementTypeAlias);
155150
shader.AddInput("router_values", ShaderUsage::UseElementTypeAlias);
156-
shader.AddInput("indirect_experts", ShaderUsage::UseElementTypeAlias);
151+
shader.AddInput("expert_tokens", ShaderUsage::UseElementTypeAlias);
157152
shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);
158153

159-
return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix_1token.wgsl.template");
154+
return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix.wgsl.template");
160155
}
161156

162157
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
163158
{"hidden_size", ProgramUniformVariableDataType::Uint32},
164-
{"expert_idx", ProgramUniformVariableDataType::Uint32});
159+
{"num_experts", ProgramUniformVariableDataType::Uint32},
160+
{"expert_idx", ProgramUniformVariableDataType::Uint32},
161+
{"token_offset", ProgramUniformVariableDataType::Uint32});
165162

166163
private:
167164
};
@@ -235,87 +232,97 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
235232
Status status;
236233

237234
Tensor* output_tensor = context.Output(0, input_shape);
238-
const int total_output_size = (static_cast<int>(input_shape.Size()) + 3) / 4;
239-
240-
// we are accumulating expert results into output_tensor, need to initialize to zero
241-
ZeroTensorProgram zero;
242-
zero
243-
.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, ProgramOutput::Flatten, 4})
244-
.SetDispatchGroupSize((total_output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
245-
.AddUniformVariables({static_cast<uint32_t>(total_output_size)});
246-
ORT_RETURN_IF_ERROR(context.RunProgram(zero));
247235

248236
if (moe_params.num_rows == 1) {
249-
// Optimized code path for 1 token to avoid gpu -> cpu copy
237+
// Fused MoE path for 1 token: instead of looping k times with separate dispatches,
238+
// run a single batched MatMulNBits with M=k where each row uses a different expert's
239+
// weights via weight_index_indirect. A's single row is broadcast to all k rows.
240+
// This reduces dispatches from 1 + k*4 = 17 to 5 (gate + fc1 + swiglu + fc2 + mix).
250241

251-
const int num_tokens = 1;
242+
const uint32_t k = static_cast<uint32_t>(k_);
243+
const uint32_t num_tokens = 1;
252244
TensorShape gate_value_shape({num_tokens, num_experts});
253-
TensorShape indirect_experts_shape({k_});
245+
TensorShape indirect_experts_shape({k});
254246

255247
Tensor router_values = context.CreateGPUTensor(dtype, gate_value_shape);
256248
Tensor indirect_experts = context.CreateGPUTensor(dtype_uint32, indirect_experts_shape);
257249

250+
// Step 1: Gate — select top-k experts
258251
Gate1TokenProgram gate{k_, is_fp16};
259252
gate
260253
.AddInputs({{router_logits, ProgramTensorMetadataDependency::Type}})
261254
.AddOutput({&router_values, ProgramTensorMetadataDependency::None})
262255
.AddOutput({&indirect_experts, ProgramTensorMetadataDependency::None})
263256
.SetWorkgroupSize(num_experts)
264-
.SetDispatchGroupSize(static_cast<uint32_t>(num_tokens))
265-
.AddUniformVariables({static_cast<uint32_t>(num_tokens), num_experts})
257+
.SetDispatchGroupSize(num_tokens)
258+
.AddUniformVariables({num_tokens, num_experts})
266259
.CacheHint(k_, is_fp16 ? "fp16" : "fp32");
267-
268260
ORT_RETURN_IF_ERROR(context.RunProgram(gate));
269261

270-
for (uint32_t expert_idx = 0; expert_idx < static_cast<uint32_t>(k_); expert_idx++) {
271-
TensorShape fc1_output_shape({num_tokens, fc1_output_size});
272-
Tensor fc1_outputs = context.CreateGPUTensor(dtype, fc1_output_shape);
273-
TensorShape fc1_activated_shape({num_tokens, moe_params.inter_size});
274-
Tensor fc1_activated = context.CreateGPUTensor(dtype, fc1_activated_shape);
275-
TensorShape fc2_output_shape({num_tokens, N_fc2});
276-
Tensor fc2_outputs = context.CreateGPUTensor(dtype, fc2_output_shape);
277-
278-
status = ApplyMatMulNBits(hidden_state, fc1_experts_weights, fc1_scales, nullptr, fc1_experts_bias_optional,
279-
K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
280-
&fc1_outputs, expert_idx, &indirect_experts);
281-
ORT_RETURN_IF_ERROR(status);
282-
283-
if (is_swiglu) {
284-
SwigLuProgram swiglu;
285-
swiglu
286-
.AddInputs({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2}})
287-
.AddOutput({&fc1_activated, ProgramTensorMetadataDependency::None})
288-
.SetWorkgroupSize(128)
289-
.SetDispatchGroupSize(((num_tokens * static_cast<uint32_t>(moe_params.inter_size)) + 127) / 128)
290-
.AddUniformVariables({static_cast<uint32_t>(num_tokens),
291-
static_cast<uint32_t>(moe_params.inter_size),
292-
activation_alpha_,
293-
activation_beta_,
294-
swiglu_limit_});
295-
ORT_RETURN_IF_ERROR(context.RunProgram(swiglu));
296-
} else {
297-
ORT_THROW("only swiglu is supported for WebGPU.");
298-
}
299-
300-
status = ApplyMatMulNBits(&fc1_activated, fc2_experts_weights, fc2_scales, nullptr, fc2_experts_bias_optional,
301-
K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
302-
&fc2_outputs, expert_idx, &indirect_experts);
303-
ORT_RETURN_IF_ERROR(status);
262+
// Step 2: Batched fc1 MatMulNBits with M=k, per-row expert selection.
263+
// A is (1, hidden_size) but dispatched with override_M=k; shader broadcasts A row 0.
264+
TensorShape fc1_output_shape({static_cast<int64_t>(k), fc1_output_size});
265+
Tensor fc1_outputs = context.CreateGPUTensor(dtype, fc1_output_shape);
266+
status = ApplyMatMulNBits(hidden_state, fc1_experts_weights, fc1_scales, nullptr, fc1_experts_bias_optional,
267+
K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
268+
&fc1_outputs, 0, &indirect_experts, /*override_M=*/k);
269+
ORT_RETURN_IF_ERROR(status);
270+
271+
// Step 3: SwiGLU on all k rows at once
272+
TensorShape fc1_activated_shape({static_cast<int64_t>(k), moe_params.inter_size});
273+
Tensor fc1_activated = context.CreateGPUTensor(dtype, fc1_activated_shape);
274+
if (is_swiglu) {
275+
SwigLuProgram swiglu;
276+
swiglu
277+
.AddInputs({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2}})
278+
.AddOutput({&fc1_activated, ProgramTensorMetadataDependency::None})
279+
.SetWorkgroupSize(128)
280+
.SetDispatchGroupSize(((k * static_cast<uint32_t>(moe_params.inter_size)) + 127) / 128)
281+
.AddUniformVariables({k,
282+
static_cast<uint32_t>(moe_params.inter_size),
283+
activation_alpha_,
284+
activation_beta_,
285+
swiglu_limit_});
286+
ORT_RETURN_IF_ERROR(context.RunProgram(swiglu));
287+
} else {
288+
ORT_THROW("only swiglu is supported for WebGPU.");
289+
}
304290

305-
QMoEFinalMix1TokenProgram final_mix;
306-
final_mix
307-
.AddInputs({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
308-
.AddInputs({{&router_values, ProgramTensorMetadataDependency::Type}})
309-
.AddInputs({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
310-
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
311-
.SetDispatchGroupSize(1)
312-
.AddUniformVariables({hidden_size, expert_idx});
291+
// Step 4: Batched fc2 MatMulNBits with M=k, per-row expert selection
292+
// fc1_activated already has k rows (one per expert), no override_M needed.
293+
TensorShape fc2_output_shape({static_cast<int64_t>(k), N_fc2});
294+
Tensor fc2_outputs = context.CreateGPUTensor(dtype, fc2_output_shape);
295+
status = ApplyMatMulNBits(&fc1_activated, fc2_experts_weights, fc2_scales, nullptr, fc2_experts_bias_optional,
296+
K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
297+
&fc2_outputs, 0, &indirect_experts, /*override_M=*/0);
298+
ORT_RETURN_IF_ERROR(status);
299+
300+
// Step 5: Fused FinalMix — accumulate all k expert results weighted by router_values
301+
// Dispatch across hidden_size (not k) to avoid race: each thread accumulates all k experts.
302+
const uint32_t mix_wg_size = 256;
303+
FusedFinalMix1TokenProgram final_mix;
304+
final_mix
305+
.AddInputs({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
306+
.AddInputs({{&router_values, ProgramTensorMetadataDependency::Type}})
307+
.AddInputs({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
308+
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
309+
.SetWorkgroupSize(mix_wg_size)
310+
.SetDispatchGroupSize((hidden_size + mix_wg_size - 1) / mix_wg_size)
311+
.AddUniformVariables({hidden_size, k});
312+
ORT_RETURN_IF_ERROR(context.RunProgram(final_mix));
313313

314-
ORT_RETURN_IF_ERROR(context.RunProgram(final_mix));
315-
}
316314
return Status::OK();
317315
}
318316

317+
// Multi-token path: accumulates into output_tensor, need to initialize to zero.
318+
const int total_output_size = (static_cast<int>(input_shape.Size()) + 3) / 4;
319+
ZeroTensorProgram zero;
320+
zero
321+
.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, ProgramOutput::Flatten, 4})
322+
.SetDispatchGroupSize((total_output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
323+
.AddUniformVariables({static_cast<uint32_t>(total_output_size)});
324+
ORT_RETURN_IF_ERROR(context.RunProgram(zero));
325+
319326
// path for num_tokens > 1
320327
// process tokens in chunks of max_tokens to put some cap on memory usage
321328
for (int token_offset = 0; token_offset < moe_params.num_rows; token_offset += max_tokens) {

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
7272
ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count");
7373

7474
return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_small_m.wgsl.template",
75+
WGSL_TEMPLATE_PARAMETER(broadcast_a_row, broadcast_a_row_),
7576
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
7677
WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_),
7778
WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect_),
@@ -93,6 +94,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
9394
const Tensor* zero_points, const Tensor* bias,
9495
uint32_t batch_count,
9596
uint32_t M,
97+
uint32_t dispatch_M,
9698
uint32_t N,
9799
uint32_t K,
98100
uint32_t block_size,
@@ -124,25 +126,26 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
124126

125127
const bool has_zero_points = zero_points != nullptr;
126128
const bool has_bias = bias != nullptr;
127-
const bool has_weight_idx = weight_index != 0;
128129
const bool has_weight_idx_indirect = weight_index_indirect != nullptr;
130+
const bool has_weight_idx = weight_index != 0 || has_weight_idx_indirect;
129131
const bool single_scale_weights = (block_size == K * N);
130-
if (M < min_M_for_tile_optimization) {
132+
if (has_weight_idx_indirect || M < min_M_for_tile_optimization) {
131133
uint32_t tile_size_k_vec = 32;
132134
uint32_t tile_size_n = 4;
133135

134136
const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components);
135-
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights};
137+
const bool broadcast_a = dispatch_M > M;
138+
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights, broadcast_a};
136139
uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n;
137140
mul_program.SetWorkgroupSize(128);
138-
mul_program.SetDispatchGroupSize(batch_count * M * num_N_tile);
141+
mul_program.SetDispatchGroupSize(batch_count * dispatch_M * num_N_tile);
139142
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
140143
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
141144
{b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(b_components * kU32Components)},
142145
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
143-
.AddUniformVariables({batch_count, M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index})
146+
.AddUniformVariables({batch_count, M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index, dispatch_M})
144147
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1})
145-
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect);
148+
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect, broadcast_a);
146149
if (has_zero_points) {
147150
mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
148151
}

0 commit comments

Comments
 (0)