@@ -124,44 +124,41 @@ class SwigLuProgram final : public Program<SwigLuProgram> {
124124 private:
125125};
126126
127- class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram > {
127+ class FusedFinalMix1TokenProgram final : public Program<FusedFinalMix1TokenProgram > {
128128 public:
129- QMoEFinalMixProgram () : Program<QMoEFinalMixProgram >{" QMoEFinalMix " } {}
129+ FusedFinalMix1TokenProgram () : Program<FusedFinalMix1TokenProgram >{" QmoeFusedFinalMix1Token " } {}
130130
131131 Status GenerateShaderCode (ShaderHelper& shader) const override {
132132 shader.AddInput (" fc2_outputs" , ShaderUsage::UseElementTypeAlias);
133133 shader.AddInput (" router_values" , ShaderUsage::UseElementTypeAlias);
134- shader.AddInput (" expert_tokens " , ShaderUsage::UseElementTypeAlias);
134+ shader.AddInput (" indirect_experts " , ShaderUsage::UseElementTypeAlias);
135135 shader.AddOutput (" output" , ShaderUsage::UseElementTypeAlias);
136-
137- return WGSL_TEMPLATE_APPLY (shader, " moe/final_mix.wgsl.template" );
136+ return WGSL_TEMPLATE_APPLY (shader, " moe/fused_final_mix_1token.wgsl.template" );
138137 }
139138
140139 WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES (
141140 {" hidden_size" , ProgramUniformVariableDataType::Uint32},
142- {" num_experts" , ProgramUniformVariableDataType::Uint32},
143- {" expert_idx" , ProgramUniformVariableDataType::Uint32},
144- {" token_offset" , ProgramUniformVariableDataType::Uint32});
145-
146- private:
141+ {" k" , ProgramUniformVariableDataType::Uint32});
147142};
148143
149- class QMoEFinalMix1TokenProgram final : public Program<QMoEFinalMix1TokenProgram > {
144+ class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram > {
150145 public:
151- QMoEFinalMix1TokenProgram () : Program<QMoEFinalMix1TokenProgram >{" QMoEFinalMix1TokenProgram " } {}
146+ QMoEFinalMixProgram () : Program<QMoEFinalMixProgram >{" QMoEFinalMix " } {}
152147
153148 Status GenerateShaderCode (ShaderHelper& shader) const override {
154149 shader.AddInput (" fc2_outputs" , ShaderUsage::UseElementTypeAlias);
155150 shader.AddInput (" router_values" , ShaderUsage::UseElementTypeAlias);
156- shader.AddInput (" indirect_experts " , ShaderUsage::UseElementTypeAlias);
151+ shader.AddInput (" expert_tokens " , ShaderUsage::UseElementTypeAlias);
157152 shader.AddOutput (" output" , ShaderUsage::UseElementTypeAlias);
158153
159- return WGSL_TEMPLATE_APPLY (shader, " moe/final_mix_1token .wgsl.template" );
154+ return WGSL_TEMPLATE_APPLY (shader, " moe/final_mix .wgsl.template" );
160155 }
161156
162157 WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES (
163158 {" hidden_size" , ProgramUniformVariableDataType::Uint32},
164- {" expert_idx" , ProgramUniformVariableDataType::Uint32});
159+ {" num_experts" , ProgramUniformVariableDataType::Uint32},
160+ {" expert_idx" , ProgramUniformVariableDataType::Uint32},
161+ {" token_offset" , ProgramUniformVariableDataType::Uint32});
165162
166163 private:
167164};
@@ -235,87 +232,97 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
235232 Status status;
236233
237234 Tensor* output_tensor = context.Output (0 , input_shape);
238- const int total_output_size = (static_cast <int >(input_shape.Size ()) + 3 ) / 4 ;
239-
240- // we are accumulating expert results into output_tensor, need to initialize to zero
241- ZeroTensorProgram zero;
242- zero
243- .AddOutput ({output_tensor, ProgramTensorMetadataDependency::Type, ProgramOutput::Flatten, 4 })
244- .SetDispatchGroupSize ((total_output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE )
245- .AddUniformVariables ({static_cast <uint32_t >(total_output_size)});
246- ORT_RETURN_IF_ERROR (context.RunProgram (zero));
247235
248236 if (moe_params.num_rows == 1 ) {
249- // Optimized code path for 1 token to avoid gpu -> cpu copy
237+ // Fused MoE path for 1 token: instead of looping k times with separate dispatches,
238+ // run a single batched MatMulNBits with M=k where each row uses a different expert's
239+ // weights via weight_index_indirect. A's single row is broadcast to all k rows.
240+ // This reduces dispatches from 1 + k*4 (17 when k = 4) to 5 (gate + fc1 + swiglu + fc2 + mix).
250241
251- const int num_tokens = 1 ;
242+ const uint32_t k = static_cast <uint32_t >(k_);
243+ const uint32_t num_tokens = 1 ;
252244 TensorShape gate_value_shape ({num_tokens, num_experts});
253- TensorShape indirect_experts_shape ({k_ });
245+ TensorShape indirect_experts_shape ({k });
254246
255247 Tensor router_values = context.CreateGPUTensor (dtype, gate_value_shape);
256248 Tensor indirect_experts = context.CreateGPUTensor (dtype_uint32, indirect_experts_shape);
257249
250+ // Step 1: Gate — select top-k experts
258251 Gate1TokenProgram gate{k_, is_fp16};
259252 gate
260253 .AddInputs ({{router_logits, ProgramTensorMetadataDependency::Type}})
261254 .AddOutput ({&router_values, ProgramTensorMetadataDependency::None})
262255 .AddOutput ({&indirect_experts, ProgramTensorMetadataDependency::None})
263256 .SetWorkgroupSize (num_experts)
264- .SetDispatchGroupSize (static_cast < uint32_t >( num_tokens) )
265- .AddUniformVariables ({static_cast < uint32_t >( num_tokens) , num_experts})
257+ .SetDispatchGroupSize (num_tokens)
258+ .AddUniformVariables ({num_tokens, num_experts})
266259 .CacheHint (k_, is_fp16 ? " fp16" : " fp32" );
267-
268260 ORT_RETURN_IF_ERROR (context.RunProgram (gate));
269261
270- for (uint32_t expert_idx = 0 ; expert_idx < static_cast <uint32_t >(k_); expert_idx++) {
271- TensorShape fc1_output_shape ({num_tokens, fc1_output_size});
272- Tensor fc1_outputs = context.CreateGPUTensor (dtype, fc1_output_shape);
273- TensorShape fc1_activated_shape ({num_tokens, moe_params.inter_size });
274- Tensor fc1_activated = context.CreateGPUTensor (dtype, fc1_activated_shape);
275- TensorShape fc2_output_shape ({num_tokens, N_fc2});
276- Tensor fc2_outputs = context.CreateGPUTensor (dtype, fc2_output_shape);
277-
278- status = ApplyMatMulNBits (hidden_state, fc1_experts_weights, fc1_scales, nullptr , fc1_experts_bias_optional,
279- K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
280- &fc1_outputs, expert_idx, &indirect_experts);
281- ORT_RETURN_IF_ERROR (status);
282-
283- if (is_swiglu) {
284- SwigLuProgram swiglu;
285- swiglu
286- .AddInputs ({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2 }})
287- .AddOutput ({&fc1_activated, ProgramTensorMetadataDependency::None})
288- .SetWorkgroupSize (128 )
289- .SetDispatchGroupSize (((num_tokens * static_cast <uint32_t >(moe_params.inter_size )) + 127 ) / 128 )
290- .AddUniformVariables ({static_cast <uint32_t >(num_tokens),
291- static_cast <uint32_t >(moe_params.inter_size ),
292- activation_alpha_,
293- activation_beta_,
294- swiglu_limit_});
295- ORT_RETURN_IF_ERROR (context.RunProgram (swiglu));
296- } else {
297- ORT_THROW (" only swiglu is supported for WebGPU." );
298- }
299-
300- status = ApplyMatMulNBits (&fc1_activated, fc2_experts_weights, fc2_scales, nullptr , fc2_experts_bias_optional,
301- K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
302- &fc2_outputs, expert_idx, &indirect_experts);
303- ORT_RETURN_IF_ERROR (status);
262+ // Step 2: Batched fc1 MatMulNBits with M=k, per-row expert selection.
263+ // A is (1, hidden_size) but dispatched with override_M=k; shader broadcasts A row 0.
264+ TensorShape fc1_output_shape ({static_cast <int64_t >(k), fc1_output_size});
265+ Tensor fc1_outputs = context.CreateGPUTensor (dtype, fc1_output_shape);
266+ status = ApplyMatMulNBits (hidden_state, fc1_experts_weights, fc1_scales, nullptr , fc1_experts_bias_optional,
267+ K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
268+ &fc1_outputs, 0 , &indirect_experts, /* override_M=*/ k);
269+ ORT_RETURN_IF_ERROR (status);
270+
271+ // Step 3: SwiGLU on all k rows at once
272+ TensorShape fc1_activated_shape ({static_cast <int64_t >(k), moe_params.inter_size });
273+ Tensor fc1_activated = context.CreateGPUTensor (dtype, fc1_activated_shape);
274+ if (is_swiglu) {
275+ SwigLuProgram swiglu;
276+ swiglu
277+ .AddInputs ({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2 }})
278+ .AddOutput ({&fc1_activated, ProgramTensorMetadataDependency::None})
279+ .SetWorkgroupSize (128 )
280+ .SetDispatchGroupSize (((k * static_cast <uint32_t >(moe_params.inter_size )) + 127 ) / 128 )
281+ .AddUniformVariables ({k,
282+ static_cast <uint32_t >(moe_params.inter_size ),
283+ activation_alpha_,
284+ activation_beta_,
285+ swiglu_limit_});
286+ ORT_RETURN_IF_ERROR (context.RunProgram (swiglu));
287+ } else {
288+ ORT_THROW (" only swiglu is supported for WebGPU." );
289+ }
304290
305- QMoEFinalMix1TokenProgram final_mix;
306- final_mix
307- .AddInputs ({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
308- .AddInputs ({{&router_values, ProgramTensorMetadataDependency::Type}})
309- .AddInputs ({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
310- .AddOutput ({output_tensor, ProgramTensorMetadataDependency::None})
311- .SetDispatchGroupSize (1 )
312- .AddUniformVariables ({hidden_size, expert_idx});
291+ // Step 4: Batched fc2 MatMulNBits with M=k, per-row expert selection
292+ // fc1_activated already has k rows (one per expert), no override_M needed.
293+ TensorShape fc2_output_shape ({static_cast <int64_t >(k), N_fc2});
294+ Tensor fc2_outputs = context.CreateGPUTensor (dtype, fc2_output_shape);
295+ status = ApplyMatMulNBits (&fc1_activated, fc2_experts_weights, fc2_scales, nullptr , fc2_experts_bias_optional,
296+ K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
297+ &fc2_outputs, 0 , &indirect_experts, /* override_M=*/ 0 );
298+ ORT_RETURN_IF_ERROR (status);
299+
300+ // Step 5: Fused FinalMix — accumulate all k expert results weighted by router_values
301+ // Dispatch across hidden_size (not k) to avoid a race: each thread accumulates all k experts.
302+ const uint32_t mix_wg_size = 256 ;
303+ FusedFinalMix1TokenProgram final_mix;
304+ final_mix
305+ .AddInputs ({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
306+ .AddInputs ({{&router_values, ProgramTensorMetadataDependency::Type}})
307+ .AddInputs ({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
308+ .AddOutput ({output_tensor, ProgramTensorMetadataDependency::None})
309+ .SetWorkgroupSize (mix_wg_size)
310+ .SetDispatchGroupSize ((hidden_size + mix_wg_size - 1 ) / mix_wg_size)
311+ .AddUniformVariables ({hidden_size, k});
312+ ORT_RETURN_IF_ERROR (context.RunProgram (final_mix));
313313
314- ORT_RETURN_IF_ERROR (context.RunProgram (final_mix));
315- }
316314 return Status::OK ();
317315 }
318316
317+ // Multi-token path: expert results are accumulated into output_tensor, so it must be initialized to zero first.
318+ const int total_output_size = (static_cast <int >(input_shape.Size ()) + 3 ) / 4 ;
319+ ZeroTensorProgram zero;
320+ zero
321+ .AddOutput ({output_tensor, ProgramTensorMetadataDependency::Type, ProgramOutput::Flatten, 4 })
322+ .SetDispatchGroupSize ((total_output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE )
323+ .AddUniformVariables ({static_cast <uint32_t >(total_output_size)});
324+ ORT_RETURN_IF_ERROR (context.RunProgram (zero));
325+
319326 // path for num_tokens > 1
320327 // process tokens in chunks of max_tokens to put some cap on memory usage
321328 for (int token_offset = 0 ; token_offset < moe_params.num_rows ; token_offset += max_tokens) {
0 commit comments