Skip to content

Commit 4d9f390

Browse files
Merge pull request #3875 from AI-Hypercomputer:nicogrande/improve-gemma4-vllm-perf
PiperOrigin-RevId: 914403579
2 parents 1aca4a7 + 5b0ae2f commit 4d9f390

2 files changed

Lines changed: 15 additions & 1 deletion

File tree

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,11 @@ class MoEGeneral(BaseModel):
726726
description="Whether to pre-fuse MoE weights (w0 and w1) during initialization. "
727727
"This is useful for inference performance in vllm_rpa mode.",
728728
)
729+
fuse_expert_scales: bool = Field(
730+
False,
731+
description="Whether to fuse the expert scaling factors into the expert weights. "
732+
"This can improve inference performance.",
733+
)
729734

730735

731736
class MoEKernels(BaseModel):

src/maxtext/layers/moe.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,14 @@ def __init__(
539539
else:
540540
self.per_expert_scale = None
541541

542+
# Scale the output projection ahead of time during inference for higher generation throughput.
543+
if (
544+
self.per_expert_scale is not None
545+
and self.config.model_call_mode == "inference"
546+
and self.config.fuse_expert_scales
547+
):
548+
self.wo.value = self.wo.value * self.per_expert_scale.value[:, None, None]
549+
542550
def _maybe_shard_with_logical(self, inputs, logical_name):
543551
return maybe_shard_with_logical(
544552
inputs,
@@ -2242,7 +2250,8 @@ def __call__(
22422250
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
22432251
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
22442252

2245-
if self.per_expert_scale is not None:
2253+
# Only apply per expert scales if we have not fused with the out-projections at init time.
2254+
if self.per_expert_scale is not None and cfg.model_call_mode != "inference" and not cfg.fuse_expert_scales:
22462255
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
22472256

22482257
if self.wi_0_sparsity_module is not None:

0 commit comments

Comments
 (0)