Commit 3c0a85a

feat(ngram): add return offsets and word_ids + fix output_dim
1 parent 84b118b commit 3c0a85a

3 files changed

Lines changed: 104 additions & 21 deletions


torchTextClassifiers/tokenizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,4 +6,5 @@
     HuggingFaceTokenizer as HuggingFaceTokenizer,
 )
 from .base import TokenizerOutput as TokenizerOutput
+from .ngram import NGramTokenizer as NGramTokenizer
 from .WordPiece import WordPieceTokenizer as WordPieceTokenizer
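
With this re-export in place, NGramTokenizer can be imported from the package namespace like the other tokenizers, so downstream code no longer has to reach into the ngram module directly. A minimal import sketch (only names re-exported above are used):

from torchTextClassifiers.tokenizers import NGramTokenizer, TokenizerOutput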

torchTextClassifiers/tokenizers/base.py

Lines changed: 6 additions & 1 deletion
@@ -65,7 +65,11 @@ def __post_init__(self):
 
 class BaseTokenizer(ABC):
     def __init__(
-        self, vocab_size: int, output_vectorized: bool = False, output_dim: Optional[int] = None
+        self,
+        vocab_size: int,
+        padding_idx: int,
+        output_vectorized: bool = False,
+        output_dim: Optional[int] = None,
     ):
         """
         Base class for tokenizers.
@@ -78,6 +82,7 @@ def __init__(
         self.vocab_size = vocab_size
         self.output_vectorized = output_vectorized
         self.output_dim = output_dim
+        self.padding_idx = padding_idx
         if self.output_vectorized:
             if output_dim is None:
                 raise ValueError(
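
Because padding_idx is now a required argument of BaseTokenizer.__init__, every concrete tokenizer has to pass it alongside vocab_size when calling super().__init__. A minimal sketch of the updated contract with a hypothetical subclass (the class name, vocabulary handling, and choice of pad id are illustrative, not part of this commit):

from typing import List, Optional

from torchTextClassifiers.tokenizers import BaseTokenizer


class ToyTokenizer(BaseTokenizer):
    """Hypothetical subclass showing the updated constructor call."""

    def __init__(self, vocab: List[str], output_dim: Optional[int] = None):
        self.pad_token_id = 0  # illustrative padding id
        super().__init__(
            vocab_size=len(vocab) + 1,      # +1 for the padding token
            padding_idx=self.pad_token_id,  # now mandatory
            output_dim=output_dim,
        )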

torchTextClassifiers/tokenizers/ngram.py

Lines changed: 97 additions & 20 deletions
@@ -4,6 +4,7 @@
 from functools import lru_cache
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
@@ -113,7 +114,7 @@ def get(self, word: str) -> List[int]:
 
 
 # ============================================================================
-# Vectorized encoding
+# Vectorized encoding with optional metadata
 # ============================================================================
 
 
@@ -124,33 +125,78 @@ def encode_batch_vectorized(
     pad_token_id: int,
     max_length: Optional[int] = None,
     truncation: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    return_offsets_mapping: bool = False,
+    return_word_ids: bool = False,
+    force_max_length: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[List], Optional[List]]:
     """
     Vectorized batch encoding - processes all sentences together.
-    Returns padded tensors directly.
+    Returns padded tensors directly, with optional offset mappings and word IDs.
+
+    Args:
+        force_max_length: If True and max_length is set, always return tensors of size max_length
     """
     all_ids = []
+    all_offsets = [] if return_offsets_mapping else None
+    all_word_ids = [] if return_word_ids else None
     max_len = 0
 
     # First pass: encode all sentences
     for sentence in sentences:
         ids = []
+        offsets = [] if return_offsets_mapping else None
+        word_ids = [] if return_word_ids else None
+
         words = sentence.split()
+        char_offset = 0
+
+        for word_idx, word in enumerate(words):
+            # Find the actual position of this word in the original sentence
+            word_start = sentence.find(word, char_offset)
+            word_end = word_start + len(word)
+            char_offset = word_end
+
+            # Get subword tokens for this word
+            subword_tokens = subword_cache.get(word)
 
-        for word in words:
-            ids.extend(subword_cache.get(word))
+            for token_id in subword_tokens:
+                ids.append(token_id)
 
+                if return_offsets_mapping:
+                    # All subword tokens of a word map to the word's character span
+                    offsets.append((word_start, word_end))
+
+                if return_word_ids:
+                    # All subword tokens of a word get the same word_id
+                    word_ids.append(word_idx)
+
+        # Add EOS token
         ids.append(eos_token_id)
+        if return_offsets_mapping:
+            offsets.append((len(sentence), len(sentence)))  # EOS has no span
+        if return_word_ids:
+            word_ids.append(None)  # EOS is not part of any word
 
         # Truncate if needed
         if truncation and max_length and len(ids) > max_length:
             ids = ids[:max_length]
+            if return_offsets_mapping:
+                offsets = offsets[:max_length]
+            if return_word_ids:
+                word_ids = word_ids[:max_length]
 
         all_ids.append(ids)
+        if return_offsets_mapping:
+            all_offsets.append(offsets)
+        if return_word_ids:
+            all_word_ids.append(word_ids)
         max_len = max(max_len, len(ids))
 
     # Determine final sequence length
-    if max_length and not truncation:
+    if force_max_length and max_length:
+        # Always use max_length when force_max_length is True
+        seq_len = max_length
+    elif max_length and not truncation:
         seq_len = min(max_len, max_length)
     elif max_length:
         seq_len = max_length
@@ -162,13 +208,22 @@ def encode_batch_vectorized(
     input_ids = torch.full((batch_size, seq_len), pad_token_id, dtype=torch.long)
     attention_mask = torch.zeros((batch_size, seq_len), dtype=torch.long)
 
-    # Fill tensors
+    # Fill tensors and pad metadata
    for i, ids in enumerate(all_ids):
         length = min(len(ids), seq_len)
         input_ids[i, :length] = torch.tensor(ids[:length], dtype=torch.long)
         attention_mask[i, :length] = 1
 
-    return input_ids, attention_mask
+        # Pad offsets and word_ids to match sequence length
+        if return_offsets_mapping:
+            # Pad with (0, 0) for padding tokens
+            all_offsets[i] = all_offsets[i][:length] + [(0, 0)] * (seq_len - length)
+
+        if return_word_ids:
+            # Pad with None for padding tokens
+            all_word_ids[i] = all_word_ids[i][:length] + [None] * (seq_len - length)
+
+    return input_ids, attention_mask, all_offsets, all_word_ids
 
 
 # ============================================================================
@@ -183,8 +238,7 @@ class NGramTokenizer(BaseTokenizer):
     - Vectorized batch encoding
     - Cached text normalization
     - Direct tensor operations
-    - No multiprocessing overhead
-    - No Numba dependency
+    - Optional offset mapping and word ID tracking
     """
 
     PAD_TOKEN = "[PAD]"
@@ -200,6 +254,7 @@ def __init__(
         len_word_ngrams: int,
         training_text: Optional[List[str]] = None,
         preprocess: bool = True,
+        output_dim: Optional[int] = None,
         **kwargs,
     ):
         if min_n < 2:
@@ -227,9 +282,11 @@ def __init__(
         self.subword_cache = None
 
         self.vocab_size = 3 + self.nwords + self.num_tokens
-        super().__init__(vocab_size=self.vocab_size)
+        super().__init__(
+            vocab_size=self.vocab_size, padding_idx=self.pad_token_id, output_dim=output_dim
+        )
 
-    def _build_vocab(self, training_text: List[str]):
+    def train(self, training_text: List[str]):
         """Build vocabulary from training text."""
         word_counts = {}
         for sent in training_text:
@@ -261,16 +318,24 @@ def _build_vocab(self, training_text: List[str]):
     def tokenize(
         self,
         text: Union[str, List[str]],
-        padding: str = "longest",
-        max_length: Optional[int] = None,
-        truncation: bool = False,
         return_offsets_mapping: bool = False,
         return_word_ids: bool = False,
         **kwargs,
     ) -> TokenizerOutput:
         """
         Optimized tokenization with vectorized operations.
-        Note: return_offsets_mapping and return_word_ids removed for speed.
+
+        Args:
+            text: Single string or list of strings to tokenize
+            padding: Padding strategy ('longest' or 'max_length')
+            max_length: Maximum sequence length
+            truncation: Whether to truncate sequences exceeding max_length
+            return_offsets_mapping: If True, return character offsets for each token
+            return_word_ids: If True, return word indices for each token
+
+        Returns:
+            TokenizerOutput with input_ids, attention_mask, and optionally
+            offset_mapping and word_ids
         """
         is_single = isinstance(text, str)
         if is_single:
@@ -280,21 +345,33 @@ def tokenize(
         if self.preprocess:
             text = clean_text_feature(text)
 
+        if self.output_dim is not None:
+            max_length = self.output_dim
+            truncation = True
+        else:
+            max_length = None
+            truncation = False
+
         # Vectorized encoding
-        input_ids, attention_mask = encode_batch_vectorized(
+        input_ids, attention_mask, offsets, word_ids = encode_batch_vectorized(
             text,
             self.subword_cache,
             self.eos_token_id,
             self.pad_token_id,
-            max_length=max_length if padding == "max_length" else None,
+            max_length=max_length,
             truncation=truncation,
+            return_offsets_mapping=return_offsets_mapping,
+            return_word_ids=return_word_ids,
         )
 
+        offsets = torch.tensor(offsets) if return_offsets_mapping else None
+        word_ids = np.array(word_ids) if return_word_ids else None
+
         return TokenizerOutput(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            word_ids=None,
-            offset_mapping=None,
+            word_ids=word_ids,
+            offset_mapping=offsets,
         )
 
     def decode(
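
Taken together, the ngram.py changes make sequence length a property of the tokenizer (output_dim, fixed at construction), rename _build_vocab to train, and make offset and word-id metadata opt-in per call. A rough usage sketch, assuming tokenizer is an NGramTokenizer already built with output_dim=32 and trained via tokenizer.train(corpus); the sentence and printed shapes are illustrative:

# Sketch only: `tokenizer` is assumed to be a trained NGramTokenizer with output_dim=32.
out = tokenizer.tokenize(
    ["the cat sat on the mat"],
    return_offsets_mapping=True,  # (start, end) character span per token
    return_word_ids=True,         # word index per token; None for EOS/padding
)
print(out.input_ids.shape)       # torch.Size([1, 32]), driven by output_dim
print(out.attention_mask.shape)  # torch.Size([1, 32])
print(out.offset_mapping.shape)  # torch.Size([1, 32, 2]); (0, 0) marks padding
print(out.word_ids.shape)        # (1, 32) object array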
