Skip to content

Commit 9aff671

Browse files
chore: rename file
1 parent 08dc4e3 commit 9aff671

3 files changed

Lines changed: 27 additions & 4 deletions

File tree

torchTextClassifiers/categorical_value_encoder/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .value_encoder import DictEncoder as DictEncoder
2+
from .value_encoder import ValueEncoder as ValueEncoder

torchTextClassifiers/categorical_value_encoder/categorical_value_encoder.py renamed to torchTextClassifiers/value_encoder/value_encoder.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ class ValueEncoder:
4242
4343
Initialization:
4444
- label_encoder: A DictEncoder or LabelEncoder instance for encoding labels.
45-
- encoders (optional): A dictionary mapping feature names to DictEncoder or LabelEncoder instances.
45+
- encoders (optional): A dictionary mapping feature names to DictEncoder or
46+
LabelEncoder instances.
4647
4748
Properties:
4849
- vocabulary_sizes: List of vocabulary sizes (number of unique values) for each feature.
@@ -62,7 +63,8 @@ def __init__(
6263

6364
if not isinstance(label_encoder, (DictEncoder, LabelEncoder)):
6465
raise TypeError(
65-
f"label_encoder must be a DictEncoder or LabelEncoder instance, got {type(label_encoder)}"
66+
"label_encoder must be a DictEncoder or LabelEncoder instance, "
67+
f"got {type(label_encoder)}"
6668
)
6769
self.label_encoder = label_encoder
6870

@@ -153,5 +155,26 @@ def transform_labels(self, y_labels: np.ndarray) -> np.ndarray:
153155
"These values were not seen during fitting."
154156
)
155157

158+
def inverse_transform_labels(self, y_encoded: np.ndarray) -> np.ndarray:
159+
"""Decode integer-encoded labels back to original values.
160+
161+
Args:
162+
y_encoded: Array of shape (N,) with integer-encoded labels.
163+
Returns:
164+
Array of shape (N,) with original label values.
165+
Raises:
166+
ValueError: If any encoded label value was not seen during fitting.
167+
"""
168+
169+
if isinstance(self.label_encoder, DictEncoder):
170+
inverse_mapping = self.label_encoder.inverse_mapping
171+
return np.vectorize(inverse_mapping.get)(y_encoded)
172+
elif hasattr(self.label_encoder, "inverse_transform"):
173+
shape = y_encoded.shape
174+
result = self.label_encoder.inverse_transform(y_encoded.ravel())
175+
return result.reshape(shape) if len(shape) > 1 else result
176+
else:
177+
raise TypeError(f"Unsupported label encoder type: {type(self.label_encoder)}")
178+
156179
def __call__(self, array: np.ndarray) -> np.ndarray:
157180
return self.transform(array)

0 commit comments

Comments
 (0)