Skip to content

Commit 2babd53

Browse files
committed
feat(textualInversion): very early support
1 parent 15331b7 commit 2babd53

6 files changed

Lines changed: 115 additions & 12 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dist/
1414
downloads/
1515
eggs/
1616
.eggs/
17-
lib/
17+
/lib/
1818
lib64/
1919
parts/
2020
sdist/

api/app.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,23 @@
2929
from threading import Timer
3030
import extras
3131

32+
from lib.textual_inversions import handle_textual_inversions
33+
from lib.vars import (
34+
RUNTIME_DOWNLOADS,
35+
USE_DREAMBOOTH,
36+
MODEL_ID,
37+
PIPELINE,
38+
HF_AUTH_TOKEN,
39+
HOME,
40+
MODELS_DIR,
41+
)
3242

33-
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
34-
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
3543
if USE_DREAMBOOTH:
3644
from train_dreambooth import TrainDreamBooth
37-
38-
MODEL_ID = os.environ.get("MODEL_ID")
39-
PIPELINE = os.environ.get("PIPELINE")
40-
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
41-
HOME = os.path.expanduser("~")
42-
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
43-
4445
print(os.environ.get("USE_PATCHMATCH"))
4546
if os.environ.get("USE_PATCHMATCH") == "1":
4647
from PyPatchMatch import patch_match
4748

48-
4949
torch.set_grad_enabled(False)
5050
always_normalize_model_id = None
5151

@@ -332,6 +332,9 @@ def sendStatus():
332332
is_url = call_inputs.get("is_url", False)
333333
image_decoder = getFromUrl if is_url else decodeBase64Image
334334

335+
textual_inversions = call_inputs.get("textual_inversions", [])
336+
handle_textual_inversions(textual_inversions, model)
337+
335338
# Better to use new lora_weights in next section
336339
attn_procs = call_inputs.get("attn_procs", None)
337340
if attn_procs is not last_attn_procs:
@@ -388,7 +391,7 @@ def sendStatus():
388391
cache_fname = "lora_weights--" + fname
389392
path = os.path.join(MODELS_DIR, cache_fname)
390393
if not os.path.exists(path):
391-
storage.download_and_extract(path)
394+
storage.download_and_extract(path, status=status)
392395
print("Load lora_weights `" + lora_weights + "` from `" + path + "`")
393396
pipeline.load_lora_weights(
394397
MODELS_DIR, weight_name=cache_fname, local_files_only=True

api/lib/__init__.py

Whitespace-only changes.

api/lib/textual_inversions.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import json
2+
import re
3+
import os
4+
from utils import Storage
5+
from .vars import MODELS_DIR
6+
7+
last_textual_inversions = None
8+
last_textual_inversion_model = None
9+
loaded_textual_inversion_tokens = []
10+
11+
tokenRe = re.compile(
12+
r"[#&]{1}fname=(?P<fname>[^\.]+)\.(?:pt|safetensors)(&token=(?P<token>[^&]+))?$"
13+
)
14+
15+
16+
def strMap(str: str):
17+
match = re.search(tokenRe, str)
18+
print(match)
19+
if match:
20+
return match.group("token") or match.group("fname")
21+
22+
23+
def extract_tokens_from_list(textual_inversions: list):
24+
return list(map(strMap, textual_inversions))
25+
26+
27+
def handle_textual_inversions(textual_inversions: list, model):
28+
global last_textual_inversions
29+
global last_textual_inversion_model
30+
global loaded_textual_inversion_tokens
31+
32+
textual_inversions_str = json.dumps(textual_inversions)
33+
if (
34+
textual_inversions_str is not last_textual_inversions
35+
or model is not last_textual_inversion_model
36+
):
37+
if (model is not last_textual_inversion_model):
38+
loaded_textual_inversion_tokens = []
39+
last_textual_inversion_model = model
40+
# print({"textual_inversions": textual_inversions})
41+
# tokens_to_load = extract_tokens_from_list(textual_inversions)
42+
# print({"tokens_loaded": loaded_textual_inversion_tokens})
43+
# print({"tokens_to_load": tokens_to_load})
44+
#
45+
# for token in loaded_textual_inversion_tokens:
46+
# if token not in tokens_to_load:
47+
# print("[TextualInversion] Removing uneeded token: " + token)
48+
# del pipeline.tokenizer.get_vocab()[token]
49+
# # del pipeline.text_encoder.get_input_embeddings().weight.data[token]
50+
# pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
51+
#
52+
# loaded_textual_inversion_tokens = tokens_to_load
53+
54+
last_textual_inversions = textual_inversions_str
55+
for textual_inversion in textual_inversions:
56+
storage = Storage(textual_inversion, no_raise=True)
57+
if storage:
58+
storage_query_fname = storage.query.get("fname")
59+
if storage_query_fname:
60+
fname = storage_query_fname[0]
61+
else:
62+
fname = textual_inversion.split("/").pop()
63+
path = os.path.join(MODELS_DIR, "textual_inversion--" + fname)
64+
if not os.path.exists(path):
65+
storage.download_file(path)
66+
print("Load textual inversion " + path)
67+
token = storage.query.get("token", None)
68+
if token not in loaded_textual_inversion_tokens:
69+
model.load_textual_inversion(
70+
path, token=token, local_files_only=True
71+
)
72+
loaded_textual_inversion_tokens.append(token)
73+
else:
74+
print("Load textual inversion " + textual_inversion)
75+
model.load_textual_inversion(textual_inversion)

api/lib/textual_inversions_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import unittest
2+
from .textual_inversions import extract_tokens_from_list
3+
4+
5+
class TextualInversionsTest(unittest.TestCase):
6+
def test_extract_tokens_query_fname(self):
7+
tis = ["https://civitai.com/api/download/models/106132#fname=4nj0lie.pt"]
8+
tokens = extract_tokens_from_list(tis)
9+
self.assertEqual(tokens[0], "4nj0lie")
10+
11+
def test_extract_tokens_query_token(self):
12+
tis = [
13+
"https://civitai.com/api/download/models/106132#fname=4nj0lie.pt&token=4nj0lie"
14+
]
15+
tokens = extract_tokens_from_list(tis)
16+
self.assertEqual(tokens[0], "4nj0lie")

api/lib/vars.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import os
2+
3+
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
4+
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
5+
MODEL_ID = os.environ.get("MODEL_ID")
6+
PIPELINE = os.environ.get("PIPELINE")
7+
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
8+
HOME = os.path.expanduser("~")
9+
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")

0 commit comments

Comments
 (0)