File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -726,6 +726,11 @@ class MoEGeneral(BaseModel):
726726 description = "Whether to pre-fuse MoE weights (w0 and w1) during initialization. "
727727 "This is useful for inference performance in vllm_rpa mode." ,
728728 )
729+ fuse_expert_scales : bool = Field (
730+ False ,
731+ description = "Whether to fuse the expert scaling factors into the expert weights. "
732+ "This can improve inference performance." ,
733+ )
729734
730735
731736class MoEKernels (BaseModel ):
Original file line number Diff line number Diff line change @@ -539,6 +539,14 @@ def __init__(
539539 else :
540540 self .per_expert_scale = None
541541
542+ # Scale the output projection ahead of time during inference for higher generation throughput.
543+ if (
544+ self .per_expert_scale is not None
545+ and self .config .model_call_mode == "inference"
546+ and self .config .fuse_expert_scales
547+ ):
548+ self .wo .value = self .wo .value * self .per_expert_scale .value [:, None , None ]
549+
542550 def _maybe_shard_with_logical (self , inputs , logical_name ):
543551 return maybe_shard_with_logical (
544552 inputs ,
@@ -2242,7 +2250,8 @@ def __call__(
22422250 w0_kernel = jnp .asarray (self .wi_0 [...], self .dtype )
22432251 w1_kernel = jnp .asarray (self .wi_1 [...], self .dtype )
22442252
2245- if self .per_expert_scale is not None :
2253+ # Only apply per expert scales if we have not fused with the out-projections at init time.
2254+ if self .per_expert_scale is not None and cfg .model_call_mode != "inference" and not cfg .fuse_expert_scales :
22462255 wo_kernel = wo_kernel * jnp .asarray (self .per_expert_scale [...], self .dtype )[:, None , None ]
22472256
22482257 if self .wi_0_sparsity_module is not None :
You can’t perform that action at this time.
0 commit comments