Skip to content

Commit 16193aa

Browse files
authored
Merge pull request #14 from MinishLab/add_push
add push_to_hub functionality
2 parents f54b7cd + f3a859e commit 16193aa

2 files changed

Lines changed: 27 additions & 1 deletion

File tree

model2vec/model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from logging import getLogger
44
from pathlib import Path
5+
from tempfile import TemporaryDirectory
56
from typing import Any, Iterator
67

78
import numpy as np
@@ -10,7 +11,7 @@
1011
from torch.nn import EmbeddingBag
1112
from tqdm import tqdm
1213

13-
from model2vec.utils import load_pretrained, save_pretrained
14+
from model2vec.utils import load_pretrained, push_folder_to_hub, save_pretrained
1415

1516
PathLike = Path | str
1617

@@ -154,3 +155,14 @@ def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndar
154155
def _batch(sentences: list[str], batch_size: int) -> Iterator[list[str]]:
155156
"""Batch the sentences into equal-sized."""
156157
return (sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size))
158+
159+
def push_to_hub(self, repo_id: str, token: str | None) -> None:
160+
"""
161+
Push the model to the huggingface hub.
162+
163+
:param repo_id: The repo id to push to.
164+
:param token: The huggingface token to use.
165+
"""
166+
with TemporaryDirectory() as temp_dir:
167+
self.save_pretrained(temp_dir)
168+
push_folder_to_hub(Path(temp_dir), repo_id, token)

model2vec/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,17 @@ def load_pretrained(
100100
)
101101

102102
return embeddings, tokenizer, config
103+
104+
105+
def push_folder_to_hub(folder_path: Path, repo_id: str, huggingface_token: str | None) -> None:
106+
"""
107+
Push a model folder to the huggingface hub.
108+
109+
:param folder_path: The path to the folder.
110+
:param repo_id: The repo name.
111+
:param huggingface_token: The huggingface token.
112+
"""
113+
if not huggingface_hub.repo_exists(repo_id=repo_id, token=huggingface_token):
114+
huggingface_hub.create_repo(repo_id, token=huggingface_token)
115+
huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=huggingface_token)
116+
logger.info(f"Pushed model to {repo_id}")

0 commit comments

Comments
 (0)