Skip to content

Commit 993be12

Browse files
committed
feat(sdxl,compel): Support. AutoPipeline default, safety_check fix
1 parent 747fc0d commit 993be12

5 files changed

Lines changed: 108 additions & 16 deletions

File tree

api/app.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828
from threading import Timer
2929
import extras
3030

31+
from diffusers import StableDiffusionXLPipeline
32+
3133
from lib.textual_inversions import handle_textual_inversions
34+
from lib.prompts import prepare_prompts
3235
from lib.vars import (
3336
RUNTIME_DOWNLOADS,
3437
USE_DREAMBOOTH,
@@ -290,7 +293,7 @@ def sendStatus():
290293
if PIPELINE == "ALL":
291294
pipeline_name = call_inputs.get("PIPELINE", None)
292295
if not pipeline_name:
293-
pipeline_name = "StableDiffusionPipeline"
296+
pipeline_name = "AutoPipelineForText2Image"
294297
result["$meta"].update({"PIPELINE": pipeline_name})
295298

296299
pipeline = getPipelineForModel(
@@ -329,7 +332,11 @@ def sendStatus():
329332
}
330333

331334
safety_checker = call_inputs.get("safety_checker", True)
332-
pipeline.safety_checker = model.safety_checker if safety_checker else None
335+
pipeline.safety_checker = (
336+
model.safety_checker
337+
if safety_checker and hasattr(model, "safety_checker")
338+
else None
339+
)
333340
is_url = call_inputs.get("is_url", False)
334341
image_decoder = getFromUrl if is_url else decodeBase64Image
335342

@@ -399,6 +406,8 @@ def sendStatus():
399406
else 1
400407
)
401408
cross_attention_kwargs.update({"scale": storage_query_scale})
409+
# https://github.com/damian0815/compel/issues/42#issuecomment-1656989385
410+
pipeline._lora_scale = storage_query_scale
402411
if storage_query_fname:
403412
fname = storage_query_fname[0]
404413
else:
@@ -569,8 +578,22 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
569578
"inference", step / model_inputs.get("num_inference_steps", 50)
570579
)
571580

581+
is_sdxl = isinstance(model, StableDiffusionXLPipeline)
582+
print("is_sdxl", is_sdxl)
583+
572584
with torch.inference_mode():
573585
custom_pipeline_method = call_inputs.get("custom_pipeline_method", None)
586+
print(
587+
pipeline,
588+
{
589+
"cross_attention_kwargs": cross_attention_kwargs,
590+
"callback": callback,
591+
"**model_inputs": model_inputs,
592+
},
593+
)
594+
595+
if call_inputs.get("compel_prompts", False):
596+
prepare_prompts(pipeline, model_inputs, is_sdxl)
574597

575598
try:
576599
async_pipeline = asyncio.to_thread(
@@ -581,13 +604,13 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
581604
callback=callback,
582605
**model_inputs,
583606
)
584-
if call_inputs.get("PIPELINE") != "StableDiffusionPipeline":
585-
# autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
586-
# still broken in 0.5.1
587-
with autocast(device_id):
588-
images = (await async_pipeline).images
589-
else:
590-
images = (await async_pipeline).images
607+
# if call_inputs.get("PIPELINE") != "StableDiffusionPipeline":
608+
# # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
609+
# # still broken in 0.5.1
610+
# with autocast(device_id):
611+
# images = (await async_pipeline).images
612+
# else:
613+
images = (await async_pipeline).images
591614

592615
except Exception as err:
593616
return {

api/getPipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,11 @@ def getPipelineForModel(
6565
start = time.time()
6666

6767
if hasattr(diffusers_pipelines, pipeline_name):
68-
if hasattr(model, "components"):
69-
pipeline = getattr(diffusers_pipelines, pipeline_name)(**model.components)
68+
pipeline_class = getattr(diffusers_pipelines, pipeline_name)
69+
if hasattr(pipeline_class, "from_pipe"):
70+
pipeline = pipeline_class.from_pipe(model)
71+
elif hasattr(model, "components"):
72+
pipeline = pipeline_class(**model.components)
7073
else:
7174
pipeline = getattr(diffusers_pipelines, pipeline_name)(
7275
vae=model.vae,

api/lib/prompts.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
2+
3+
4+
def prepare_prompts(pipeline, model_inputs, is_sdxl):
5+
textual_inversion_manager = DiffusersTextualInversionManager(pipeline)
6+
if is_sdxl:
7+
compel = Compel(
8+
tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
9+
text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
10+
# diffusers has no ti in sdxl yet
11+
# https://github.com/huggingface/diffusers/issues/4376#issuecomment-1659016141
12+
# textual_inversion_manager=textual_inversion_manager,
13+
truncate_long_prompts=False,
14+
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
15+
requires_pooled=[False, True],
16+
)
17+
conditioning, pooled = compel(model_inputs.get("prompt"))
18+
negative_conditioning, negative_pooled = compel(
19+
model_inputs.get("negative_prompt")
20+
)
21+
[
22+
conditioning,
23+
negative_conditioning,
24+
] = compel.pad_conditioning_tensors_to_same_length(
25+
[conditioning, negative_conditioning]
26+
)
27+
model_inputs.update(
28+
{
29+
"prompt": None,
30+
"negative_prompt": None,
31+
"prompt_embeds": conditioning,
32+
"negative_prompt_embeds": negative_conditioning,
33+
"pooled_prompt_embeds": pooled,
34+
"negative_pooled_prompt_embeds": negative_pooled,
35+
}
36+
)
37+
38+
else:
39+
compel = Compel(
40+
tokenizer=pipeline.tokenizer,
41+
text_encoder=pipeline.text_encoder,
42+
textual_inversion_manager=textual_inversion_manager,
43+
truncate_long_prompts=False,
44+
)
45+
conditioning = compel(model_inputs.get("prompt"))
46+
negative_conditioning = compel(model_inputs.get("negative_prompt"))
47+
[
48+
conditioning,
49+
negative_conditioning,
50+
] = compel.pad_conditioning_tensors_to_same_length(
51+
[conditioning, negative_conditioning]
52+
)
53+
model_inputs.update(
54+
{
55+
"prompt": None,
56+
"negative_prompt": None,
57+
"prompt_embeds": conditioning,
58+
"negative_prompt_embeds": negative_conditioning,
59+
}
60+
)

api/loadModel.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import os
3-
from diffusers import pipelines as _pipelines, StableDiffusionPipeline
3+
from diffusers import pipelines as _pipelines, AutoPipelineForText2Image
44
from getScheduler import getScheduler, DEFAULT_SCHEDULER
55
from precision import torch_dtype_from_precision
66
from device import device
@@ -25,7 +25,14 @@
2525
]
2626

2727

28-
def loadModel(model_id: str, load=True, precision=None, revision=None, send_opts={}):
28+
def loadModel(
29+
model_id: str,
30+
load=True,
31+
precision=None,
32+
revision=None,
33+
send_opts={},
34+
pipeline_class=AutoPipelineForText2Image,
35+
):
2936
torch_dtype = torch_dtype_from_precision(precision)
3037
if revision == "":
3138
revision = None
@@ -46,9 +53,7 @@ def loadModel(model_id: str, load=True, precision=None, revision=None, send_opts
4653
+ (f" ({revision})" if revision else "")
4754
)
4855

49-
pipeline = (
50-
StableDiffusionPipeline if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
51-
)
56+
pipeline = pipeline_class if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
5257

5358
scheduler = getScheduler(model_id, DEFAULT_SCHEDULER, not load)
5459

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,4 @@ xtarfile[zstd]==0.1.0
6262
bitsandbytes==0.40.2 # released 2023-07-17
6363

6464
invisible-watermark==0.2.0 # released 2023-07-06
65+
compel==2.0.1 # released 2023-07-29

0 commit comments

Comments
 (0)