Skip to content

Commit aaa1a35

Browse files
feat: create SentenceEmbedder as a new component
- moved LabelAttention into SentenceEmbedder as one aggregation type; adapted tests accordingly
1 parent 22d4507 commit aaa1a35

7 files changed

Lines changed: 48 additions & 240 deletions

File tree

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def mock_tokenizer():
6363
tokenizer = Mock()
6464
tokenizer.vocab_size = 1000
6565
tokenizer.padding_idx = 1
66+
tokenizer.output_vectorized = False
6667
tokenizer.tokenize = Mock(
6768
return_value={
6869
"input_ids": np.array([[1, 2, 3], [4, 5, 6]]),

tests/test_pipeline.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
CategoricalVariableNet,
1212
ClassificationHead,
1313
LabelAttentionConfig,
14-
TextEmbedder,
15-
TextEmbedderConfig,
14+
SentenceEmbedder,
15+
SentenceEmbedderConfig,
16+
TokenEmbedder,
17+
TokenEmbedderConfig,
1618
)
1719
from torchTextClassifiers.tokenizers import NGramTokenizer
1820
from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder
@@ -122,24 +124,29 @@ def run_full_pipeline(
122124
sequence_len=sequence_len,
123125
)
124126

125-
# Create text embedder
126-
text_embedder_config = TextEmbedderConfig(
127+
# Create token embedder
128+
token_embedder_config = TokenEmbedderConfig(
127129
vocab_size=vocab_size,
128130
embedding_dim=model_params["embedding_dim"],
129131
padding_idx=padding_idx,
130132
attention_config=attention_config,
133+
)
134+
token_embedder = TokenEmbedder(token_embedder_config=token_embedder_config)
135+
136+
# Create sentence embedder
137+
sentence_embedder_config = SentenceEmbedderConfig(
131138
label_attention_config=(
132139
LabelAttentionConfig(
133140
n_head=attention_config.n_head,
134141
num_classes=num_classes,
142+
embedding_dim=model_params["embedding_dim"],
135143
)
136144
if label_attention_enabled
137145
else None
138146
),
147+
aggregation_method=None if label_attention_enabled else "mean",
139148
)
140-
141-
text_embedder = TextEmbedder(text_embedder_config=text_embedder_config)
142-
text_embedder.init_weights()
149+
sentence_embedder = SentenceEmbedder(sentence_embedder_config=sentence_embedder_config)
143150

144151
# Create categorical variable net (vocab sizes from fitted encoder)
145152
categorical_var_net = CategoricalVariableNet(
@@ -156,7 +163,8 @@ def run_full_pipeline(
156163

157164
# Create model
158165
model = TextClassificationModel(
159-
text_embedder=text_embedder,
166+
token_embedder=token_embedder,
167+
sentence_embedder=sentence_embedder,
160168
categorical_variable_net=categorical_var_net,
161169
classification_head=classification_head,
162170
)

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class TokenEmbedder(nn.Module):
4646
4747
"""
4848

49+
cos: torch.Tensor
50+
sin: torch.Tensor
51+
4952
def __init__(self, token_embedder_config: TokenEmbedderConfig):
5053
super().__init__()
5154

@@ -112,10 +115,7 @@ def init_weights(self):
112115
for block in self.transformer.h:
113116
torch.nn.init.zeros_(block.mlp.c_proj.weight)
114117
torch.nn.init.zeros_(block.attn.c_proj.weight)
115-
# init the rotary embeddings
116-
head_dim = self.attention_config.n_embd // self.attention_config.n_head
117-
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
118-
self.cos, self.sin = cos, sin
118+
119119
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
120120
if self.embedding_layer.weight.device.type == "cuda":
121121
self.embedding_layer.to(dtype=torch.bfloat16)
@@ -205,9 +205,7 @@ def __init__(self, label_attention_config: LabelAttentionConfig):
205205
super().__init__()
206206

207207
if label_attention_config is None:
208-
raise ValueError(
209-
"label_attention_config must be provided to use LabelAttention."
210-
)
208+
raise ValueError("label_attention_config must be provided to use LabelAttention.")
211209

212210
self.label_attention_config = label_attention_config
213211
self.num_classes = label_attention_config.num_classes
@@ -311,7 +309,7 @@ def forward(
311309

312310
class SentenceEmbedder(nn.Module):
313311
def __init__(self, sentence_embedder_config: SentenceEmbedderConfig):
314-
312+
super().__init__()
315313
"""
316314
A module to aggregate token embeddings.
317315
@@ -322,7 +320,7 @@ def __init__(self, sentence_embedder_config: SentenceEmbedderConfig):
322320
- aggregation_method=None: in that case you need to provide a label attention
323321
"""
324322

325-
self.config
323+
self.config = sentence_embedder_config
326324
self.label_attention_config = sentence_embedder_config.label_attention_config
327325
self.aggregation_method = sentence_embedder_config.aggregation_method
328326

torchTextClassifiers/model/model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
CategoricalForwardType,
1616
CategoricalVariableNet,
1717
ClassificationHead,
18+
SentenceEmbedder,
1819
TokenEmbedder,
19-
SentenceEmbedder
2020
)
2121
from torchTextClassifiers.model.components.attention import norm
2222

@@ -53,20 +53,22 @@ def __init__(
5353
classification_head (ClassificationHead): The classification head module.
5454
token_embedder (Optional[TextEmbedder]): The text embedding module.
5555
If not provided, assumes that input text is already embedded (as tensors) and directly passed to the classification head.
56-
sentence_embedder:
56+
sentence_embedder:
5757
categorical_variable_net (Optional[CategoricalVariableNet]): The categorical variable network module.
5858
If not provided, assumes no categorical variables are used.
5959
"""
6060
super().__init__()
6161

6262
self.token_embedder = token_embedder
63+
self.sentence_embedder = sentence_embedder
6364

6465
if self.token_embedder is not None:
6566
self.token_embedder.init_weights()
6667
if self.sentence_embedder is None:
67-
raise ValueError("You have provided a TokenEmbedder but no SentenceEmbedder: please provide one.")
68-
else:
69-
self.sentence_embedder = sentence_embedder
68+
raise ValueError(
69+
"You have provided a TokenEmbedder but no SentenceEmbedder: please provide one."
70+
)
71+
7072
self.categorical_variable_net = categorical_variable_net
7173
if not self.categorical_variable_net:
7274
logger.info("🔹 No categorical variable network provided; using only text embeddings.")
@@ -76,7 +78,6 @@ def __init__(
7678
self._validate_component_connections()
7779

7880
torch.nn.init.zeros_(self.classification_head.net.weight)
79-
8081

8182
def _validate_component_connections(self):
8283
def _check_text_categorical_connection(self, token_embedder, cat_var_net):
@@ -158,7 +159,9 @@ def forward(
158159
attention_mask=attention_mask,
159160
)
160161
x_token = token_embed_output["token_embeddings"]
161-
sentence_embedding_output = self.sentence_embedder(x_token, attention_mask, return_label_attention_matrix=return_label_attention_matrix)
162+
sentence_embedding_output = self.sentence_embedder(
163+
x_token, attention_mask, return_label_attention_matrix=return_label_attention_matrix
164+
)
162165
x_text = sentence_embedding_output["sentence_embedding"]
163166
if return_label_attention_matrix:
164167
label_attention_matrix = sentence_embedding_output["label_attention_matrix"]

torchTextClassifiers/test copy.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

torchTextClassifiers/test.py

Lines changed: 0 additions & 95 deletions
This file was deleted.

0 commit comments

Comments (0)