
Commit 997a4f8

chore: add new tests, by components

Parent: 997181b

1 file changed: tests/test_components.py (243 additions & 0 deletions)
@@ -0,0 +1,243 @@
import pytest
import torch

from torchTextClassifiers.model.components import (
    AttentionConfig,
    CategoricalForwardType,
    CategoricalVariableNet,
    ClassificationHead,
    LabelAttentionConfig,
    SentenceEmbedder,
    SentenceEmbedderConfig,
    TokenEmbedder,
    TokenEmbedderConfig,
)
from torchTextClassifiers.model.model import TextClassificationModel

BATCH = 4
SEQ_LEN = 20
EMB_DIM = 16  # divisible by 4 (n_head) and head_dim=4 is even (rotary)
VOCAB_SIZE = 100
PADDING_IDX = 0
NUM_CLASSES = 3


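# Shared fixtures: token ids whose last two positions are padding, the
# matching attention mask, and random token embeddings.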
@pytest.fixture
def input_ids():
    ids = torch.randint(1, VOCAB_SIZE, (BATCH, SEQ_LEN))
    ids[:, -2:] = PADDING_IDX
    return ids


@pytest.fixture
def attention_mask(input_ids):
    return (input_ids != PADDING_IDX).long()


@pytest.fixture
def token_embeddings():
    return torch.randn(BATCH, SEQ_LEN, EMB_DIM)


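# TokenEmbedder: output shapes with and without an attention stack, with
# rotary positional encoding, and validation of mismatched mask shapes.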
class TestTokenEmbedder:
    def test_no_attention(self, input_ids, attention_mask):
        embedder = TokenEmbedder(
            TokenEmbedderConfig(
                vocab_size=VOCAB_SIZE, embedding_dim=EMB_DIM, padding_idx=PADDING_IDX
            )
        )
        out = embedder(input_ids, attention_mask)
        assert out["token_embeddings"].shape == (BATCH, SEQ_LEN, EMB_DIM)
        assert out["attention_mask"].shape == (BATCH, SEQ_LEN)

    def test_with_attention(self, input_ids, attention_mask):
        embedder = TokenEmbedder(
            TokenEmbedderConfig(
                vocab_size=VOCAB_SIZE,
                embedding_dim=EMB_DIM,
                padding_idx=PADDING_IDX,
                attention_config=AttentionConfig(
                    n_layers=2, n_head=4, n_kv_head=4, positional_encoding=False
                ),
            )
        )
        out = embedder(input_ids, attention_mask)
        assert out["token_embeddings"].shape == (BATCH, SEQ_LEN, EMB_DIM)

    def test_with_rotary_positional_encoding(self, input_ids, attention_mask):
        embedder = TokenEmbedder(
            TokenEmbedderConfig(
                vocab_size=VOCAB_SIZE,
                embedding_dim=EMB_DIM,
                padding_idx=PADDING_IDX,
                attention_config=AttentionConfig(
                    n_layers=1,
                    n_head=4,
                    n_kv_head=4,
                    positional_encoding=True,
                    sequence_len=SEQ_LEN,
                ),
            )
        )
        out = embedder(input_ids, attention_mask)
        assert out["token_embeddings"].shape == (BATCH, SEQ_LEN, EMB_DIM)

    def test_shape_mismatch_raises(self):
        embedder = TokenEmbedder(
            TokenEmbedderConfig(
                vocab_size=VOCAB_SIZE, embedding_dim=EMB_DIM, padding_idx=PADDING_IDX
            )
        )
        with pytest.raises(ValueError):
            embedder(
                torch.randint(1, VOCAB_SIZE, (BATCH, SEQ_LEN)),
                torch.ones(BATCH, SEQ_LEN + 1, dtype=torch.long),
            )


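# SentenceEmbedder: mean/first/last aggregation, label-attention output
# shapes, and the invalid config with neither aggregation nor label attention.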
class TestSentenceEmbedder:
    @pytest.mark.parametrize("method", ["mean", "first", "last"])
    def test_aggregation_methods(self, token_embeddings, attention_mask, method):
        embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method=method))
        out = embedder(token_embeddings, attention_mask)
        assert out["sentence_embedding"].shape == (BATCH, EMB_DIM)
        assert out["label_attention_matrix"] is None

    def test_label_attention_output_shape(self, token_embeddings, attention_mask):
        embedder = SentenceEmbedder(
            SentenceEmbedderConfig(
                aggregation_method=None,
                label_attention_config=LabelAttentionConfig(
                    n_head=4, num_classes=NUM_CLASSES, embedding_dim=EMB_DIM
                ),
            )
        )
        out = embedder(token_embeddings, attention_mask)
        assert out["sentence_embedding"].shape == (BATCH, NUM_CLASSES, EMB_DIM)
        assert out["label_attention_matrix"] is None

    def test_label_attention_matrix_returned(self, token_embeddings, attention_mask):
        embedder = SentenceEmbedder(
            SentenceEmbedderConfig(
                aggregation_method=None,
                label_attention_config=LabelAttentionConfig(
                    n_head=4, num_classes=NUM_CLASSES, embedding_dim=EMB_DIM
                ),
            )
        )
        out = embedder(token_embeddings, attention_mask, return_label_attention_matrix=True)
        assert out["label_attention_matrix"].shape == (BATCH, 4, NUM_CLASSES, SEQ_LEN)

    def test_none_aggregation_without_label_attention_raises(self):
        with pytest.raises(ValueError):
            SentenceEmbedder(SentenceEmbedderConfig(aggregation_method=None))


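# CategoricalVariableNet: the three forward types, selected by how
# categorical_embedding_dims is given (per-feature list, single int, or None
# with a text embedding dim), plus out-of-range input validation.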
class TestCategoricalVariableNet:
    def test_concatenate_all(self):
        net = CategoricalVariableNet(
            categorical_vocabulary_sizes=[4, 5],
            categorical_embedding_dims=[3, 6],
        )
        assert net.forward_type == CategoricalForwardType.CONCATENATE_ALL
        assert net.output_dim == 9
        out = net(torch.randint(0, 3, (BATCH, 2)))
        assert out.shape == (BATCH, 9)

    def test_average_and_concat(self):
        net = CategoricalVariableNet(
            categorical_vocabulary_sizes=[4, 5],
            categorical_embedding_dims=8,
        )
        assert net.forward_type == CategoricalForwardType.AVERAGE_AND_CONCAT
        assert net.output_dim == 8
        out = net(torch.randint(0, 3, (BATCH, 2)))
        assert out.shape == (BATCH, 8)

    def test_sum_to_text(self):
        net = CategoricalVariableNet(
            categorical_vocabulary_sizes=[4, 5],
            categorical_embedding_dims=None,
            text_embedding_dim=EMB_DIM,
        )
        assert net.forward_type == CategoricalForwardType.SUM_TO_TEXT
        assert net.output_dim == EMB_DIM
        out = net(torch.randint(0, 3, (BATCH, 2)))
        assert out.shape == (BATCH, EMB_DIM)

    def test_out_of_range_value_raises(self):
        net = CategoricalVariableNet(
            categorical_vocabulary_sizes=[4, 5],
            categorical_embedding_dims=[3, 6],
        )
        with pytest.raises(ValueError):
            net(torch.tensor([[10, 1]] * BATCH))  # first feature value 10 >= vocab 4


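# End-to-end model: text only, text plus categorical features, label
# attention logits and matrix, and a missing sentence embedder.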
class TestTextClassificationModel:
    def _token_embedder(self):
        return TokenEmbedder(
            TokenEmbedderConfig(
                vocab_size=VOCAB_SIZE, embedding_dim=EMB_DIM, padding_idx=PADDING_IDX
            )
        )

    def _sentence_embedder(self, label_attention=False):
        if label_attention:
            return SentenceEmbedder(
                SentenceEmbedderConfig(
                    aggregation_method=None,
                    label_attention_config=LabelAttentionConfig(
                        n_head=4, num_classes=NUM_CLASSES, embedding_dim=EMB_DIM
                    ),
                )
            )
        return SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))

    def test_text_only(self, input_ids, attention_mask):
        model = TextClassificationModel(
            token_embedder=self._token_embedder(),
            sentence_embedder=self._sentence_embedder(),
            classification_head=ClassificationHead(input_dim=EMB_DIM, num_classes=NUM_CLASSES),
        )
        logits = model(input_ids, attention_mask, torch.empty(BATCH, 0))
        assert logits.shape == (BATCH, NUM_CLASSES)

    def test_text_and_categorical(self, input_ids, attention_mask):
        cat_net = CategoricalVariableNet(
            categorical_vocabulary_sizes=[4, 5],
            categorical_embedding_dims=[3, 6],
        )
        model = TextClassificationModel(
            token_embedder=self._token_embedder(),
            sentence_embedder=self._sentence_embedder(),
            categorical_variable_net=cat_net,
            classification_head=ClassificationHead(
                input_dim=EMB_DIM + cat_net.output_dim, num_classes=NUM_CLASSES
            ),
        )
        logits = model(input_ids, attention_mask, torch.randint(0, 3, (BATCH, 2)))
        assert logits.shape == (BATCH, NUM_CLASSES)

    def test_label_attention_logits_and_matrix(self, input_ids, attention_mask):
        model = TextClassificationModel(
            token_embedder=self._token_embedder(),
            sentence_embedder=self._sentence_embedder(label_attention=True),
            classification_head=ClassificationHead(input_dim=EMB_DIM, num_classes=1),
        )
        result = model(
            input_ids,
            attention_mask,
            torch.empty(BATCH, 0),
            return_label_attention_matrix=True,
        )
        assert result["logits"].shape == (BATCH, NUM_CLASSES)
        assert result["label_attention_matrix"].shape == (BATCH, 4, NUM_CLASSES, SEQ_LEN)

    def test_missing_sentence_embedder_raises(self):
        with pytest.raises(ValueError):
            TextClassificationModel(
                token_embedder=self._token_embedder(),
                sentence_embedder=None,
                classification_head=ClassificationHead(input_dim=EMB_DIM, num_classes=NUM_CLASSES),
            )
