-
Notifications
You must be signed in to change notification settings - Fork 981
Expand file tree
/
Copy pathgemma4_norm.py
More file actions
50 lines (38 loc) · 1.74 KB
/
gemma4_norm.py
File metadata and controls
50 lines (38 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
"""Gemma 4 RMSNorm — self-contained re-implementation.
Numerically identical to ``transformers.models.gemma4.modeling_gemma4.Gemma4RMSNorm``
(same float32 upcast and ``pow(mean_squared, -0.5)`` normalization), but
without the transformers import so this module is exportable and dep-light.
"""
from functools import partial
import torch
from torch import nn
class RMSNorm(nn.Module):
    """Root-mean-square normalization as used by Gemma 4.

    Computes ``y = (x / rms(x)) * weight`` entirely in float32, then casts
    the result back to the input dtype. Unlike Gemma 2/3, which scale by
    ``(1 + weight)``, Gemma 4 multiplies by ``weight`` directly.

    Args:
        dim: Size of the last (normalized) dimension; also the weight shape.
        eps: Small constant added to the mean square for numerical stability.
        with_scale: When ``False``, the learnable weight is omitted entirely
            (used for the v-norm and the router norm, per the module docstring).
    """

    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale
        if with_scale:
            # Initialized to ones, so an untrained layer is a pure RMS norm.
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Deliberately pow(ms, -0.5) rather than rsqrt, to match the
        # transformers reference (which cites Torch/JAX compiler
        # differences as the reason for this form).
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + self.eps
        return x * mean_sq.pow(-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # All math in float32 regardless of input dtype.
        out = self._norm(x.float())
        if not self.with_scale:
            return out.type_as(x)
        return (out * self.weight.float()).type_as(x)
# Weight-free variant (used for the v-norm in attention): a `partial` rather
# than a subclass/function so the resulting factory forwards `dim`/`eps`
# positionally or by keyword exactly like RMSNorm itself.
RMSNormNoWeight = partial(RMSNorm, with_scale=False)