from typing import Optional

import torch
-from torch import nn
+import torch.nn as nn
+from torch.nn import functional as F

from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm


+@dataclass
+class LabelAttentionConfig:
+    n_head: int
+    n_kv_head: int
+    num_classes: int
+
+
@dataclass
class TextEmbedderConfig:
    vocab_size: int
    embedding_dim: int
    padding_idx: int
    attention_config: Optional[AttentionConfig] = None
+    label_attention_config: Optional[LabelAttentionConfig] = None


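For orientation, a minimal sketch of how the new config might be wired up; the field values below are placeholders, not part of this commit. Passing a LabelAttentionConfig is what enables the label-attention path, while leaving it as None keeps the previous mean-pooling behaviour.

    label_cfg = LabelAttentionConfig(n_head=8, n_kv_head=8, num_classes=10)
    embedder_cfg = TextEmbedderConfig(
        vocab_size=50_000,
        embedding_dim=256,
        padding_idx=0,
        attention_config=None,              # token-level transformer blocks stay optional
        label_attention_config=label_cfg,   # None would disable label attention
    )
    embedder = TextEmbedder(embedder_cfg)
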
class TextEmbedder(nn.Module):
@@ -26,8 +35,9 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)

-        if self.attention_config is not None:
-            self.attention_config.n_embd = text_embedder_config.embedding_dim
+        self.enable_label_attention = text_embedder_config.label_attention_config is not None
+        if self.enable_label_attention:
+            self.label_attention_module = LabelAttentionClassifier(self.config)

        self.vocab_size = text_embedder_config.vocab_size
        self.embedding_dim = text_embedder_config.embedding_dim
@@ -40,6 +50,7 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
        )

        if self.attention_config is not None:
+            self.attention_config.n_embd = text_embedder_config.embedding_dim
            self.transformer = nn.ModuleDict(
                {
                    "h": nn.ModuleList(
@@ -105,8 +116,23 @@ def _init_weights(self, module):
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        """Converts input token IDs to their corresponding embeddings."""
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        return_label_attention_matrix: bool = False,
+    ) -> torch.Tensor:
+        """Converts input token IDs to their corresponding embeddings.
+
+        Args:
+            input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized input text.
+            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens.
+            return_label_attention_matrix (bool): Whether to return the label attention matrix.
+        Returns:
+            torch.Tensor: Text embeddings, shape (batch_size, embedding_dim) if self.enable_label_attention is False, else (batch_size, num_labels, embedding_dim).
+            torch.Tensor: Label attention matrix, shape (batch_size, n_head, num_labels, seq_len) if return_label_attention_matrix is True, else None.
+                Also None if label attention is disabled (even if return_label_attention_matrix is True).
+        """

        encoded_text = input_ids  # clearer name
        if encoded_text.dtype != torch.long:
@@ -138,14 +164,25 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:

        token_embeddings = norm(token_embeddings)

-        text_embedding = self._get_sentence_embedding(
-            token_embeddings=token_embeddings, attention_mask=attention_mask
-        )
+        text_embedding, label_attention_matrix = self._get_sentence_embedding(
+            token_embeddings=token_embeddings,
+            attention_mask=attention_mask,
+            return_label_attention_matrix=return_label_attention_matrix,
+        ).values()

-        return text_embedding
+        if return_label_attention_matrix:
+            return (
+                text_embedding,
+                label_attention_matrix,
+            )  # label_attention_matrix is None if label attention is disabled
+        else:
+            return text_embedding
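A short illustration of the two call patterns the reworked forward supports; the tensors below are placeholders, and the shape comments assume label attention is enabled (with it disabled, the first call returns (batch_size, embedding_dim) and the second returns the embedding together with None).

    ids = torch.randint(1, embedder_cfg.vocab_size, (4, 32))   # (batch_size, seq_len)
    mask = torch.ones_like(ids)                                 # every token treated as non-pad
    emb = embedder(ids, mask)                                   # (4, num_classes, embedding_dim)
    emb, attn = embedder(ids, mask, return_label_attention_matrix=True)
    # attn has shape (4, n_head, num_classes, seq_len)
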

    def _get_sentence_embedding(
-        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+        self,
+        token_embeddings: torch.Tensor,
+        attention_mask: torch.Tensor,
+        return_label_attention_matrix: bool = False,
    ) -> torch.Tensor:
        """
        Compute sentence embedding from embedded tokens - "remove" second dimension.
@@ -163,7 +200,7 @@ def _get_sentence_embedding(
        # mask pad-tokens

        if self.attention_config is not None:
-            if self.attention_config.aggregation_method is not None:
+            if self.attention_config.aggregation_method is not None:  # default is "mean"
                if self.attention_config.aggregation_method == "first":
                    return token_embeddings[:, 0, :]
                elif self.attention_config.aggregation_method == "last":
@@ -181,25 +218,29 @@ def _get_sentence_embedding(

        assert self.attention_config is None or self.attention_config.aggregation_method == "mean"

-        mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
-        masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
-
-        sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
-            min=1.0
-        )  # avoid division by zero
-
-        sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
-
-        return sentence_embedding
-
-    def __call__(self, *args, **kwargs):
-        out = super().__call__(*args, **kwargs)
-        if out.dim() != 2:
-            raise ValueError(
-                f"Output of {self.__class__.__name__}.forward must be 2D "
-                f"(got shape {tuple(out.shape)})"
+        if self.enable_label_attention:
+            label_attention_result = self.label_attention_module(
+                token_embeddings, compute_attention_matrix=return_label_attention_matrix
            )
-        return out
+            sentence_embedding = label_attention_result[
+                "sentence_embedding"
+            ]  # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix
+            label_attention_matrix = label_attention_result["attention_matrix"]
+
+        else:  # sentence embedding = mean of (non-pad) token embeddings
+            mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
+            masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
+            sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
+                min=1.0
+            )  # avoid division by zero
+
+            sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
+            label_attention_matrix = None
+
+        return {
+            "sentence_embedding": sentence_embedding,
+            "label_attention_matrix": label_attention_matrix,
+        }

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # autodetect the device from model embeddings
@@ -221,3 +262,79 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        )  # add batch and head dims for later broadcasting
        return cos, sin

+
+
+class LabelAttentionClassifier(nn.Module):
+    """
+    A head for aggregating token embeddings into label-specific sentence embeddings using a cross-attention mechanism.
+    Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings.
+
+    """
+
+    def __init__(self, config: TextEmbedderConfig):
+        super().__init__()
+
+        label_attention_config = config.label_attention_config
+        self.embedding_dim = config.embedding_dim
+        self.num_classes = label_attention_config.num_classes
+        self.n_head = label_attention_config.n_head
+        self.n_kv_head = label_attention_config.n_kv_head
+        self.enable_gqa = (
+            self.n_head != self.n_kv_head
+        )  # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
+        self.head_dim = self.embedding_dim // self.n_head
+
+        self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)
+
+        self.c_q = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
+        self.c_k = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False)
+        self.c_v = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False)
+        self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
+
+    def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = False):
+        """
+        Args:
+            token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input.
+            compute_attention_matrix (bool): Whether to compute and return the attention matrix.
+        Returns:
+            dict: {
+                "sentence_embedding": torch.Tensor, shape (batch, num_classes, d_model): Label-specific sentence embeddings.
+                "attention_matrix": Optional[torch.Tensor], shape (batch, n_head, num_classes, seq_len): Attention weights if compute_attention_matrix is True, else None.
+            }
+
+        """
+        B, T, C = token_embeddings.size()
+
+        # Create label indices [0, 1, ..., num_classes - 1] for the whole batch
+        label_indices = torch.arange(self.num_classes, device=token_embeddings.device).expand(B, -1)
+
+        all_label_embeddings = self.label_embeds(
+            label_indices
+        )  # Shape: [batch, num_classes, d_model]
+        all_label_embeddings = norm(all_label_embeddings)
+
+        q = self.c_q(all_label_embeddings).view(B, self.num_classes, self.n_head, self.head_dim)
+        k = self.c_k(token_embeddings).view(B, T, self.n_kv_head, self.head_dim)
+        v = self.c_v(token_embeddings).view(B, T, self.n_kv_head, self.head_dim)
+
+        q, k = norm(q), norm(k)  # QK norm
+        q, k, v = (
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+        )  # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+
+        y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa)
+
+        # Re-assemble the heads side by side and project back to residual stream
+        y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1)  # (bs, n_labels, d_model)
+        y = self.c_proj(y)
+
+        attention_matrix = None
+        if compute_attention_matrix:
+            if self.enable_gqa:
+                # repeat kv heads so the manual score computation lines up with the query heads
+                k = k.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+            # size (B, n_head, n_labels, seq_len) - we let the user handle aggregation over heads if desired
+            attention_matrix = torch.softmax(
+                torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5), dim=-1
+            )
+
+        return {"sentence_embedding": y, "attention_matrix": attention_matrix}
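The comment in _get_sentence_embedding hints that each label-specific embedding is meant to be scored by a (d_embed, 1) projection. A minimal sketch of such a head, assuming one shared linear scorer across labels; the PerLabelHead name and this design are illustrative, not part of this commit.

    class PerLabelHead(nn.Module):
        """Scores each label from its own attended sentence embedding."""

        def __init__(self, embedding_dim: int):
            super().__init__()
            self.scorer = nn.Linear(embedding_dim, 1)  # the (d_embed, 1) matrix mentioned above

        def forward(self, label_embeddings: torch.Tensor) -> torch.Tensor:
            # label_embeddings: (batch, num_labels, d_embed) -> logits: (batch, num_labels)
            return self.scorer(label_embeddings).squeeze(-1)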