
Commit 830a45c

feat!(tokenizer): ensure output is consistent across all tokenizers
1 parent 6bdb750 commit 830a45c

4 files changed

Lines changed: 79 additions & 15 deletions

File tree

torchTextClassifiers/dataset/dataset.py
torchTextClassifiers/tokenizers/__init__.py
torchTextClassifiers/tokenizers/base.py
torchTextClassifiers/torchTextClassifiers.py

torchTextClassifiers/dataset/dataset.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -81,8 +81,8 @@ def collate_fn(self, batch):
             categorical_tensors = None
 
         return {
-            "input_ids": tokenize_output["input_ids"],
-            "attention_mask": tokenize_output["attention_mask"],
+            "input_ids": tokenize_output.input_ids,
+            "attention_mask": tokenize_output.attention_mask,
             "categorical_vars": categorical_tensors,
             "labels": labels_tensor,
         }
```
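For context, the collated batch keeps the same keys and is consumed unchanged downstream; a minimal sketch, assuming `dataset` stands in for whatever dataset class defines the `collate_fn` above:

```python
from torch.utils.data import DataLoader

# Hypothetical wiring; `dataset` is a placeholder for the dataset instance
# exposing the collate_fn shown in this diff.
loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)

for batch in loader:
    ids = batch["input_ids"]          # torch.Tensor, (batch_size, seq_len)
    mask = batch["attention_mask"]    # torch.Tensor, same shape
    cats = batch["categorical_vars"]  # tensor, or None when absent
    labels = batch["labels"]
    break
```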
torchTextClassifiers/tokenizers/__init__.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -1,10 +1,9 @@
 from .base import (
     HAS_HF as HAS_HF,
 )
-from .base import (
-    BaseTokenizer as BaseTokenizer,
-)
+from .base import BaseTokenizer as BaseTokenizer
 from .base import (
     HuggingFaceTokenizer as HuggingFaceTokenizer,
 )
+from .base import TokenizerOutput as TokenizerOutput
 from .WordPiece import WordPieceTokenizer as WordPieceTokenizer
```
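The net effect on the package surface: `TokenizerOutput` is now importable alongside the existing tokenizer classes.

```python
# All four names resolve from the package after this commit:
from torchTextClassifiers.tokenizers import (
    BaseTokenizer,
    HuggingFaceTokenizer,
    TokenizerOutput,
    WordPieceTokenizer,
)
```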

torchTextClassifiers/tokenizers/base.py

Lines changed: 65 additions & 3 deletions
```diff
@@ -1,5 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import torch
 
 try:
     from tokenizers import Tokenizer
```
```diff
@@ -10,6 +14,55 @@
     HAS_HF = False
 
 
+@dataclass
+class TokenizerOutput:
+    input_ids: torch.Tensor  # shape: (batch_size, seq_len)
+    attention_mask: torch.Tensor  # shape: (batch_size, seq_len)
+    offset_mapping: Optional[torch.Tensor] = None  # shape: (batch_size, seq_len, 2)
+    word_ids: Optional[np.ndarray] = None  # shape: (batch_size, seq_len)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "TokenizerOutput":
+        return cls(**data)
+
+    def __post_init__(self):
+        # --- Basic type checks ---
+        if not isinstance(self.input_ids, torch.Tensor):
+            raise TypeError(f"input_ids must be a torch.Tensor, got {type(self.input_ids)}")
+        if not isinstance(self.attention_mask, torch.Tensor):
+            raise TypeError(
+                f"attention_mask must be a torch.Tensor, got {type(self.attention_mask)}"
+            )
+        if self.offset_mapping is not None and not isinstance(self.offset_mapping, torch.Tensor):
+            raise TypeError(
+                f"offset_mapping must be a torch.Tensor or None, got {type(self.offset_mapping)}"
+            )
+        if self.word_ids is not None and not isinstance(self.word_ids, np.ndarray):
+            raise TypeError(f"word_ids must be a numpy.ndarray or None, got {type(self.word_ids)}")
+
+        # --- Shape consistency checks ---
+        if self.input_ids.shape != self.attention_mask.shape:
+            raise ValueError(
+                f"Shape mismatch: input_ids {self.input_ids.shape} and attention_mask {self.attention_mask.shape}"
+            )
+
+        if self.offset_mapping is not None:
+            expected_shape = (*self.input_ids.shape, 2)
+            if self.offset_mapping.shape != expected_shape:
+                raise ValueError(
+                    f"offset_mapping should have shape {expected_shape}, got {self.offset_mapping.shape}"
+                )
+
+        if self.word_ids is not None:
+            if self.word_ids.shape != self.input_ids.shape:
+                raise ValueError(
+                    f"word_ids should have shape {self.input_ids.shape}, got {self.word_ids.shape}"
+                )
+
+
 class BaseTokenizer(ABC):
     def __init__(
         self, vocab_size: int, output_vectorized: bool = False, output_dim: Optional[int] = None
```
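A minimal sketch of how the new container behaves, assuming only `torch` and the class above (the id values are arbitrary):

```python
import torch

out = TokenizerOutput(
    input_ids=torch.tensor([[101, 2023, 102]]),
    attention_mask=torch.tensor([[1, 1, 1]]),
)

# Round-trips through plain dicts via dataclasses.asdict.
restored = TokenizerOutput.from_dict(out.to_dict())
assert torch.equal(restored.input_ids, out.input_ids)

# __post_init__ validates eagerly, so malformed outputs fail at construction:
try:
    TokenizerOutput(
        input_ids=torch.tensor([[101, 102]]),
        attention_mask=torch.tensor([[1, 1, 1]]),  # seq_len 3 vs. 2
    )
except ValueError as err:
    print(err)  # Shape mismatch: input_ids ... and attention_mask ...
```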
```diff
@@ -32,7 +85,7 @@ def __init__(
         )
 
     @abstractmethod
-    def tokenize(self, text: Union[str, List[str]]) -> list:
+    def tokenize(self, text: Union[str, List[str]]) -> TokenizerOutput:
         """Tokenizes the raw input text into a list of tokens."""
         pass
 
```
```diff
@@ -72,14 +125,23 @@ def tokenize(
         # Pad to longest sequence if no output_dim is specified
         padding = True if self.output_dim is None else "max_length"
 
-        return self.tokenizer(
+        tokenize_output = self.tokenizer(
             text,
             padding=padding,
             return_tensors="pt",
             max_length=self.output_dim,
             return_offsets_mapping=return_offsets_mapping,
         )  # method from PreTrainedTokenizerFast
 
+        encoded_text = tokenize_output["input_ids"]
+
+        return TokenizerOutput(
+            input_ids=encoded_text,
+            attention_mask=tokenize_output["attention_mask"],
+            offset_mapping=tokenize_output.get("offset_mapping", None),
+            word_ids=np.array([tokenize_output.word_ids(i) for i in range(len(encoded_text))]),
+        )
+
     @classmethod
     def load_from_pretrained(cls, tokenizer_name: str, output_dim: Optional[int] = None):
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
```
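With that, `HuggingFaceTokenizer.tokenize` returns the dataclass instead of a raw `BatchEncoding`. A rough sketch of a call site (the checkpoint name and `output_dim` are illustrative, and it assumes `load_from_pretrained` returns a ready tokenizer instance):

```python
hf_tok = HuggingFaceTokenizer.load_from_pretrained("bert-base-uncased", output_dim=16)

out = hf_tok.tokenize(["hello world", "a longer example sentence"])
print(out.input_ids.shape)       # (2, 16) with output_dim=16 ("max_length" padding)
print(out.attention_mask.shape)  # same shape as input_ids
print(out.word_ids[0])           # word index per token; None at special tokens
```

One consequence of stacking `word_ids` with `np.array` is that the `None` entries for special tokens force the array to `object` dtype, so downstream code should not assume an integer array.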

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 10 additions & 7 deletions
```diff
@@ -28,7 +28,7 @@
     ClassificationHead,
     TextEmbedder,
 )
-from torchTextClassifiers.tokenizers import BaseTokenizer
+from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -474,8 +474,13 @@ def predict(
             text.tolist(), return_offsets_mapping=return_offsets_mapping
         )
 
-        encoded_text = tokenize_output["input_ids"]  # (batch_size, seq_len)
-        attention_mask = tokenize_output["attention_mask"]  # (batch_size, seq_len)
+        if not isinstance(tokenize_output, TokenizerOutput):
+            raise TypeError(
+                f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method."
+            )
+
+        encoded_text = tokenize_output.input_ids  # (batch_size, seq_len)
+        attention_mask = tokenize_output.attention_mask  # (batch_size, seq_len)
 
         if categorical_variables is not None:
             categorical_vars = torch.tensor(
```
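Since `predict` now hard-fails on anything that is not a `TokenizerOutput`, custom tokenizers have to wrap their results. A hypothetical subclass, assuming `tokenize` is the only abstract method, that `BaseTokenizer.__init__` stores `vocab_size` as an attribute, and that `return_offsets_mapping` must be accepted because `predict` passes it through:

```python
from typing import List, Union

import torch

from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput


class WhitespaceTokenizer(BaseTokenizer):
    """Toy example: hash whitespace-separated words into a fixed vocab."""

    def tokenize(
        self, text: Union[str, List[str]], return_offsets_mapping: bool = False
    ) -> TokenizerOutput:
        texts = [text] if isinstance(text, str) else text
        rows = [[hash(w) % self.vocab_size for w in t.split()] for t in texts]
        seq_len = max(len(r) for r in rows)

        # Left-aligned ids with zero padding; the mask marks real tokens.
        input_ids = torch.zeros(len(rows), seq_len, dtype=torch.long)
        attention_mask = torch.zeros(len(rows), seq_len, dtype=torch.long)
        for i, r in enumerate(rows):
            input_ids[i, : len(r)] = torch.tensor(r)
            attention_mask[i, : len(r)] = 1

        return TokenizerOutput(input_ids=input_ids, attention_mask=attention_mask)
```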
```diff
@@ -511,10 +516,8 @@ def predict(
                 "prediction": predictions,
                 "confidence": confidence,
                 "attributions": all_attributions,
-                "offset_mapping": tokenize_output.get("offset_mapping", None),
-                "word_ids": np.array(
-                    [tokenize_output.word_ids(i) for i in range(len(encoded_text))]
-                ),
+                "offset_mapping": tokenize_output.offset_mapping,
+                "word_ids": tokenize_output.word_ids,
             }
         else:
             return {
```
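On the consuming side, the payload fields now map straight onto the dataclass. A sketch, where `clf` and `texts` are placeholders for a fitted classifier and an array of input strings, and where `return_offsets_mapping` is assumed to be exposed as a `predict` argument:

```python
result = clf.predict(texts, return_offsets_mapping=True)

word_ids = result["word_ids"]       # (batch_size, seq_len), None at special tokens
offsets = result["offset_mapping"]  # (batch_size, seq_len, 2) tensor, or None

# Map each token of the first input back to its word and character span.
for tok_idx, w_id in enumerate(word_ids[0]):
    if w_id is not None:
        start, end = offsets[0, tok_idx].tolist()
        print(f"token {tok_idx} -> word {w_id}, chars [{start}:{end})")
```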
