@@ -254,6 +254,7 @@ def sendStatus():
254254 )
255255 # downloaded_models.update({normalized_model_id: True})
256256 clearPipelines ()
257+ cross_attention_kwargs = None
257258 if model :
258259 model .to ("cpu" ) # Necessary to avoid a memory leak
259260 await send (
@@ -287,6 +288,7 @@ def sendStatus():
287288 if MODEL_ID == "ALL" :
288289 if last_model_id != normalized_model_id :
289290 clearPipelines ()
291+ cross_attention_kwargs = None
290292 model = loadModel (normalized_model_id , send_opts = send_opts )
291293 last_model_id = normalized_model_id
292294 else :
@@ -447,8 +449,12 @@ def sendStatus():
447449 if mi_cross_attention_kwargs :
448450 model_inputs .pop ("cross_attention_kwargs" )
449451 if isinstance (mi_cross_attention_kwargs , str ):
452+ if not cross_attention_kwargs :
453+ cross_attention_kwargs = {}
450454 cross_attention_kwargs .update (json .loads (mi_cross_attention_kwargs ))
451455 elif type (mi_cross_attention_kwargs ) == dict :
456+ if not cross_attention_kwargs :
457+ cross_attention_kwargs = {}
452458 cross_attention_kwargs .update (mi_cross_attention_kwargs )
453459 else :
454460 return {
@@ -459,6 +465,8 @@ def sendStatus():
459465 }
460466
461467 print ({"cross_attention_kwargs" : cross_attention_kwargs })
468+ if cross_attention_kwargs :
469+ model_inputs .update ({"cross_attention_kwargs" : cross_attention_kwargs })
462470
463471 # Parse out your arguments
464472 # prompt = model_inputs.get("prompt", None)
@@ -595,14 +603,11 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
595603 or isinstance (model , StableDiffusionXLImg2ImgPipeline )
596604 or isinstance (model , StableDiffusionXLInpaintPipeline )
597605 )
598- print ("is_sdxl" , is_sdxl )
599606
600607 with torch .inference_mode ():
601608 custom_pipeline_method = call_inputs .get ("custom_pipeline_method" , None )
602609 print (
603- pipeline ,
604610 {
605- "cross_attention_kwargs" : cross_attention_kwargs ,
606611 "callback" : callback ,
607612 "**model_inputs" : model_inputs ,
608613 },
@@ -616,7 +621,6 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
616621 getattr (pipeline , custom_pipeline_method )
617622 if custom_pipeline_method
618623 else pipeline ,
619- cross_attention_kwargs = cross_attention_kwargs ,
620624 callback = callback ,
621625 ** model_inputs ,
622626 )
0 commit comments