@@ -10,7 +10,12 @@

 from modalities.config.lookup_enum import LookupEnum
 from modalities.config.utils import convert_base_model_config_to_dict
-from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig
+from modalities.models.components.layer_norms import (
+    LayerNormConfig,
+    PytorchRMSLayerNormConfig,
+    RMSLayerNorm,
+    RMSLayerNormConfig,
+)
 from modalities.models.model import ActivationType, NNModel, SwiGLU
 from modalities.util import parse_enum_by_name

@@ -33,15 +38,17 @@ class LayerNorms(LookupEnum):
     Attributes:
         RMSNorm: RMSLayerNorm class.
         LayerNorm: nn.LayerNorm class.
+        PyTorchRMSNorm: nn.RMSNorm class.
     """

     rms_norm = RMSLayerNorm
     layer_norm = nn.LayerNorm
+    pytorch_rms_norm = nn.RMSNorm


 class LayerNormWrapperConfig(BaseModel):
     norm_type: LayerNorms
-    config: LayerNormConfig | RMSLayerNormConfig
+    config: PytorchRMSLayerNormConfig | RMSLayerNormConfig | LayerNormConfig


 class PositionTypes(str, Enum):
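For context: `LookupEnum` members store the norm class itself, so `norm_type.value(**dict(config))` (used in `__init__` further down) instantiates the selected norm directly. A minimal sketch of what the new `pytorch_rms_norm` entry resolves to, assuming torch >= 2.4 for `nn.RMSNorm`; the head dimension and eps are illustrative values, not repo defaults:

```python
import torch
import torch.nn as nn

head_dim = 64  # illustrative head dimension, not a repo default
# LayerNorms.pytorch_rms_norm.value is nn.RMSNorm, so instantiation boils down to:
rms = nn.RMSNorm(normalized_shape=head_dim, eps=1e-6)

q = torch.randn(2, 8, 16, head_dim)  # (B, nh_q, T, hd)
print(rms(q).shape)  # torch.Size([2, 8, 16, 64]); normalizes over the last (head) dim
```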
@@ -292,6 +299,7 @@ def parse_sharding_strategy_by_name(cls, name):
         config: RotaryTransformConfig | IdentityTransformConfig

     qkv_transforms: list[QueryKeyValueTransformConfig]
+    qk_norm_config: Optional[LayerNormWrapperConfig] = None


 class GPT2LLMConfig(BaseModel):
@@ -461,6 +469,23 @@ def __init__(
             for transform_config in attention_config.qkv_transforms
         )

+        # QK norm - helpful for models >1B parameters to stabilize training.
+        # Baseline attention logits without QK norm: (Q @ K^T) / sqrt(d_h),
+        # or in the geometric form of the dot product: (||q_i|| * ||k_j|| * cos(θ_ij)) / sqrt(d_h).
+        # To increase the distance between logits, the model must therefore scale q or k
+        # OR adjust the angle between them. QK norm constrains the norms of q and k,
+        # forcing the model to mostly adjust the angle, which stabilizes training.
+        if attention_config.qk_norm_config is not None:
+            self.q_norm = attention_config.qk_norm_config.norm_type.value(
+                **dict(attention_config.qk_norm_config.config)
+            )
+            self.k_norm = attention_config.qk_norm_config.norm_type.value(
+                **dict(attention_config.qk_norm_config.config)
+            )
+        else:
+            self.q_norm = None
+            self.k_norm = None
+
     def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Applies projections to the input tensor to get queries, keys, and values.
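A small numeric illustration of the reasoning in the comment above (a sketch with made-up values, using a gain-free RMSNorm): once q and k are RMS-normalized over the head dimension, their norms are pinned to sqrt(d_h), so the logit (q · k) / sqrt(d_h) reduces to sqrt(d_h) * cos(θ) and can only move by changing the angle:

```python
import torch
import torch.nn.functional as F

d_h = 64
q = torch.randn(d_h) * 50  # deliberately large norms, mimicking a drifting model
k = torch.randn(d_h) * 50

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # gain-free RMSNorm, for illustration only
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

logit_raw = (q @ k) / d_h**0.5  # unbounded: grows with ||q|| * ||k||
logit_qk = (rms_norm(q) @ rms_norm(k)) / d_h**0.5
cos_theta = F.cosine_similarity(q, k, dim=0)
# logit_qk == sqrt(d_h) * cos(theta) up to eps, i.e. bounded by sqrt(d_h)
print(logit_raw.item(), logit_qk.item(), (d_h**0.5 * cos_theta).item())
```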
@@ -632,6 +657,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         # q: (B, nh_q, T, hd), k: (B, nh_kv, T, hd), v: (B, nh_kv, T, hd)
         q, k, v = CausalSelfAttention.execute_qkv_transforms(q, k, v, self.qkv_transforms, self.n_head_q)
+        if self.q_norm is not None and self.k_norm is not None:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
         y = CausalSelfAttention.execute_attention(q, k, v, self.dropout, self.attention_impl)  # (B, T, nh_q, hd)
         y = y.reshape(B, T, -1)  # (B, T, n_embd), re-assemble all head outputs side by side
         return self.resid_dropout(self.c_proj(y))  # (B, T, n_embd), output projection
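One possible alternative to the `None` checks above (not what this commit does) is to default `q_norm`/`k_norm` to `nn.Identity()`, which keeps `forward` branch-free; either way the norm broadcasts over `(B, nh, T)` and treats each head vector independently. A sketch with illustrative sizes:

```python
import torch
import torch.nn as nn

B, nh, T, hd = 2, 8, 16, 64  # illustrative sizes
q_norm = nn.RMSNorm(hd)       # or nn.Identity() when QK norm is disabled
q = torch.randn(B, nh, T, hd)
q = q_norm(q)                 # each (hd,)-vector normalized independently; shape unchanged
print(q.shape)                # torch.Size([2, 8, 16, 64])
```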