Skip to content

Commit 3f1f980

Browse files
committed
fix(x_attn_kwargs): only pass to pipeline if set
1 parent 5f46faa commit 3f1f980

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

api/app.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def sendStatus():
254254
)
255255
# downloaded_models.update({normalized_model_id: True})
256256
clearPipelines()
257+
cross_attention_kwargs = None
257258
if model:
258259
model.to("cpu") # Necessary to avoid a memory leak
259260
await send(
@@ -287,6 +288,7 @@ def sendStatus():
287288
if MODEL_ID == "ALL":
288289
if last_model_id != normalized_model_id:
289290
clearPipelines()
291+
cross_attention_kwargs = None
290292
model = loadModel(normalized_model_id, send_opts=send_opts)
291293
last_model_id = normalized_model_id
292294
else:
@@ -447,8 +449,12 @@ def sendStatus():
447449
if mi_cross_attention_kwargs:
448450
model_inputs.pop("cross_attention_kwargs")
449451
if isinstance(mi_cross_attention_kwargs, str):
452+
if not cross_attention_kwargs:
453+
cross_attention_kwargs = {}
450454
cross_attention_kwargs.update(json.loads(mi_cross_attention_kwargs))
451455
elif type(mi_cross_attention_kwargs) == dict:
456+
if not cross_attention_kwargs:
457+
cross_attention_kwargs = {}
452458
cross_attention_kwargs.update(mi_cross_attention_kwargs)
453459
else:
454460
return {
@@ -459,6 +465,8 @@ def sendStatus():
459465
}
460466

461467
print({"cross_attention_kwargs": cross_attention_kwargs})
468+
if cross_attention_kwargs:
469+
model_inputs.update({"cross_attention_kwargs": cross_attention_kwargs})
462470

463471
# Parse out your arguments
464472
# prompt = model_inputs.get("prompt", None)
@@ -595,14 +603,11 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
595603
or isinstance(model, StableDiffusionXLImg2ImgPipeline)
596604
or isinstance(model, StableDiffusionXLInpaintPipeline)
597605
)
598-
print("is_sdxl", is_sdxl)
599606

600607
with torch.inference_mode():
601608
custom_pipeline_method = call_inputs.get("custom_pipeline_method", None)
602609
print(
603-
pipeline,
604610
{
605-
"cross_attention_kwargs": cross_attention_kwargs,
606611
"callback": callback,
607612
"**model_inputs": model_inputs,
608613
},
@@ -616,7 +621,6 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
616621
getattr(pipeline, custom_pipeline_method)
617622
if custom_pipeline_method
618623
else pipeline,
619-
cross_attention_kwargs=cross_attention_kwargs,
620624
callback=callback,
621625
**model_inputs,
622626
)

0 commit comments

Comments (0)