Skip to content

Commit 6348836

Browse files
committed
feat(attn_procs): initial URL work (see notes)
TODO:
* test S3 for single files (should work; HTTP is confirmed working)
* test archives (entirely untested — all of the archive-handling code is new)
1 parent ee2d835 commit 6348836

2 files changed

Lines changed: 58 additions & 11 deletions

File tree

api/app.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
import traceback
2222
from precision import MODEL_REVISION, MODEL_PRECISION
2323
from device import device, device_id, device_name
24-
from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
24+
from diffusers.models.cross_attention import CrossAttnProcessor
25+
from utils import Storage
26+
from hashlib import sha256
27+
2528

2629
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
2730
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
@@ -278,6 +281,18 @@ def inference(all_inputs: dict) -> dict:
278281
if attn_procs is not last_attn_procs:
279282
last_attn_procs = attn_procs
280283
if attn_procs:
284+
storage = Storage(attn_procs, no_raise=True)
285+
if storage:
286+
fname = storage.url.split("/").pop()
287+
hash = sha256(attn_procs.encode("utf-8")).hexdigest()
288+
if True:
289+
# TODO, way to specify explicit name
290+
path = os.path.join(
291+
MODELS_DIR, "attn_proc--url_" + hash[:7] + "--" + fname
292+
)
293+
attn_procs = path
294+
if not os.path.exists(path):
295+
storage.download_and_extract(path)
281296
print("Load attn_procs " + attn_procs)
282297
pipeline.unet.load_attn_procs(attn_procs)
283298
else:
@@ -286,7 +301,7 @@ def inference(all_inputs: dict) -> dict:
286301

287302
# TODO, generalize
288303
cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
289-
if cross_attention_kwargs:
304+
if isinstance(cross_attention_kwargs, str):
290305
model_inputs["cross_attention_kwargs"] = json.loads(cross_attention_kwargs)
291306

292307
# Parse out your arguments

tests/integration/test_attn_procs.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,31 @@
55

66

77
class TestAttnProcs:
8-
def test_hf_download(self):
9-
"""
10-
Make sure when switching models we release VRAM afterwards.
11-
"""
12-
dda = getDDA(
8+
def setup_class(self):
    """Start the shared DDA container once, before any test in this class runs."""
    print("setup_class")
    # self.minio = minio = getMinio("global")

    dda = getDDA(
        # minio=minio
        stream_logs=True,
    )
    print(dda)
    # Keep a handle on the container so teardown_class can stop it.
    self.dda = dda

    # Base arguments shared by every runTest call in this class.
    self.TEST_ARGS = {"test_url": dda.url}
1819

19-
mem_usage = list()
20+
def teardown_class(self):
    """Stop the DDA container started by setup_class."""
    print("teardown_class")
    # The minio instance is shared globally, so it is deliberately left running:
    # self.minio.stop()
    self.dda.stop()
2024

25+
def test_hf_download(self):
26+
"""
27+
Download user/repo from HuggingFace.
28+
"""
2129
# fp32 model is obviously bigger
2230
result = runTest(
2331
"txt2img",
24-
TEST_ARGS,
32+
self.TEST_ARGS,
2533
{
2634
"MODEL_ID": "runwayml/stable-diffusion-v1-5",
2735
"MODEL_REVISION": "fp16",
@@ -37,4 +45,28 @@ def test_hf_download(self):
3745
)
3846

3947
assert result["image_base64"]
40-
dda.stop()
48+
49+
def test_http_download_diffusers_archive(self):
    """
    Load LoRA attention-processor weights from a plain HTTP(S) URL
    (a single .bin file fetched directly) rather than from a
    HuggingFace user/repo id.
    """
    result = runTest(
        "txt2img",
        self.TEST_ARGS,
        {
            "MODEL_ID": "runwayml/stable-diffusion-v1-5",
            "MODEL_REVISION": "fp16",
            "MODEL_PRECISION": "fp16",
            # Direct URL download path (exercises the Storage/URL branch,
            # not the HF repo-id branch).
            "attn_procs": "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin",
        },
        {
            "num_inference_steps": 1,
            "prompt": "A picture of a sks dog in a bucket",
            "seed": 1,
            "cross_attention_kwargs": {"scale": 0.5},
        },
    )

    assert result["image_base64"]

0 commit comments

Comments
 (0)