
Commit 601fa46

feat!(label attention): enable label attention
- module and config created to do that, mainly attached to the TextEmbedder (it aggregates the token embeddings to produce a sentence embedding, instead of naive averaging)
- rest of the code has been adapted, especially categorical variable handling in TextClassificationModel
1 parent 7033120 commit 601fa46

5 files changed

Lines changed: 179 additions & 36 deletions
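For context, a minimal sketch of how the new configuration is meant to be wired. The import path follows the re-exports added in __init__.py below; the hyperparameter values are illustrative, not repository defaults.

from torchTextClassifiers.model.components import (
    LabelAttentionConfig,
    TextEmbedderConfig,
)

# Illustrative values only - n_head / n_kv_head / num_classes are the three fields
# declared by the new LabelAttentionConfig dataclass in text_embedder.py.
label_attention_config = LabelAttentionConfig(n_head=4, n_kv_head=4, num_classes=10)

# Attaching the config to TextEmbedderConfig enables the label-attention aggregation;
# leaving label_attention_config=None keeps the previous mean-pooling behaviour.
text_embedder_config = TextEmbedderConfig(
    vocab_size=30_000,   # illustrative
    embedding_dim=64,    # illustrative
    padding_idx=0,       # illustrative
    attention_config=None,
    label_attention_config=label_attention_config,
)

A TextEmbedder built from such a config returns a (batch_size, num_classes, embedding_dim) tensor instead of a single (batch_size, embedding_dim) sentence embedding.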


torchTextClassifiers/model/components/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -8,5 +8,6 @@
     CategoricalVariableNet as CategoricalVariableNet,
 )
 from .classification_head import ClassificationHead as ClassificationHead
+from .text_embedder import LabelAttentionConfig as LabelAttentionConfig
 from .text_embedder import TextEmbedder as TextEmbedder
 from .text_embedder import TextEmbedderConfig as TextEmbedderConfig

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 146 additions & 29 deletions
@@ -3,17 +3,26 @@
 from typing import Optional
 
 import torch
-from torch import nn
+import torch.nn as nn
+from torch.nn import functional as F
 
 from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm
 
 
+@dataclass
+class LabelAttentionConfig:
+    n_head: int
+    n_kv_head: int
+    num_classes: int
+
+
 @dataclass
 class TextEmbedderConfig:
     vocab_size: int
     embedding_dim: int
     padding_idx: int
     attention_config: Optional[AttentionConfig] = None
+    label_attention_config: Optional[LabelAttentionConfig] = None
 
 
 class TextEmbedder(nn.Module):
@@ -26,8 +35,9 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
         if isinstance(self.attention_config, dict):
             self.attention_config = AttentionConfig(**self.attention_config)
 
-        if self.attention_config is not None:
-            self.attention_config.n_embd = text_embedder_config.embedding_dim
+        self.enable_label_attention = text_embedder_config.label_attention_config is not None
+        if self.enable_label_attention:
+            self.label_attention_module = LabelAttentionClassifier(self.config)
 
         self.vocab_size = text_embedder_config.vocab_size
         self.embedding_dim = text_embedder_config.embedding_dim
@@ -40,6 +50,7 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
         )
 
         if self.attention_config is not None:
+            self.attention_config.n_embd = text_embedder_config.embedding_dim
             self.transformer = nn.ModuleDict(
                 {
                     "h": nn.ModuleList(
@@ -105,8 +116,23 @@ def _init_weights(self, module):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
 
-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        """Converts input token IDs to their corresponding embeddings."""
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        return_label_attention_matrix: bool = False,
+    ) -> torch.Tensor:
+        """Converts input token IDs to their corresponding embeddings.
+
+        Args:
+            input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized
+            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+            return_label_attention_matrix (bool): Whether to return the label attention matrix
+        Returns:
+            torch.Tensor: Text embeddings, shape (batch_size, embedding_dim) if self.enable_label_attention is False, else (batch_size, num_labels, embedding_dim)
+            torch.Tensor: Label attention matrix, shape (batch_size, num_labels, seq_len) if return_label_attention_matrix is True, else None.
+                Also None if label attention is disabled (even if return_label_attention_matrix is True)
+        """
 
         encoded_text = input_ids  # clearer name
         if encoded_text.dtype != torch.long:
@@ -138,14 +164,25 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torc
 
         token_embeddings = norm(token_embeddings)
 
-        text_embedding = self._get_sentence_embedding(
-            token_embeddings=token_embeddings, attention_mask=attention_mask
-        )
+        text_embedding, label_attention_matrix = self._get_sentence_embedding(
+            token_embeddings=token_embeddings,
+            attention_mask=attention_mask,
+            return_label_attention_matrix=return_label_attention_matrix,
+        ).values()
 
-        return text_embedding
+        if return_label_attention_matrix:
+            return (
+                text_embedding,
+                label_attention_matrix,
+            )  # label_attention_matrix is None if label attention is disabled
+        else:
+            return text_embedding
 
     def _get_sentence_embedding(
-        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+        self,
+        token_embeddings: torch.Tensor,
+        attention_mask: torch.Tensor,
+        return_label_attention_matrix: bool = False,
     ) -> torch.Tensor:
         """
         Compute sentence embedding from embedded tokens - "remove" second dimension.
@@ -163,7 +200,7 @@ def _get_sentence_embedding(
         # mask pad-tokens
 
         if self.attention_config is not None:
-            if self.attention_config.aggregation_method is not None:
+            if self.attention_config.aggregation_method is not None:  # default is "mean"
                 if self.attention_config.aggregation_method == "first":
                     return token_embeddings[:, 0, :]
                 elif self.attention_config.aggregation_method == "last":
@@ -181,25 +218,29 @@ def _get_sentence_embedding(
 
         assert self.attention_config is None or self.attention_config.aggregation_method == "mean"
 
-        mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
-        masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
-
-        sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
-            min=1.0
-        )  # avoid division by zero
-
-        sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
-
-        return sentence_embedding
-
-    def __call__(self, *args, **kwargs):
-        out = super().__call__(*args, **kwargs)
-        if out.dim() != 2:
-            raise ValueError(
-                f"Output of {self.__class__.__name__}.forward must be 2D "
-                f"(got shape {tuple(out.shape)})"
+        if self.enable_label_attention:
+            label_attention_result = self.label_attention_module(
+                token_embeddings, compute_attention_matrix=return_label_attention_matrix
             )
-        return out
+            sentence_embedding = label_attention_result[
+                "sentence_embedding"
+            ]  # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix
+            label_attention_matrix = label_attention_result["attention_matrix"]
+
+        else:  # sentence embedding = mean of (non-pad) token embeddings
+            mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
+            masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
+            sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
+                min=1.0
+            )  # avoid division by zero
+
+            sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
+            label_attention_matrix = None
+
+        return {
+            "sentence_embedding": sentence_embedding,
+            "label_attention_matrix": label_attention_matrix,
+        }
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # autodetect the device from model embeddings
@@ -221,3 +262,79 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No
         )  # add batch and head dims for later broadcasting
 
         return cos, sin
+
+
+class LabelAttentionClassifier(nn.Module):
+    """
+    A head for aggregating token embeddings into label-specific sentence embeddings using cross-attention mechanism.
+    Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings.
+
+    """
+
+    def __init__(self, config: TextEmbedderConfig):
+        super().__init__()
+
+        label_attention_config = config.label_attention_config
+        self.embedding_dim = config.embedding_dim
+        self.num_classes = label_attention_config.num_classes
+        self.n_head = label_attention_config.n_head
+        self.n_kv_head = label_attention_config.n_kv_head
+        self.enable_gqa = (
+            self.n_head != self.n_kv_head
+        )  # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
+        self.head_dim = self.embedding_dim // self.n_head
+
+        self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)
+
+        self.c_q = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
+        self.c_k = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False)
+        self.c_v = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False)
+        self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
+
+    def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = False):
+        """
+        Args:
+            token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input.
+            compute_attention_matrix (bool): Whether to compute and return the attention matrix.
+        Returns:
+            dict: {
+                "sentence_embedding": torch.Tensor, shape (batch, num_classes, d_model): Label-specific sentence embeddings.
+                "attention_matrix": Optional[torch.Tensor], shape (batch, n_head, num_classes, seq_len): Attention weights if compute_attention_matrix is True, else None.
+            }
+
+        """
+        B, T, C = token_embeddings.size()
+
+        # 1. Create label indices [0, 1, ..., C-1] for the whole batch
+        label_indices = torch.arange(self.num_classes).expand(B, -1)
+
+        all_label_embeddings = self.label_embeds(
+            label_indices
+        )  # Shape: [batch, num_classes, d_model]
+        all_label_embeddings = norm(all_label_embeddings)
+
+        q = self.c_q(all_label_embeddings).view(B, self.num_classes, self.n_head, self.head_dim)
+        k = self.c_k(token_embeddings).view(B, T, self.n_kv_head, self.head_dim)
+        v = self.c_v(token_embeddings).view(B, T, self.n_kv_head, self.head_dim)
+
+        q, k = norm(q), norm(k)  # QK norm
+        q, k, v = (
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+        )  # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+
+        y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa)
+
+        # Re-assemble the heads side by side and project back to residual stream
+        y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1)  # (bs, n_labels, d_model)
+        y = self.c_proj(y)
+
+        attention_matrix = None
+        if compute_attention_matrix:
+            # size (B, n_head, n_labels, seq_len) - we let the user handle aggregation over heads if desired
+            attention_matrix = torch.softmax(
+                torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5), dim=-1
+            )
+
+        return {"sentence_embedding": y, "attention_matrix": attention_matrix}

torchTextClassifiers/model/lightning.py

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ def validation_step(self, batch, batch_idx: int):
         targets = batch["labels"]
 
         outputs = self.forward(batch)
+
         loss = self.loss(outputs, targets)
         self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
 

torchTextClassifiers/model/model.py

Lines changed: 24 additions & 6 deletions
@@ -1,7 +1,7 @@
1-
"""FastText model components.
1+
"""TextClassification model components.
22
33
This module contains the PyTorch model, Lightning module, and dataset classes
4-
for FastText classification. Consolidates what was previously in pytorch_model.py,
4+
for TextClassification classification. Consolidates what was previously in pytorch_model.py,
55
lightning_module.py, and dataset.py.
66
"""
77

@@ -17,6 +17,7 @@
1717
ClassificationHead,
1818
TextEmbedder,
1919
)
20+
from torchTextClassifiers.model.components.attention import norm
2021

2122
logger = logging.getLogger(__name__)
2223

@@ -67,8 +68,6 @@ def __init__(
6768

6869
self._validate_component_connections()
6970

70-
self.num_classes = self.classification_head.num_classes
71-
7271
torch.nn.init.zeros_(self.classification_head.net.weight)
7372
if self.text_embedder is not None:
7473
self.text_embedder.init_weights()
@@ -98,6 +97,17 @@ def _check_text_categorical_connection(self, text_embedder, cat_var_net):
9897
raise ValueError(
9998
"Classification head input dimension does not match expected dimension from text embedder and categorical variable net."
10099
)
100+
if self.text_embedder.enable_label_attention:
101+
self.enable_label_attention = True
102+
if self.classification_head.num_classes != 1:
103+
raise ValueError(
104+
"Label attention is enabled. TextEmbedder outputs a (num_classes, embedding_dim) tensor, so the ClassificationHead should have an output dimension of 1."
105+
)
106+
# if enable_label_attention is True, label_attention_config exists - and contains num_classes necessarily
107+
self.num_classes = self.text_embedder.config.label_attention_config.num_classes
108+
else:
109+
self.enable_label_attention = False
110+
self.num_classes = self.classification_head.num_classes
101111
else:
102112
logger.warning(
103113
"⚠️ No text embedder provided; assuming input text is already embedded or vectorized. Take care that the classification head input dimension matches the input text dimension."
@@ -131,21 +141,29 @@ def forward(
131141
if self.categorical_variable_net:
132142
x_cat = self.categorical_variable_net(categorical_vars)
133143

144+
if self.enable_label_attention:
145+
# x_text is (batch_size, num_classes, embedding_dim)
146+
# x_cat is (batch_size, cat_embedding_dim)
147+
# We need to expand x_cat to (batch_size, num_classes, cat_embedding_dim)
148+
# x_cat will be appended to x_text along the last dimension for each class
149+
x_cat = x_cat.unsqueeze(1).expand(-1, self.num_classes, -1)
150+
134151
if (
135152
self.categorical_variable_net.forward_type
136153
== CategoricalForwardType.AVERAGE_AND_CONCAT
137154
or self.categorical_variable_net.forward_type
138155
== CategoricalForwardType.CONCATENATE_ALL
139156
):
140-
x_combined = torch.cat((x_text, x_cat), dim=1)
157+
x_combined = torch.cat((x_text, x_cat), dim=-1)
141158
else:
142159
assert (
143160
self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT
144161
)
162+
145163
x_combined = x_text + x_cat
146164
else:
147165
x_combined = x_text
148166

149-
logits = self.classification_head(x_combined)
167+
logits = self.classification_head(norm(x_combined)).squeeze(-1)
150168

151169
return logits
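A quick shape sketch of the forward-pass changes above (standalone torch code, illustrative sizes): with label attention the text embedding carries a per-class axis, so the categorical embedding is broadcast across classes, concatenation happens on the last dimension, and a head with a single output unit plus squeeze(-1) recovers the usual (batch_size, num_classes) logits.

import torch
from torch import nn

B, num_classes, d_text, d_cat = 4, 10, 64, 8
x_text = torch.randn(B, num_classes, d_text)   # TextEmbedder output with label attention
x_cat = torch.randn(B, d_cat)                  # CategoricalVariableNet output

x_cat = x_cat.unsqueeze(1).expand(-1, num_classes, -1)  # (B, num_classes, d_cat)
x_combined = torch.cat((x_text, x_cat), dim=-1)         # (B, num_classes, d_text + d_cat)

head = nn.Linear(d_text + d_cat, 1)            # stand-in for ClassificationHead with num_classes=1
logits = head(x_combined).squeeze(-1)
print(logits.shape)                            # torch.Size([4, 10])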

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 7 additions & 1 deletion
@@ -29,6 +29,7 @@
     CategoricalForwardType,
     CategoricalVariableNet,
     ClassificationHead,
+    LabelAttentionConfig,
     TextEmbedder,
     TextEmbedderConfig,
 )
@@ -53,6 +54,7 @@ class ModelConfig:
     categorical_embedding_dims: Optional[Union[List[int], int]] = None
     num_classes: Optional[int] = None
     attention_config: Optional[AttentionConfig] = None
+    label_attention_config: Optional[LabelAttentionConfig] = None
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
@@ -140,6 +142,7 @@ def __init__(
         self.embedding_dim = model_config.embedding_dim
         self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
         self.num_classes = model_config.num_classes
+        self.enable_label_attention = model_config.label_attention_config is not None
 
         if self.tokenizer.output_vectorized:
             self.text_embedder = None
@@ -153,6 +156,7 @@ def __init__(
                 embedding_dim=self.embedding_dim,
                 padding_idx=tokenizer.padding_idx,
                 attention_config=model_config.attention_config,
+                label_attention_config=model_config.label_attention_config,
             )
             self.text_embedder = TextEmbedder(
                 text_embedder_config=text_embedder_config,
@@ -174,7 +178,9 @@ def __init__(
 
         self.classification_head = ClassificationHead(
             input_dim=classif_head_input_dim,
-            num_classes=model_config.num_classes,
+            num_classes=1
+            if self.enable_label_attention
+            else model_config.num_classes,  # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim)
         )
 
         self.pytorch_model = TextClassificationModel(
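Finally, a note on the attention matrix that return_label_attention_matrix=True exposes through this pipeline: it comes back per head, shape (batch, n_head, num_classes, seq_len), and the code above deliberately leaves head aggregation to the caller. Below is a standalone, illustrative sketch of one way to inspect it; the random tensor stands in for the matrix actually returned by TextEmbedder.

import torch

B, n_head, num_classes, seq_len = 4, 4, 10, 32
# stand-in for the matrix returned by TextEmbedder / LabelAttentionClassifier
attention_matrix = torch.softmax(torch.randn(B, n_head, num_classes, seq_len), dim=-1)

per_label_token_weights = attention_matrix.mean(dim=1)          # average over heads -> (B, num_classes, seq_len)
top_tokens = per_label_token_weights.topk(k=5, dim=-1).indices  # 5 most-attended token positions per label
print(top_tokens.shape)                                         # torch.Size([4, 10, 5])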
