Skip to content

Commit a66ff80

Browse files
committed
add layer norm weight plus 1
1 parent bcedecd commit a66ff80

3 files changed

Lines changed: 45 additions & 4 deletions

File tree

megatron/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from apex.normalization import MixedFusedRMSNorm as RMSNorm
77
else:
88
from .rmsnorm import RMSNorm
9-
from torch.nn import LayerNorm
9+
from .layer_norm_p1 import LayerNorm1P as LayerNorm
1010

1111
from .distributed import DistributedDataParallel
1212
from .bert_model import BertModel

megatron/model/layer_norm_p1.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import math
2+
import numbers
3+
4+
import torch
5+
import torch.nn as nn
6+
from torch.nn.parameter import Parameter
7+
from torch.nn import init
8+
9+
10+
class LayerNorm1P(torch.nn.Module):
11+
def __init__(self, normalized_shape, eps=1e-5, apply_layernorm_1p=False):
12+
super(LayerNorm1P, self).__init__()
13+
self.eps = eps
14+
self.apply_layernorm_1p = apply_layernorm_1p
15+
16+
if isinstance(normalized_shape, numbers.Integral):
17+
normalized_shape = (normalized_shape,)
18+
self.normalized_shape = torch.Size(normalized_shape)
19+
self.weight = Parameter(torch.Tensor(*normalized_shape))
20+
self.bias = Parameter(torch.Tensor(*normalized_shape))
21+
self.reset_parameters()
22+
23+
def reset_parameters(self):
24+
25+
if self.apply_layernorm_1p:
26+
init.zeros_(self.weight)
27+
init.zeros_(self.bias)
28+
else:
29+
init.ones_(self.weight)
30+
init.zeros_(self.bias)
31+
32+
def forward(self, input):
33+
if self.apply_layernorm_1p:
34+
weight_plus_1 = (self.weight + 1)
35+
output = torch.nn.functional.layer_norm(input, self.normalized_shape, weight_plus_1, self.bias, self.eps)
36+
return output
37+
else:
38+
return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)

megatron/model/transformer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,8 @@ def __init__(self, config,
913913
else:
914914
self.input_layernorm = LayerNorm(
915915
config.hidden_size,
916-
eps=config.layernorm_epsilon)
916+
eps=config.layernorm_epsilon,
917+
apply_layernorm_1p=args.apply_layernorm_1p)
917918
else:
918919
self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
919920
# Self attention.
@@ -939,7 +940,8 @@ def __init__(self, config,
939940
else:
940941
self.post_attention_layernorm = LayerNorm(
941942
config.hidden_size,
942-
eps=config.layernorm_epsilon)
943+
eps=config.layernorm_epsilon,
944+
apply_layernorm_1p=args.apply_layernorm_1p)
943945
else:
944946
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
945947
# Cross attention.
@@ -1762,7 +1764,8 @@ def build_layer(layer_number, n_e):
17621764
else:
17631765
self.final_layernorm = LayerNorm(
17641766
config.hidden_size,
1765-
eps=config.layernorm_epsilon)
1767+
eps=config.layernorm_epsilon,
1768+
apply_layernorm_1p=args.apply_layernorm_1p)
17661769
else:
17671770
self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
17681771

0 commit comments

Comments
 (0)