@@ -9,6 +9,7 @@
     AttentionConfig,
     CategoricalVariableNet,
     ClassificationHead,
+    LabelAttentionConfig,
     TextEmbedder,
     TextEmbedderConfig,
 )
@@ -51,7 +52,14 @@ def model_params():
     }
 
 
-def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params):
+def run_full_pipeline(
+    tokenizer,
+    sample_text_data,
+    categorical_data,
+    labels,
+    model_params,
+    label_attention_enabled: bool = False,
+):
     """Helper function to run the complete pipeline for a given tokenizer."""
     # Create dataset
     dataset = TextClassificationDataset(
@@ -83,6 +91,15 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
         embedding_dim=model_params["embedding_dim"],
         padding_idx=padding_idx,
         attention_config=attention_config,
+        label_attention_config=(
+            LabelAttentionConfig(
+                n_head=attention_config.n_head,
+                n_kv_head=attention_config.n_kv_head,
+                num_classes=model_params["num_classes"],
+            )
+            if label_attention_enabled
+            else None
+        ),
     )
 
     text_embedder = TextEmbedder(text_embedder_config=text_embedder_config)
@@ -98,7 +115,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
     expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim
     classification_head = ClassificationHead(
         input_dim=expected_input_dim,
-        num_classes=model_params["num_classes"],
+        num_classes=model_params["num_classes"] if not label_attention_enabled else 1,
     )
 
     # Create model
@@ -136,6 +153,15 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
         categorical_embedding_dims=model_params["categorical_embedding_dims"],
         num_classes=model_params["num_classes"],
         attention_config=attention_config,
+        label_attention_config=(
+            LabelAttentionConfig(
+                n_head=attention_config.n_head,
+                n_kv_head=attention_config.n_kv_head,
+                num_classes=model_params["num_classes"],
+            )
+            if label_attention_enabled
+            else None
+        ),
     )
 
     # Create training config
@@ -239,3 +265,26 @@ def test_ngram_tokenizer(sample_data, model_params):
 
     # Run full pipeline
     run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)
+
+
+def test_label_attention_enabled(sample_data, model_params):
+    """Test the full pipeline with label attention enabled (using WordPieceTokenizer)."""
+    sample_text_data, categorical_data, labels = sample_data
+
+    vocab_size = 100
+    tokenizer = WordPieceTokenizer(vocab_size, output_dim=50)
+    tokenizer.train(sample_text_data)
+
+    # Check tokenizer works
+    result = tokenizer.tokenize(sample_text_data)
+    assert result.input_ids.shape[0] == len(sample_text_data)
+
+    # Run full pipeline with label attention enabled
+    run_full_pipeline(
+        tokenizer,
+        sample_text_data,
+        categorical_data,
+        labels,
+        model_params,
+        label_attention_enabled=True,
+    )
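
Why the head shrinks to a single output when label attention is on (a minimal sketch, not code from this repo: the tensor shapes and the use of bare torch.nn.Linear in place of ClassificationHead are illustrative assumptions): with label attention the embedder presumably emits one feature vector per class, so the head maps each per-class vector to one logit, instead of mapping a single pooled vector to num_classes logits.

import torch

batch, num_classes, input_dim = 4, 3, 32  # illustrative sizes only

# Without label attention: one pooled vector per example -> num_classes logits.
head = torch.nn.Linear(input_dim, num_classes)
logits = head(torch.randn(batch, input_dim))           # shape: (4, 3)

# With label attention: one vector per (example, class) -> 1 logit each.
per_class_head = torch.nn.Linear(input_dim, 1)
features = torch.randn(batch, num_classes, input_dim)  # assumed per-class features
logits = per_class_head(features).squeeze(-1)          # shape: (4, 3)

Either way the test ends up comparing logits of shape (batch, num_classes) against the same labels, which is why the pipeline helper only has to toggle the head's output size.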