11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import json |
14 | 15 | from typing import Dict, List, Optional, Union |
15 | 16 |
16 | 17 | import safetensors |
17 | 18 | import torch |
18 | 19 | from huggingface_hub.utils import validate_hf_hub_args |
| 20 | +from tokenizers import Tokenizer as TokenizerFast |
19 | 21 | from torch import nn |
20 | 22 |
21 | 23 | from ..models.modeling_utils import load_state_dict |
@@ -547,23 +549,39 @@ def unload_textual_inversion( |
547 | 549 |             else:
548 | 550 |                 last_special_token_id = added_token_id
549 | 551 |
550 | | -        # Delete from tokenizer
551 | | -        for token_id, token_to_remove in zip(token_ids, tokens):
552 | | -            del tokenizer._added_tokens_decoder[token_id]
553 | | -            del tokenizer._added_tokens_encoder[token_to_remove]
554 | | -
555 | | -        # Make all token ids sequential in tokenizer
556 | | -        key_id = 1
557 | | -        for token_id in tokenizer.added_tokens_decoder:
558 | | -            if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
559 | | -                token = tokenizer._added_tokens_decoder[token_id]
560 | | -                tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
| 552 | +        # Fast tokenizers (v5+) wrap a Rust tokenizer that cannot be edited in place
| 553 | +        if hasattr(tokenizer, "_tokenizer"):
| 554 | +            # Serialize to JSON, drop the unloaded tokens, renumber the rest, and reload
| 555 | +            tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
| 556 | +            new_id = last_special_token_id + 1
| 557 | +            filtered = []
| 558 | +            for tok in tokenizer_json.get("added_tokens", []):
| 559 | +                if tok.get("content") in tokens:
| 560 | +                    continue
| 561 | +                if not tok.get("special", False):
| 562 | +                    tok["id"] = new_id
| 563 | +                    new_id += 1
| 564 | +                filtered.append(tok)
| 565 | +            tokenizer_json["added_tokens"] = filtered
| 566 | +            tokenizer._tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
| 567 | +        else:
| 568 | +            # Slow tokenizers expose plain dicts that can be edited directly
| 569 | +            for token_id, token_to_remove in zip(token_ids, tokens):
561 | 570 |                 del tokenizer._added_tokens_decoder[token_id]
562 | | -                tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
563 | | -                key_id += 1
564 | | -        tokenizer._update_trie()
565 | | -        # set correct total vocab size after removing tokens
566 | | -        tokenizer._update_total_vocab_size()
| 571 | +                del tokenizer._added_tokens_encoder[token_to_remove]
| 572 | +
| 573 | +            new_id = last_special_token_id + 1  # repack remaining added-token ids sequentially
| 574 | +            for token_id in sorted(tokenizer.added_tokens_decoder.keys()):
| 575 | +                if token_id > last_special_token_id:
| 576 | +                    if token_id != new_id:
| 577 | +                        token = tokenizer._added_tokens_decoder.pop(token_id)
| 578 | +                        tokenizer._added_tokens_decoder[new_id] = token
| 579 | +                        tokenizer._added_tokens_encoder[token.content] = new_id
| 580 | +                    new_id += 1
| 581 | +            if hasattr(tokenizer, "_update_trie"):
| 582 | +                tokenizer._update_trie()
| 583 | +            if hasattr(tokenizer, "_update_total_vocab_size"):
| 584 | +                tokenizer._update_total_vocab_size()
567 | 585 |
|
568 | 586 |         # Delete from text encoder
569 | 587 |         text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
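
Note on the slow-tokenizer branch: the compaction step above packs the surviving added-token ids back into a contiguous range right after the last special token. The same bookkeeping on a plain dict, as a quick self-contained check (the ids and token strings are made up for illustration):

```python
# Made-up example: special tokens end at id 10; after deleting one
# textual-inversion token, survivors are left at ids 11 and 13.
last_special_token_id = 10
decoder = {11: "<token-a>", 13: "<token-b>"}

new_id = last_special_token_id + 1
for token_id in sorted(decoder):
    if token_id != new_id:
        decoder[new_id] = decoder.pop(token_id)  # slide the token down into the gap
    new_id += 1

assert decoder == {11: "<token-a>", 12: "<token-b>"}
```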
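The fast-tokenizer branch leans on the `tokenizers` JSON round-trip: `Tokenizer.to_str()` serializes the full tokenizer state, including its `added_tokens` registry, and `Tokenizer.from_str()` rebuilds a tokenizer from that JSON. A minimal standalone sketch of the technique, assuming a checkpoint that ships a `tokenizer.json` (the CLIP checkpoint and the `<cat-toy>` placeholder are illustrative, not part of this change):

```python
import json

from tokenizers import Tokenizer

# Illustrative assumptions: any tokenizer.json-backed checkpoint works here;
# "<cat-toy>" stands in for a loaded textual-inversion token.
tok = Tokenizer.from_pretrained("openai/clip-vit-base-patch32")
tok.add_tokens(["<cat-toy>"])
assert tok.token_to_id("<cat-toy>") is not None

# Serialize, drop the placeholder from the added-token registry, reload.
state = json.loads(tok.to_str())
state["added_tokens"] = [t for t in state["added_tokens"] if t["content"] != "<cat-toy>"]
tok = Tokenizer.from_str(json.dumps(state))

assert tok.token_to_id("<cat-toy>") is None
```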
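One step up the stack, the code path this commit touches is driven by the public loader API. A usage sketch (the base checkpoint and concept repo mirror the ones used in the diffusers docs and are assumed available):

```python
import torch

from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Load a textual-inversion embedding, then unload it again;
# unloading runs the tokenizer cleanup patched above.
pipe.load_textual_inversion("sd-concepts-library/gta5-artwork")
pipe.unload_textual_inversion("<gta5-artwork>")
```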