55from loadModel import loadModel , MODEL_IDS
66from diffusers import AutoencoderKL , UNet2DConditionModel , DDPMScheduler
77from transformers import CLIPTextModel , CLIPTokenizer
8- from precision import PRECISION , revision_from_precision , torch_dtype_from_precision
98from utils import Storage
109import subprocess
1110from pathlib import Path
1211import shutil
1312from convert_to_diffusers import main as convert_to_diffusers
1413from download_checkpoint import main as download_checkpoint
1514
16- MODEL_ID = os .environ .get ("MODEL_ID" )
17- MODEL_URL = os .environ .get ("MODEL_URL" )
1815USE_DREAMBOOTH = os .environ .get ("USE_DREAMBOOTH" )
1916HF_AUTH_TOKEN = os .environ .get ("HF_AUTH_TOKEN" )
2017RUNTIME_DOWNLOADS = os .environ .get ("RUNTIME_DOWNLOADS" )
18+
2119HOME = os .path .expanduser ("~" )
2220MODELS_DIR = os .path .join (HOME , ".cache" , "diffusers-api" )
2321Path (MODELS_DIR ).mkdir (parents = True , exist_ok = True )
@@ -55,19 +53,17 @@ def download_model(
5553 "hf_model_id" : hf_model_id ,
5654 },
5755 )
58- url = model_url or MODEL_URL
5956 hf_model_id = hf_model_id or model_id
60- model_revision = model_revision or revision_from_precision ()
61- normalized_model_id = id
57+ normalized_model_id = model_id
6258
63- if url != "" :
64- normalized_model_id = normalize_model_id (model_id , model_precision )
59+ if model_url != "" :
60+ normalized_model_id = normalize_model_id (model_id , model_revision )
6561 print ({"normalized_model_id" : normalized_model_id })
66- filename = url .split ("/" ).pop ()
62+ filename = model_url .split ("/" ).pop ()
6763 if not filename :
6864 filename = normalized_model_id + ".tar.zst"
6965 model_file = os .path .join (MODELS_DIR , filename )
70- storage = Storage (url , default_path = normalized_model_id + ".tar.zst" )
66+ storage = Storage (model_url , default_path = normalized_model_id + ".tar.zst" )
7167 exists = storage .file_exists ()
7268 if exists :
7369 storage .download_file (model_file )
@@ -98,16 +94,28 @@ def download_model(
9894 )
9995 else :
10096 print ("Does not exist, let's try find it on huggingface" )
101- print ({"model_precision" : model_precision , "model_revision" : model_revision })
97+ print (
98+ {
99+ "model_precision" : model_precision ,
100+ "model_revision" : model_revision ,
101+ }
102+ )
102103 # This would be quicker to just model.to("cuda") afterwards, but
103104 # this conveniently logs all the timings (and doesn't happen often)
104105 print ("download" )
105106 send ("download" , "start" , {})
106- model = loadModel (hf_model_id , False , precision = model_precision , revision = model_revision ) # download
107+ model = loadModel (
108+ hf_model_id ,
109+ False ,
110+ precision = model_precision ,
111+ revision = model_revision ,
112+ ) # download
107113 send ("download" , "done" , {})
108114
109115 print ("load" )
110- model = loadModel (hf_model_id , True , precision = model_precision , revision = model_revision ) # load
116+ model = loadModel (
117+ hf_model_id , True , precision = model_precision , revision = model_revision
118+ ) # load
111119 # dir = "models--" + model_id.replace("/", "--") + "--dda"
112120 dir = os .path .join (MODELS_DIR , normalized_model_id )
113121 model .save_pretrained (dir , safe_serialization = True )
@@ -137,12 +145,12 @@ def download_model(
137145 return
138146
139147 # do a dry run of loading the huggingface model, which will download weights at build time
140- # For local dev & preview deploys, download all the models (terrible for serverless deploys)
141- if MODEL_ID == "ALL" :
142- for MODEL_I in MODEL_IDS :
143- loadModel ( MODEL_I , False , precision = model_revision )
144- else :
145- loadModel ( normalized_model_id , False , precision = model_revision )
148+ loadModel (
149+ model_id = normalized_model_id ,
150+ load = False ,
151+ precision = model_precision ,
152+ revision = model_revision ,
153+ )
146154
147155 # if USE_DREAMBOOTH:
148156 # Actually we can re-use these from the above loaded model
@@ -164,4 +172,10 @@ def download_model(
164172
165173
166174if __name__ == "__main__" :
167- download_model ("" , MODEL_ID , PRECISION )
175+ download_model (
176+ model_url = os .environ .get ("MODEL_URL" ),
177+ model_id = os .environ .get ("MODEL_ID" ),
178+ hf_model_id = os .environ .get ("HF_MODEL_ID" ),
179+ model_revision = os .environ .get ("MODEL_REVISION" ),
180+ model_precision = os .environ .get ("MODEL_PRECISION" ),
181+ )
0 commit comments