Skip to content

Commit 6edc821

Browse files
committed
feat(downloads): allow separate MODEL_REVISION and MODEL_PRECISION
TODO: allow the same for builds
1 parent adaa7f6 commit 6edc821

3 files changed

Lines changed: 12 additions & 9 deletions

File tree

api/app.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def inference(all_inputs: dict) -> dict:
152152

153153
if RUNTIME_DOWNLOADS:
154154
hf_model_id = call_inputs.get("HF_MODEL_ID", None)
155+
model_revision = call_inputs.get("MODEL_REVISION", None)
155156
model_precision = call_inputs.get("MODEL_PRECISION", None)
156157
checkpoint_url = call_inputs.get("CHECKPOINT_URL", None)
157158
checkpoint_config_url = call_inputs.get("CHECKPOINT_CONFIG_URL", None)
@@ -171,10 +172,11 @@ def inference(all_inputs: dict) -> dict:
171172
download_model(
172173
model_id=model_id,
173174
model_url=model_url,
174-
model_revision=model_precision,
175+
model_revision=model_revision or model_precision,
175176
checkpoint_url=checkpoint_url,
176177
checkpoint_config_url=checkpoint_config_url,
177178
hf_model_id=hf_model_id,
179+
model_precision=model_precision,
178180
)
179181
# downloaded_models.update({normalized_model_id: True})
180182
clearPipelines()

api/download.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def download_model(
4444
checkpoint_url=None,
4545
checkpoint_config_url=None,
4646
hf_model_id=None,
47+
model_precision=None,
4748
):
4849
print(
4950
"download_model",
@@ -56,11 +57,11 @@ def download_model(
5657
)
5758
url = model_url or MODEL_URL
5859
hf_model_id = hf_model_id or model_id
59-
revision = model_revision or revision_from_precision()
60+
model_revision = model_revision or revision_from_precision()
6061
normalized_model_id = id
6162

6263
if url != "":
63-
normalized_model_id = normalize_model_id(model_id, model_revision)
64+
normalized_model_id = normalize_model_id(model_id, model_precision)
6465
print({"normalized_model_id": normalized_model_id})
6566
filename = url.split("/").pop()
6667
if not filename:
@@ -97,16 +98,16 @@ def download_model(
9798
)
9899
else:
99100
print("Does not exist, let's try find it on huggingface")
100-
print("precision = ", {"model_revision": model_revision})
101+
print({"model_precision": model_precision, "model_revision": model_revision})
101102
# This would be quicker to just model.to("cuda") afterwards, but
102103
# this conveniently logs all the timings (and doesn't happen often)
103104
print("download")
104105
send("download", "start", {})
105-
model = loadModel(hf_model_id, False, precision=model_revision) # download
106+
model = loadModel(hf_model_id, False, precision=model_precision, revision=model_revision) # download
106107
send("download", "done", {})
107108

108109
print("load")
109-
model = loadModel(hf_model_id, True, precision=model_revision) # load
110+
model = loadModel(hf_model_id, True, precision=model_precision, revision=model_revision) # load
110111
# dir = "models--" + model_id.replace("/", "--") + "--dda"
111112
dir = os.path.join(MODELS_DIR, normalized_model_id)
112113
model.save_pretrained(dir, safe_serialization=True)

api/loadModel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
]
2525

2626

27-
def loadModel(model_id: str, load=True, precision=None):
28-
print("loadModel", {"model_id": model_id, "load": load, "precision": precision})
29-
revision = revision_from_precision(precision)
27+
def loadModel(model_id: str, load=True, precision=None, revision=None):
28+
revision = revision or revision_from_precision(precision)
3029
torch_dtype = torch_dtype_from_precision(precision)
30+
print("loadModel", {"model_id": model_id, "load": load, "precision": precision, "revision": revision})
3131
print(
3232
("Loading" if load else "Downloading")
3333
+ " model: "

0 commit comments

Comments
 (0)