@@ -130,6 +130,7 @@ def truncateInputs(inputs: dict):
130130
131131last_xformers_memory_efficient_attention = {}
132132last_attn_procs = None
133+ last_lora_weights = None
133134
134135
135136# Inference is run for every server call
@@ -143,6 +144,7 @@ async def inference(all_inputs: dict, response) -> dict:
143144 global last_xformers_memory_efficient_attention
144145 global always_normalize_model_id
145146 global last_attn_procs
147+ global last_lora_weights
146148
147149 clearSession ()
148150
@@ -244,6 +246,8 @@ def sendStatus():
244246 "loadModel" , "done" , {"startRequestId" : startRequestId }, send_opts
245247 )
246248 last_model_id = normalized_model_id
249+ last_attn_procs = None
250+ last_lora_weights = None
247251 else :
248252 if always_normalize_model_id :
249253 normalized_model_id = always_normalize_model_id
@@ -312,8 +316,13 @@ def sendStatus():
312316 is_url = call_inputs .get ("is_url" , False )
313317 image_decoder = getFromUrl if is_url else decodeBase64Image
314318
319+ # Better to use new lora_weights in next section
315320 attn_procs = call_inputs .get ("attn_procs" , None )
316321 if attn_procs is not last_attn_procs :
322+ print (
323+ "[DEPRECATED] Using `attn_procs` for LoRAs is deprecated. "
324+ + "Please use `lora_weights` instead."
325+ )
317326 last_attn_procs = attn_procs
318327 if attn_procs :
319328 storage = Storage (attn_procs , no_raise = True )
@@ -344,6 +353,40 @@ def sendStatus():
344353 print ("Clearing attn procs" )
345354 pipeline .unet .set_attn_processor (CrossAttnProcessor ())
346355
356+ # Currently we only support a single string, but we should allow
357+ # an array too in anticipation of multi-LoRA support in diffusers
358+ # tracked at https://github.com/huggingface/diffusers/issues/2613.
359+ lora_weights = call_inputs .get ("lora_weights" , None )
360+ if lora_weights is not last_lora_weights :
361+ last_lora_weights = lora_weights
362+ if lora_weights :
363+ pipeline .unet .set_attn_processor (CrossAttnProcessor ())
364+ storage = Storage (lora_weights , no_raise = True )
365+ if storage :
366+ storage_query_fname = storage .query .get ("fname" )
367+ if storage_query_fname :
368+ fname = storage_query_fname [0 ]
369+ else :
370+ hash = sha256 (lora_weights .encode ("utf-8" )).hexdigest ()
371+ fname = "url_" + hash [:7 ] + "--" + storage .url .split ("/" ).pop ()
372+ cache_fname = "lora_weights--" + fname
373+ path = os .path .join (MODELS_DIR , cache_fname )
374+ if not os .path .exists (path ):
375+ storage .download_and_extract (path )
376+ print ("Load lora_weights `" + lora_weights + "` from `" + path + "`" )
377+ pipeline .load_lora_weights (
378+ MODELS_DIR , weight_name = cache_fname , local_files_only = True
379+ )
380+ else :
381+ print ("Loading from huggingface not supported yet: " + lora_weights )
382+ # maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s?
383+ # lora_model_id = "sayakpaul/civitai-light-shadow-lora"
384+ # lora_filename = "light_and_shadow.safetensors"
385+ # pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
386+ else :
387+ print ("Clearing attn procs" )
388+ pipeline .unet .set_attn_processor (CrossAttnProcessor ())
389+
347390 # TODO, generalize
348391 cross_attention_kwargs = model_inputs .get ("cross_attention_kwargs" , None )
349392 if isinstance (cross_attention_kwargs , str ):
0 commit comments