|
| 1 | +import json |
| 2 | +import re |
| 3 | +import os |
| 4 | +from utils import Storage |
| 5 | +from .vars import MODELS_DIR |
| 6 | + |
| 7 | +last_textual_inversions = None |
| 8 | +last_textual_inversion_model = None |
| 9 | +loaded_textual_inversion_tokens = [] |
| 10 | + |
| 11 | +tokenRe = re.compile( |
| 12 | + r"[#&]{1}fname=(?P<fname>[^\.]+)\.(?:pt|safetensors)(&token=(?P<token>[^&]+))?$" |
| 13 | +) |
| 14 | + |
| 15 | + |
| 16 | +def strMap(str: str): |
| 17 | + match = re.search(tokenRe, str) |
| 18 | + print(match) |
| 19 | + if match: |
| 20 | + return match.group("token") or match.group("fname") |
| 21 | + |
| 22 | + |
| 23 | +def extract_tokens_from_list(textual_inversions: list): |
| 24 | + return list(map(strMap, textual_inversions)) |
| 25 | + |
| 26 | + |
| 27 | +def handle_textual_inversions(textual_inversions: list, model): |
| 28 | + global last_textual_inversions |
| 29 | + global last_textual_inversion_model |
| 30 | + global loaded_textual_inversion_tokens |
| 31 | + |
| 32 | + textual_inversions_str = json.dumps(textual_inversions) |
| 33 | + if ( |
| 34 | + textual_inversions_str is not last_textual_inversions |
| 35 | + or model is not last_textual_inversion_model |
| 36 | + ): |
| 37 | + if (model is not last_textual_inversion_model): |
| 38 | + loaded_textual_inversion_tokens = [] |
| 39 | + last_textual_inversion_model = model |
| 40 | + # print({"textual_inversions": textual_inversions}) |
| 41 | + # tokens_to_load = extract_tokens_from_list(textual_inversions) |
| 42 | + # print({"tokens_loaded": loaded_textual_inversion_tokens}) |
| 43 | + # print({"tokens_to_load": tokens_to_load}) |
| 44 | + # |
| 45 | + # for token in loaded_textual_inversion_tokens: |
| 46 | + # if token not in tokens_to_load: |
| 47 | + # print("[TextualInversion] Removing uneeded token: " + token) |
| 48 | + # del pipeline.tokenizer.get_vocab()[token] |
| 49 | + # # del pipeline.text_encoder.get_input_embeddings().weight.data[token] |
| 50 | + # pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer)) |
| 51 | + # |
| 52 | + # loaded_textual_inversion_tokens = tokens_to_load |
| 53 | + |
| 54 | + last_textual_inversions = textual_inversions_str |
| 55 | + for textual_inversion in textual_inversions: |
| 56 | + storage = Storage(textual_inversion, no_raise=True) |
| 57 | + if storage: |
| 58 | + storage_query_fname = storage.query.get("fname") |
| 59 | + if storage_query_fname: |
| 60 | + fname = storage_query_fname[0] |
| 61 | + else: |
| 62 | + fname = textual_inversion.split("/").pop() |
| 63 | + path = os.path.join(MODELS_DIR, "textual_inversion--" + fname) |
| 64 | + if not os.path.exists(path): |
| 65 | + storage.download_file(path) |
| 66 | + print("Load textual inversion " + path) |
| 67 | + token = storage.query.get("token", None) |
| 68 | + if token not in loaded_textual_inversion_tokens: |
| 69 | + model.load_textual_inversion( |
| 70 | + path, token=token, local_files_only=True |
| 71 | + ) |
| 72 | + loaded_textual_inversion_tokens.append(token) |
| 73 | + else: |
| 74 | + print("Load textual inversion " + textual_inversion) |
| 75 | + model.load_textual_inversion(textual_inversion) |
0 commit comments