
Commit 823467b

fix: add a flag for return_word_ids
aligning with NGramTokenizer
1 parent ab70485 commit 823467b

3 files changed, 16 additions & 4 deletions


torchTextClassifiers/tokenizers/base.py

Lines changed: 10 additions & 2 deletions
@@ -122,7 +122,10 @@ def __init__(
         self.output_dim = output_dim  # constant context size for all batch
 
     def tokenize(
-        self, text: Union[str, List[str]], return_offsets_mapping: Optional[bool] = False
+        self,
+        text: Union[str, List[str]],
+        return_offsets_mapping: Optional[bool] = False,
+        return_word_ids: Optional[bool] = False,
     ) -> list:
         if not self.trained:
             raise RuntimeError("Tokenizer must be trained before tokenization.")
@@ -142,11 +145,16 @@ def tokenize(
 
         encoded_text = tokenize_output["input_ids"]
 
+        if return_word_ids:
+            word_ids = np.array([tokenize_output.word_ids(i) for i in range(len(encoded_text))])
+        else:
+            word_ids = None
+
         return TokenizerOutput(
             input_ids=encoded_text,
             attention_mask=tokenize_output["attention_mask"],
             offset_mapping=tokenize_output.get("offset_mapping", None),
-            word_ids=np.array([tokenize_output.word_ids(i) for i in range(len(encoded_text))]),
+            word_ids=word_ids,
         )
 
     @classmethod
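With this change word-id computation is opt-in: by default tokenize leaves word_ids as None and only builds the array when return_word_ids=True, which keeps the signature in line with NGramTokenizer (per the commit message) and skips the per-sequence word_ids(i) calls when the alignment is not needed. A minimal usage sketch, assuming a trained tokenizer instance and that TokenizerOutput exposes its fields as attributes; the variable names are illustrative, not taken from the diff:

    texts = ["the cat sat on the mat", "a second document"]

    # Default call: word-id computation is skipped and word_ids comes back as None
    out = tokenizer.tokenize(texts)
    assert out.word_ids is None

    # Opt in when token-to-word alignment is needed (e.g. for explainability)
    out = tokenizer.tokenize(
        texts,
        return_offsets_mapping=True,
        return_word_ids=True,
    )
    print(out.word_ids)  # per-token word indices, aligned with input_ids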

torchTextClassifiers/tokenizers/ngram.py

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ def __init__(
         self.subword_cache = None
 
         self.vocab_size = 3 + self.nwords + self.num_tokens
-        print("brrrrr ", self.vocab_size)
+
         super().__init__(
             vocab_size=self.vocab_size, padding_idx=self.pad_token_id, output_dim=output_dim
         )

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 5 additions & 1 deletion
@@ -460,6 +460,7 @@ def predict(
 
         if explain:
            return_offsets_mapping = True  # to be passed to the tokenizer
+            return_word_ids = True
            if self.pytorch_model.text_embedder is None:
                raise RuntimeError(
                    "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
@@ -474,6 +475,7 @@ def predict(
            )  # initialize a Captum layer gradient integrator
        else:
            return_offsets_mapping = False
+            return_word_ids = False
 
         X_test = self._check_X(X_test)
         text = X_test["text"]
@@ -482,7 +484,9 @@ def predict(
         self.pytorch_model.eval().cpu()
 
         tokenize_output = self.tokenizer.tokenize(
-            text.tolist(), return_offsets_mapping=return_offsets_mapping
+            text.tolist(),
+            return_offsets_mapping=return_offsets_mapping,
+            return_word_ids=return_word_ids,
         )
 
         if not isinstance(tokenize_output, TokenizerOutput):
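Callers of predict do not need to change anything: the new flag is derived from explain inside predict, mirroring how return_offsets_mapping was already handled. A hedged sketch of the two paths, assuming a fitted classifier clf and a predict(X_test, explain=...) signature consistent with the parameters visible in these hunks; clf and the exact shape of X_test are illustrative, not shown in the diff:

    import numpy as np

    X_test = {"text": np.array(["first document to score", "second document"])}

    # explain=False: tokenize() is called with both flags False,
    # so neither offsets nor word ids are computed
    preds = clf.predict(X_test)

    # explain=True: predict sets return_offsets_mapping=True and return_word_ids=True
    # before calling self.tokenizer.tokenize(...), so Captum attributions can be
    # mapped back to words (the return value with explain=True is not shown here)
    result = clf.predict(X_test, explain=True)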
