+import logging
 import math
 from dataclasses import dataclass
 from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
 
 from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm
 
+logger = logging.getLogger(__name__)
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    handlers=[logging.StreamHandler()],
+)
+
 
 @dataclass
 class LabelAttentionConfig:
     n_head: int
     num_classes: int
+    embedding_dim: int
 
 
 @dataclass
-class TextEmbedderConfig:
+class TokenEmbedderConfig:
     vocab_size: int
     embedding_dim: int
     padding_idx: int
     attention_config: Optional[AttentionConfig] = None
+
+
+@dataclass
+class SentenceEmbedderConfig:
+    aggregation_method: Optional[str] = "mean"  # or 'last', or 'first'
     label_attention_config: Optional[LabelAttentionConfig] = None
 
 
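# Usage sketch (illustrative, not part of the commit): wiring the three configs
# together with hypothetical sizes. aggregation_method can be left as None when
# a label_attention_config is supplied, since label attention does the aggregation.
label_cfg = LabelAttentionConfig(n_head=4, num_classes=10, embedding_dim=128)
token_cfg = TokenEmbedderConfig(vocab_size=30_000, embedding_dim=128, padding_idx=0)
sent_cfg = SentenceEmbedderConfig(aggregation_method=None, label_attention_config=label_cfg)
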
-class TextEmbedder(nn.Module):
-    def __init__(self, text_embedder_config: TextEmbedderConfig):
+class TokenEmbedder(nn.Module):
+    """
+    A module that takes tokenized text and outputs dense vector representations (one for each token).
+    """
+
+    def __init__(self, token_embedder_config: TokenEmbedderConfig):
         super().__init__()
 
-        self.config = text_embedder_config
+        self.config = token_embedder_config
 
-        self.attention_config = text_embedder_config.attention_config
+        self.attention_config = token_embedder_config.attention_config
         if isinstance(self.attention_config, dict):
             self.attention_config = AttentionConfig(**self.attention_config)
 
-        # Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig
-        self.label_attention_config = text_embedder_config.label_attention_config
-        if isinstance(self.label_attention_config, dict):
-            self.label_attention_config = LabelAttentionConfig(**self.label_attention_config)
-        # Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier)
-        # always see a LabelAttentionConfig instance rather than a raw dict.
-        self.config.label_attention_config = self.label_attention_config
-
-        self.enable_label_attention = self.label_attention_config is not None
-        if self.enable_label_attention:
-            self.label_attention_module = LabelAttentionClassifier(self.config)
-
-        self.vocab_size = text_embedder_config.vocab_size
-        self.embedding_dim = text_embedder_config.embedding_dim
-        self.padding_idx = text_embedder_config.padding_idx
+        self.vocab_size = token_embedder_config.vocab_size
+        self.embedding_dim = token_embedder_config.embedding_dim
+        self.padding_idx = token_embedder_config.padding_idx
 
         self.embedding_layer = nn.Embedding(
             embedding_dim=self.embedding_dim,
@@ -57,7 +66,7 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
         )
 
         if self.attention_config is not None:
-            self.attention_config.n_embd = text_embedder_config.embedding_dim
+            self.attention_config.n_embd = token_embedder_config.embedding_dim
             self.transformer = nn.ModuleDict(
                 {
                     "h": nn.ModuleList(
@@ -127,30 +136,7 @@ def forward(
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
-        return_label_attention_matrix: bool = False,
     ) -> Dict[str, Optional[torch.Tensor]]:
132- """Converts input token IDs to their corresponding embeddings.
133-
134- Args:
135- input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized
136- attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
137- return_label_attention_matrix (bool): Whether to return the label attention matrix.
138-
139- Returns:
140- dict: A dictionary with the following keys:
141-
142- - "sentence_embedding" (torch.Tensor): Text embeddings of shape
143- (batch_size, embedding_dim) if ``self.enable_label_attention`` is False,
144- else (batch_size, num_classes, embedding_dim), where ``num_classes``
145- is the number of label classes.
146-
147- - "label_attention_matrix" (Optional[torch.Tensor]): Label attention
148- matrix of shape (batch_size, n_head, num_classes, seq_len) if
149- ``return_label_attention_matrix`` is True and label attention is
150- enabled, otherwise ``None``. The dimensions correspond to
151- (batch_size, attention heads, label classes, sequence length).
152- """
153-
         encoded_text = input_ids  # clearer name
         if encoded_text.dtype != torch.long:
             encoded_text = encoded_text.to(torch.long)
@@ -181,92 +167,9 @@ def forward(
 
         token_embeddings = norm(token_embeddings)
 
-        out = self._get_sentence_embedding(
-            token_embeddings=token_embeddings,
-            attention_mask=attention_mask,
-            return_label_attention_matrix=return_label_attention_matrix,
-        )
-
-        text_embedding = out["sentence_embedding"]
-        label_attention_matrix = out["label_attention_matrix"]
-        return {
-            "sentence_embedding": text_embedding,
-            "label_attention_matrix": label_attention_matrix,
-        }
-
-    def _get_sentence_embedding(
-        self,
-        token_embeddings: torch.Tensor,
-        attention_mask: torch.Tensor,
-        return_label_attention_matrix: bool = False,
-    ) -> Dict[str, Optional[torch.Tensor]]:
-        """
-        Compute sentence embedding from embedded tokens - "remove" second dimension.
-
-        Args (output from dataset collate_fn):
-            token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
-            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
-            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
-        Returns:
-            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
-                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
-                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
-        """
-
-        # average over non-pad token embeddings
-        # attention mask has 1 for non-pad tokens and 0 for pad token positions
-
-        # mask pad-tokens
-
-        if self.attention_config is not None:
-            if self.attention_config.aggregation_method is not None:  # default is "mean"
-                if self.attention_config.aggregation_method == "first":
-                    return {
-                        "sentence_embedding": token_embeddings[:, 0, :],
-                        "label_attention_matrix": None,
-                    }
-                elif self.attention_config.aggregation_method == "last":
-                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
-                    return {
-                        "sentence_embedding": token_embeddings[
-                            torch.arange(token_embeddings.size(0)),
-                            lengths - 1,
-                            :,
-                        ],
-                        "label_attention_matrix": None,
-                    }
-                else:
-                    if self.attention_config.aggregation_method != "mean":
-                        raise ValueError(
-                            f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
-                        )
-
-        assert self.attention_config is None or self.attention_config.aggregation_method == "mean"
-
-        if self.enable_label_attention:
-            label_attention_result = self.label_attention_module(
-                token_embeddings,
-                attention_mask=attention_mask,
-                compute_attention_matrix=return_label_attention_matrix,
-            )
-            sentence_embedding = label_attention_result[
-                "sentence_embedding"
-            ]  # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix
-            label_attention_matrix = label_attention_result["attention_matrix"]
-
-        else:  # sentence embedding = mean of (non-pad) token embeddings
-            mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
-            masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
-            sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
-                min=1.0
-            )  # avoid division by zero
-
-            sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
-            label_attention_matrix = None
-
         return {
-            "sentence_embedding": sentence_embedding,
-            "label_attention_matrix": label_attention_matrix,
+            "token_embeddings": token_embeddings,
+            "attention_mask": attention_mask,
         }
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
@@ -291,20 +194,25 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No
         return cos, sin
 
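# Illustrative sketch (not part of the commit): the helper's body is elided in
# this diff. A standard RoPE precomputation matching its signature might look
# like the following textbook formulation; this is not necessarily the author's
# exact code.
def precompute_rotary_embeddings(seq_len, head_dim, base=10000, device=None):
    # one inverse frequency per pair of embedding dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    return angles.cos(), angles.sin()
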
 
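# Usage sketch (illustrative, not part of the commit), hypothetical sizes: the
# token embedder now returns per-token embeddings plus the mask, rather than a
# pooled sentence embedding (assuming attention_config=None for brevity).
emb = TokenEmbedder(TokenEmbedderConfig(vocab_size=30_000, embedding_dim=128, padding_idx=0))
input_ids = torch.randint(1, 30_000, (2, 16))   # (batch_size, seq_len)
mask = torch.ones(2, 16, dtype=torch.long)      # 1 = non-pad token
out = emb(input_ids, attention_mask=mask)
# out["token_embeddings"]: (2, 16, 128); out["attention_mask"]: (2, 16)
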
-class LabelAttentionClassifier(nn.Module):
+class LabelAttention(nn.Module):
     """
     A head for aggregating token embeddings into label-specific sentence embeddings using a cross-attention mechanism.
     Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings.
     """
 
-    def __init__(self, config: TextEmbedderConfig):
+    def __init__(self, label_attention_config: LabelAttentionConfig):
         super().__init__()
 
-        label_attention_config = config.label_attention_config
-        self.embedding_dim = config.embedding_dim
+        if label_attention_config is None:
+            raise ValueError(
+                "label_attention_config must be provided to use LabelAttention."
+            )
+
+        self.label_attention_config = label_attention_config
         self.num_classes = label_attention_config.num_classes
         self.n_head = label_attention_config.n_head
+        self.embedding_dim = label_attention_config.embedding_dim
 
         # Validate head configuration
         self.head_dim = self.embedding_dim // self.n_head
@@ -399,3 +307,111 @@ def forward(
         attention_matrix = torch.softmax(attention_scores, dim=-1)
 
         return {"sentence_embedding": y, "attention_matrix": attention_matrix}
+
+
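# Illustrative sketch (not part of the commit), hypothetical sizes: each of the
# num_classes labels attends over the tokens and receives its own embedding.
# Shapes follow the docstrings in this file.
la = LabelAttention(LabelAttentionConfig(n_head=4, num_classes=10, embedding_dim=128))
token_embeddings = torch.randn(2, 16, 128)      # (batch_size, seq_len, embedding_dim)
mask = torch.ones(2, 16, dtype=torch.long)      # all tokens are non-pad
out = la(token_embeddings, attention_mask=mask, compute_attention_matrix=True)
# out["sentence_embedding"]: (2, 10, 128) -- one sentence embedding per label
# out["attention_matrix"]:   (2, 4, 10, 16) -- (batch, heads, labels, seq_len)
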
+class SentenceEmbedder(nn.Module):
+    """
+    A module to aggregate token embeddings into a sentence embedding.
+
+    Four modes are possible:
+    - aggregation_method="mean" (default): token embeddings are averaged
+    - aggregation_method="first": the sentence embedding is the first token's embedding (common in BERT-like models ([CLS] token))
+    - aggregation_method="last": the sentence embedding is the last token's embedding (common in GPT-like models)
+    - aggregation_method=None: in that case you need to provide a label_attention_config
+    """
+
+    def __init__(self, sentence_embedder_config: SentenceEmbedderConfig):
+        super().__init__()
+
+        self.config = sentence_embedder_config
+        self.label_attention_config = sentence_embedder_config.label_attention_config
+        self.aggregation_method = sentence_embedder_config.aggregation_method
+
+        if isinstance(self.label_attention_config, dict):
+            self.label_attention_config = LabelAttentionConfig(**self.label_attention_config)
+        # Keep self.config in sync so downstream components (e.g., LabelAttention)
+        # always see a LabelAttentionConfig instance rather than a raw dict.
+        self.config.label_attention_config = self.label_attention_config
+
+        if self.label_attention_config is not None:
+            self.label_attention_module = LabelAttention(
+                label_attention_config=self.label_attention_config
+            )
+            if self.aggregation_method is not None:
+                logger.warning(
+                    "aggregation_method is ignored when label_attention_config is provided, "
+                    "since label attention produces label-specific sentence embeddings "
+                    "without further aggregation."
+                )
+                self.aggregation_method = None  # override to avoid confusion
+
+        if self.aggregation_method not in (None, "mean", "first", "last"):
+            raise ValueError(
+                f"Unsupported aggregation method: {self.aggregation_method}. Supported methods are None, 'mean', 'first', 'last'."
+            )
+        if self.aggregation_method is None and self.label_attention_config is None:
+            raise ValueError(
+                "aggregation_method cannot be None when label_attention_config is not provided, "
+                "since we need some way to aggregate token embeddings into a sentence embedding. "
+                "Please specify an aggregation method (e.g., 'mean') or provide a "
+                "label_attention_config to use label attention for aggregation."
+            )
+    def forward(
+        self,
+        token_embeddings: torch.Tensor,
+        attention_mask: torch.Tensor,
+        return_label_attention_matrix: bool = False,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Compute a sentence embedding from embedded tokens, collapsing the sequence dimension.
+
+        Args:
+            token_embeddings (torch.Tensor[float]), shape (batch_size, seq_len, embedding_dim): embedded tokens
+            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): attention mask indicating non-pad tokens
+            return_label_attention_matrix (bool): whether to compute and return the label attention matrix
+        Returns:
+            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
+                - 'sentence_embedding': sentence embeddings, shape (batch_size, embedding_dim), or (batch_size, n_labels, embedding_dim) if label attention is enabled
+                - 'label_attention_matrix': attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
+        """
+        if self.aggregation_method is not None:  # default is "mean"
+            if self.aggregation_method == "first":
+                return {
+                    "sentence_embedding": token_embeddings[:, 0, :],
+                    "label_attention_matrix": None,
+                }
+            elif self.aggregation_method == "last":
+                lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
+                return {
+                    "sentence_embedding": token_embeddings[
+                        torch.arange(token_embeddings.size(0)),
+                        lengths - 1,
+                        :,
+                    ],
+                    "label_attention_matrix": None,
+                }
+            else:  # mean
+                mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
+                masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
+                sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
+                    min=1.0
+                )  # avoid division by zero
+
+                sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
+                return {
+                    "sentence_embedding": sentence_embedding,
+                    "label_attention_matrix": None,
+                }
+
+        else:
+            label_attention_result = self.label_attention_module(
+                token_embeddings,
+                attention_mask=attention_mask,
+                compute_attention_matrix=return_label_attention_matrix,
+            )
+            sentence_embedding = label_attention_result[
+                "sentence_embedding"
+            ]  # (bs, n_labels, d_embed), so the classifier needs to be a (d_embed, 1) matrix
+            label_attention_matrix = label_attention_result["attention_matrix"]
+            return {
+                "sentence_embedding": sentence_embedding,
+                "label_attention_matrix": label_attention_matrix,
+            }
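
# End-to-end sketch (illustrative, not part of the commit), hypothetical sizes:
# composing the two modules with label attention enabled. The classifier head
# consuming these label-specific embeddings is out of scope here.
tok = TokenEmbedder(TokenEmbedderConfig(vocab_size=30_000, embedding_dim=128, padding_idx=0))
sent = SentenceEmbedder(
    SentenceEmbedderConfig(
        aggregation_method=None,
        label_attention_config=LabelAttentionConfig(n_head=4, num_classes=10, embedding_dim=128),
    )
)
input_ids = torch.randint(1, 30_000, (2, 16))
mask = torch.ones(2, 16, dtype=torch.long)
tokens = tok(input_ids, attention_mask=mask)
out = sent(tokens["token_embeddings"], tokens["attention_mask"], return_label_attention_matrix=True)
# out["sentence_embedding"]: (2, 10, 128); out["label_attention_matrix"]: (2, 4, 10, 16)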