Skip to content

Commit 519a32d

Browse files
feat(test): add a test of the full pipeline with different tokenizers
1 parent 4e2ffa5 commit 519a32d

1 file changed

Lines changed: 236 additions & 0 deletions

File tree

tests/test_pipeline.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
5+
from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
6+
from torchTextClassifiers.dataset import TextClassificationDataset
7+
from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
8+
from torchTextClassifiers.model.components import (
9+
AttentionConfig,
10+
CategoricalVariableNet,
11+
ClassificationHead,
12+
TextEmbedder,
13+
TextEmbedderConfig,
14+
)
15+
from torchTextClassifiers.tokenizers import HuggingFaceTokenizer, NGramTokenizer, WordPieceTokenizer
16+
from torchTextClassifiers.utilities.plot_explainability import (
17+
map_attributions_to_char,
18+
map_attributions_to_word,
19+
plot_attributions_at_char,
20+
plot_attributions_at_word,
21+
)
22+
23+
24+
@pytest.fixture
def sample_data():
    """Provide the small text/categorical/label dataset shared by every test.

    Returns:
        tuple: (texts, categorical matrix of shape (6, 2), label vector).
    """
    texts = [
        "This is a positive example",
        "This is a negative example",
        "Another positive case",
        "Another negative case",
        "Good example here",
        "Bad example here",
    ]
    # Rows alternate [1, 0] / [0, 1] for the two binary categorical variables.
    categorical = np.array(
        [[1, 0] if idx % 2 == 0 else [0, 1] for idx in range(len(texts))],
        dtype=int,
    )
    # NOTE(review): the final label is 5 while the rest alternate 1/0 —
    # legal since num_classes is 10, but confirm it is deliberate.
    labels = np.array([1, 0, 1, 0, 1, 5])

    return texts, categorical, labels
39+
40+
41+
@pytest.fixture
def model_params():
    """Provide the model hyperparameters shared by every pipeline test."""
    return dict(
        embedding_dim=96,
        n_layers=2,
        n_head=4,
        num_classes=10,
        categorical_vocab_sizes=[2, 2],
        categorical_embedding_dims=[4, 7],
    )
52+
53+
54+
def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params):
    """Helper function to run the complete pipeline for a given tokenizer.

    Exercises, in order: dataset/dataloader construction, manual model
    assembly (text embedder + categorical net + classification head), a raw
    forward pass, a module-level predict step, config-driven training via
    ``torchTextClassifiers``, and finally the explainability helpers
    (attribution mapping + plotting) on the first sample.

    Args:
        tokenizer: A trained tokenizer exposing ``vocab_size``,
            ``padding_idx`` and ``output_dim``.
        sample_text_data: List of raw input strings.
        categorical_data: np.ndarray of shape (n_samples, 2) holding the
            categorical variable values.
        labels: np.ndarray of integer class labels, one per text.
        model_params: Dict of hyperparameters (see the ``model_params``
            fixture for the expected keys).
    """
    # Create dataset. labels=None: this dataset only feeds inference batches;
    # the labels are passed separately to ``ttc.train`` below.
    dataset = TextClassificationDataset(
        texts=sample_text_data,
        categorical_variables=categorical_data.tolist(),
        tokenizer=tokenizer,
        labels=None,
    )

    dataloader = dataset.create_dataloader(batch_size=4)
    batch = next(iter(dataloader))

    # Get tokenizer parameters
    vocab_size = tokenizer.vocab_size
    padding_idx = tokenizer.padding_idx
    sequence_len = tokenizer.output_dim

    # Create attention config (n_kv_head == n_head, i.e. no key/value sharing)
    attention_config = AttentionConfig(
        n_layers=model_params["n_layers"],
        n_head=model_params["n_head"],
        n_kv_head=model_params["n_head"],
        sequence_len=sequence_len,
    )

    # Create text embedder
    text_embedder_config = TextEmbedderConfig(
        vocab_size=vocab_size,
        embedding_dim=model_params["embedding_dim"],
        padding_idx=padding_idx,
        attention_config=attention_config,
    )

    text_embedder = TextEmbedder(text_embedder_config=text_embedder_config)
    text_embedder.init_weights()

    # Create categorical variable net
    categorical_var_net = CategoricalVariableNet(
        categorical_vocabulary_sizes=model_params["categorical_vocab_sizes"],
        categorical_embedding_dims=model_params["categorical_embedding_dims"],
    )

    # Create classification head: its input is the text embedding concatenated
    # with the categorical embeddings, hence the summed dimension below.
    expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim
    classification_head = ClassificationHead(
        input_dim=expected_input_dim,
        num_classes=model_params["num_classes"],
    )

    # Create model
    model = TextClassificationModel(
        text_embedder=text_embedder,
        categorical_variable_net=categorical_var_net,
        classification_head=classification_head,
    )

    # Test forward pass (smoke test: should not raise)
    model(**batch)

    # Create module
    module = TextClassificationModule(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam,
        optimizer_params={"lr": 1e-3},
        scheduler=None,
        scheduler_params=None,
        scheduler_interval="epoch",
    )

    # Test prediction
    module.predict_step(batch)

    # Prepare data for training: text column followed by categorical columns.
    # NOTE(review): np.column_stack coerces the mixed inputs to a string
    # array; presumably ttc.train re-parses the categorical columns — confirm.
    X = np.column_stack([sample_text_data, categorical_data])
    Y = labels

    # Create model config
    model_config = ModelConfig(
        embedding_dim=model_params["embedding_dim"],
        categorical_vocabulary_sizes=model_params["categorical_vocab_sizes"],
        categorical_embedding_dims=model_params["categorical_embedding_dims"],
        num_classes=model_params["num_classes"],
        attention_config=attention_config,
    )

    # Create training config (a single epoch: we only check the loop runs)
    training_config = TrainingConfig(
        lr=1e-3,
        batch_size=4,
        num_epochs=1,
    )

    # Create classifier
    ttc = torchTextClassifiers(
        tokenizer=tokenizer,
        model_config=model_config,
    )

    # Train (validating on the training data itself — fine for a smoke test)
    ttc.train(
        X_train=X,
        y_train=Y,
        X_val=X,
        y_val=Y,
        training_config=training_config,
    )

    # Predict with explanations
    top_k = 5
    predictions = ttc.predict(X, top_k=top_k, explain=True)

    # Test explainability functions on the first sample only
    text_idx = 0
    text = sample_text_data[text_idx]
    offsets = predictions["offset_mapping"][text_idx]
    attributions = predictions["attributions"][text_idx]
    word_ids = predictions["word_ids"][text_idx]

    word_attributions = map_attributions_to_word(attributions, word_ids)
    char_attributions = map_attributions_to_char(attributions, offsets, text)

    # Note: We're not actually plotting in tests, just calling the functions
    # to ensure they don't raise errors
    plot_attributions_at_char(text, char_attributions)
    plot_attributions_at_word(text, word_attributions)
181+
182+
183+
def test_wordpiece_tokenizer(sample_data, model_params):
    """Run the end-to-end pipeline with a freshly trained WordPieceTokenizer."""
    texts, cats, ys = sample_data

    tokenizer = WordPieceTokenizer(100, output_dim=50)
    tokenizer.train(texts)

    # Sanity check: one row of token ids per input text.
    tokenized = tokenizer.tokenize(texts)
    assert tokenized.input_ids.shape[0] == len(texts)

    run_full_pipeline(tokenizer, texts, cats, ys, model_params)
197+
198+
199+
def test_huggingface_tokenizer(sample_data, model_params):
    """Run the end-to-end pipeline with a pretrained HuggingFace tokenizer."""
    texts, cats, ys = sample_data

    tokenizer = HuggingFaceTokenizer.load_from_pretrained(
        "google-bert/bert-base-uncased", output_dim=50
    )

    # Sanity check: one row of token ids per input text.
    tokenized = tokenizer.tokenize(texts)
    assert tokenized.input_ids.shape[0] == len(texts)

    run_full_pipeline(tokenizer, texts, cats, ys, model_params)
213+
214+
215+
def test_ngram_tokenizer(sample_data, model_params):
    """Run the end-to-end pipeline with an NGramTokenizer, plus a decode check."""
    texts, cats, ys = sample_data

    tokenizer = NGramTokenizer(
        min_count=3, min_n=2, max_n=5, num_tokens=100, len_word_ngrams=2, output_dim=76
    )
    tokenizer.train(texts)

    # Single-text tokenization with offset/word-id tracking must succeed.
    single = tokenizer.tokenize(
        texts[0], return_offsets_mapping=True, return_word_ids=True
    )
    assert single.input_ids is not None

    # batch_decode should yield one decoded string per input text.
    encoded = tokenizer.tokenize(texts)
    decoded = tokenizer.batch_decode(encoded.input_ids.tolist())
    assert len(decoded) == len(texts)

    run_full_pipeline(tokenizer, texts, cats, ys, model_params)

0 commit comments

Comments
 (0)