Skip to content

Commit 1ff499b

Browse files
feat: add label decoding after prediction
1 parent 9aff671 commit 1ff499b

1 file changed

Lines changed: 14 additions & 8 deletions

File tree

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
ModelCheckpoint,
2323
)
2424

25-
from torchTextClassifiers.categorical_value_encoder import ValueEncoder
2625
from torchTextClassifiers.dataset import TextClassificationDataset
2726
from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
2827
from torchTextClassifiers.model.components import (
@@ -35,6 +34,7 @@
3534
TextEmbedderConfig,
3635
)
3736
from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
37+
from torchTextClassifiers.value_encoder import ValueEncoder
3838

3939
logger = logging.getLogger(__name__)
4040

@@ -125,7 +125,7 @@ def __init__(
125125
126126
Example:
127127
>>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
128-
>>> from torchTextClassifiers.categorical_value_encoder import ValueEncoder, DictEncoder
128+
>>> from torchTextClassifiers.value_encoder import ValueEncoder, DictEncoder
129129
>>> # Build one DictEncoder per categorical feature
130130
>>> encoders = {str(i): DictEncoder({v: j for j, v in enumerate(sorted(set(X_categorical[:, i])))})
131131
... for i in range(X_categorical.shape[1])}
@@ -492,9 +492,9 @@ def _check_categorical_variables(
492492
raise ValueError(
493493
"Raw categorical input encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw categorical values to integers."
494494
)
495-
X[:, 1:] = self.value_encoder.transform(X[:, 1:])
496-
497-
categorical_variables = X[:, 1:].astype(int)
495+
categorical_variables = self.value_encoder.transform(X[:, 1:]).astype(int)
496+
else:
497+
categorical_variables = X[:, 1:].astype(int)
498498

499499
for j in range(num_cat_vars):
500500
max_cat_value = categorical_variables[:, j].max()
@@ -549,6 +549,7 @@ def _check_Y(self, Y, raw_labels: bool) -> np.ndarray:
549549
def predict(
550550
self,
551551
X_test: np.ndarray,
552+
raw_categorical_inputs: bool = True,
552553
top_k=1,
553554
explain_with_label_attention: bool = False,
554555
explain_with_captum=False,
@@ -593,7 +594,7 @@ def predict(
593594
return_offsets_mapping = False
594595
return_word_ids = False
595596

596-
X_test = self._check_X(X_test)
597+
X_test = self._check_X(X_test, raw_categorical_inputs)
597598
text = X_test["text"]
598599
categorical_variables = X_test["categorical_variables"]
599600

@@ -638,7 +639,12 @@ def predict(
638639

639640
label_scores_topk = torch.topk(label_scores, k=top_k, dim=1)
640641

641-
predictions = label_scores_topk.indices # get the top_k most likely predictions
642+
integer_predictions = label_scores_topk.indices # integer class indices (needed for captum)
643+
if self.value_encoder is not None:
644+
predictions = self.value_encoder.inverse_transform_labels(integer_predictions.numpy())
645+
else:
646+
predictions = integer_predictions
647+
642648
confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores
643649

644650
if explain:
@@ -648,7 +654,7 @@ def predict(
648654
for k in range(top_k):
649655
attributions = lig.attribute(
650656
(encoded_text, attention_mask, categorical_vars),
651-
target=torch.Tensor(predictions[:, k]).long(),
657+
target=integer_predictions[:, k],
652658
) # (batch_size, seq_len)
653659
attributions = attributions.sum(dim=-1)
654660
captum_attributions.append(attributions.detach().cpu())

0 commit comments

Comments
 (0)