Skip to content

Commit adaa7f6

Browse files
committed
feat(downloads): allow HF_MODEL_ID call-arg (defaults to MODEL_ID)
Sometimes you might want the unique model_id you use to differ from the HuggingFace MODEL_ID (user/repo), e.g. when the model is constantly updated and you want to create a new unique local model from each update. TODO: separate MODEL_REVISION and MODEL_PRECISION. Allow HF_MODEL_ID env variable for builds.
1 parent 6a78229 commit adaa7f6

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

api/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def inference(all_inputs: dict) -> dict:
151151
normalized_model_id = model_id
152152

153153
if RUNTIME_DOWNLOADS:
154+
hf_model_id = call_inputs.get("HF_MODEL_ID", None)
154155
model_precision = call_inputs.get("MODEL_PRECISION", None)
155156
checkpoint_url = call_inputs.get("CHECKPOINT_URL", None)
156157
checkpoint_config_url = call_inputs.get("CHECKPOINT_CONFIG_URL", None)
@@ -173,6 +174,7 @@ def inference(all_inputs: dict) -> dict:
173174
model_revision=model_precision,
174175
checkpoint_url=checkpoint_url,
175176
checkpoint_config_url=checkpoint_config_url,
177+
hf_model_id=hf_model_id,
176178
)
177179
# downloaded_models.update({normalized_model_id: True})
178180
clearPipelines()

api/download.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,19 @@ def download_model(
4343
model_revision=None,
4444
checkpoint_url=None,
4545
checkpoint_config_url=None,
46+
hf_model_id=None,
4647
):
4748
print(
4849
"download_model",
4950
{
5051
"model_url": model_url,
5152
"model_id": model_id,
5253
"model_revision": model_revision,
54+
"hf_model_id": hf_model_id,
5355
},
5456
)
55-
id = model_id or MODEL_ID
5657
url = model_url or MODEL_URL
58+
hf_model_id = hf_model_id or model_id
5759
revision = model_revision or revision_from_precision()
5860
normalized_model_id = id
5961

@@ -100,11 +102,11 @@ def download_model(
100102
# this conveniently logs all the timings (and doesn't happen often)
101103
print("download")
102104
send("download", "start", {})
103-
model = loadModel(model_id, False, precision=model_revision) # download
105+
model = loadModel(hf_model_id, False, precision=model_revision) # download
104106
send("download", "done", {})
105107

106108
print("load")
107-
model = loadModel(model_id, True, precision=model_revision) # load
109+
model = loadModel(hf_model_id, True, precision=model_revision) # load
108110
# dir = "models--" + model_id.replace("/", "--") + "--dda"
109111
dir = os.path.join(MODELS_DIR, normalized_model_id)
110112
model.save_pretrained(dir, safe_serialization=True)

0 commit comments

Comments
 (0)