2222 ModelCheckpoint ,
2323)
2424
25- from torchTextClassifiers .categorical_value_encoder import ValueEncoder
2625from torchTextClassifiers .dataset import TextClassificationDataset
2726from torchTextClassifiers .model import TextClassificationModel , TextClassificationModule
2827from torchTextClassifiers .model .components import (
3534 TextEmbedderConfig ,
3635)
3736from torchTextClassifiers .tokenizers import BaseTokenizer , TokenizerOutput
37+ from torchTextClassifiers .value_encoder import ValueEncoder
3838
3939logger = logging .getLogger (__name__ )
4040
@@ -125,7 +125,7 @@ def __init__(
125125
126126 Example:
127127 >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
128- >>> from torchTextClassifiers.categorical_value_encoder import ValueEncoder, DictEncoder
128+ >>> from torchTextClassifiers.value_encoder import ValueEncoder, DictEncoder
129129 >>> # Build one DictEncoder per categorical feature
130130 >>> encoders = {str(i): DictEncoder({v: j for j, v in enumerate(sorted(set(X_categorical[:, i])))})
131131 ... for i in range(X_categorical.shape[1])}
@@ -492,9 +492,9 @@ def _check_categorical_variables(
492492 raise ValueError (
493493 "Raw categorical input encoding is enabled, but no value_encoder was provided. Please provide a ValueEncoder to encode raw categorical values to integers."
494494 )
495- X [:, 1 :] = self .value_encoder .transform (X [:, 1 :])
496-
497- categorical_variables = X [:, 1 :].astype (int )
495+ categorical_variables = self .value_encoder .transform (X [:, 1 :]). astype ( int )
496+ else :
497+ categorical_variables = X [:, 1 :].astype (int )
498498
499499 for j in range (num_cat_vars ):
500500 max_cat_value = categorical_variables [:, j ].max ()
@@ -549,6 +549,7 @@ def _check_Y(self, Y, raw_labels: bool) -> np.ndarray:
549549 def predict (
550550 self ,
551551 X_test : np .ndarray ,
552+ raw_categorical_inputs : bool = True ,
552553 top_k = 1 ,
553554 explain_with_label_attention : bool = False ,
554555 explain_with_captum = False ,
@@ -593,7 +594,7 @@ def predict(
593594 return_offsets_mapping = False
594595 return_word_ids = False
595596
596- X_test = self ._check_X (X_test )
597+ X_test = self ._check_X (X_test , raw_categorical_inputs )
597598 text = X_test ["text" ]
598599 categorical_variables = X_test ["categorical_variables" ]
599600
@@ -638,7 +639,12 @@ def predict(
638639
639640 label_scores_topk = torch .topk (label_scores , k = top_k , dim = 1 )
640641
641- predictions = label_scores_topk .indices # get the top_k most likely predictions
642+ integer_predictions = label_scores_topk .indices # integer class indices (needed for captum)
643+ if self .value_encoder is not None :
644+ predictions = self .value_encoder .inverse_transform_labels (integer_predictions .numpy ())
645+ else :
646+ predictions = integer_predictions
647+
642648 confidence = torch .round (label_scores_topk .values , decimals = 2 ) # and their scores
643649
644650 if explain :
@@ -648,7 +654,7 @@ def predict(
648654 for k in range (top_k ):
649655 attributions = lig .attribute (
650656 (encoded_text , attention_mask , categorical_vars ),
651- target = torch . Tensor ( predictions [:, k ]). long () ,
657+ target = integer_predictions [:, k ],
652658 ) # (batch_size, seq_len)
653659 attributions = attributions .sum (dim = - 1 )
654660 captum_attributions .append (attributions .detach ().cpu ())
0 commit comments