import numpy as np
import pytest
import torch
+ from sklearn.preprocessing import LabelEncoder

from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
+ from torchTextClassifiers.categorical_value_encoder import CategoricalValueEncoder, DictEncoder
from torchTextClassifiers.dataset import TextClassificationDataset
from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
from torchTextClassifiers.model.components import (
    TextEmbedder,
    TextEmbedderConfig,
)
- from torchTextClassifiers.tokenizers import HuggingFaceTokenizer, NGramTokenizer, WordPieceTokenizer
+ from torchTextClassifiers.tokenizers import NGramTokenizer
+
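+ # These tokenizers are optional extras; the guarded import below lets the module
+ # still be imported when their dependencies are not installed.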
+ try:
+     from torchTextClassifiers.tokenizers import HuggingFaceTokenizer, WordPieceTokenizer
+ except ImportError:
+     pass
+
from torchTextClassifiers.utilities.plot_explainability import (
    map_attributions_to_char,
    map_attributions_to_word,
@@ -33,21 +41,31 @@ def sample_data():
        "Good example here",
        "Bad example here",
    ]
-     categorical_data = np.array([[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]]).astype(int)
-     labels = np.array([1, 0, 1, 0, 1, 5])
+     # String categorical variables — two features, two unique values each
+     categorical_data = np.array(
+         [
+             ["cat", "red"],
+             ["dog", "blue"],
+             ["cat", "red"],
+             ["dog", "blue"],
+             ["cat", "red"],
+             ["dog", "blue"],
+         ]
+     )
+     # String labels
+     labels = np.array(["positive", "negative", "positive", "negative", "positive", "negative"])

    return sample_text_data, categorical_data, labels


@pytest.fixture
def model_params():
-     """Fixture providing common model parameters."""
+     """Fixture providing common model parameters (class count and vocab sizes are
+     derived from data at runtime inside run_full_pipeline)."""
    return {
        "embedding_dim": 96,
        "n_layers": 2,
        "n_head": 4,
-         "num_classes": 10,
-         "categorical_vocab_sizes": [2, 2],
        "categorical_embedding_dims": [4, 7],
    }

@@ -61,10 +79,28 @@ def run_full_pipeline(
    label_attention_enabled: bool = False,
):
    """Helper function to run the complete pipeline for a given tokenizer."""
-     # Create dataset
+
+     # --- Encode categorical variables (string → int) ---
+     n_features = categorical_data.shape[1]
+     encoders = {
+         str(i): DictEncoder(
+             {v: j for j, v in enumerate(sorted(set(categorical_data[:, i].tolist())))}
+         )
+         for i in range(n_features)
+     }
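+     # One DictEncoder per categorical column, mapping each of that column's
+     # sorted unique string values to a 0-based integer index.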
+     cat_encoder = CategoricalValueEncoder(encoders)
+     encoded_categorical = cat_encoder.transform(categorical_data)
+     vocab_sizes = cat_encoder.vocabulary_sizes
+
+     # --- Encode string labels to contiguous integers ---
+     label_encoder = LabelEncoder()
+     encoded_labels = label_encoder.fit_transform(labels)
+     num_classes = len(label_encoder.classes_)
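+     # LabelEncoder sorts the class names, so the sample labels above yield
+     # num_classes == 2 ("negative", "positive").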
+
+     # --- Direct component test: dataset with already-encoded integers ---
    dataset = TextClassificationDataset(
        texts=sample_text_data,
-         categorical_variables=categorical_data.tolist(),
+         categorical_variables=encoded_categorical.tolist(),
        tokenizer=tokenizer,
        labels=None,
    )
@@ -94,7 +130,7 @@ def run_full_pipeline(
    label_attention_config = (
        LabelAttentionConfig(
            n_head=attention_config.n_head,
-             num_classes=model_params["num_classes"],
+             num_classes=num_classes,
        )
        if label_attention_enabled
        else None
@@ -104,17 +140,17 @@ def run_full_pipeline(
    text_embedder = TextEmbedder(text_embedder_config=text_embedder_config)
    text_embedder.init_weights()

-     # Create categorical variable net
+     # Create categorical variable net (vocab sizes from fitted encoder)
    categorical_var_net = CategoricalVariableNet(
-         categorical_vocabulary_sizes=model_params["categorical_vocab_sizes"],
+         categorical_vocabulary_sizes=vocab_sizes,
        categorical_embedding_dims=model_params["categorical_embedding_dims"],
    )

    # Create classification head
    expected_input_dim = model_params["embedding_dim"] + categorical_var_net.output_dim
    classification_head = ClassificationHead(
        input_dim=expected_input_dim,
-         num_classes=model_params["num_classes"] if not label_attention_enabled else 1,
+         num_classes=num_classes if not label_attention_enabled else 1,
    )
@@ -141,45 +177,47 @@ def run_full_pipeline(
    # Test prediction
    module.predict_step(batch)

-     # Prepare data for training
+     # --- Wrapper pipeline with string categorical data ---
+     # X keeps categorical columns as raw strings; the wrapper encoder handles them.
    X = np.column_stack([sample_text_data, categorical_data])
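+     # column_stack yields a single string-dtype matrix: text in column 0,
+     # the raw categorical strings in the remaining columns.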
-     Y = labels
+     Y = encoded_labels  # integer-encoded labels (from LabelEncoder)

-     # Create model config
+     # Create model config (vocab sizes and num_classes come from the encoders)
    model_config = ModelConfig(
        embedding_dim=model_params["embedding_dim"],
-         categorical_vocabulary_sizes=model_params["categorical_vocab_sizes"],
+         categorical_vocabulary_sizes=vocab_sizes,
        categorical_embedding_dims=model_params["categorical_embedding_dims"],
-         num_classes=model_params["num_classes"],
+         num_classes=num_classes,
        attention_config=attention_config,
        n_heads_label_attention=attention_config.n_head,
    )

-     # Create training config
    training_config = TrainingConfig(
        lr=1e-3,
        batch_size=4,
        num_epochs=1,
    )

-     # Create classifier
+     # Create classifier — pass the fitted categorical encoder
    ttc = torchTextClassifiers(
        tokenizer=tokenizer,
        model_config=model_config,
+         categorical_encoder=cat_encoder,
    )

-     # Train
+     # Train with raw string categorical data
    ttc.train(
        X_train=X,
        y_train=Y,
        X_val=X,
        y_val=Y,
        training_config=training_config,
    )
-     ttc.load(ttc.save_path)  # test load
+     assert ttc.save_path is not None
+     ttc.load(ttc.save_path)  # test load (encoder is also saved/restored)

    # Predict with explanations
-     top_k = 5
+     top_k = min(5, num_classes)

    predictions = ttc.predict(
        X,
@@ -197,15 +235,14 @@ def run_full_pipeline(
        expected_shape = (
            len(sample_text_data),  # batch_size
            model_params["n_head"],  # n_head
-             model_params["num_classes"],  # num_classes
+             num_classes,  # num_classes (derived from label encoder)
            tokenizer.output_dim,  # seq_len
        )
        assert label_attention_attributions.shape == expected_shape, (
            f"Label attention attributions shape mismatch. "
            f"Expected {expected_shape}, got {label_attention_attributions.shape}"
        )
    else:
-         # When label attention is not enabled, the attributions should be None
        assert (
            predictions.get("label_attention_attributions") is None
        ), "Label attention attributions should be None when label_attention_enabled is False"
@@ -220,8 +257,6 @@ def run_full_pipeline(
    words, word_attributions = map_attributions_to_word(attributions, text, word_ids, offsets)
    char_attributions = map_attributions_to_char(attributions, offsets, text)

-     # Note: We're not actually plotting in tests, just calling the functions
-     # to ensure they don't raise errors
    plot_attributions_at_char(text, char_attributions)
    plot_attributions_at_word(
        text=text,
@@ -238,11 +273,9 @@ def test_wordpiece_tokenizer(sample_data, model_params):
    tokenizer = WordPieceTokenizer(vocab_size, output_dim=50)
    tokenizer.train(sample_text_data)

-     # Check tokenizer works
    result = tokenizer.tokenize(sample_text_data)
    assert result.input_ids.shape[0] == len(sample_text_data)

-     # Run full pipeline
    run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)

@@ -254,11 +287,9 @@ def test_huggingface_tokenizer(sample_data, model_params):
        "google-bert/bert-base-uncased", output_dim=50
    )

-     # Check tokenizer works
    result = tokenizer.tokenize(sample_text_data)
    assert result.input_ids.shape[0] == len(sample_text_data)

-     # Run full pipeline
    run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)

@@ -271,18 +302,15 @@ def test_ngram_tokenizer(sample_data, model_params):
    )
    tokenizer.train(sample_text_data)

-     # Check tokenizer works
    result = tokenizer.tokenize(
        sample_text_data[0], return_offsets_mapping=True, return_word_ids=True
    )
    assert result.input_ids is not None

-     # Check batch decode
    batch_result = tokenizer.tokenize(sample_text_data)
    decoded = tokenizer.batch_decode(batch_result.input_ids.tolist())
    assert len(decoded) == len(sample_text_data)

-     # Run full pipeline
    run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, model_params)

@@ -294,11 +322,9 @@ def test_label_attention_enabled(sample_data, model_params):
    tokenizer = WordPieceTokenizer(vocab_size, output_dim=50)
    tokenizer.train(sample_text_data)

-     # Check tokenizer works
    result = tokenizer.tokenize(sample_text_data)
    assert result.input_ids.shape[0] == len(sample_text_data)

-     # Run full pipeline with label attention enabled
    run_full_pipeline(
        tokenizer,
        sample_text_data,