@@ -1,21 +1,107 @@
+import math
+from dataclasses import dataclass
+from typing import Optional
+
 import torch
 from torch import nn
 
+from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm
+
+
+@dataclass
+class TextEmbedderConfig:
+    vocab_size: int
+    embedding_dim: int
+    padding_idx: int
+    attention_config: Optional[AttentionConfig] = None
+
 
 class TextEmbedder(nn.Module):
-    def __init__(self, vocab_size: int, embedding_dim: int, padding_idx: int):
+    def __init__(self, text_embedder_config: TextEmbedderConfig):
         super().__init__()
 
-        self.vocab_size = vocab_size
-        self.embedding_dim = embedding_dim
-        self.padding_idx = padding_idx
+        self.config = text_embedder_config
+
+        self.attention_config = text_embedder_config.attention_config
+        if self.attention_config is not None:
+            self.attention_config.n_embd = text_embedder_config.embedding_dim
+
+        self.vocab_size = text_embedder_config.vocab_size
+        self.embedding_dim = text_embedder_config.embedding_dim
+        self.padding_idx = text_embedder_config.padding_idx
 
         self.embedding_layer = nn.Embedding(
-            embedding_dim=embedding_dim,
-            num_embeddings=vocab_size,
+            embedding_dim=self.embedding_dim,
+            num_embeddings=self.vocab_size,
             padding_idx=self.padding_idx,
         )
 
+        if self.attention_config is not None:
+            self.transformer = nn.ModuleDict(
+                {
+                    "h": nn.ModuleList(
+                        [
+                            Block(self.attention_config, layer_idx)
+                            for layer_idx in range(self.attention_config.n_layers)
+                        ]
+                    ),
+                }
+            )
+
+            head_dim = self.attention_config.n_embd // self.attention_config.n_head
+
+            if head_dim * self.attention_config.n_head != self.attention_config.n_embd:
+                raise ValueError("embedding_dim must be divisible by n_head.")
+
+            if self.attention_config.positional_encoding:
+                if head_dim % 2 != 0:
+                    raise ValueError(
+                        "embedding_dim / n_head must be even for rotary positional embeddings."
+                    )
+
+                if self.attention_config.sequence_len is None:
+                    raise ValueError(
+                        "sequence_len must be specified in AttentionConfig when positional_encoding is True."
+                    )
+
+                self.rotary_seq_len = self.attention_config.sequence_len * 10
+                cos, sin = self._precompute_rotary_embeddings(
+                    seq_len=self.rotary_seq_len, head_dim=head_dim
+                )
+
+                self.register_buffer(
+                    "cos", cos, persistent=False
+                )  # persistent=False means it's not saved to the checkpoint
+                self.register_buffer("sin", sin, persistent=False)
+
+    def init_weights(self):
+        self.apply(self._init_weights)
+
+        # zero out c_proj weights in all blocks
+        if self.attention_config is not None:
+            for block in self.transformer.h:
+                torch.nn.init.zeros_(block.mlp.c_proj.weight)
+                torch.nn.init.zeros_(block.attn.c_proj.weight)
+            # init the rotary embeddings
+            head_dim = self.attention_config.n_embd // self.attention_config.n_head
+            cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+            self.cos, self.sin = cos, sin
+        # Cast the embeddings from fp32 to bf16: the optimizer can tolerate it and it saves memory, both in the model and in the activations
+        if self.embedding_layer.weight.device.type == "cuda":
+            self.embedding_layer.to(dtype=torch.bfloat16)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            # https://arxiv.org/pdf/2310.17813
+            fan_out = module.weight.size(0)
+            fan_in = module.weight.size(1)
+            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+
     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
         """Converts input token IDs to their corresponding embeddings."""
 
@@ -36,6 +122,19 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
             encoded_text
         )  # (batch_size, seq_len, embedding_dim)
 
+        token_embeddings = norm(token_embeddings)
+
+        if self.attention_config is not None:
+            if self.attention_config.positional_encoding:
+                cos_sin = self.cos[:, :seq_len], self.sin[:, :seq_len]
+            else:
+                cos_sin = None
+
+            for block in self.transformer.h:
+                token_embeddings = block(token_embeddings, cos_sin)
+
+            token_embeddings = norm(token_embeddings)
+
         text_embedding = self._get_sentence_embedding(
             token_embeddings=token_embeddings, attention_mask=attention_mask
         )
@@ -57,9 +156,28 @@ def _get_sentence_embedding(
 
         # average over non-pad token embeddings
        # attention mask has 1 for non-pad tokens and 0 for pad token positions
-        # TODO: add attention logic at some point
 
         # mask pad-tokens
+
+        if self.attention_config is not None:
+            if self.attention_config.aggregation_method is not None:
+                if self.attention_config.aggregation_method == "first":
+                    return token_embeddings[:, 0, :]
+                elif self.attention_config.aggregation_method == "last":
+                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
+                    return token_embeddings[
+                        torch.arange(token_embeddings.size(0), device=token_embeddings.device),
+                        lengths - 1,
+                        :,
+                    ]
+                else:
+                    if self.attention_config.aggregation_method != "mean":
+                        raise ValueError(
+                            f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
+                        )
+
+        assert self.attention_config is None or self.attention_config.aggregation_method in (None, "mean")
+
         mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
         masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
 
@@ -79,3 +197,24 @@ def __call__(self, *args, **kwargs):
                 f"(got shape {tuple(out.shape)})"
             )
         return out
+
+    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+        # autodetect the device from model embeddings
+        if device is None:
+            device = next(self.parameters()).device
+
+        # stride the channels
+        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+        inv_freq = 1.0 / (base ** (channel_range / head_dim))
+        # stride the time steps
+        t = torch.arange(seq_len, dtype=torch.float32, device=device)
+        # calculate the rotation frequencies at each (time, channel) pair
+        freqs = torch.outer(t, inv_freq)
+        cos, sin = freqs.cos(), freqs.sin()
+        cos, sin = cos.bfloat16(), sin.bfloat16()  # keep them in bfloat16
+        cos, sin = (
+            cos[None, :, None, :],
+            sin[None, :, None, :],
+        )  # add batch and head dims for later broadcasting
+
+        return cos, sin
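
Not part of the diff above: a minimal standalone sketch of the rotary table that _precompute_rotary_embeddings builds, to make the buffer shapes concrete. The concrete values (embedding_dim=128, n_head=4, sequence_len=256) are made up for illustration, and how Block consumes the cos_sin pair is an assumption, since the attention code is not shown in this commit.

import torch

def precompute_rotary(seq_len: int, head_dim: int, base: int = 10000):
    # mirrors _precompute_rotary_embeddings above, minus the nn.Module plumbing
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))   # (head_dim // 2,)
    t = torch.arange(seq_len, dtype=torch.float32)          # (seq_len,)
    freqs = torch.outer(t, inv_freq)                        # (seq_len, head_dim // 2)
    cos, sin = freqs.cos(), freqs.sin()
    # add batch and head dims for broadcasting: (1, seq_len, 1, head_dim // 2)
    return cos[None, :, None, :], sin[None, :, None, :]

# embedding_dim=128, n_head=4 -> head_dim=32; the buffers cover sequence_len * 10 positions
cos, sin = precompute_rotary(seq_len=256 * 10, head_dim=32)
print(cos.shape)  # torch.Size([1, 2560, 1, 16])
# forward() then slices the tables to the current batch length: cos[:, :seq_len], sin[:, :seq_len]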
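
Also not part of the commit: a rough usage sketch of the new config-driven constructor. The AttentionConfig field names (n_embd, n_head, n_layers, positional_encoding, sequence_len, aggregation_method) are inferred from how they are read above; its full constructor signature, any defaults, and the module path of TextEmbedder are assumptions.

import torch

from torchTextClassifiers.model.components.attention import AttentionConfig
# module path of the embedder is assumed:
from torchTextClassifiers.model.components.text_embedder import TextEmbedder, TextEmbedderConfig

# hypothetical settings; AttentionConfig may require or default other fields
attention_config = AttentionConfig(
    n_embd=128,                 # overwritten with embedding_dim inside TextEmbedder.__init__ anyway
    n_head=4,                   # head_dim = 128 / 4 = 32, even, so rotary embeddings are allowed
    n_layers=2,
    positional_encoding=True,
    sequence_len=256,
    aggregation_method="mean",  # or "first" / "last"
)

config = TextEmbedderConfig(
    vocab_size=30_000,
    embedding_dim=128,
    padding_idx=0,
    attention_config=attention_config,
)

embedder = TextEmbedder(config)
embedder.init_weights()

input_ids = torch.randint(1, 30_000, (8, 256))              # (batch_size, seq_len)
attention_mask = (input_ids != config.padding_idx).long()   # 1 = real token, 0 = padding
sentence_embeddings = embedder(input_ids, attention_mask)   # expected shape: (8, 128)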