@@ -122,6 +122,7 @@ def truncateInputs(inputs: dict):
122122# last_xformers_memory_efficient_attention = {}
123123last_attn_procs = None
124124last_lora_weights = None
125+ cross_attention_kwargs = None
125126
126127
127128# Inference is ran for every server call
@@ -135,6 +136,7 @@ async def inference(all_inputs: dict, response) -> dict:
135136 global always_normalize_model_id
136137 global last_attn_procs
137138 global last_lora_weights
139+ global cross_attention_kwargs
138140
139141 clearSession ()
140142
@@ -379,6 +381,7 @@ def sendStatus():
379381 lora_weights_joined = json .dumps (lora_weights )
380382 if last_lora_weights != lora_weights_joined :
381383 last_lora_weights = lora_weights_joined
384+ cross_attention_kwargs = {}
382385 print ("Unloading previous LoRA weights" )
383386 pipeline .unload_lora_weights ()
384387
@@ -390,6 +393,12 @@ def sendStatus():
390393 storage = Storage (weights , no_raise = True , status = status )
391394 if storage :
392395 storage_query_fname = storage .query .get ("fname" )
396+ storage_query_scale = (
397+ float (storage .query .get ("scale" )[0 ])
398+ if storage .query .get ("scale" )
399+ else 1
400+ )
401+ cross_attention_kwargs .update ({"scale" : storage_query_scale })
393402 if storage_query_fname :
394403 fname = storage_query_fname [0 ]
395404 else :
@@ -413,9 +422,22 @@ def sendStatus():
413422 print ("No changes to LoRAs since last call" )
414423
415424 # TODO, generalize
416- cross_attention_kwargs = model_inputs .get ("cross_attention_kwargs" , None )
417- if isinstance (cross_attention_kwargs , str ):
418- model_inputs ["cross_attention_kwargs" ] = json .loads (cross_attention_kwargs )
425+ mi_cross_attention_kwargs = model_inputs .get ("cross_attention_kwargs" , None )
426+ if mi_cross_attention_kwargs :
427+ model_inputs .pop ("cross_attention_kwargs" )
428+ if isinstance (mi_cross_attention_kwargs , str ):
429+ cross_attention_kwargs .update (json .loads (mi_cross_attention_kwargs ))
430+ elif type (mi_cross_attention_kwargs ) == dict :
431+ cross_attention_kwargs .update (mi_cross_attention_kwargs )
432+ else :
433+ return {
434+ "$error" : {
435+ "code" : "INVALID_CROSS_ATTENTION_KWARGS" ,
436+ "message" : "`cross_attention_kwargs` should be a dict or json string" ,
437+ }
438+ }
439+
440+ print ({"cross_attention_kwargs" : cross_attention_kwargs })
419441
420442 # Parse out your arguments
421443 # prompt = model_inputs.get("prompt", None)
@@ -555,6 +577,7 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
555577 getattr (pipeline , custom_pipeline_method )
556578 if custom_pipeline_method
557579 else pipeline ,
580+ cross_attention_kwargs = cross_attention_kwargs ,
558581 callback = callback ,
559582 ** model_inputs ,
560583 )
0 commit comments