Skip to content

Commit 9d9bb42

Browse files
Merge branch 'main' into renovate/astral-sh-uv-0.x
2 parents 0824cbd + e71ab0f commit 9d9bb42

7 files changed

Lines changed: 408 additions & 75 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,4 @@ example_files/
183183
_site/
184184
.quarto/
185185
**/*.quarto_ipynb
186+
my_ttc/

tests/test_pipeline.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AttentionConfig,
1010
CategoricalVariableNet,
1111
ClassificationHead,
12+
LabelAttentionConfig,
1213
TextEmbedder,
1314
TextEmbedderConfig,
1415
)
@@ -51,7 +52,14 @@ def model_params():
5152
}
5253

5354

54-
def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params):
55+
def run_full_pipeline(
56+
tokenizer,
57+
sample_text_data,
58+
categorical_data,
59+
labels,
60+
model_params,
61+
label_attention_enabled: bool = False,
62+
):
5563
"""Helper function to run the complete pipeline for a given tokenizer."""
5664
# Create dataset
5765
dataset = TextClassificationDataset(
@@ -83,6 +91,14 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
8391
embedding_dim=model_params["embedding_dim"],
8492
padding_idx=padding_idx,
8593
attention_config=attention_config,
94+
label_attention_config=(
95+
LabelAttentionConfig(
96+
n_head=attention_config.n_head,
97+
num_classes=model_params["num_classes"],
98+
)
99+
if label_attention_enabled
100+
else None
101+
),
86102
)
87103

88104
text_embedder = TextEmbedder(text_embedder_config=text_embedder_config)
@@ -98,7 +114,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
98114
expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim
99115
classification_head = ClassificationHead(
100116
input_dim=expected_input_dim,
101-
num_classes=model_params["num_classes"],
117+
num_classes=model_params["num_classes"] if not label_attention_enabled else 1,
102118
)
103119

104120
# Create model
@@ -136,6 +152,14 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
136152
categorical_embedding_dims=model_params["categorical_embedding_dims"],
137153
num_classes=model_params["num_classes"],
138154
attention_config=attention_config,
155+
label_attention_config=(
156+
LabelAttentionConfig(
157+
n_head=attention_config.n_head,
158+
num_classes=model_params["num_classes"],
159+
)
160+
if label_attention_enabled
161+
else None
162+
),
139163
)
140164

141165
# Create training config
@@ -163,13 +187,41 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
163187

164188
# Predict with explanations
165189
top_k = 5
166-
predictions = ttc.predict(X, top_k=top_k, explain=True)
190+
191+
predictions = ttc.predict(
192+
X,
193+
top_k=top_k,
194+
explain_with_label_attention=label_attention_enabled,
195+
explain_with_captum=True,
196+
)
197+
198+
# Test label attention assertions
199+
if label_attention_enabled:
200+
assert (
201+
predictions["label_attention_attributions"] is not None
202+
), "Label attention attributions should not be None when label_attention_enabled is True"
203+
label_attention_attributions = predictions["label_attention_attributions"]
204+
expected_shape = (
205+
len(sample_text_data), # batch_size
206+
model_params["n_head"], # n_head
207+
model_params["num_classes"], # num_classes
208+
tokenizer.output_dim, # seq_len
209+
)
210+
assert label_attention_attributions.shape == expected_shape, (
211+
f"Label attention attributions shape mismatch. "
212+
f"Expected {expected_shape}, got {label_attention_attributions.shape}"
213+
)
214+
else:
215+
# When label attention is not enabled, the attributions should be None
216+
assert (
217+
predictions.get("label_attention_attributions") is None
218+
), "Label attention attributions should be None when label_attention_enabled is False"
167219

168220
# Test explainability functions
169221
text_idx = 0
170222
text = sample_text_data[text_idx]
171223
offsets = predictions["offset_mapping"][text_idx]
172-
attributions = predictions["attributions"][text_idx]
224+
attributions = predictions["captum_attributions"][text_idx]
173225
word_ids = predictions["word_ids"][text_idx]
174226

175227
words, word_attributions = map_attributions_to_word(attributions, text, word_ids, offsets)
@@ -239,3 +291,26 @@ def test_ngram_tokenizer(sample_data, model_params):
239291

240292
# Run full pipeline
241293
run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)
294+
295+
296+
def test_label_attention_enabled(sample_data, model_params):
297+
"""Test the full pipeline with label attention enabled (using WordPieceTokenizer)."""
298+
sample_text_data, categorical_data, labels = sample_data
299+
300+
vocab_size = 100
301+
tokenizer = WordPieceTokenizer(vocab_size, output_dim=50)
302+
tokenizer.train(sample_text_data)
303+
304+
# Check tokenizer works
305+
result = tokenizer.tokenize(sample_text_data)
306+
assert result.input_ids.shape[0] == len(sample_text_data)
307+
308+
# Run full pipeline with label attention enabled
309+
run_full_pipeline(
310+
tokenizer,
311+
sample_text_data,
312+
categorical_data,
313+
labels,
314+
model_params,
315+
label_attention_enabled=True,
316+
)

torchTextClassifiers/model/components/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@
88
CategoricalVariableNet as CategoricalVariableNet,
99
)
1010
from .classification_head import ClassificationHead as ClassificationHead
11+
from .text_embedder import LabelAttentionConfig as LabelAttentionConfig
1112
from .text_embedder import TextEmbedder as TextEmbedder
1213
from .text_embedder import TextEmbedderConfig as TextEmbedderConfig

0 commit comments

Comments
 (0)