
Commit 22d4507

feat: split TextEmbedder into TokenEmbedder and SentenceEmbedder
TokenEmbedder -> takes tokenized input and outputs a tensor of shape (bs, seq_len, d_embed). SentenceEmbedder -> follows the TokenEmbedder and outputs a (bs, d_embed) tensor.
1 parent e0f8f5e commit 22d4507

5 files changed

Lines changed: 204 additions & 173 deletions
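
For orientation, a minimal usage sketch of the new split (a hypothetical example, not part of the commit; sizes, tensors, and default settings are illustrative and assume the configs and forward signatures shown in the diff below):

import torch

from torchTextClassifiers.model.components.text_embedder import (
    SentenceEmbedder,
    SentenceEmbedderConfig,
    TokenEmbedder,
    TokenEmbedderConfig,
)

# Token-level embedder: (bs, seq_len) token ids -> (bs, seq_len, d_embed)
token_embedder = TokenEmbedder(
    TokenEmbedderConfig(vocab_size=1000, embedding_dim=64, padding_idx=0)
)
# Sentence-level aggregator: (bs, seq_len, d_embed) -> (bs, d_embed)
sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))

input_ids = torch.randint(1, 1000, (4, 16))            # (bs, seq_len)
attention_mask = torch.ones(4, 16, dtype=torch.long)   # 1 = non-pad token

tokens_out = token_embedder(input_ids, attention_mask)  # {"token_embeddings", "attention_mask"}
sentence_out = sentence_embedder(**tokens_out)
print(sentence_out["sentence_embedding"].shape)         # expected: torch.Size([4, 64])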

File tree

torchTextClassifiers/model/components/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -9,5 +9,5 @@
)
from .classification_head import ClassificationHead as ClassificationHead
from .text_embedder import LabelAttentionConfig as LabelAttentionConfig
-from .text_embedder import TextEmbedder as TextEmbedder
-from .text_embedder import TextEmbedderConfig as TextEmbedderConfig
+from .text_embedder import TokenEmbedder as TokenEmbedder, TokenEmbedderConfig as TokenEmbedderConfig
+from .text_embedder import SentenceEmbedder as SentenceEmbedder, SentenceEmbedderConfig as SentenceEmbedderConfig

torchTextClassifiers/model/components/attention.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ class AttentionConfig:
    n_kv_head: int
    sequence_len: Optional[int] = None
    positional_encoding: bool = True
-    aggregation_method: str = "mean"  # or 'last', or 'first'


#### Attention Block #####

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 149 additions & 133 deletions
@@ -1,3 +1,4 @@
+import logging
import math
from dataclasses import dataclass
from typing import Dict, Optional
@@ -8,47 +9,55 @@

from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm

+logger = logging.getLogger(__name__)
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    handlers=[logging.StreamHandler()],
+)
+

@dataclass
class LabelAttentionConfig:
    n_head: int
    num_classes: int
+    embedding_dim: int


@dataclass
-class TextEmbedderConfig:
+class TokenEmbedderConfig:
    vocab_size: int
    embedding_dim: int
    padding_idx: int
    attention_config: Optional[AttentionConfig] = None
+
+
+@dataclass
+class SentenceEmbedderConfig:
+    aggregation_method: Optional[str] = "mean"  # or 'last', or 'first'
    label_attention_config: Optional[LabelAttentionConfig] = None


-class TextEmbedder(nn.Module):
-    def __init__(self, text_embedder_config: TextEmbedderConfig):
+class TokenEmbedder(nn.Module):
+    """
+    A module that takes tokenized text and outputs dense vector representations (one for each token).
+
+    """
+
+    def __init__(self, token_embedder_config: TokenEmbedderConfig):
        super().__init__()

-        self.config = text_embedder_config
+        self.config = token_embedder_config

-        self.attention_config = text_embedder_config.attention_config
+        self.attention_config = token_embedder_config.attention_config
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)

-        # Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig
-        self.label_attention_config = text_embedder_config.label_attention_config
-        if isinstance(self.label_attention_config, dict):
-            self.label_attention_config = LabelAttentionConfig(**self.label_attention_config)
-            # Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier)
-            # always see a LabelAttentionConfig instance rather than a raw dict.
-            self.config.label_attention_config = self.label_attention_config
-
-        self.enable_label_attention = self.label_attention_config is not None
-        if self.enable_label_attention:
-            self.label_attention_module = LabelAttentionClassifier(self.config)
-
-        self.vocab_size = text_embedder_config.vocab_size
-        self.embedding_dim = text_embedder_config.embedding_dim
-        self.padding_idx = text_embedder_config.padding_idx
+        self.vocab_size = token_embedder_config.vocab_size
+        self.embedding_dim = token_embedder_config.embedding_dim
+        self.padding_idx = token_embedder_config.padding_idx

        self.embedding_layer = nn.Embedding(
            embedding_dim=self.embedding_dim,
@@ -57,7 +66,7 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
        )

        if self.attention_config is not None:
-            self.attention_config.n_embd = text_embedder_config.embedding_dim
+            self.attention_config.n_embd: int = token_embedder_config.embedding_dim
            self.transformer = nn.ModuleDict(
                {
                    "h": nn.ModuleList(
@@ -127,30 +136,7 @@ def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
-        return_label_attention_matrix: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
-        """Converts input token IDs to their corresponding embeddings.
-
-        Args:
-            input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized
-            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
-            return_label_attention_matrix (bool): Whether to return the label attention matrix.
-
-        Returns:
-            dict: A dictionary with the following keys:
-
-            - "sentence_embedding" (torch.Tensor): Text embeddings of shape
-              (batch_size, embedding_dim) if ``self.enable_label_attention`` is False,
-              else (batch_size, num_classes, embedding_dim), where ``num_classes``
-              is the number of label classes.
-
-            - "label_attention_matrix" (Optional[torch.Tensor]): Label attention
-              matrix of shape (batch_size, n_head, num_classes, seq_len) if
-              ``return_label_attention_matrix`` is True and label attention is
-              enabled, otherwise ``None``. The dimensions correspond to
-              (batch_size, attention heads, label classes, sequence length).
-        """
-
        encoded_text = input_ids  # clearer name
        if encoded_text.dtype != torch.long:
            encoded_text = encoded_text.to(torch.long)
@@ -181,92 +167,9 @@ def forward(

        token_embeddings = norm(token_embeddings)

-        out = self._get_sentence_embedding(
-            token_embeddings=token_embeddings,
-            attention_mask=attention_mask,
-            return_label_attention_matrix=return_label_attention_matrix,
-        )
-
-        text_embedding = out["sentence_embedding"]
-        label_attention_matrix = out["label_attention_matrix"]
-        return {
-            "sentence_embedding": text_embedding,
-            "label_attention_matrix": label_attention_matrix,
-        }
-
-    def _get_sentence_embedding(
-        self,
-        token_embeddings: torch.Tensor,
-        attention_mask: torch.Tensor,
-        return_label_attention_matrix: bool = False,
-    ) -> Dict[str, Optional[torch.Tensor]]:
-        """
-        Compute sentence embedding from embedded tokens - "remove" second dimension.
-
-        Args (output from dataset collate_fn):
-            token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
-            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
-            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
-        Returns:
-            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
-                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
-                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
-        """
-
-        # average over non-pad token embeddings
-        # attention mask has 1 for non-pad tokens and 0 for pad token positions
-
-        # mask pad-tokens
-
-        if self.attention_config is not None:
-            if self.attention_config.aggregation_method is not None:  # default is "mean"
-                if self.attention_config.aggregation_method == "first":
-                    return {
-                        "sentence_embedding": token_embeddings[:, 0, :],
-                        "label_attention_matrix": None,
-                    }
-                elif self.attention_config.aggregation_method == "last":
-                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
-                    return {
-                        "sentence_embedding": token_embeddings[
-                            torch.arange(token_embeddings.size(0)),
-                            lengths - 1,
-                            :,
-                        ],
-                        "label_attention_matrix": None,
-                    }
-                else:
-                    if self.attention_config.aggregation_method != "mean":
-                        raise ValueError(
-                            f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
-                        )
-
-        assert self.attention_config is None or self.attention_config.aggregation_method == "mean"
-
-        if self.enable_label_attention:
-            label_attention_result = self.label_attention_module(
-                token_embeddings,
-                attention_mask=attention_mask,
-                compute_attention_matrix=return_label_attention_matrix,
-            )
-            sentence_embedding = label_attention_result[
-                "sentence_embedding"
-            ]  # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix
-            label_attention_matrix = label_attention_result["attention_matrix"]
-
-        else:  # sentence embedding = mean of (non-pad) token embeddings
-            mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
-            masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
-            sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
-                min=1.0
-            )  # avoid division by zero
-
-            sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
-            label_attention_matrix = None
-
        return {
-            "sentence_embedding": sentence_embedding,
-            "label_attention_matrix": label_attention_matrix,
+            "token_embeddings": token_embeddings,
+            "attention_mask": attention_mask,
        }

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
@@ -291,20 +194,25 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No
        return cos, sin


-class LabelAttentionClassifier(nn.Module):
+class LabelAttention(nn.Module):
    """
    A head for aggregating token embeddings into label-specific sentence embeddings using cross-attention mechanism.
    Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings.

    """

-    def __init__(self, config: TextEmbedderConfig):
+    def __init__(self, label_attention_config: LabelAttentionConfig):
        super().__init__()

-        label_attention_config = config.label_attention_config
-        self.embedding_dim = config.embedding_dim
+        if label_attention_config is None:
+            raise ValueError(
+                "label_attention_config must be provided to use LabelAttention."
+            )
+
+        self.label_attention_config = label_attention_config
        self.num_classes = label_attention_config.num_classes
        self.n_head = label_attention_config.n_head
+        self.embedding_dim = label_attention_config.embedding_dim

        # Validate head configuration
        self.head_dim = self.embedding_dim // self.n_head
@@ -399,3 +307,111 @@ def forward(
        attention_matrix = torch.softmax(attention_scores, dim=-1)

        return {"sentence_embedding": y, "attention_matrix": attention_matrix}
+
+
+class SentenceEmbedder(nn.Module):
+    def __init__(self, sentence_embedder_config: SentenceEmbedderConfig):
+        """
+        A module to aggregate token embeddings into a sentence embedding.
+
+        Four modes are possible:
+        - aggregation_method="mean" (default): token embeddings are averaged
+        - aggregation_method="first": sentence embedding is the first token's embedding (common in BERT-like models ([CLS] token))
+        - aggregation_method="last": sentence embedding is the last token's embedding (common in GPT-like models)
+        - aggregation_method=None: in that case you need to provide a label_attention_config
+        """
+        super().__init__()
+
+        self.config = sentence_embedder_config
+        self.label_attention_config = sentence_embedder_config.label_attention_config
+        self.aggregation_method = sentence_embedder_config.aggregation_method
+
+        if isinstance(self.label_attention_config, dict):
+            self.label_attention_config = LabelAttentionConfig(**self.label_attention_config)
+            # Keep self.config in sync so downstream components (e.g., LabelAttention)
+            # always see a LabelAttentionConfig instance rather than a raw dict.
+            self.config.label_attention_config = (
+                self.label_attention_config
+            )
+
+        if self.label_attention_config is not None:
+            self.label_attention_module = LabelAttention(
+                label_attention_config=self.label_attention_config
+            )
+            if self.aggregation_method is not None:
+                logger.warning(
+                    "aggregation_method is ignored when label_attention_config is provided, since label attention produces label-specific sentence embeddings without further aggregation."
+                )
+                self.aggregation_method = None  # override to avoid confusion
+
+        if self.aggregation_method not in (None, "mean", "first", "last"):
+            raise ValueError(
+                f"Unsupported aggregation method: {self.aggregation_method}. Supported methods are None, 'mean', 'first', 'last'."
+            )
+        if self.aggregation_method is None:
+            if self.label_attention_config is None:
+                raise ValueError(
+                    "aggregation_method cannot be None when label_attention_config is not provided, since we need some way to aggregate token embeddings into a sentence embedding. Please specify an aggregation method (e.g., 'mean') or provide a label_attention_config to use label attention for aggregation."
+                )
+
+    def forward(
+        self,
+        token_embeddings: torch.Tensor,
+        attention_mask: torch.Tensor,
+        return_label_attention_matrix: bool = False,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Compute sentence embedding from embedded tokens - "remove" the second (sequence) dimension.
+
+        Args (output from dataset collate_fn):
+            token_embeddings (torch.Tensor), shape (batch_size, seq_len, embedding_dim): Embedded tokens
+            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
+        Returns:
+            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
+                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
+                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
+        """
+        if self.aggregation_method is not None:  # default is "mean"
+            if self.aggregation_method == "first":
+                return {
+                    "sentence_embedding": token_embeddings[:, 0, :],
+                    "label_attention_matrix": None,
+                }
+            elif self.aggregation_method == "last":
+                lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
+                return {
+                    "sentence_embedding": token_embeddings[
+                        torch.arange(token_embeddings.size(0)),
+                        lengths - 1,
+                        :,
+                    ],
+                    "label_attention_matrix": None,
+                }
+            else:  # mean
+                mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
+                masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
+                sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
+                    min=1.0
+                )  # avoid division by zero
+
+                sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
+                return {
+                    "sentence_embedding": sentence_embedding,
+                    "label_attention_matrix": None,
+                }
+
+        else:
+            label_attention_result = self.label_attention_module(
+                token_embeddings,
+                attention_mask=attention_mask,
+                compute_attention_matrix=return_label_attention_matrix,
+            )
+            sentence_embedding = label_attention_result[
+                "sentence_embedding"
+            ]  # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix
+            label_attention_matrix = label_attention_result["attention_matrix"]
+            return {
+                "sentence_embedding": sentence_embedding,
+                "label_attention_matrix": label_attention_matrix,
+            }
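
For the label-attention path added in this file, a hedged sketch (illustrative, not part of the commit): with a LabelAttentionConfig, the SentenceEmbedder ignores aggregation_method and returns one embedding per class, so the output gains a num_classes dimension; the attention-matrix shape below follows the docstring removed above, (batch_size, n_head, num_classes, seq_len).

import torch

from torchTextClassifiers.model.components.text_embedder import (
    LabelAttentionConfig,
    SentenceEmbedder,
    SentenceEmbedderConfig,
)

# embedding_dim is assumed divisible by n_head (head_dim = embedding_dim // n_head).
config = SentenceEmbedderConfig(
    aggregation_method=None,
    label_attention_config=LabelAttentionConfig(n_head=4, num_classes=10, embedding_dim=64),
)
embedder = SentenceEmbedder(config)

token_embeddings = torch.randn(4, 16, 64)             # (bs, seq_len, d_embed), e.g. TokenEmbedder output
attention_mask = torch.ones(4, 16, dtype=torch.long)

out = embedder(token_embeddings, attention_mask, return_label_attention_matrix=True)
print(out["sentence_embedding"].shape)       # expected: torch.Size([4, 10, 64])
print(out["label_attention_matrix"].shape)   # expected: torch.Size([4, 4, 10, 16])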
