from abc import ABC, abstractmethod
- from typing import List, Optional, Union
+ from dataclasses import asdict, dataclass
+ from typing import Any, Dict, List, Optional, Union
+
+ import numpy as np
+ import torch

try:
    from tokenizers import Tokenizer
except ImportError:
    HAS_HF = False


+ @dataclass
+ class TokenizerOutput:
+     input_ids: torch.Tensor  # shape: (batch_size, seq_len)
+     attention_mask: torch.Tensor  # shape: (batch_size, seq_len)
+     offset_mapping: Optional[torch.Tensor] = None  # shape: (batch_size, seq_len, 2)
+     word_ids: Optional[np.ndarray] = None  # shape: (batch_size, seq_len)
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "TokenizerOutput":
+         return cls(**data)
+
+     def __post_init__(self):
+         # --- Basic type checks ---
+         if not isinstance(self.input_ids, torch.Tensor):
+             raise TypeError(f"input_ids must be a torch.Tensor, got {type(self.input_ids)}")
+         if not isinstance(self.attention_mask, torch.Tensor):
+             raise TypeError(
+                 f"attention_mask must be a torch.Tensor, got {type(self.attention_mask)}"
+             )
+         if self.offset_mapping is not None and not isinstance(self.offset_mapping, torch.Tensor):
+             raise TypeError(
+                 f"offset_mapping must be a torch.Tensor or None, got {type(self.offset_mapping)}"
+             )
+         if self.word_ids is not None and not isinstance(self.word_ids, np.ndarray):
+             raise TypeError(f"word_ids must be a numpy.ndarray or None, got {type(self.word_ids)}")
+
+         # --- Shape consistency checks ---
+         if self.input_ids.shape != self.attention_mask.shape:
+             raise ValueError(
+                 f"Shape mismatch: input_ids {self.input_ids.shape} and attention_mask {self.attention_mask.shape}"
+             )
+
+         if self.offset_mapping is not None:
+             expected_shape = (*self.input_ids.shape, 2)
+             if self.offset_mapping.shape != expected_shape:
+                 raise ValueError(
+                     f"offset_mapping should have shape {expected_shape}, got {self.offset_mapping.shape}"
+                 )
+
+         if self.word_ids is not None:
+             if self.word_ids.shape != self.input_ids.shape:
+                 raise ValueError(
+                     f"word_ids should have shape {self.input_ids.shape}, got {self.word_ids.shape}"
+                 )
+
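For orientation, a small usage sketch of the new dataclass (illustrative values; not part of the diff):

# Hypothetical usage of TokenizerOutput; the id values below are made up.
ids = torch.tensor([[101, 7592, 102]])   # (batch_size=1, seq_len=3)
mask = torch.ones_like(ids)
out = TokenizerOutput(input_ids=ids, attention_mask=mask)
restored = TokenizerOutput.from_dict(out.to_dict())  # asdict-based round trip
# Validation runs in __post_init__, so inconsistent inputs fail fast:
# TokenizerOutput(input_ids=ids, attention_mask=torch.ones(1, 5))  # raises ValueError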
class BaseTokenizer(ABC):
    def __init__(
        self, vocab_size: int, output_vectorized: bool = False, output_dim: Optional[int] = None
@@ -32,7 +85,7 @@ def __init__(
        )

    @abstractmethod
-     def tokenize(self, text: Union[str, List[str]]) -> list:
+     def tokenize(self, text: Union[str, List[str]]) -> TokenizerOutput:
3689 """Tokenizes the raw input text into a list of tokens."""
        pass

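To make the new contract concrete, a minimal sketch of a possible subclass follows (hypothetical; it assumes __init__ stores vocab_size on self, which this diff does not show):

# Toy subclass: splits on whitespace and hashes tokens into the vocab range.
class WhitespaceTokenizer(BaseTokenizer):
    def tokenize(self, text: Union[str, List[str]]) -> TokenizerOutput:
        batch = [text] if isinstance(text, str) else text
        token_lists = [t.split() for t in batch]
        seq_len = max(len(toks) for toks in token_lists)
        ids = torch.zeros(len(batch), seq_len, dtype=torch.long)
        mask = torch.zeros(len(batch), seq_len, dtype=torch.long)
        for i, toks in enumerate(token_lists):
            for j, tok in enumerate(toks):
                ids[i, j] = hash(tok) % self.vocab_size  # assumed attribute
                mask[i, j] = 1  # mark real (non-padding) positions
        return TokenizerOutput(input_ids=ids, attention_mask=mask)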
@@ -72,14 +125,23 @@ def tokenize(
        # Pad to longest sequence if no output_dim is specified
        padding = True if self.output_dim is None else "max_length"

-         return self.tokenizer(
+         tokenize_output = self.tokenizer(
            text,
            padding=padding,
            return_tensors="pt",
            max_length=self.output_dim,
            return_offsets_mapping=return_offsets_mapping,
        )  # method from PreTrainedTokenizerFast

+         encoded_text = tokenize_output["input_ids"]
+
+         return TokenizerOutput(
+             input_ids=encoded_text,
+             attention_mask=tokenize_output["attention_mask"],
+             offset_mapping=tokenize_output.get("offset_mapping", None),
+             word_ids=np.array([tokenize_output.word_ids(i) for i in range(len(encoded_text))]),
+         )
+
    @classmethod
    def load_from_pretrained(cls, tokenizer_name: str, output_dim: Optional[int] = None):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
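A usage sketch for the new return type (the concrete HF-backed subclass is only partially visible in this diff; HFTokenizer is an assumed name):

tok = HFTokenizer.load_from_pretrained("bert-base-uncased")
out = tok.tokenize(["hello world", "a longer example"])
print(out.input_ids.shape)        # (2, seq_len), padded to the longest sequence
print(out.attention_mask.sum(1))  # count of real (non-padding) tokens per sequence
print(out.word_ids[0])            # word index per token; None marks special tokens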