Commit 2880e1a

Merge pull request #414 from CYHSM/qk_norm
Add QK norm to Causal Self Attention Block
2 parents 34786d8 + 5fed18e commit 2880e1a

4 files changed: 96 additions & 10 deletions

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@ requires-python = ">=3.10,<3.13"
 description = "Modalities, a PyTorch-native framework for distributed and reproducible foundation model training."
 readme = "README.md"
 dependencies = [
-    "numpy<2.0",
+    "numpy",
+    "torch",
     "packaging",
     "tqdm",
     "pyyaml",

src/modalities/models/components/layer_norms.py

Lines changed: 13 additions & 6 deletions
@@ -11,16 +11,13 @@ class RMSLayerNorm(nn.Module):
     def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-5):
         """
         Initializes a LayerNorm module.
-
         Args:
             ndim (int): The number of dimensions of the input tensor.
             bias (bool, optional): If True, adds a learnable bias to the normalized tensor. Defaults to True.
             epsilon (float, optional): A small value added to the denominator for numerical stability. Defaults to 1e-5.
-
         Note:
             Original paper: https://arxiv.org/pdf/1910.07467.pdf
             Source code adopted from https://github.com/facebookresearch/llama/blob/a0a4da8b497c566403941ceec47c2512ecf9dd20/llama/model.py#L34C1-L77C36
-
         Returns:
             None
         """
@@ -41,13 +38,10 @@ def _norm(self, x: torch.Tensor) -> torch.Tensor:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass of the layer normalization module.
-
         Args:
             x (torch.Tensor): Input tensor.
-
         Returns:
             torch.Tensor: Output tensor after applying layer normalization.
-
         """
         output = self._norm(x.float()).type_as(x)
         if self.bias is None:
@@ -97,3 +91,16 @@ class RMSLayerNormConfig(BaseModel):
     ndim: Annotated[int, Field(strict=True, ge=1)]
     epsilon: Annotated[float, Field(gt=0, default=1e-6)]
     bias: Annotated[bool, Field(strict=True, default=True)]
+
+
+class PytorchRMSLayerNormConfig(BaseModel):
+    """
+    Configuration class for PyTorch's native RMS norm (nn.RMSNorm).
+
+    Args:
+        normalized_shape (int): The expected size of the input shape.
+        eps (float, optional): Small value added to the denominator to avoid division by zero. Defaults to 1e-5.
+    """
+
+    normalized_shape: Annotated[int, Field(strict=True, ge=1)]
+    eps: Annotated[float, Field(strict=True, gt=0, default=1e-5)]
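For orientation (not part of the diff): a minimal sketch of how the new config class is meant to be consumed. Its fields are validated by Pydantic and then unpacked as keyword arguments into PyTorch's nn.RMSNorm (available from PyTorch 2.4 onward), mirroring the norm_type.value(**dict(config)) dispatch used in gpt2_model.py below; the tensor shapes are only illustrative.

import torch
from torch import nn

from modalities.models.components.layer_norms import PytorchRMSLayerNormConfig

# Validate the settings, then unpack them into nn.RMSNorm - the same pattern the
# GPT2 attention block uses via LayerNorms.pytorch_rms_norm.value(**dict(config)).
config = PytorchRMSLayerNormConfig(normalized_shape=64, eps=1e-5)
rms_norm = nn.RMSNorm(**dict(config))

x = torch.randn(2, 8, 64)  # illustrative (batch, seq, hidden) input
y = rms_norm(x)            # RMS-normalized over the last dimension
assert y.shape == x.shape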

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 30 additions & 2 deletions
@@ -10,7 +10,12 @@

 from modalities.config.lookup_enum import LookupEnum
 from modalities.config.utils import convert_base_model_config_to_dict
-from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig
+from modalities.models.components.layer_norms import (
+    LayerNormConfig,
+    PytorchRMSLayerNormConfig,
+    RMSLayerNorm,
+    RMSLayerNormConfig,
+)
 from modalities.models.model import ActivationType, NNModel, SwiGLU
 from modalities.util import parse_enum_by_name

@@ -33,15 +38,17 @@ class LayerNorms(LookupEnum):
     Attributes:
         RMSNorm: RMSLayerNorm class.
         LayerNorm: nn.LayerNorm class.
+        PyTorchRMSNorm: nn.RMSNorm class.
     """

     rms_norm = RMSLayerNorm
     layer_norm = nn.LayerNorm
+    pytorch_rms_norm = nn.RMSNorm


 class LayerNormWrapperConfig(BaseModel):
     norm_type: LayerNorms
-    config: LayerNormConfig | RMSLayerNormConfig
+    config: PytorchRMSLayerNormConfig | RMSLayerNormConfig | LayerNormConfig


 class PositionTypes(str, Enum):
@@ -292,6 +299,7 @@ def parse_sharding_strategy_by_name(cls, name):
         config: RotaryTransformConfig | IdentityTransformConfig

     qkv_transforms: list[QueryKeyValueTransformConfig]
+    qk_norm_config: Optional[LayerNormWrapperConfig] = None


 class GPT2LLMConfig(BaseModel):
@@ -461,6 +469,23 @@ def __init__(
             for transform_config in attention_config.qkv_transforms
         )

+        # QK norm - helpful for models >1B parameters to stabilize training.
+        # Baseline logits w/o QK norm: (Q @ K^T) / sqrt(d_h)
+        # With the geometric form of the dot product: (||q_i|| * ||k_j|| * cos(θ_ij)) / sqrt(d_h)
+        # So if the model wants to increase the distance between logits,
+        # it needs to scale q or k OR adjust the angle between them.
+        # QK norm forces the model to mostly adjust the angle between q and k, which stabilizes training.
+        if attention_config.qk_norm_config is not None:
+            self.q_norm = attention_config.qk_norm_config.norm_type.value(
+                **dict(attention_config.qk_norm_config.config)
+            )
+            self.k_norm = attention_config.qk_norm_config.norm_type.value(
+                **dict(attention_config.qk_norm_config.config)
+            )
+        else:
+            self.q_norm = None
+            self.k_norm = None
+
     def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Applies projections to the input tensor to get queries, keys, and values.
@@ -632,6 +657,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         # q: (B, nh_q, T, hd), k: (B, nh_kv, T, hd), v: (B, nh_kv, T, hd)
         q, k, v = CausalSelfAttention.execute_qkv_transforms(q, k, v, self.qkv_transforms, self.n_head_q)
+        if self.q_norm is not None and self.k_norm is not None:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
         y = CausalSelfAttention.execute_attention(q, k, v, self.dropout, self.attention_impl)  # (B, T, nh_q, hd)
         y = y.reshape(B, T, -1)  # (B, T, n_embd), re-assemble all head outputs side by side
         return self.resid_dropout(self.c_proj(y))  # (B, T, n_embd), output projection
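As a side note (not part of the PR), the geometric argument in the __init__ comment above can be checked numerically: with the learnable scale at its initial value of 1, RMSNorm over the head dimension pins ||q|| and ||k|| to roughly sqrt(d_h), so the scaled logit (q · k) / sqrt(d_h) reduces to roughly sqrt(d_h) * cos(θ) and can only be moved by rotating q and k. A minimal sketch (PyTorch 2.4+ for nn.RMSNorm; all values illustrative):

import math

import torch
from torch import nn

torch.manual_seed(0)
d_h = 8                     # head dimension, illustrative
q = 5.0 * torch.randn(d_h)  # deliberately large-magnitude query
k = 5.0 * torch.randn(d_h)  # deliberately large-magnitude key

rms_norm = nn.RMSNorm(d_h)  # learnable scale starts at ones
qn, kn = rms_norm(q), rms_norm(k)

# Magnitudes are pinned to ~sqrt(d_h) regardless of the input scale ...
print(qn.norm().item(), kn.norm().item(), math.sqrt(d_h))

# ... so the scaled logit is ~sqrt(d_h) * cos(theta): only the angle matters.
logit = (qn @ kn) / math.sqrt(d_h)
cos_theta = nn.functional.cosine_similarity(qn, kn, dim=0)
print(logit.item(), (math.sqrt(d_h) * cos_theta).item())  # approximately equal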

tests/models/test_causal_self_attention.py

Lines changed: 51 additions & 1 deletion
@@ -7,7 +7,13 @@
 import pytest
 import torch

-from modalities.models.gpt2.gpt2_model import AttentionConfig, CausalSelfAttention
+from modalities.models.gpt2.gpt2_model import (
+    AttentionConfig,
+    CausalSelfAttention,
+    LayerNorms,
+    LayerNormWrapperConfig,
+    PytorchRMSLayerNormConfig,
+)

 torch.manual_seed(0)

@@ -222,3 +228,47 @@ def test_attention_implementation_approximate_equality(
         atol=2.5e-3,  # default for bfloat16: 1e-5
         rtol=0.016,  # default for bfloat16: 0.016
     )
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This test requires 1 GPU.")
+@pytest.mark.parametrize(
+    "n_head_q, n_head_kv, n_embd, attention_impl",
+    [
+        (4, 4, 32, "manual"),
+        (8, 2, 32, "manual"),
+        (4, 4, 32, "pytorch_flash"),
+        (8, 2, 32, "pytorch_flash"),
+        (4, 4, 32, "dao_flash"),
+        (8, 2, 32, "dao_flash"),
+    ],
+)
+def test_qk_norm(n_head_q, n_head_kv, n_embd, attention_impl):
+    batch_size = 2
+    block_size = 10
+    head_dim = n_embd // n_head_q
+    embedding_shape = (batch_size, block_size - 1, n_embd)
+    embedded_input_seq = _get_random_input_seq(embedding_shape)
+
+    attention_config_no_norm = AttentionConfig(qkv_transforms=[], use_qk_norm=False)
+    attention_config_with_norm = AttentionConfig(
+        qkv_transforms=[],
+        use_qk_norm=True,
+        qk_norm_config=LayerNormWrapperConfig(
+            norm_type=LayerNorms.pytorch_rms_norm, config=PytorchRMSLayerNormConfig(normalized_shape=head_dim)
+        ),
+    )
+
+    # Create two separate layers with the same initial weights
+    torch.manual_seed(0)
+    layer_no_norm = _get_random_attention_layer(n_head_q, n_head_kv, n_embd, attention_impl, attention_config_no_norm)
+
+    torch.manual_seed(0)
+    layer_with_norm = _get_random_attention_layer(
+        n_head_q, n_head_kv, n_embd, attention_impl, attention_config_with_norm
+    )
+
+    output_no_norm = layer_no_norm(embedded_input_seq)
+    output_with_norm = layer_with_norm(embedded_input_seq)
+
+    assert output_no_norm.shape == output_with_norm.shape == embedding_shape
+    assert not torch.allclose(output_no_norm, output_with_norm, atol=1e-6)
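The new test follows the pattern of the existing attention tests: two CausalSelfAttention layers are built from the same seed, one with and one without QK norm, and the outputs must keep the input shape while differing numerically. On a machine with at least one GPU it should be selectable with standard pytest filtering, e.g. pytest tests/models/test_causal_self_attention.py -k test_qk_norm.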
