Skip to content

Commit a2572ad

Browse files
fix: fix tests to new value encoder
1 parent 1ff499b commit a2572ad

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

tests/test_pipeline.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from sklearn.preprocessing import LabelEncoder
55

66
from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
7-
from torchTextClassifiers.categorical_value_encoder import CategoricalValueEncoder, DictEncoder
87
from torchTextClassifiers.dataset import TextClassificationDataset
98
from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
109
from torchTextClassifiers.model.components import (
@@ -16,6 +15,7 @@
1615
TextEmbedderConfig,
1716
)
1817
from torchTextClassifiers.tokenizers import NGramTokenizer
18+
from torchTextClassifiers.value_encoder import DictEncoder, ValueEncoder
1919

2020
try:
2121
from torchTextClassifiers.tokenizers import HuggingFaceTokenizer, WordPieceTokenizer
@@ -88,15 +88,16 @@ def run_full_pipeline(
8888
)
8989
for i in range(n_features)
9090
}
91-
cat_encoder = CategoricalValueEncoder(encoders)
92-
encoded_categorical = cat_encoder.transform(categorical_data)
93-
vocab_sizes = cat_encoder.vocabulary_sizes
9491

9592
# --- Encode string labels to contiguous integers ---
9693
label_encoder = LabelEncoder()
97-
encoded_labels = label_encoder.fit_transform(labels)
94+
label_encoder.fit(labels)
9895
num_classes = len(label_encoder.classes_)
9996

97+
value_encoder = ValueEncoder(label_encoder, encoders)
98+
encoded_categorical = value_encoder.transform(categorical_data)
99+
vocab_sizes = value_encoder.vocabulary_sizes
100+
100101
# --- Direct component test: dataset with already-encoded integers ---
101102
dataset = TextClassificationDataset(
102103
texts=sample_text_data,
@@ -180,7 +181,7 @@ def run_full_pipeline(
180181
# --- Wrapper pipeline with string categorical data ---
181182
# X keeps categorical columns as raw strings; the wrapper encoder handles them.
182183
X = np.column_stack([sample_text_data, categorical_data])
183-
Y = encoded_labels # integer-encoded labels (from LabelEncoder)
184+
Y = labels # raw string labels (encoded by value_encoder)
184185

185186
# Create model config (vocab sizes and num_classes come from the encoders)
186187
model_config = ModelConfig(
@@ -198,11 +199,11 @@ def run_full_pipeline(
198199
num_epochs=1,
199200
)
200201

201-
# Create classifier — pass the fitted categorical encoder
202+
# Create classifier — pass the fitted value encoder
202203
ttc = torchTextClassifiers(
203204
tokenizer=tokenizer,
204205
model_config=model_config,
205-
categorical_encoder=cat_encoder,
206+
value_encoder=value_encoder,
206207
)
207208

208209
# Train with raw string categorical data

0 commit comments

Comments
 (0)