
Commit 59502a1

small fixes, many comments

1 parent 8611ad5

2 files changed: 48 additions & 15 deletions

model2vec/distill/distillation.py

Lines changed: 4 additions & 3 deletions
@@ -89,12 +89,13 @@ def distill_from_model(
     if not all_tokens:
         raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")
 
-    # Convert tokens to IDs
-    token_ids = turn_tokens_into_ids(all_tokens, tokenizer)
-
     # Create the embeddings.
     unk_token = tokenizer.special_tokens_map.get("unk_token")
     pad_token = tokenizer.special_tokens_map.get("pad_token")
+
+    # Convert tokens to IDs
+    token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)
+
     embeddings = create_embeddings(
         tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
     )
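The call to turn_tokens_into_ids moves below the special-token lookups because it now takes the unk token's string form as an argument. A minimal sketch of that lookup, assuming any fast Hugging Face tokenizer ("bert-base-uncased" here is an illustrative choice, not from the commit):

from transformers import AutoTokenizer

# Assumption: any PreTrainedTokenizerFast works; BERT is just an example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# special_tokens_map maps names to string forms, e.g. {"unk_token": "[UNK]", ...}.
unk_token = tokenizer.special_tokens_map.get("unk_token")  # "[UNK]"
pad_token = tokenizer.special_tokens_map.get("pad_token")  # "[PAD]"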

model2vec/distill/tokenizer.py

Lines changed: 44 additions & 12 deletions
@@ -3,7 +3,6 @@
 import json
 import logging
 import re
-from string import punctuation
 from typing import Any
 
 from tokenizers import Regex, Tokenizer
@@ -267,19 +266,24 @@ def _process_internal_tokens(
     # Figure out whether token is a subword or not.
     encoded = backend_tokenizer.encode(f" {'a' * 25}", add_special_tokens=False)
     first_token, second_token, *_ = encoded.tokens
-    # Remove the space prefix if exists.
+    # Isolate the prefix. We can't do first_token[0] because we don't know
+    # how long the prefix is.
     # e.g., "Ġaaaa" -> "Ġ"
-    word_prefix, _ = first_token.split("a", maxsplit=1)
+    a_index = 0 if "a" not in first_token else first_token.index("a")
+    word_prefix = first_token[:a_index]
     is_byte_prefix = word_prefix == "Ġ"
     second_token = encoded.tokens[1]
     # The second token is the first subword token.
     # If a tokenizer uses subwords, this token will have been prefixed.
-    subword_prefix, _ = second_token.split("a", maxsplit=1)
+    # We don't know how long the prefix is.
+    a_index = 0 if "a" not in second_token else second_token.index("a")
+    subword_prefix = second_token[:a_index]
 
     pre_tokenizer: PreTokenizer | None = backend_tokenizer.pre_tokenizer
 
     for token in internal_tokens:
-        token_object = _create_single_internal_token(
+        # Create the token objects. If this returns None, it was unsuccessful for some reason.
+        if token_object := _create_single_internal_token(
             token=token,
             subword_prefix=subword_prefix,
             word_prefix=word_prefix,
@@ -288,8 +292,7 @@ def _process_internal_tokens(
             token_remove_regex=token_remove_regex,
             added_tokens_to_keep=added_tokens_to_keep,
             added_tokens_to_remove=added_tokens_to_remove,
-        )
-        if token_object:
+        ):
             cleaned_internal_tokens.append(token_object)
 
     return cleaned_internal_tokens
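The probe above encodes a long run of "a"s to discover the tokenizer's word and subword prefixes without assuming their length. A minimal sketch of the same idea, under the assumption of a WordPiece model ("bert-base-uncased" is again an illustrative choice):

from transformers import AutoTokenizer

# Assumption: "bert-base-uncased" stands in for any fast tokenizer.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
backend_tokenizer = tokenizer.backend_tokenizer

# Encode " aaaa..." so the first token is word-initial and the second is a subword.
encoded = backend_tokenizer.encode(f" {'a' * 25}", add_special_tokens=False)
first_token, second_token, *_ = encoded.tokens

# Slice up to the first "a": works for prefixes of any length ("", "Ġ", "##", "▁").
a_index = 0 if "a" not in first_token else first_token.index("a")
word_prefix = first_token[:a_index]      # "" for WordPiece, "Ġ" for byte-level BPE
a_index = 0 if "a" not in second_token else second_token.index("a")
subword_prefix = second_token[:a_index]  # "##" for WordPiece
print(repr(word_prefix), repr(subword_prefix))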
@@ -307,23 +310,38 @@ def _create_single_internal_token(
 ) -> Token | None:
     """Create a token object from a string."""
     if token in added_tokens_to_remove:
+        # We remove any added tokens that aren't [UNK] or [PAD].
         return None
     if token in added_tokens_to_keep:
-        # Don't put special tokens through the regular motions.
+        # Don't put added tokens through the regular motions.
         return Token(form=token, normalized_form=token, is_subword=False, is_internal=True)
     if token_remove_regex and token_remove_regex.match(token):
+        # If the regex matches, remove the token.
         return None
+
+    # A token is a subword if there is a subword prefix and the token
+    # starts with that prefix, or if there is a WORD prefix and the token
+    # does not start with it. For metaspace tokenizers, for example:
+    # "doghouse" -> ["▁dog", "house"]
+    # So we can only tell that "house" is a subword by knowing that it is not
+    # prefixed while word-initial tokens are.
     is_subword = False
     if subword_prefix:
         is_subword = bool(token.startswith(subword_prefix))
     if word_prefix:
         is_subword = not bool(token.startswith(word_prefix))
 
+    # Byte-prefixed tokenizers don't need to be checked.
     if pre_tokenizer is not None and not is_byte_prefix:
+        # We need to check the token without prefixes. If we have a word prefix,
+        # we need to check tokens that are subwords, and the other way around
+        # for subword prefixes.
         if (subword_prefix and not is_subword) or (word_prefix and is_subword):
+            # If this is True, the token is unreachable, even though it is a subword token.
            if len(pre_tokenizer.pre_tokenize_str(token)) > 1:
                 return None
 
+    # Turn the token into a normalized form for later processing.
     normalized_form = _create_normalized_form(token, subword_prefix, word_prefix, is_byte_prefix, is_subword)
 
     return Token(form=token, normalized_form=normalized_form, is_subword=is_subword, is_internal=True)
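A standalone sketch of the subword heuristic those comments describe; the helper name and example tokens are illustrative, not part of the commit:

def is_subword(token: str, subword_prefix: str, word_prefix: str) -> bool:
    # A token is a subword if it carries the subword prefix, or, when only
    # word-initial tokens are prefixed, if it lacks the word prefix.
    result = False
    if subword_prefix:
        result = token.startswith(subword_prefix)
    if word_prefix:
        result = not token.startswith(word_prefix)
    return result

# WordPiece: subwords are prefixed with "##".
print(is_subword("##ing", subword_prefix="##", word_prefix=""))  # True
# Metaspace: word-initial tokens are prefixed with "▁", subwords are bare.
print(is_subword("▁dog", subword_prefix="", word_prefix="▁"))    # False
print(is_subword("house", subword_prefix="", word_prefix="▁"))   # True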
@@ -345,32 +363,46 @@ def _create_normalized_form(
     return f"▁{token}"
 
 
-def turn_tokens_into_ids(tokens: list[Token], tokenizer: PreTrainedTokenizerFast) -> list[list[int]]:
+def turn_tokens_into_ids(tokens: list[Token], tokenizer: PreTrainedTokenizerFast, unk_token: str) -> list[list[int]]:
     """
     Convert a list of Token objects to their corresponding token ID sequences.
 
     :param tokens: List of Token objects to convert
     :param tokenizer: The tokenizer to use for converting tokens to IDs
+    :param unk_token: The string form of the unk token.
     :return: List of token IDs corresponding to the input tokens
     """
-    # Implementation will be added later
+    unk_id = tokenizer.convert_tokens_to_ids(unk_token)
     bos, _, eos = tokenizer.encode("a", add_special_tokens=True)
     token_ids = []
     for token in tokens:
         if token.is_internal:
-            token_ids.append([bos, tokenizer.convert_tokens_to_ids(token.form), eos])
+            # Careful: any incorrect tokens will just get `[UNK]`, so this could go horribly wrong.
+            token_id = tokenizer.convert_tokens_to_ids(token.form)
+            # Explicitly check and warn if `unk_id` appears, but don't crash.
+            if token_id == unk_id and token.form != unk_token:
+                logger.warning(f"Token {token.form} was set to unk. This is wrong.")
+            token_ids.append([bos, token_id, eos])
         else:
             token_ids.append(tokenizer.encode(token.form))
 
     return token_ids
 
 
 def _normalize_vocabulary_token(token: str, pre_tokenizer: PreTokenizer) -> str:
-    # Add prefix space for byte-level tokenizers.
+    """Normalize a token that is not in the initial token vocabulary."""
+    # Add prefix space for byte tokenizers.
     prefixed_token = f" {token}"
     pretokenized_tokens, offsets = zip(*pre_tokenizer.pre_tokenize_str(prefixed_token))
+    # The first item is always the start of the token.
     new_token = [pretokenized_tokens[0]]
+    # Loop over the subtokens and offsets.
     for t, (s, _) in zip(pretokenized_tokens[1:], offsets[1:]):
+        # If the character before the subtoken is a space, we have a
+        # multiword token, e.g., "room for the moon", which is split into
+        # ["room", "for", "the", "moon"].
+        # If it doesn't have a space, it is part of a complex multiword token,
+        # e.g., "chat-gpt", which is split into ["chat", "-", "gpt"].
         if prefixed_token[s - 1] == " ":
             new_token.append(f" {t}")
         else:
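A minimal sketch of the failure mode the new unk guard in turn_tokens_into_ids catches, assuming a BERT-style fast tokenizer (model name illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumption
unk_token = tokenizer.special_tokens_map["unk_token"]           # "[UNK]"
unk_id = tokenizer.convert_tokens_to_ids(unk_token)

# A form that is in the vocabulary maps to a real id...
print(tokenizer.convert_tokens_to_ids("house") != unk_id)  # True

# ...but a form that is NOT in the vocabulary silently maps to unk_id,
# which is exactly what the added warning flags instead of crashing.
print(tokenizer.convert_tokens_to_ids("definitely_not_a_token") == unk_id)  # True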

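And a sketch of the pre-tokenizer behaviour that _normalize_vocabulary_token relies on, assuming a byte-level pre-tokenizer (the GPT-2 tokenizer is an illustrative choice): each pre-tokenized piece comes with character offsets into the input, so the character before a piece distinguishes multiword splits ("room for the moon") from intra-word splits ("chat-gpt").

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: byte-level pre-tokenizer
pre_tokenizer = tokenizer.backend_tokenizer.pre_tokenizer

prefixed = " chat-gpt"
pieces = pre_tokenizer.pre_tokenize_str(prefixed)
print(pieces)
# e.g. [('Ġchat', (0, 5)), ('-', (5, 6)), ('gpt', (6, 9))]

for piece, (start, _) in pieces[1:]:
    # A space before the piece means a word boundary; anything else means
    # the piece belongs to the same complex token (like "chat-gpt").
    print(piece, "boundary" if prefixed[start - 1] == " " else "intra-word")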