from functools import lru_cache
from typing import List, Optional, Tuple, Union

+ import numpy as np
import torch

from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
@@ -113,7 +114,7 @@ def get(self, word: str) -> List[int]:


# ============================================================================
- # Vectorized encoding
+ # Vectorized encoding with optional metadata
# ============================================================================


@@ -124,33 +125,78 @@ def encode_batch_vectorized(
    pad_token_id: int,
    max_length: Optional[int] = None,
    truncation: bool = False,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+     return_offsets_mapping: bool = False,
+     return_word_ids: bool = False,
+     force_max_length: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List], Optional[List]]:
    """
    Vectorized batch encoding - processes all sentences together.
-     Returns padded tensors directly.
+     Returns padded tensors directly, with optional offset mappings and word IDs.
+
+     Args:
+         force_max_length: If True and max_length is set, always return tensors of size max_length.
    """
    all_ids = []
+     all_offsets = [] if return_offsets_mapping else None
+     all_word_ids = [] if return_word_ids else None
    max_len = 0

    # First pass: encode all sentences
    for sentence in sentences:
        ids = []
+         offsets = [] if return_offsets_mapping else None
+         word_ids = [] if return_word_ids else None
+
        words = sentence.split()
+         char_offset = 0
+
+         for word_idx, word in enumerate(words):
+             # Find the actual position of this word in the original sentence
+             word_start = sentence.find(word, char_offset)
+             word_end = word_start + len(word)
+             char_offset = word_end
+
+             # Get subword tokens for this word
+             subword_tokens = subword_cache.get(word)

-         for word in words:
-             ids.extend(subword_cache.get(word))
+             for token_id in subword_tokens:
+                 ids.append(token_id)

+                 if return_offsets_mapping:
+                     # All subword tokens of a word map to the word's character span
+                     offsets.append((word_start, word_end))
+
+                 if return_word_ids:
+                     # All subword tokens of a word get the same word_id
+                     word_ids.append(word_idx)
+
+         # Add EOS token
        ids.append(eos_token_id)
+         if return_offsets_mapping:
+             offsets.append((len(sentence), len(sentence)))  # EOS has no span
+         if return_word_ids:
+             word_ids.append(None)  # EOS is not part of any word

        # Truncate if needed
        if truncation and max_length and len(ids) > max_length:
            ids = ids[:max_length]
+             if return_offsets_mapping:
+                 offsets = offsets[:max_length]
+             if return_word_ids:
+                 word_ids = word_ids[:max_length]

        all_ids.append(ids)
+         if return_offsets_mapping:
+             all_offsets.append(offsets)
+         if return_word_ids:
+             all_word_ids.append(word_ids)
        max_len = max(max_len, len(ids))

    # Determine final sequence length
-     if max_length and not truncation:
+     if force_max_length and max_length:
+         # Always use max_length when force_max_length is True
+         seq_len = max_length
+     elif max_length and not truncation:
        seq_len = min(max_len, max_length)
    elif max_length:
        seq_len = max_length
@@ -162,13 +208,22 @@ def encode_batch_vectorized(
    input_ids = torch.full((batch_size, seq_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((batch_size, seq_len), dtype=torch.long)

-     # Fill tensors
+     # Fill tensors and pad metadata
    for i, ids in enumerate(all_ids):
        length = min(len(ids), seq_len)
        input_ids[i, :length] = torch.tensor(ids[:length], dtype=torch.long)
        attention_mask[i, :length] = 1

-     return input_ids, attention_mask
+         # Pad offsets and word_ids to match sequence length
+         if return_offsets_mapping:
+             # Pad with (0, 0) for padding tokens
+             all_offsets[i] = all_offsets[i][:length] + [(0, 0)] * (seq_len - length)
+
+         if return_word_ids:
+             # Pad with None for padding tokens
+             all_word_ids[i] = all_word_ids[i][:length] + [None] * (seq_len - length)
+
+     return input_ids, attention_mask, all_offsets, all_word_ids


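For reference outside the diff, here is a minimal usage sketch of the extended helper. Everything in it is illustrative rather than part of the change: the toy cache stands in for the tokenizer's real subword cache (any object exposing get(word) -> List[int] satisfies the subword_cache argument), the special-token ids are made up, and the sketch assumes encode_batch_vectorized is in scope (e.g. imported from this module).

class _ToyCache:
    """Stand-in for the real subword cache: two fake subword ids per word."""
    def get(self, word):
        return [3 + len(word), 4 + len(word)]

input_ids, attention_mask, offsets, word_ids = encode_batch_vectorized(
    ["hello world", "hi"],
    _ToyCache(),
    eos_token_id=2,   # illustrative special-token ids
    pad_token_id=0,
    max_length=8,
    truncation=True,
    return_offsets_mapping=True,
    return_word_ids=True,
)
# seq_len is max_length (8) because truncation=True; with force_max_length=True
# the batch would be padded out to max_length even when no sequence reaches it.
# offsets[0] -> [(0, 5), (0, 5), (6, 11), (6, 11), (11, 11), (0, 0), (0, 0), (0, 0)]
# word_ids[0] -> [0, 0, 1, 1, None, None, None, None]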
# ============================================================================
@@ -183,8 +238,7 @@ class NGramTokenizer(BaseTokenizer):
    - Vectorized batch encoding
    - Cached text normalization
    - Direct tensor operations
-     - No multiprocessing overhead
-     - No Numba dependency
+     - Optional offset mapping and word ID tracking
    """

    PAD_TOKEN = "[PAD]"
@@ -200,6 +254,7 @@ def __init__(
        len_word_ngrams: int,
        training_text: Optional[List[str]] = None,
        preprocess: bool = True,
+         output_dim: Optional[int] = None,
        **kwargs,
    ):
        if min_n < 2:
@@ -227,9 +282,11 @@ def __init__(
        self.subword_cache = None

        self.vocab_size = 3 + self.nwords + self.num_tokens
-         super().__init__(vocab_size=self.vocab_size)
+         super().__init__(
+             vocab_size=self.vocab_size, padding_idx=self.pad_token_id, output_dim=output_dim
+         )

-     def _build_vocab(self, training_text: List[str]):
+     def train(self, training_text: List[str]):
        """Build vocabulary from training text."""
        word_counts = {}
        for sent in training_text:
@@ -261,16 +318,24 @@ def _build_vocab(self, training_text: List[str]):
    def tokenize(
        self,
        text: Union[str, List[str]],
-         padding: str = "longest",
-         max_length: Optional[int] = None,
-         truncation: bool = False,
        return_offsets_mapping: bool = False,
        return_word_ids: bool = False,
        **kwargs,
    ) -> TokenizerOutput:
        """
        Optimized tokenization with vectorized operations.
-         Note: return_offsets_mapping and return_word_ids removed for speed.
+
+         Sequence length is controlled by the tokenizer's output_dim: when it is
+         set, sequences are padded or truncated to that length; otherwise they
+         are padded to the longest sequence in the batch.
+
+         Args:
+             text: Single string or list of strings to tokenize
+             return_offsets_mapping: If True, return character offsets for each token
+             return_word_ids: If True, return word indices for each token
+
+         Returns:
+             TokenizerOutput with input_ids, attention_mask, and optionally
+             offset_mapping and word_ids
        """
        is_single = isinstance(text, str)
        if is_single:
@@ -280,21 +345,33 @@ def tokenize(
        if self.preprocess:
            text = clean_text_feature(text)

+         if self.output_dim is not None:
+             max_length = self.output_dim
+             truncation = True
+         else:
+             max_length = None
+             truncation = False
+
        # Vectorized encoding
-         input_ids, attention_mask = encode_batch_vectorized(
+         input_ids, attention_mask, offsets, word_ids = encode_batch_vectorized(
            text,
            self.subword_cache,
            self.eos_token_id,
            self.pad_token_id,
-             max_length=max_length if padding == "max_length" else None,
+             max_length=max_length,
            truncation=truncation,
+             return_offsets_mapping=return_offsets_mapping,
+             return_word_ids=return_word_ids,
        )

+         offsets = torch.tensor(offsets) if return_offsets_mapping else None
+         word_ids = np.array(word_ids) if return_word_ids else None
+
        return TokenizerOutput(
            input_ids=input_ids,
            attention_mask=attention_mask,
-             word_ids=None,
-             offset_mapping=None,
+             word_ids=word_ids,
+             offset_mapping=offsets,
        )
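A hedged sketch of how the restored flags behave after this change. The tokenizer instance, its constructor arguments, and output_dim=16 are assumptions made for illustration; the full constructor lies outside this diff.

# Assumed: `tokenizer` is a built and trained NGramTokenizer whose output_dim
# was set to 16 at construction.
out = tokenizer.tokenize(
    ["hello world", "hi"],
    return_offsets_mapping=True,
    return_word_ids=True,
)

out.input_ids.shape       # torch.Size([2, 16]) - padded/truncated to output_dim
out.attention_mask.shape  # torch.Size([2, 16])
out.offset_mapping.shape  # torch.Size([2, 16, 2]) - (0, 0) spans mark padding
out.word_ids[0]           # object array like [0, 0, 1, 1, None, ...] (None = EOS/padding)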

    def decode(