File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -197,6 +197,28 @@ def run_full_pipeline(
197197 explain_with_captum = True ,
198198 )
199199
200+ # Test label attention assertions
201+ if label_attention_enabled :
202+ assert predictions ["label_attention_attributions" ] is not None , (
203+ "Label attention attributions should not be None when label_attention_enabled is True"
204+ )
205+ label_attention_attributions = predictions ["label_attention_attributions" ]
206+ expected_shape = (
207+ len (sample_text_data ), # batch_size
208+ model_params ["n_head" ], # n_head
209+ model_params ["num_classes" ], # num_classes
210+ tokenizer .output_dim , # seq_len
211+ )
212+ assert label_attention_attributions .shape == expected_shape , (
213+ f"Label attention attributions shape mismatch. "
214+ f"Expected { expected_shape } , got { label_attention_attributions .shape } "
215+ )
216+ else :
217+ # When label attention is not enabled, the attributions should be None
218+ assert predictions .get ("label_attention_attributions" ) is None , (
219+ "Label attention attributions should be None when label_attention_enabled is False"
220+ )
221+
200222 # Test explainability functions
201223 text_idx = 0
202224 text = sample_text_data [text_idx ]
You can’t perform that action at this time.
0 commit comments