Skip to content

Commit 36ae572

Browse files
committed
update huggingface example
1 parent 5821050 commit 36ae572

File tree

2 files changed: 39 additions (+39) and 13 deletions (-13).
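Lines changed: 28 additions & 0 deletions (the file path for this diff was not captured; the contents below look like a Beam SDK ignore file — confirm against the commit)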
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,28 @@
1+
# Generated by Beam SDK
2+
.beamignore
3+
pyproject.toml
4+
.git
5+
.idea
6+
.python-version
7+
.vscode
8+
.venv
9+
venv
10+
__pycache__
11+
.DS_Store
12+
.config
13+
drive/MyDrive
14+
.coverage
15+
.pytest_cache
16+
.ipynb
17+
.ruff_cache
18+
.dockerignore
19+
.ipynb_checkpoints
20+
.env.local
21+
.envrc
22+
**/__pycache__/
23+
**/.pytest_cache/
24+
**/node_modules/
25+
**/.venv/
26+
*.pyc
27+
.next/
28+
.circleci

utils/huggingface_clone/app.py

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,15 @@
11
from beam import function, Volume, Image
2-
from transformers import AutoModelForCausalLM, AutoTokenizer
2+
from huggingface_hub import snapshot_download
3+
import os
34

45
@function(
56
app="volume-imports",
67
name="huggingface-clone-model",
7-
secrets=["HUGGINGFACE_TOKEN"],
8+
secrets=["HUGGINGFACE_TOKEN", "HF_TOKEN"],
89
memory="8gb",
10+
gpu="T4",
911
image=Image(
10-
python_packages=["torch","transformers"]
12+
python_packages=["torch", "huggingface_hub"]
1113
),
1214
volumes=[Volume(name="huggingface_models", mount_path="/huggingface_models")]
1315
)
@@ -16,25 +18,21 @@ def handler(*, model_name: str = ""):
1618
raise ValueError("model_name is required")
1719

1820
print(f"Downloading model: {model_name}")
21+
token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
1922

2023
try:
21-
# Download model and tokenizer
22-
model = AutoModelForCausalLM.from_pretrained(model_name)
23-
tokenizer = AutoTokenizer.from_pretrained(model_name)
24+
os.makedirs(f"/huggingface_models/{model_name}", exist_ok=True)
2425

25-
# Save to local volume
26-
save_path = f"/huggingface_models/{model_name.replace('/', '_')}"
27-
model.save_pretrained(save_path)
28-
tokenizer.save_pretrained(save_path)
29-
print(f"Model and tokenizer saved to: {save_path}")
26+
path = snapshot_download(repo_id=model_name, local_dir=f"/huggingface_models/{model_name}", token=token)
27+
print(f"Model downloaded to: {path}")
3028

3129
return {
3230
"model_name": model_name,
33-
"saved_path": save_path
31+
"saved_path": path
3432
}
3533
except Exception as e:
3634
print(f"Failed to download model: {str(e)}")
3735
raise Exception(f"Failed to download model: {str(e)}")
3836

3937
if __name__ == "__main__":
40-
handler(model_name="distilbert/distilgpt2")
38+
handler(model_name="tencent/Hunyuan3D-2.1")

0 commit comments

Comments
 (0)