Skip to content

Commit fa9dd16

Browse files
committed
feat(build): separate MODEL_REVISION, MODEL_PRECISION, HF_MODEL_ID
The above already worked for runtime downloads, but can now be used as build-args to download the model at build time and include it in your image. SOME VERY IMPORTANT NOTES: 1) MODEL_REVISION no longer defaults to MODEL_PRECISION; you need to specify it separately (however, it still defaults to "fp16" in the Dockerfile). You'll get a warning if you specify MODEL_PRECISION without MODEL_REVISION, to help in the most common case of the old behaviour. 2) The build-arg PRECISION still works but has been deprecated in favour of MODEL_PRECISION (which was already the call-arg name). 3) normalized_model_id now uses MODEL_REVISION, so for the early birds already using the "cloud cache" on S3, your filenames might no longer match in some cases.
1 parent 0f37a4e commit fa9dd16

File tree

6 files changed

+79
-40
lines changed

6 files changed

+79
-40
lines changed

api/app.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import requests
2020
from download import download_model, normalize_model_id
2121
import traceback
22-
from precision import PRECISION
22+
from precision import MODEL_REVISION, MODEL_PRECISION
2323

2424
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
2525
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
@@ -72,14 +72,19 @@ def init():
7272

7373
if not RUNTIME_DOWNLOADS:
7474
# Uh doesn't this break non-cached images? TODO... IMAGE_CACHE
75-
normalized_model_id = normalize_model_id(MODEL_ID, PRECISION)
75+
normalized_model_id = normalize_model_id(MODEL_ID, MODEL_REVISION)
7676
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
7777
if os.path.isdir(model_dir):
7878
always_normalize_model_id = model_dir
7979
else:
8080
normalized_model_id = MODEL_ID
8181

82-
model = loadModel(model_dir, True, PRECISION)
82+
model = loadModel(
83+
model_id = model_dir,
84+
load=True,
85+
precision=MODEL_PRECISION,
86+
revision=MODEL_REVISION,
87+
)
8388
else:
8489
model = None
8590

@@ -156,7 +161,7 @@ def inference(all_inputs: dict) -> dict:
156161
model_precision = call_inputs.get("MODEL_PRECISION", None)
157162
checkpoint_url = call_inputs.get("CHECKPOINT_URL", None)
158163
checkpoint_config_url = call_inputs.get("CHECKPOINT_CONFIG_URL", None)
159-
normalized_model_id = normalize_model_id(model_id, model_precision)
164+
normalized_model_id = normalize_model_id(model_id, model_revision)
160165
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
161166
if last_model_id != normalized_model_id:
162167
# if not downloaded_models.get(normalized_model_id, None):
@@ -172,7 +177,7 @@ def inference(all_inputs: dict) -> dict:
172177
download_model(
173178
model_id=model_id,
174179
model_url=model_url,
175-
model_revision=model_revision or model_precision,
180+
model_revision=model_revision,
176181
checkpoint_url=checkpoint_url,
177182
checkpoint_config_url=checkpoint_config_url,
178183
hf_model_id=hf_model_id,
@@ -182,7 +187,7 @@ def inference(all_inputs: dict) -> dict:
182187
clearPipelines()
183188
if model:
184189
model.to("cpu") # Necessary to avoid a memory leak
185-
model = loadModel(normalized_model_id, True, model_precision)
190+
model = loadModel(model_id=normalized_model_id, load=True, precision=model_precision)
186191
last_model_id = normalized_model_id
187192
else:
188193
if always_normalize_model_id:

api/download.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,17 @@
55
from loadModel import loadModel, MODEL_IDS
66
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
77
from transformers import CLIPTextModel, CLIPTokenizer
8-
from precision import PRECISION, revision_from_precision, torch_dtype_from_precision
98
from utils import Storage
109
import subprocess
1110
from pathlib import Path
1211
import shutil
1312
from convert_to_diffusers import main as convert_to_diffusers
1413
from download_checkpoint import main as download_checkpoint
1514

16-
MODEL_ID = os.environ.get("MODEL_ID")
17-
MODEL_URL = os.environ.get("MODEL_URL")
1815
USE_DREAMBOOTH = os.environ.get("USE_DREAMBOOTH")
1916
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
2017
RUNTIME_DOWNLOADS = os.environ.get("RUNTIME_DOWNLOADS")
18+
2119
HOME = os.path.expanduser("~")
2220
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
2321
Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)
@@ -55,19 +53,17 @@ def download_model(
5553
"hf_model_id": hf_model_id,
5654
},
5755
)
58-
url = model_url or MODEL_URL
5956
hf_model_id = hf_model_id or model_id
60-
model_revision = model_revision or revision_from_precision()
61-
normalized_model_id = id
57+
normalized_model_id = model_id
6258

63-
if url != "":
64-
normalized_model_id = normalize_model_id(model_id, model_precision)
59+
if model_url != "":
60+
normalized_model_id = normalize_model_id(model_id, model_revision)
6561
print({"normalized_model_id": normalized_model_id})
66-
filename = url.split("/").pop()
62+
filename = model_url.split("/").pop()
6763
if not filename:
6864
filename = normalized_model_id + ".tar.zst"
6965
model_file = os.path.join(MODELS_DIR, filename)
70-
storage = Storage(url, default_path=normalized_model_id + ".tar.zst")
66+
storage = Storage(model_url, default_path=normalized_model_id + ".tar.zst")
7167
exists = storage.file_exists()
7268
if exists:
7369
storage.download_file(model_file)
@@ -98,16 +94,28 @@ def download_model(
9894
)
9995
else:
10096
print("Does not exist, let's try find it on huggingface")
101-
print({"model_precision": model_precision, "model_revision": model_revision})
97+
print(
98+
{
99+
"model_precision": model_precision,
100+
"model_revision": model_revision,
101+
}
102+
)
102103
# This would be quicker to just model.to("cuda") afterwards, but
103104
# this conveniently logs all the timings (and doesn't happen often)
104105
print("download")
105106
send("download", "start", {})
106-
model = loadModel(hf_model_id, False, precision=model_precision, revision=model_revision) # download
107+
model = loadModel(
108+
hf_model_id,
109+
False,
110+
precision=model_precision,
111+
revision=model_revision,
112+
) # download
107113
send("download", "done", {})
108114

109115
print("load")
110-
model = loadModel(hf_model_id, True, precision=model_precision, revision=model_revision) # load
116+
model = loadModel(
117+
hf_model_id, True, precision=model_precision, revision=model_revision
118+
) # load
111119
# dir = "models--" + model_id.replace("/", "--") + "--dda"
112120
dir = os.path.join(MODELS_DIR, normalized_model_id)
113121
model.save_pretrained(dir, safe_serialization=True)
@@ -137,12 +145,12 @@ def download_model(
137145
return
138146

139147
# do a dry run of loading the huggingface model, which will download weights at build time
140-
# For local dev & preview deploys, download all the models (terrible for serverless deploys)
141-
if MODEL_ID == "ALL":
142-
for MODEL_I in MODEL_IDS:
143-
loadModel(MODEL_I, False, precision=model_revision)
144-
else:
145-
loadModel(normalized_model_id, False, precision=model_revision)
148+
loadModel(
149+
model_id=normalized_model_id,
150+
load=False,
151+
precision=model_precision,
152+
revision=model_revision,
153+
)
146154

147155
# if USE_DREAMBOOTH:
148156
# Actually we can re-use these from the above loaded model
@@ -164,4 +172,10 @@ def download_model(
164172

165173

166174
if __name__ == "__main__":
167-
download_model("", MODEL_ID, PRECISION)
175+
download_model(
176+
model_url=os.environ.get("MODEL_URL"),
177+
model_id=os.environ.get("MODEL_ID"),
178+
hf_model_id=os.environ.get("HF_MODEL_ID"),
179+
model_revision=os.environ.get("MODEL_REVISION"),
180+
model_precision=os.environ.get("MODEL_PRECISION"),
181+
)

api/getPipeline.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
DiffusionPipeline,
55
pipelines as diffusers_pipelines,
66
)
7-
from precision import revision, torch_dtype
87

98
HOME = os.path.expanduser("~")
109
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
@@ -83,8 +82,8 @@ def getPipelineForModel(pipeline_name: str, model, model_id):
8382

8483
pipeline = DiffusionPipeline.from_pretrained(
8584
model_dir or model_id,
86-
revision=revision,
87-
torch_dtype=torch_dtype,
85+
# revision=revision,
86+
# torch_dtype=torch_dtype,
8887
custom_pipeline="./diffusers/examples/community/" + pipeline_name + ".py",
8988
local_files_only=True,
9089
**model.components,

api/loadModel.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from diffusers import pipelines as _pipelines, StableDiffusionPipeline
44
from getScheduler import getScheduler, DEFAULT_SCHEDULER
5-
from precision import revision_from_precision, torch_dtype_from_precision
5+
from precision import torch_dtype_from_precision
66
import time
77

88
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
@@ -25,9 +25,16 @@
2525

2626

2727
def loadModel(model_id: str, load=True, precision=None, revision=None):
28-
revision = revision or revision_from_precision(precision)
2928
torch_dtype = torch_dtype_from_precision(precision)
30-
print("loadModel", {"model_id": model_id, "load": load, "precision": precision, "revision": revision})
29+
print(
30+
"loadModel",
31+
{
32+
"model_id": model_id,
33+
"load": load,
34+
"precision": precision,
35+
"revision": revision,
36+
},
37+
)
3138
print(
3239
("Loading" if load else "Downloading")
3340
+ " model: "

api/precision.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
11
import os
22
import torch
33

4-
PRECISION = os.getenv("PRECISION")
4+
DEPRECATED_PRECISION = os.getenv("PRECISION")
5+
MODEL_PRECISION = os.getenv("MODEL_PRECISION") or DEPRECATED_PRECISION
6+
MODEL_REVISION = os.getenv("MODEL_REVISION")
57

6-
revision = None if PRECISION == "" else PRECISION
7-
torch_dtype = None if PRECISION == "" else torch.float16
8+
if DEPRECATED_PRECISION:
9+
print("Warning: PRECISION variable been deprecated and renamed MODEL_PRECISION")
10+
print("Your setup still works but in a future release, this will throw an error")
811

12+
if MODEL_PRECISION and not MODEL_REVISION:
13+
print("Warning: we no longer default to MODEL_REVISION=MODEL_PRECISION, please")
14+
print(f'explicitly set MODEL_REVISION="{MODEL_PRECISION}" if that\'s what you')
15+
print("want.")
916

10-
def revision_from_precision(precision=PRECISION):
11-
return precision if precision else None
1217

18+
def revision_from_precision(precision=MODEL_PRECISION):
19+
# return precision if precision else None
20+
raise Exception("revision_from_precision no longer supported")
1321

14-
def torch_dtype_from_precision(precision=PRECISION):
22+
23+
def torch_dtype_from_precision(precision=MODEL_PRECISION):
24+
if precision == "fp16":
25+
return torch.float16
26+
return None
27+
28+
29+
def torch_dtype_from_precision(precision=MODEL_PRECISION):
1530
if precision == "fp16":
1631
return torch.float16
1732
return None

api/train_dreambooth.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from transformers import CLIPTextModel, CLIPTokenizer
3838

3939
# DDA
40-
from precision import revision, torch_dtype
4140
from send import send, get_now
4241
from utils import Storage
4342
import subprocess
@@ -55,7 +54,7 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs):
5554
params = {
5655
# Defaults
5756
"pretrained_model_name_or_path": model_id, # DDA, TODO
58-
"revision": revision, # DDA, was: None
57+
"revision": None,
5958
"tokenizer_name": None,
6059
"instance_data_dir": "instance_data_dir", # DDA TODO
6160
"class_data_dir": "class_data_dir", # DDA, was: None,

0 commit comments

Comments
 (0)