@@ -44,6 +44,7 @@ def download_model(
4444 checkpoint_url = None ,
4545 checkpoint_config_url = None ,
4646 hf_model_id = None ,
47+ model_precision = None ,
4748):
4849 print (
4950 "download_model" ,
@@ -56,11 +57,11 @@ def download_model(
5657 )
5758 url = model_url or MODEL_URL
5859 hf_model_id = hf_model_id or model_id
59- revision = model_revision or revision_from_precision ()
60+ model_revision = model_revision or revision_from_precision ()
6061 normalized_model_id = id
6162
6263 if url != "" :
63- normalized_model_id = normalize_model_id (model_id , model_revision )
64+ normalized_model_id = normalize_model_id (model_id , model_precision )
6465 print ({"normalized_model_id" : normalized_model_id })
6566 filename = url .split ("/" ).pop ()
6667 if not filename :
@@ -97,16 +98,16 @@ def download_model(
9798 )
9899 else :
99100 print ("Does not exist, let's try find it on huggingface" )
100- print ("precision = " , { "model_revision" : model_revision })
101+ print ({ "model_precision" : model_precision , "model_revision" : model_revision })
101102 # This would be quicker to just model.to("cuda") afterwards, but
102103 # this conveniently logs all the timings (and doesn't happen often)
103104 print ("download" )
104105 send ("download" , "start" , {})
105- model = loadModel (hf_model_id , False , precision = model_revision ) # download
106+ model = loadModel (hf_model_id , False , precision = model_precision , revision = model_revision ) # download
106107 send ("download" , "done" , {})
107108
108109 print ("load" )
109- model = loadModel (hf_model_id , True , precision = model_revision ) # load
110+ model = loadModel (hf_model_id , True , precision = model_precision , revision = model_revision ) # load
110111 # dir = "models--" + model_id.replace("/", "--") + "--dda"
111112 dir = os .path .join (MODELS_DIR , normalized_model_id )
112113 model .save_pretrained (dir , safe_serialization = True )
0 commit comments