Skip to content

Commit 0cb839d

Browse files
committed
feat(checkpoints): support #fname query in HTTPStorage
1 parent 0c7a757 commit 0cb839d

File tree

3 files changed

+25
-13
lines changed

3 files changed

+25
-13
lines changed

api/convert_to_diffusers.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@ def main(
2222
checkpoint_url: str,
2323
checkpoint_config_url: str,
2424
checkpoint_args: dict = {},
25+
path=None,
2526
):
26-
fname = CHECKPOINT_DIR + "/" + checkpoint_url.split("/").pop()
27+
if not path:
28+
fname = checkpoint_url.split("/").pop()
29+
path = os.path.join(CHECKPOINT_DIR, fname)
2730

2831
if checkpoint_config_url and checkpoint_config_url != "":
2932
storage = Storage(checkpoint_config_url)
30-
configPath = (
31-
CHECKPOINT_DIR + "/" + checkpoint_url.split("/").pop() + "_config.yaml"
32-
)
33+
configPath = CHECKPOINT_DIR + "/" + path + "_config.yaml"
3334
print(f"Downloading {checkpoint_config_url} to {configPath}...")
3435
storage.download_file(configPath)
3536

@@ -47,7 +48,7 @@ def main(
4748
# "./diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py"
4849
# )
4950

50-
print("Converting " + fname + " to diffusers model " + model_id + "...", flush=True)
51+
print("Converting " + path + " to diffusers model " + model_id + "...", flush=True)
5152

5253
# These are now in main requirements.txt.
5354
# subprocess.run(
@@ -112,11 +113,11 @@ def main(
112113
# our defaults
113114
args.update(
114115
{
115-
"checkpoint_path": fname,
116+
"checkpoint_path": path,
116117
"original_config_file": configPath if checkpoint_config_url else None,
117118
"device": device_id,
118119
"extract_ema": True,
119-
"from_safetensors": "safetensor" in fname.lower(),
120+
"from_safetensors": "safetensor" in path.lower(),
120121
}
121122
)
122123

api/download.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,12 @@ async def download_model(
8080
await asyncio.to_thread(storage.download_and_extract, model_file, model_dir)
8181
else:
8282
if checkpoint_url:
83-
download_checkpoint(checkpoint_url)
83+
path = download_checkpoint(checkpoint_url)
8484
convert_to_diffusers(
8585
model_id=model_id,
8686
checkpoint_url=checkpoint_url,
8787
checkpoint_config_url=checkpoint_config_url,
88+
path=path,
8889
)
8990
else:
9091
print("Does not exist, let's try find it on huggingface")
@@ -138,11 +139,12 @@ async def download_model(
138139

139140
else:
140141
if checkpoint_url:
141-
download_checkpoint(checkpoint_url)
142+
path = download_checkpoint(checkpoint_url)
142143
convert_to_diffusers(
143144
model_id=model_id,
144145
checkpoint_url=checkpoint_url,
145146
checkpoint_config_url=checkpoint_config_url,
147+
path=path,
146148
)
147149
else:
148150
# do a dry run of loading the huggingface model, which will download weights at build time

api/download_checkpoint.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,19 @@
88
def main(checkpoint_url: str):
99
if not os.path.isdir(CHECKPOINT_DIR):
1010
os.makedirs(CHECKPOINT_DIR)
11-
fname = CHECKPOINT_DIR + "/" + checkpoint_url.split("/").pop()
12-
if not os.path.isfile(fname):
13-
storage = Storage(checkpoint_url)
14-
storage.download_file(fname)
11+
12+
storage = Storage(checkpoint_url)
13+
storage_query_fname = storage.query.get("fname")
14+
if storage_query_fname:
15+
fname = storage_query_fname[0]
16+
else:
17+
fname = checkpoint_url.split("/").pop()
18+
path = os.path.join(CHECKPOINT_DIR, fname)
19+
20+
if not os.path.isfile(path):
21+
storage.download_file(path)
22+
23+
return path
1524

1625

1726
if __name__ == "__main__":

0 commit comments

Comments
 (0)