99 AttentionConfig ,
1010 CategoricalVariableNet ,
1111 ClassificationHead ,
12+ LabelAttentionConfig ,
1213 TextEmbedder ,
1314 TextEmbedderConfig ,
1415)
@@ -51,7 +52,14 @@ def model_params():
5152 }
5253
5354
54- def run_full_pipeline (tokenizer , sample_text_data , categorical_data , labels , model_params ):
55+ def run_full_pipeline (
56+ tokenizer ,
57+ sample_text_data ,
58+ categorical_data ,
59+ labels ,
60+ model_params ,
61+ label_attention_enabled : bool = False ,
62+ ):
5563 """Helper function to run the complete pipeline for a given tokenizer."""
5664 # Create dataset
5765 dataset = TextClassificationDataset (
@@ -83,6 +91,14 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
8391 embedding_dim = model_params ["embedding_dim" ],
8492 padding_idx = padding_idx ,
8593 attention_config = attention_config ,
94+ label_attention_config = (
95+ LabelAttentionConfig (
96+ n_head = attention_config .n_head ,
97+ num_classes = model_params ["num_classes" ],
98+ )
99+ if label_attention_enabled
100+ else None
101+ ),
86102 )
87103
88104 text_embedder = TextEmbedder (text_embedder_config = text_embedder_config )
@@ -98,7 +114,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
98114 expected_input_dim = model_params ["embedding_dim" ] + categorical_var_net .output_dim
99115 classification_head = ClassificationHead (
100116 input_dim = expected_input_dim ,
101- num_classes = model_params ["num_classes" ],
117+ num_classes = model_params ["num_classes" ] if not label_attention_enabled else 1 ,
102118 )
103119
104120 # Create model
@@ -136,6 +152,14 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
136152 categorical_embedding_dims = model_params ["categorical_embedding_dims" ],
137153 num_classes = model_params ["num_classes" ],
138154 attention_config = attention_config ,
155+ label_attention_config = (
156+ LabelAttentionConfig (
157+ n_head = attention_config .n_head ,
158+ num_classes = model_params ["num_classes" ],
159+ )
160+ if label_attention_enabled
161+ else None
162+ ),
139163 )
140164
141165 # Create training config
@@ -163,13 +187,41 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
163187
164188 # Predict with explanations
165189 top_k = 5
166- predictions = ttc .predict (X , top_k = top_k , explain = True )
190+
191+ predictions = ttc .predict (
192+ X ,
193+ top_k = top_k ,
194+ explain_with_label_attention = label_attention_enabled ,
195+ explain_with_captum = True ,
196+ )
197+
198+ # Test label attention assertions
199+ if label_attention_enabled :
200+ assert (
201+ predictions ["label_attention_attributions" ] is not None
202+ ), "Label attention attributions should not be None when label_attention_enabled is True"
203+ label_attention_attributions = predictions ["label_attention_attributions" ]
204+ expected_shape = (
205+ len (sample_text_data ), # batch_size
206+ model_params ["n_head" ], # n_head
207+ model_params ["num_classes" ], # num_classes
208+ tokenizer .output_dim , # seq_len
209+ )
210+ assert label_attention_attributions .shape == expected_shape , (
211+ f"Label attention attributions shape mismatch. "
212+ f"Expected { expected_shape } , got { label_attention_attributions .shape } "
213+ )
214+ else :
215+ # When label attention is not enabled, the attributions should be None
216+ assert (
217+ predictions .get ("label_attention_attributions" ) is None
218+ ), "Label attention attributions should be None when label_attention_enabled is False"
167219
168220 # Test explainability functions
169221 text_idx = 0
170222 text = sample_text_data [text_idx ]
171223 offsets = predictions ["offset_mapping" ][text_idx ]
172- attributions = predictions ["attributions " ][text_idx ]
224+ attributions = predictions ["captum_attributions " ][text_idx ]
173225 word_ids = predictions ["word_ids" ][text_idx ]
174226
175227 words , word_attributions = map_attributions_to_word (attributions , text , word_ids , offsets )
@@ -239,3 +291,26 @@ def test_ngram_tokenizer(sample_data, model_params):
239291
240292 # Run full pipeline
241293 run_full_pipeline (tokenizer , sample_text_data , categorical_data , labels , model_params )
294+
295+
def test_label_attention_enabled(sample_data, model_params):
    """Run the end-to-end pipeline with label attention turned on.

    Trains a small WordPiece tokenizer on the sample corpus, sanity-checks
    its tokenization output, then delegates model construction, training,
    prediction, and the label-attention assertions to ``run_full_pipeline``
    with ``label_attention_enabled=True``.
    """
    sample_text_data, categorical_data, labels = sample_data

    # Small vocabulary keeps tokenizer training fast for the test corpus.
    wp_tokenizer = WordPieceTokenizer(100, output_dim=50)
    wp_tokenizer.train(sample_text_data)

    # Sanity check: one row of input_ids per input text.
    tokenized = wp_tokenizer.tokenize(sample_text_data)
    assert tokenized.input_ids.shape[0] == len(sample_text_data)

    # Exercise the complete pipeline with label attention enabled; all
    # label-attention-specific assertions live inside run_full_pipeline.
    run_full_pipeline(
        wp_tokenizer,
        sample_text_data,
        categorical_data,
        labels,
        model_params,
        label_attention_enabled=True,
    )
0 commit comments