Commit d67ea01

[vLLM IR] rework gemma_rms_norm (vllm-project#39014)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
1 parent 302f30a commit d67ea01

8 files changed: 106 additions & 75 deletions


tests/kernels/core/test_layernorm.py

Lines changed: 29 additions & 1 deletion

@@ -6,7 +6,7 @@
 
 from tests.kernels.quant_utils import FP8_DTYPE
 from tests.kernels.utils import opcheck
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import set_random_seed
 

@@ -162,3 +162,31 @@ def test_fused_rms_norm_quant(
         atol=1e-3,
         rtol=1e-3,
     )
+
+
+@torch.inference_mode()
+def test_gemma_rms_norm_mixed_input_weight_dtype(default_vllm_config) -> None:
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
+
+    device = CUDA_DEVICES[0]
+    torch.set_default_device(device)
+
+    num_tokens, hidden_size = 32, 1024
+    x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
+    layer = GemmaRMSNorm(hidden_size, eps=1e-6).to(device=device)
+    layer.weight.data.normal_(mean=0.0, std=0.1)
+
+    # Gemma uses fp32 weight parameter while activations can be bf16.
+    assert layer.weight.dtype == torch.float32
+    out = layer(x)
+
+    x_fp32 = x.float()
+    weight_fp32 = layer.weight.data.float() + 1.0
+    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
+    ref = (x_fp32 * torch.rsqrt(variance + layer.variance_epsilon) * weight_fp32).to(
+        x.dtype
+    )
+
+    assert out.dtype == x.dtype
+    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
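
To run just the new test locally (assuming the usual vLLM pytest setup and a CUDA device):

    pytest tests/kernels/core/test_layernorm.py -k gemma_rms_norm_mixed_input_weight_dtype -q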

vllm/compilation/passes/fusion/allreduce_rms_fusion.py

Lines changed: 21 additions & 3 deletions

@@ -12,6 +12,9 @@
 from torch._inductor.pattern_matcher import PatternMatcherPass
 
 import vllm.ir.ops
+from vllm.compilation.passes.fusion.rms_quant_fusion import (
+    _rms_input_weight_dtype_match,
+)
 from vllm.config import VllmConfig
 from vllm.config.utils import Range
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce

@@ -320,7 +323,12 @@ def replacement(
             return allreduce[3], allreduce[1]
 
         pm.register_replacement(
-            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+            pattern,
+            replacement,
+            self.get_inputs(),
+            pm.fwd_only,
+            pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
 

@@ -459,7 +467,12 @@ def replacement(
             return allreduce[4], allreduce[1]
 
         pm.register_replacement(
-            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+            pattern,
+            replacement,
+            self.get_inputs(),
+            pm.fwd_only,
+            pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
 

@@ -621,7 +634,12 @@ def replacement(
             return allreduce[4], allreduce[1], allreduce[5]
 
         pm.register_replacement(
-            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+            pattern,
+            replacement,
+            self.get_inputs(),
+            pm.fwd_only,
+            pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
 

vllm/compilation/passes/fusion/rms_quant_fusion.py

Lines changed: 29 additions & 1 deletion

@@ -38,6 +38,22 @@
 FP4_DTYPE = torch.uint8
 
 
+_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default
+
+
+# TODO: extend rmsnorm quant kernels to support mixed input/weight dtypes,
+# and remove this check.
+def _rms_input_weight_dtype_match(match: pm.Match) -> bool:
+    """Prevent fusion when rms_norm input and weight dtypes differ."""
+    for node in match.nodes:
+        if node.target == _RMS_NORM_OP:
+            # rms_norm(x, weight, epsilon, variance_size)
+            x, weight = node.args[0], node.args[1]
+            if isinstance(x, fx.Node) and isinstance(weight, fx.Node):
+                return x.meta["val"].dtype == weight.meta["val"].dtype
+    return True
+
+
 def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
     return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
 

@@ -186,7 +202,14 @@ def replacement(
         ]
         pattern(*inputs)
 
-        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
+        pm.register_replacement(
+            pattern,
+            replacement,
+            inputs,
+            pm.fwd_only,
+            pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
+        )
 
 
 class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):

@@ -249,6 +272,7 @@ def replacement(
             inputs,
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
 

@@ -350,6 +374,7 @@ def replacement(
             self.rmsnorm_matcher.inputs() + [scale],
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
 

@@ -445,6 +470,7 @@ def replacement(
             ],
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
 

@@ -503,6 +529,7 @@ def replacement(
             ],
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
 

@@ -559,6 +586,7 @@ def replacement(
             self.rmsnorm_matcher.inputs(),
             pm.fwd_only,
             pm_pass,
+            extra_check=_rms_input_weight_dtype_match,
         )
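
The extra_check callback is evaluated by the Inductor pattern matcher for every candidate match before the rewrite is applied; returning False skips the fusion for that match. A minimal sketch of the gate that _rms_input_weight_dtype_match implements, using plain dtypes instead of fx nodes (illustrative only, not the vLLM API):

    import torch

    def dtype_gate(x_dtype: torch.dtype, weight_dtype: torch.dtype) -> bool:
        # Mirrors _rms_input_weight_dtype_match: only fuse when dtypes agree.
        return x_dtype == weight_dtype

    # Llama-style RMSNorm: bf16 activations and bf16 weight -> fusion allowed.
    assert dtype_gate(torch.bfloat16, torch.bfloat16)
    # Gemma-style RMSNorm: bf16 activations and fp32 weight -> fusion skipped;
    # the unfused ir.ops.rms_norm path handles the mixed dtypes instead.
    assert not dtype_gate(torch.bfloat16, torch.float32)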

vllm/ir/ops/layernorm.py

Lines changed: 2 additions & 3 deletions

@@ -16,7 +16,6 @@ def rms_norm(
     x_var = x if variance_size is None else x[..., :variance_size]
     variance = x_var.pow(2).mean(dim=-1, keepdim=True)
     x = x * torch.rsqrt(variance + epsilon)
-    x = x.to(orig_dtype)
     if weight is not None:
-        x = x * weight
-    return x
+        x = x.to(weight.dtype) * weight
+    return x.to(orig_dtype)
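
The reworked op now multiplies by the weight in the weight's dtype and only casts back to the input dtype at the end, which is what lets GemmaRMSNorm pass an fp32 weight with bf16 activations. A self-contained sketch of that behavior, assuming the fp32 upcast that happens earlier in the real function (the actual vLLM signature and kernel dispatch are not reproduced here):

    import torch

    def rms_norm_sketch(
        x: torch.Tensor,
        weight: torch.Tensor | None,
        epsilon: float,
        variance_size: int | None = None,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.float()  # fp32 upcast, done before the changed lines in the real op
        x_var = x if variance_size is None else x[..., :variance_size]
        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + epsilon)
        if weight is not None:
            # Multiply in the weight's dtype (fp32 for Gemma), cast back below.
            x = x.to(weight.dtype) * weight
        return x.to(orig_dtype)

    x = torch.randn(2, 8, dtype=torch.bfloat16)
    w = torch.randn(8, dtype=torch.float32)
    assert rms_norm_sketch(x, w, 1e-6).dtype == torch.bfloat16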

vllm/kernels/aiter_ops.py

Lines changed: 4 additions & 6 deletions

@@ -36,13 +36,11 @@ def is_aiter_found() -> bool:
 
 rms_no_var_16bit_only = (
     lambda x, weight, epsilon, variance_size=None: variance_size is None
-    and x.dtype
-    in (
-        torch.float16,
-        torch.bfloat16,
-    )
+    and x.dtype in (torch.float16, torch.bfloat16)
+    and (weight is None or weight.dtype == x.dtype)
 )
-"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
+"""AITER rms_norm only supports float16 and bfloat16 acts, no var_size override,
+and requires weight dtype to match x dtype."""
 
 
 @ir.ops.rms_norm.register_impl(

vllm/kernels/vllm_c.py

Lines changed: 5 additions & 2 deletions

@@ -11,8 +11,11 @@
 CUDA_ALIKE = current_platform.is_cuda_alike()
 """Most kernels in this file are supported on all CUDA-alike platforms."""
 
-rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None
-"""vLLM kernel does not support variance_size parameter."""
+rms_no_var_size = (
+    lambda x, weight, epsilon, variance_size=None: variance_size is None
+    and (weight is None or weight.dtype == x.dtype)
+)
+"""vLLM kernel requires no variance_size override and matching input/weight dtype."""
 
 
 @ir.ops.rms_norm.register_impl(
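
These support predicates are plain callables, so their effect is easy to check in isolation. A small sketch of how the updated rms_no_var_size gate behaves for matching and mixed dtypes (the tensors are hypothetical inputs; the registration mechanism itself is not shown):

    import torch

    rms_no_var_size = (
        lambda x, weight, epsilon, variance_size=None: variance_size is None
        and (weight is None or weight.dtype == x.dtype)
    )

    x_bf16 = torch.randn(4, 8, dtype=torch.bfloat16)
    w_bf16 = torch.randn(8, dtype=torch.bfloat16)
    w_fp32 = torch.randn(8, dtype=torch.float32)

    assert rms_no_var_size(x_bf16, w_bf16, 1e-6)      # custom kernel eligible
    assert not rms_no_var_size(x_bf16, w_fp32, 1e-6)  # mixed dtypes -> native fallback
    assert not rms_no_var_size(x_bf16, w_bf16, 1e-6, variance_size=4)  # var override -> fallback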

vllm/kernels/xpu_ops.py

Lines changed: 3 additions & 1 deletion

@@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool:
 XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
 """Kernels in this file are supported if vLLM XPU kernels are installed."""
 
-rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None
+rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None and (
+    weight is None or weight.dtype == x.dtype
+)
 
 
 @ir.ops.rms_norm.register_impl(

vllm/model_executor/layers/layernorm.py

Lines changed: 13 additions & 58 deletions

@@ -376,77 +376,32 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    @staticmethod
-    def _forward_static_no_residual(
-        weight: torch.Tensor,
-        variance_epsilon: float,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward() without residual."""
-        orig_dtype = x.dtype
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + variance_epsilon)
-        x = x * (1.0 + weight.float())
-        x = x.to(orig_dtype)
-        return x
-
-    @staticmethod
-    def _forward_static_with_residual(
-        weight: torch.Tensor,
-        variance_epsilon: float,
-        x: torch.Tensor,
-        residual: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """PyTorch-native implementation equivalent to forward() with residual."""
-        orig_dtype = x.dtype
-        x = (
-            x.float() + residual.float()
-            if orig_dtype == torch.float16
-            else x + residual
-        )
-        residual = x
-
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + variance_epsilon)
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + weight.float())
-        x = x.to(orig_dtype)
-        return x, residual
-
     def forward_native(
         self,
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
-        if residual is None:
-            return self._forward_static_no_residual(
-                self.weight.data, self.variance_epsilon, x
-            )
-        else:
-            return self._forward_static_with_residual(
-                self.weight.data, self.variance_epsilon, x, residual
+        orig_dtype = x.dtype
+        weight = self.weight.data.float() + 1.0
+        if residual is not None:
+            x = (
+                x.float() + residual.float()
+                if orig_dtype == torch.float16
+                else x + residual
             )
+            residual = x
+        # ir.ops.rms_norm handles fp32 upcast internally
+        out = ir.ops.rms_norm(x, weight, self.variance_epsilon)
+        return (
+            out.to(orig_dtype) if residual is None else (out.to(orig_dtype), residual)
+        )
 
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if torch.compiler.is_compiling():
-            return self.forward_native(x, residual)
-
-        if not getattr(self, "_is_compiled", False):
-            self._forward_static_no_residual = torch.compile(  # type: ignore
-                self._forward_static_no_residual
-            )
-            self._forward_static_with_residual = torch.compile(  # type: ignore
-                self._forward_static_with_residual
-            )
-            self._is_compiled = True
         return self.forward_native(x, residual)