-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathtest_pipeline.py
More file actions
344 lines (287 loc) · 11.1 KB
/
test_pipeline.py
File metadata and controls
344 lines (287 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import numpy as np
import pytest
import torch
from sklearn.preprocessing import LabelEncoder

from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
from torchTextClassifiers.dataset import TextClassificationDataset
from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
from torchTextClassifiers.model.components import (
    AttentionConfig,
    CategoricalVariableNet,
    ClassificationHead,
    LabelAttentionConfig,
    SentenceEmbedder,
    SentenceEmbedderConfig,
    TokenEmbedder,
    TokenEmbedderConfig,
)
from torchTextClassifiers.tokenizers import NGramTokenizer
from torchTextClassifiers.utilities.plot_explainability import (
    map_attributions_to_char,
    map_attributions_to_word,
    plot_attributions_at_char,
    plot_attributions_at_word,
)
from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder

# Optional tokenizers (presumably backed by optional extras). If the import
# fails, the names are simply left undefined at module level.
try:
    from torchTextClassifiers.tokenizers import HuggingFaceTokenizer, WordPieceTokenizer
except ImportError:
    pass
@pytest.fixture
def sample_data():
    """Provide the tiny labelled corpus shared by every pipeline test.

    Returns:
        A ``(texts, categorical, labels)`` tuple:

        * ``texts`` – six short sentences;
        * ``categorical`` – a ``(6, 2)`` NumPy array of strings, i.e. two
          string-valued features with two distinct values each;
        * ``labels`` – a NumPy array of six string labels.

        Rows alternate between a "positive" row (``cat`` / ``red``) and a
        "negative" row (``dog`` / ``blue``).
    """
    texts = [
        "This is a positive example",
        "This is a negative example",
        "Another positive case",
        "Another negative case",
        "Good example here",
        "Bad example here",
    ]
    # The (positive-row, negative-row) feature pair repeated three times.
    categorical = np.array([["cat", "red"], ["dog", "blue"]] * 3)
    # Matching string labels, in the same alternating order.
    labels = np.array(["positive", "negative"] * 3)
    return texts, categorical, labels
@pytest.fixture
def model_params():
    """Provide the architecture hyper-parameters shared by every pipeline test.

    Only sizes that do not depend on the data are fixed here. The number of
    classes and the categorical vocabulary sizes are computed from the data
    inside ``run_full_pipeline``.
    """
    return dict(
        embedding_dim=96,
        n_layers=2,
        n_head=4,
        # Presumably one embedding width per categorical column — keep in
        # sync with the number of columns in ``sample_data``.
        categorical_embedding_dims=[4, 7],
    )
def _fit_value_encoder(categorical_data, labels):
    """Fit the encoder that maps raw string inputs to contiguous integers.

    Each categorical column ``i`` gets a ``DictEncoder`` registered under the
    key ``str(i)``. It maps the column's sorted unique values to 0, 1, ....
    Labels go through a fitted scikit-learn ``LabelEncoder``.

    Returns:
        ``(value_encoder, num_classes)``, where ``num_classes`` is the number
        of distinct labels.
    """
    n_features = categorical_data.shape[1]
    encoders = {
        str(i): DictEncoder(
            {v: j for j, v in enumerate(sorted(set(categorical_data[:, i].tolist())))}
        )
        for i in range(n_features)
    }
    label_encoder = LabelEncoder()
    label_encoder.fit(labels)
    num_classes = len(label_encoder.classes_)
    return ValueEncoder(label_encoder, encoders), num_classes


def _check_components(
    tokenizer,
    sample_text_data,
    encoded_categorical,
    vocab_sizes,
    num_classes,
    attention_config,
    model_params,
    label_attention_enabled,
):
    """Assemble a ``TextClassificationModel`` by hand and smoke-test it.

    The dataset receives categorical values that are already integer-encoded.
    One batch is pushed through a raw forward pass, then through
    ``TextClassificationModule.predict_step``. Outputs are not inspected, so
    this only checks that the components fit together without raising.
    """
    dataset = TextClassificationDataset(
        texts=sample_text_data,
        categorical_variables=encoded_categorical.tolist(),
        tokenizer=tokenizer,
        labels=None,
    )
    dataloader = dataset.create_dataloader(batch_size=4)
    batch = next(iter(dataloader))

    token_embedder_config = TokenEmbedderConfig(
        vocab_size=tokenizer.vocab_size,
        embedding_dim=model_params["embedding_dim"],
        padding_idx=tokenizer.padding_idx,
        attention_config=attention_config,
    )
    token_embedder = TokenEmbedder(token_embedder_config=token_embedder_config)

    # With label attention enabled, no "mean" aggregation method is set.
    sentence_embedder_config = SentenceEmbedderConfig(
        label_attention_config=(
            LabelAttentionConfig(
                n_head=attention_config.n_head,
                num_classes=num_classes,
                embedding_dim=model_params["embedding_dim"],
            )
            if label_attention_enabled
            else None
        ),
        aggregation_method=None if label_attention_enabled else "mean",
    )
    sentence_embedder = SentenceEmbedder(sentence_embedder_config=sentence_embedder_config)

    # Vocabulary sizes come from the fitted value encoder.
    categorical_var_net = CategoricalVariableNet(
        categorical_vocabulary_sizes=vocab_sizes,
        categorical_embedding_dims=model_params["categorical_embedding_dims"],
    )

    # Head input width = text embedding width + categorical net output width.
    # With label attention the head is built with a single output; otherwise
    # it is built with one output per class.
    expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim
    classification_head = ClassificationHead(
        input_dim=expected_input_dim,
        num_classes=num_classes if not label_attention_enabled else 1,
    )

    model = TextClassificationModel(
        token_embedder=token_embedder,
        sentence_embedder=sentence_embedder,
        categorical_variable_net=categorical_var_net,
        classification_head=classification_head,
    )
    model(**batch)  # forward pass must not raise

    module = TextClassificationModule(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam,
        optimizer_params={"lr": 1e-3},
        scheduler=None,
        scheduler_params=None,
        scheduler_interval="epoch",
    )
    module.predict_step(batch)  # prediction step must not raise


def _train_and_predict(
    tokenizer,
    sample_text_data,
    categorical_data,
    labels,
    value_encoder,
    vocab_sizes,
    num_classes,
    attention_config,
    model_params,
    label_attention_enabled,
):
    """Train the ``torchTextClassifiers`` wrapper on raw strings and predict.

    The wrapper is trained for one epoch, validating on the training data. It
    is then reloaded from its ``save_path`` and asked for top-k predictions
    with Captum explanations. Label-attention explanations are also requested
    when ``label_attention_enabled`` is true.

    Returns:
        Whatever ``torchTextClassifiers.predict`` returns.
    """
    # X keeps the categorical columns as raw strings; the fitted value encoder
    # handed to the wrapper converts them.
    X = np.column_stack([sample_text_data, categorical_data])
    Y = labels  # raw string labels (encoded by value_encoder)

    model_config = ModelConfig(
        embedding_dim=model_params["embedding_dim"],
        categorical_vocabulary_sizes=vocab_sizes,
        categorical_embedding_dims=model_params["categorical_embedding_dims"],
        num_classes=num_classes,
        attention_config=attention_config,
        # Request label attention only when the caller enabled it, so the
        # wrapper builds the same kind of model as the component check.
        # Previously the head count was passed unconditionally, so the plain
        # (mean-aggregation) wrapper path was never exercised.
        # NOTE(review): assumes ``None`` disables label attention in
        # ``ModelConfig`` — confirm.
        n_heads_label_attention=(
            attention_config.n_head if label_attention_enabled else None
        ),
    )
    training_config = TrainingConfig(
        lr=1e-3,
        batch_size=4,
        num_epochs=1,
    )

    # The fitted value encoder is passed to the wrapper.
    ttc = torchTextClassifiers(
        tokenizer=tokenizer,
        model_config=model_config,
        value_encoder=value_encoder,
    )
    ttc.train(
        X_train=X,
        y_train=Y,
        X_val=X,
        y_val=Y,
        training_config=training_config,
    )

    assert ttc.save_path is not None
    # NOTE(review): the return value of ``load`` is discarded. If it returns a
    # new instance instead of updating ``ttc`` in place, the predictions below
    # come from the in-memory model, not the reloaded one — confirm.
    ttc.load(ttc.save_path)  # test load (encoder is also saved/restored)

    return ttc.predict(
        X,
        top_k=min(5, num_classes),
        explain_with_label_attention=label_attention_enabled,
        explain_with_captum=True,
    )


def _check_explainability(predictions, sample_text_data, text_idx=0):
    """Map one text's Captum attributions to words and characters, then plot them."""
    text = sample_text_data[text_idx]
    offsets = predictions["offset_mapping"][text_idx]
    attributions = predictions["captum_attributions"][text_idx]
    word_ids = predictions["word_ids"][text_idx]

    words, word_attributions = map_attributions_to_word(attributions, text, word_ids, offsets)
    char_attributions = map_attributions_to_char(attributions, offsets, text)

    plot_attributions_at_char(text, char_attributions)
    plot_attributions_at_word(
        text=text,
        words=words.values(),
        attributions_per_word=word_attributions,
    )


def run_full_pipeline(
    tokenizer,
    sample_text_data,
    categorical_data,
    labels,
    model_params,
    label_attention_enabled: bool = False,
):
    """Run the complete pipeline for a given tokenizer.

    Both stages use the same fitted encoders and the same attention config:

    1. The low-level components are assembled by hand and smoke-tested on
       integer-encoded inputs.
    2. The ``torchTextClassifiers`` wrapper is trained on raw string inputs,
       reloaded, and asked for predictions with explanations.

    ``label_attention_enabled`` toggles label attention in both stages. The
    returned label-attention attributions are checked accordingly. Finally,
    the Captum attributions of the first text are mapped to words and
    characters and plotted.

    Args:
        tokenizer: A trained tokenizer exposing ``vocab_size``, ``padding_idx``
            and ``output_dim``.
        sample_text_data: List of raw texts.
        categorical_data: 2-D array of string categorical values, one row per
            text.
        labels: String labels, one per text.
        model_params: Dict with ``embedding_dim``, ``n_layers``, ``n_head`` and
            ``categorical_embedding_dims``.
        label_attention_enabled: Whether to build models with label attention.
    """
    # Encode string categoricals and labels; class count and vocabulary sizes
    # are derived from the data.
    value_encoder, num_classes = _fit_value_encoder(categorical_data, labels)
    encoded_categorical = value_encoder.transform(categorical_data)
    vocab_sizes = value_encoder.vocabulary_sizes

    attention_config = AttentionConfig(
        n_layers=model_params["n_layers"],
        n_head=model_params["n_head"],
        n_kv_head=model_params["n_head"],
        sequence_len=tokenizer.output_dim,
    )

    _check_components(
        tokenizer,
        sample_text_data,
        encoded_categorical,
        vocab_sizes,
        num_classes,
        attention_config,
        model_params,
        label_attention_enabled,
    )

    predictions = _train_and_predict(
        tokenizer,
        sample_text_data,
        categorical_data,
        labels,
        value_encoder,
        vocab_sizes,
        num_classes,
        attention_config,
        model_params,
        label_attention_enabled,
    )

    if label_attention_enabled:
        label_attention_attributions = predictions["label_attention_attributions"]
        assert (
            label_attention_attributions is not None
        ), "Label attention attributions should not be None when label_attention_enabled is True"
        expected_shape = (
            len(sample_text_data),  # batch_size
            model_params["n_head"],  # n_head
            num_classes,  # num_classes (derived from label encoder)
            tokenizer.output_dim,  # seq_len
        )
        assert label_attention_attributions.shape == expected_shape, (
            f"Label attention attributions shape mismatch. "
            f"Expected {expected_shape}, got {label_attention_attributions.shape}"
        )
    else:
        assert (
            predictions.get("label_attention_attributions") is None
        ), "Label attention attributions should be None when label_attention_enabled is False"

    _check_explainability(predictions, sample_text_data)
def test_wordpiece_tokenizer(sample_data, model_params):
    """Test the full pipeline with WordPieceTokenizer.

    Trains the tokenizer on the sample texts and checks that batch
    tokenization yields one row per text. It then runs ``run_full_pipeline``.
    The test is skipped when the optional tokenizer cannot be imported.
    """
    # The module-level import of this tokenizer swallows ImportError, which
    # would otherwise surface here as a confusing NameError. Re-import locally
    # and skip cleanly instead.
    try:
        from torchTextClassifiers.tokenizers import WordPieceTokenizer
    except ImportError as exc:
        pytest.skip(f"WordPieceTokenizer unavailable: {exc}")

    sample_text_data, categorical_data, labels = sample_data
    vocab_size = 100
    tokenizer = WordPieceTokenizer(vocab_size, output_dim=50)
    tokenizer.train(sample_text_data)

    result = tokenizer.tokenize(sample_text_data)
    assert result.input_ids.shape[0] == len(sample_text_data)

    run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)
def test_huggingface_tokenizer(sample_data, model_params):
    """Test the full pipeline with HuggingFaceTokenizer.

    Loads the ``google-bert/bert-base-uncased`` tokenizer and checks that
    batch tokenization yields one row per text. It then runs
    ``run_full_pipeline``. The test is skipped when the optional tokenizer
    cannot be imported.
    """
    # The module-level import of this tokenizer swallows ImportError, which
    # would otherwise surface here as a confusing NameError. Re-import locally
    # and skip cleanly instead.
    try:
        from torchTextClassifiers.tokenizers import HuggingFaceTokenizer
    except ImportError as exc:
        pytest.skip(f"HuggingFaceTokenizer unavailable: {exc}")

    sample_text_data, categorical_data, labels = sample_data
    tokenizer = HuggingFaceTokenizer.load_from_pretrained(
        "google-bert/bert-base-uncased", output_dim=50
    )

    result = tokenizer.tokenize(sample_text_data)
    assert result.input_ids.shape[0] == len(sample_text_data)

    run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)
def test_ngram_tokenizer(sample_data, model_params):
    """Exercise NGramTokenizer on its own, then run the full pipeline with it.

    Covers single-text tokenization with offsets and word ids, plus a batch
    encode/decode round trip that must return one string per input text.
    """
    texts, categorical_data, labels = sample_data

    ngram_tokenizer = NGramTokenizer(
        min_count=3,
        min_n=2,
        max_n=5,
        num_tokens=100,
        len_word_ngrams=2,
        output_dim=76,
    )
    ngram_tokenizer.train(texts)

    # Single text, asking for offsets and word ids as well.
    single = ngram_tokenizer.tokenize(
        texts[0], return_offsets_mapping=True, return_word_ids=True
    )
    assert single.input_ids is not None

    # Batch encode, then decode back to one string per text.
    encoded = ngram_tokenizer.tokenize(texts)
    round_trip = ngram_tokenizer.batch_decode(encoded.input_ids.tolist())
    assert len(round_trip) == len(texts)

    run_full_pipeline(ngram_tokenizer, texts, categorical_data, labels, model_params)
def test_label_attention_enabled(sample_data, model_params):
    """Test the full pipeline with label attention enabled (using WordPieceTokenizer).

    Trains a WordPieceTokenizer on the sample texts and checks batch
    tokenization. It then runs ``run_full_pipeline`` with
    ``label_attention_enabled=True``. The test is skipped when the optional
    tokenizer cannot be imported.
    """
    # The module-level import of this tokenizer swallows ImportError, which
    # would otherwise surface here as a confusing NameError. Re-import locally
    # and skip cleanly instead.
    try:
        from torchTextClassifiers.tokenizers import WordPieceTokenizer
    except ImportError as exc:
        pytest.skip(f"WordPieceTokenizer unavailable: {exc}")

    sample_text_data, categorical_data, labels = sample_data
    vocab_size = 100
    tokenizer = WordPieceTokenizer(vocab_size, output_dim=50)
    tokenizer.train(sample_text_data)

    result = tokenizer.tokenize(sample_text_data)
    assert result.input_ids.shape[0] == len(sample_text_data)

    run_full_pipeline(
        tokenizer,
        sample_text_data,
        categorical_data,
        labels,
        model_params,
        label_attention_enabled=True,
    )