Skip to content

Commit 747fc0d

Browse files
committed
feat(loras): ability to specify #?scale=0.1 -> cross_attn_kwargs
1 parent 106dc4b commit 747fc0d

1 file changed

Lines changed: 26 additions & 3 deletions

File tree

api/app.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def truncateInputs(inputs: dict):
122122
# last_xformers_memory_efficient_attention = {}
123123
last_attn_procs = None
124124
last_lora_weights = None
125+
cross_attention_kwargs = None
125126

126127

127128
# Inference is run for every server call
@@ -135,6 +136,7 @@ async def inference(all_inputs: dict, response) -> dict:
135136
global always_normalize_model_id
136137
global last_attn_procs
137138
global last_lora_weights
139+
global cross_attention_kwargs
138140

139141
clearSession()
140142

@@ -379,6 +381,7 @@ def sendStatus():
379381
lora_weights_joined = json.dumps(lora_weights)
380382
if last_lora_weights != lora_weights_joined:
381383
last_lora_weights = lora_weights_joined
384+
cross_attention_kwargs = {}
382385
print("Unloading previous LoRA weights")
383386
pipeline.unload_lora_weights()
384387

@@ -390,6 +393,12 @@ def sendStatus():
390393
storage = Storage(weights, no_raise=True, status=status)
391394
if storage:
392395
storage_query_fname = storage.query.get("fname")
396+
storage_query_scale = (
397+
float(storage.query.get("scale")[0])
398+
if storage.query.get("scale")
399+
else 1
400+
)
401+
cross_attention_kwargs.update({"scale": storage_query_scale})
393402
if storage_query_fname:
394403
fname = storage_query_fname[0]
395404
else:
@@ -413,9 +422,22 @@ def sendStatus():
413422
print("No changes to LoRAs since last call")
414423

415424
# TODO, generalize
416-
cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
417-
if isinstance(cross_attention_kwargs, str):
418-
model_inputs["cross_attention_kwargs"] = json.loads(cross_attention_kwargs)
425+
mi_cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
426+
if mi_cross_attention_kwargs:
427+
model_inputs.pop("cross_attention_kwargs")
428+
if isinstance(mi_cross_attention_kwargs, str):
429+
cross_attention_kwargs.update(json.loads(mi_cross_attention_kwargs))
430+
elif type(mi_cross_attention_kwargs) == dict:
431+
cross_attention_kwargs.update(mi_cross_attention_kwargs)
432+
else:
433+
return {
434+
"$error": {
435+
"code": "INVALID_CROSS_ATTENTION_KWARGS",
436+
"message": "`cross_attention_kwargs` should be a dict or json string",
437+
}
438+
}
439+
440+
print({"cross_attention_kwargs": cross_attention_kwargs})
419441

420442
# Parse out your arguments
421443
# prompt = model_inputs.get("prompt", None)
@@ -555,6 +577,7 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
555577
getattr(pipeline, custom_pipeline_method)
556578
if custom_pipeline_method
557579
else pipeline,
580+
cross_attention_kwargs=cross_attention_kwargs,
558581
callback=callback,
559582
**model_inputs,
560583
)

0 commit comments

Comments
 (0)