Commit 2ac39ba

itazap, sayakpaul, and github-actions[bot] authored
fast tok update (#13036)
* v5 tok update
* ruff
* keep pre v5 slow code path
* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent ef91301 commit 2ac39ba
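
For context, the method touched by this commit, unload_textual_inversion(), is the counterpart of load_textual_inversion() on diffusers pipelines. A minimal usage sketch follows; the checkpoint and embedding ids are illustrative and not taken from the commit.

import torch
from diffusers import StableDiffusionPipeline

# Illustrative ids; any Stable Diffusion checkpoint and textual-inversion
# embedding exercise the same code path.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Adds the <cat-toy> token to the tokenizer and its embedding to the text encoder.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
image = pipe("a <cat-toy> sitting on a park bench").images[0]

# Removes the added tokens again; with a transformers v5 fast tokenizer this now
# takes the JSON serialize/filter/reload branch added in the diff below, while
# slow tokenizers keep the pre-v5 code path.
pipe.unload_textual_inversion()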

File tree: 1 file changed (+34, -16 lines)


src/diffusers/loaders/textual_inversion.py

Lines changed: 34 additions & 16 deletions
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from typing import Dict, List, Optional, Union
 
 import safetensors
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
+from tokenizers import Tokenizer as TokenizerFast
 from torch import nn
 
 from ..models.modeling_utils import load_state_dict
@@ -547,23 +549,39 @@ def unload_textual_inversion(
             else:
                 last_special_token_id = added_token_id
 
-        # Delete from tokenizer
-        for token_id, token_to_remove in zip(token_ids, tokens):
-            del tokenizer._added_tokens_decoder[token_id]
-            del tokenizer._added_tokens_encoder[token_to_remove]
-
-        # Make all token ids sequential in tokenizer
-        key_id = 1
-        for token_id in tokenizer.added_tokens_decoder:
-            if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
-                token = tokenizer._added_tokens_decoder[token_id]
-                tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
+        # Fast tokenizers (v5+)
+        if hasattr(tokenizer, "_tokenizer"):
+            # Fast tokenizers: serialize, filter tokens, reload
+            tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
+            new_id = last_special_token_id + 1
+            filtered = []
+            for tok in tokenizer_json.get("added_tokens", []):
+                if tok.get("content") in set(tokens):
+                    continue
+                if not tok.get("special", False):
+                    tok["id"] = new_id
+                    new_id += 1
+                filtered.append(tok)
+            tokenizer_json["added_tokens"] = filtered
+            tokenizer._tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
+        else:
+            # Slow tokenizers
+            for token_id, token_to_remove in zip(token_ids, tokens):
                 del tokenizer._added_tokens_decoder[token_id]
-                tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
-                key_id += 1
-        tokenizer._update_trie()
-        # set correct total vocab size after removing tokens
-        tokenizer._update_total_vocab_size()
+                del tokenizer._added_tokens_encoder[token_to_remove]
+
+            key_id = 1
+            for token_id in list(tokenizer.added_tokens_decoder.keys()):
+                if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
+                    token = tokenizer._added_tokens_decoder[token_id]
+                    tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
+                    del tokenizer._added_tokens_decoder[token_id]
+                    tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
+                    key_id += 1
+            if hasattr(tokenizer, "_update_trie"):
+                tokenizer._update_trie()
+            if hasattr(tokenizer, "_update_total_vocab_size"):
+                tokenizer._update_total_vocab_size()
 
         # Delete from text encoder
         text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
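
The fast-tokenizer branch above can also be sketched standalone. This is a minimal sketch, not part of the commit: the CLIP checkpoint and the <my-concept> token are assumptions for illustration. The idea is to dump the backing Rust tokenizer to JSON with Tokenizer.to_str(), drop the unwanted entries from added_tokens, and rebuild it with Tokenizer.from_str(), which sidesteps the slow tokenizer's private _added_tokens_* bookkeeping.

import json

from tokenizers import Tokenizer as TokenizerFast
from transformers import CLIPTokenizerFast

# Illustrative checkpoint; any fast (Rust-backed) tokenizer behaves the same way.
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
tokenizer.add_tokens(["<my-concept>"])

# Serialize the backing tokenizers.Tokenizer to its JSON representation.
tokenizer_json = json.loads(tokenizer._tokenizer.to_str())

# Keep every added token except the one being removed.
tokenizer_json["added_tokens"] = [
    tok for tok in tokenizer_json.get("added_tokens", []) if tok.get("content") != "<my-concept>"
]

# Rebuild the Rust tokenizer from the edited JSON and swap it back in.
tokenizer._tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))

assert "<my-concept>" not in tokenizer.get_vocab()

In the diff itself, non-special added tokens are additionally renumbered (tok["id"] = new_id) so that their ids stay contiguous after a token in the middle of the added range is dropped; the sketch skips that step because it only removes the last added token.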
