Skip to content

Commit 7a64846

Browse files
committed
feat(loras): use load_lora_weights (works with A1111 files too)
UNFINISHED. Initial implementation works but needs more testing. Also, let's from the get-go support an array of LoRAs (for when diffusers allows multi loras in a future release).
1 parent 4fe13ef commit 7a64846

2 files changed

Lines changed: 139 additions & 0 deletions

File tree

api/app.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def truncateInputs(inputs: dict):
130130

131131
last_xformers_memory_efficient_attention = {}
132132
last_attn_procs = None
133+
last_lora_weights = None
133134

134135

135136
# Inference is ran for every server call
@@ -143,6 +144,7 @@ async def inference(all_inputs: dict, response) -> dict:
143144
global last_xformers_memory_efficient_attention
144145
global always_normalize_model_id
145146
global last_attn_procs
147+
global last_lora_weights
146148

147149
clearSession()
148150

@@ -244,6 +246,8 @@ def sendStatus():
244246
"loadModel", "done", {"startRequestId": startRequestId}, send_opts
245247
)
246248
last_model_id = normalized_model_id
249+
last_attn_procs = None
250+
last_lora_weights = None
247251
else:
248252
if always_normalize_model_id:
249253
normalized_model_id = always_normalize_model_id
@@ -312,8 +316,13 @@ def sendStatus():
312316
is_url = call_inputs.get("is_url", False)
313317
image_decoder = getFromUrl if is_url else decodeBase64Image
314318

319+
# Better to use the new lora_weights input in the next section
315320
attn_procs = call_inputs.get("attn_procs", None)
316321
if attn_procs is not last_attn_procs:
322+
print(
323+
"[DEPRECATED] Using `attn_procs` for LoRAs is deprecated. "
324+
+ "Please use `lora_weights` instead."
325+
)
317326
last_attn_procs = attn_procs
318327
if attn_procs:
319328
storage = Storage(attn_procs, no_raise=True)
@@ -344,6 +353,40 @@ def sendStatus():
344353
print("Clearing attn procs")
345354
pipeline.unet.set_attn_processor(CrossAttnProcessor())
346355

356+
# Currently we only support a single string, but we should allow
357+
# an array too in anticipation of multi-LoRA support in diffusers
358+
# tracked at https://github.com/huggingface/diffusers/issues/2613.
359+
lora_weights = call_inputs.get("lora_weights", None)
360+
if lora_weights is not last_lora_weights:
361+
last_lora_weights = lora_weights
362+
if lora_weights:
363+
pipeline.unet.set_attn_processor(CrossAttnProcessor())
364+
storage = Storage(lora_weights, no_raise=True)
365+
if storage:
366+
storage_query_fname = storage.query.get("fname")
367+
if storage_query_fname:
368+
fname = storage_query_fname[0]
369+
else:
370+
hash = sha256(lora_weights.encode("utf-8")).hexdigest()
371+
fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop()
372+
cache_fname = "lora_weights--" + fname
373+
path = os.path.join(MODELS_DIR, cache_fname)
374+
if not os.path.exists(path):
375+
storage.download_and_extract(path)
376+
print("Load lora_weights `" + lora_weights + "` from `" + path + "`")
377+
pipeline.load_lora_weights(
378+
MODELS_DIR, weight_name=cache_fname, local_files_only=True
379+
)
380+
else:
381+
print("Loading from huggingface not supported yet: " + lora_weights)
382+
# maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s?
383+
# lora_model_id = "sayakpaul/civitai-light-shadow-lora"
384+
# lora_filename = "light_and_shadow.safetensors"
385+
# pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
386+
else:
387+
print("Clearing attn procs")
388+
pipeline.unet.set_attn_processor(CrossAttnProcessor())
389+
347390
# TODO, generalize
348391
cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
349392
if isinstance(cross_attention_kwargs, str):

tests/integration/test_loras.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import sys
2+
import os
3+
from .lib import getMinio, getDDA
4+
from test import runTest
5+
6+
7+
class TestLoRAs:
8+
def setup_class(self):
9+
print("setup_class")
10+
# self.minio = minio = getMinio("global")
11+
12+
self.dda = dda = getDDA(
13+
# minio=minio
14+
stream_logs=True,
15+
)
16+
print(dda)
17+
18+
self.TEST_ARGS = {"test_url": dda.url}
19+
20+
def teardown_class(self):
21+
print("teardown_class")
22+
# self.minio.stop() - leave global up
23+
self.dda.stop()
24+
25+
if False:
26+
27+
def test_lora_hf_download(self):
28+
"""
29+
Download user/repo from HuggingFace.
30+
"""
31+
# fp32 model is obviously bigger
32+
result = runTest(
33+
"txt2img",
34+
self.TEST_ARGS,
35+
{
36+
"MODEL_ID": "runwayml/stable-diffusion-v1-5",
37+
"MODEL_REVISION": "fp16",
38+
"MODEL_PRECISION": "fp16",
39+
"attn_procs": "patrickvonplaten/lora_dreambooth_dog_example",
40+
},
41+
{
42+
"num_inference_steps": 1,
43+
"prompt": "A picture of a sks dog in a bucket",
44+
"seed": 1,
45+
"cross_attention_kwargs": {"scale": 0.5},
46+
},
47+
)
48+
49+
assert result["image_base64"]
50+
51+
if False:
52+
53+
def test_lora_http_download_pytorch_bin(self):
54+
"""
55+
Download pytorch_lora_weights.bin directly.
56+
"""
57+
result = runTest(
58+
"txt2img",
59+
self.TEST_ARGS,
60+
{
61+
"MODEL_ID": "runwayml/stable-diffusion-v1-5",
62+
"MODEL_REVISION": "fp16",
63+
"MODEL_PRECISION": "fp16",
64+
"attn_procs": "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin",
65+
},
66+
{
67+
"num_inference_steps": 1,
68+
"prompt": "A picture of a sks dog in a bucket",
69+
"seed": 1,
70+
"cross_attention_kwargs": {"scale": 0.5},
71+
},
72+
)
73+
74+
assert result["image_base64"]
75+
76+
# These formats are not supported by diffusers yet :(
77+
def test_lora_http_download_civitai_safetensors(self):
78+
result = runTest(
79+
"txt2img",
80+
self.TEST_ARGS,
81+
{
82+
"MODEL_ID": "runwayml/stable-diffusion-v1-5",
83+
"MODEL_REVISION": "fp16",
84+
"MODEL_PRECISION": "fp16",
85+
# https://civitai.com/models/5373/makima-chainsaw-man-lora
86+
"lora_weights": "https://civitai.com/api/download/models/6244#fname=makima_offset.safetensors",
87+
"safety_checker": False,
88+
},
89+
{
90+
"num_inference_steps": 1,
91+
"prompt": "masterpiece, (photorealistic:1.4), best quality, beautiful lighting, (ulzzang-6500:0.5), makima \(chainsaw man\), (red hair)+(long braided hair)+(bangs), yellow eyes, golden eyes, ((ringed eyes)), (white shirt), (necktie), RAW photo, 8k uhd, film grain",
92+
"seed": 1,
93+
},
94+
)
95+
96+
assert result["image_base64"]

0 commit comments

Comments
 (0)