Skip to content

Commit 0558f97

Browse files
Add assertions for label attention attributions in tests
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
1 parent 70b79c9 commit 0558f97

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

tests/test_pipeline.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,28 @@ def run_full_pipeline(
197197
explain_with_captum=True,
198198
)
199199

200+
# Test label attention assertions
201+
if label_attention_enabled:
202+
assert predictions["label_attention_attributions"] is not None, (
203+
"Label attention attributions should not be None when label_attention_enabled is True"
204+
)
205+
label_attention_attributions = predictions["label_attention_attributions"]
206+
expected_shape = (
207+
len(sample_text_data), # batch_size
208+
model_params["n_head"], # n_head
209+
model_params["num_classes"], # num_classes
210+
tokenizer.output_dim, # seq_len
211+
)
212+
assert label_attention_attributions.shape == expected_shape, (
213+
f"Label attention attributions shape mismatch. "
214+
f"Expected {expected_shape}, got {label_attention_attributions.shape}"
215+
)
216+
else:
217+
# When label attention is not enabled, the attributions should be None
218+
assert predictions.get("label_attention_attributions") is None, (
219+
"Label attention attributions should be None when label_attention_enabled is False"
220+
)
221+
200222
# Test explainability functions
201223
text_idx = 0
202224
text = sample_text_data[text_idx]

0 commit comments

Comments
 (0)