2121import traceback
2222from precision import MODEL_REVISION , MODEL_PRECISION
2323from device import device , device_id , device_name
24- from diffusers .models .cross_attention import CrossAttnProcessor , LoRACrossAttnProcessor
24+ from diffusers .models .cross_attention import CrossAttnProcessor
25+ from utils import Storage
26+ from hashlib import sha256
27+
2528
2629RUNTIME_DOWNLOADS = os .getenv ("RUNTIME_DOWNLOADS" ) == "1"
2730USE_DREAMBOOTH = os .getenv ("USE_DREAMBOOTH" ) == "1"
@@ -278,6 +281,18 @@ def inference(all_inputs: dict) -> dict:
278281 if attn_procs is not last_attn_procs :
279282 last_attn_procs = attn_procs
280283 if attn_procs :
284+ storage = Storage (attn_procs , no_raise = True )
285+ if storage :
286+ fname = storage .url .split ("/" ).pop ()
287+ hash = sha256 (attn_procs .encode ("utf-8" )).hexdigest ()
288+ if True :
289+ # TODO, way to specify explicit name
290+ path = os .path .join (
291+ MODELS_DIR , "attn_proc--url_" + hash [:7 ] + "--" + fname
292+ )
293+ attn_procs = path
294+ if not os .path .exists (path ):
295+ storage .download_and_extract (path )
281296 print ("Load attn_procs " + attn_procs )
282297 pipeline .unet .load_attn_procs (attn_procs )
283298 else :
@@ -286,7 +301,7 @@ def inference(all_inputs: dict) -> dict:
286301
287302 # TODO, generalize
288303 cross_attention_kwargs = model_inputs .get ("cross_attention_kwargs" , None )
289- if cross_attention_kwargs :
304+ if isinstance ( cross_attention_kwargs , str ) :
290305 model_inputs ["cross_attention_kwargs" ] = json .loads (cross_attention_kwargs )
291306
292307 # Parse out your arguments
0 commit comments