2828from threading import Timer
2929import extras
3030
31+ from diffusers import StableDiffusionXLPipeline
32+
3133from lib .textual_inversions import handle_textual_inversions
34+ from lib .prompts import prepare_prompts
3235from lib .vars import (
3336 RUNTIME_DOWNLOADS ,
3437 USE_DREAMBOOTH ,
@@ -290,7 +293,7 @@ def sendStatus():
290293 if PIPELINE == "ALL" :
291294 pipeline_name = call_inputs .get ("PIPELINE" , None )
292295 if not pipeline_name :
293- pipeline_name = "StableDiffusionPipeline "
296+ pipeline_name = "AutoPipelineForText2Image "
294297 result ["$meta" ].update ({"PIPELINE" : pipeline_name })
295298
296299 pipeline = getPipelineForModel (
@@ -329,7 +332,11 @@ def sendStatus():
329332 }
330333
331334 safety_checker = call_inputs .get ("safety_checker" , True )
332- pipeline .safety_checker = model .safety_checker if safety_checker else None
335+ pipeline .safety_checker = (
336+ model .safety_checker
337+ if safety_checker and hasattr (model , "safety_checker" )
338+ else None
339+ )
333340 is_url = call_inputs .get ("is_url" , False )
334341 image_decoder = getFromUrl if is_url else decodeBase64Image
335342
@@ -399,6 +406,8 @@ def sendStatus():
399406 else 1
400407 )
401408 cross_attention_kwargs .update ({"scale" : storage_query_scale })
409+ # https://github.com/damian0815/compel/issues/42#issuecomment-1656989385
410+ pipeline ._lora_scale = storage_query_scale
402411 if storage_query_fname :
403412 fname = storage_query_fname [0 ]
404413 else :
@@ -569,8 +578,22 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
569578 "inference" , step / model_inputs .get ("num_inference_steps" , 50 )
570579 )
571580
581+ is_sdxl = isinstance (model , StableDiffusionXLPipeline )
582+ print ("is_sdxl" , is_sdxl )
583+
572584 with torch .inference_mode ():
573585 custom_pipeline_method = call_inputs .get ("custom_pipeline_method" , None )
586+ print (
587+ pipeline ,
588+ {
589+ "cross_attention_kwargs" : cross_attention_kwargs ,
590+ "callback" : callback ,
591+ "**model_inputs" : model_inputs ,
592+ },
593+ )
594+
595+ if call_inputs .get ("compel_prompts" , False ):
596+ prepare_prompts (pipeline , model_inputs , is_sdxl )
574597
575598 try :
576599 async_pipeline = asyncio .to_thread (
@@ -581,13 +604,13 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
581604 callback = callback ,
582605 ** model_inputs ,
583606 )
584- if call_inputs .get ("PIPELINE" ) != "StableDiffusionPipeline" :
585- # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
586- # still broken in 0.5.1
587- with autocast (device_id ):
588- images = (await async_pipeline ).images
589- else :
590- images = (await async_pipeline ).images
607+ # if call_inputs.get("PIPELINE") != "StableDiffusionPipeline":
608+ # # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
609+ # # still broken in 0.5.1
610+ # with autocast(device_id):
611+ # images = (await async_pipeline).images
612+ # else:
613+ images = (await async_pipeline ).images
591614
592615 except Exception as err :
593616 return {
0 commit comments