Skip to content

Commit 4c69a68

Browse files
authored
feat: add subfolder loading (#218)
1 parent 6731674 commit 4c69a68

2 files changed

Lines changed: 50 additions & 28 deletions

File tree

model2vec/hf_utils.py

Lines changed: 37 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ def save_pretrained(
2323
tokenizer: Tokenizer,
2424
config: dict[str, Any],
2525
create_model_card: bool = True,
26+
subfolder: str | None = None,
2627
**kwargs: Any,
2728
) -> None:
2829
"""
@@ -33,8 +34,10 @@ def save_pretrained(
3334
:param tokenizer: The tokenizer.
3435
:param config: A metadata config.
3536
:param create_model_card: Whether to create a model card.
37+
:param subfolder: The subfolder to save the model in.
3638
:param **kwargs: Any additional arguments.
3739
"""
40+
folder_path = folder_path / subfolder if subfolder else folder_path
3841
folder_path.mkdir(exist_ok=True, parents=True)
3942
save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
4043
tokenizer.save(str(folder_path / "tokenizer.json"))
@@ -92,14 +95,18 @@ def _create_model_card(
9295

9396

9497
def load_pretrained(
95-
folder_or_repo_path: str | Path, token: str | None = None, from_sentence_transformers: bool = False
98+
folder_or_repo_path: str | Path,
99+
subfolder: str | None = None,
100+
token: str | None = None,
101+
from_sentence_transformers: bool = False,
96102
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
97103
"""
98104
Loads a pretrained model from a folder.
99105
100106
:param folder_or_repo_path: The folder or repo path to load from.
101107
- If this is a local path, we will load from the local path.
102108
- If the local path is not found, we will attempt to load from the huggingface hub.
109+
:param subfolder: The subfolder to load from.
103110
:param token: The huggingface token to use.
104111
:param from_sentence_transformers: Whether to load the model from a sentence transformers model.
105112
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
@@ -116,36 +123,47 @@ def load_pretrained(
116123
config_name = "config.json"
117124

118125
folder_or_repo_path = Path(folder_or_repo_path)
119-
if folder_or_repo_path.exists():
120-
embeddings_path = folder_or_repo_path / model_file
126+
127+
local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path
128+
129+
if local_folder.exists():
130+
embeddings_path = local_folder / model_file
121131
if not embeddings_path.exists():
122-
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")
132+
raise FileNotFoundError(f"Embeddings file does not exist in {local_folder}")
123133

124-
config_path = folder_or_repo_path / config_name
134+
config_path = local_folder / config_name
125135
if not config_path.exists():
126-
raise FileNotFoundError(f"Config file does not exist in {folder_or_repo_path}")
136+
raise FileNotFoundError(f"Config file does not exist in {local_folder}")
127137

128-
tokenizer_path = folder_or_repo_path / tokenizer_file
138+
tokenizer_path = local_folder / tokenizer_file
129139
if not tokenizer_path.exists():
130-
raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")
140+
raise FileNotFoundError(f"Tokenizer file does not exist in {local_folder}")
131141

132142
# README is optional, so this is a bit finicky.
133-
readme_path = folder_or_repo_path / "README.md"
143+
readme_path = local_folder / "README.md"
134144
metadata = _get_metadata_from_readme(readme_path)
135145

136146
else:
137147
logger.info("Folder does not exist locally, attempting to use huggingface hub.")
138-
embeddings_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), model_file, token=token)
148+
embeddings_path = huggingface_hub.hf_hub_download(
149+
folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
150+
)
139151

140152
try:
141-
readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token)
153+
readme_path = huggingface_hub.hf_hub_download(
154+
folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
155+
)
142156
metadata = _get_metadata_from_readme(Path(readme_path))
143157
except huggingface_hub.utils.EntryNotFoundError:
144158
logger.info("No README found in the model folder. No model card loaded.")
145159
metadata = {}
146160

147-
config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), config_name, token=token)
148-
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), tokenizer_file, token=token)
161+
config_path = huggingface_hub.hf_hub_download(
162+
folder_or_repo_path.as_posix(), config_name, token=token, subfolder=subfolder
163+
)
164+
tokenizer_path = huggingface_hub.hf_hub_download(
165+
folder_or_repo_path.as_posix(), tokenizer_file, token=token, subfolder=subfolder
166+
)
149167

150168
opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
151169
if from_sentence_transformers:
@@ -176,11 +194,15 @@ def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
176194
return data
177195

178196

179-
def push_folder_to_hub(folder_path: Path, repo_id: str, private: bool, token: str | None) -> None:
197+
def push_folder_to_hub(
198+
folder_path: Path, subfolder: str | None, repo_id: str, private: bool, token: str | None
199+
) -> None:
180200
"""
181201
Push a model folder to the huggingface hub, including model card.
182202
183203
:param folder_path: The path to the folder.
204+
:param subfolder: The subfolder to push to.
205+
If None, the folder will be pushed to the root of the repo.
184206
:param repo_id: The repo name.
185207
:param private: Whether the repo is private.
186208
:param token: The huggingface token.
@@ -189,15 +211,6 @@ def push_folder_to_hub(folder_path: Path, repo_id: str, private: bool, token: st
189211
huggingface_hub.create_repo(repo_id, token=token, private=private)
190212

191213
# Push model card and all model files to the Hugging Face hub
192-
huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token)
193-
194-
# Check if the model card exists, and push it if available
195-
model_card_path = folder_path / "README.md"
196-
if model_card_path.exists():
197-
card = ModelCard.load(model_card_path)
198-
card.push_to_hub(repo_id=repo_id, token=token)
199-
logger.info(f"Pushed model card to {repo_id}")
200-
else:
201-
logger.warning(f"Model card README.md not found in {folder_path}. Skipping model card upload.")
214+
huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token, path_in_repo=subfolder)
202215

203216
logger.info(f"Pushed model to {repo_id}")

model2vec/model.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -96,12 +96,13 @@ def normalize(self, value: bool) -> None:
9696
)
9797
self.config["normalize"] = value
9898

99-
def save_pretrained(self, path: PathLike, model_name: str | None = None) -> None:
99+
def save_pretrained(self, path: PathLike, model_name: str | None = None, subfolder: str | None = None) -> None:
100100
"""
101101
Save the pretrained model.
102102
103103
:param path: The path to save to.
104104
:param model_name: The model name to use in the Model Card.
105+
:param subfolder: The subfolder to save to.
105106
"""
106107
from model2vec.hf_utils import save_pretrained
107108

@@ -113,6 +114,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None) -> None
113114
base_model_name=self.base_model_name,
114115
language=self.language,
115116
model_name=model_name,
117+
subfolder=subfolder,
116118
)
117119

118120
def tokenize(self, sentences: list[str], max_length: int | None = None) -> list[list[int]]:
@@ -151,6 +153,7 @@ def from_pretrained(
151153
path: PathLike,
152154
token: str | None = None,
153155
normalize: bool | None = None,
156+
subfolder: str | None = None,
154157
quantize_to: str | DType | None = None,
155158
dimensionality: int | None = None,
156159
) -> StaticModel:
@@ -162,6 +165,7 @@ def from_pretrained(
162165
:param path: The path to load your static model from.
163166
:param token: The huggingface token to use.
164167
:param normalize: Whether to normalize the embeddings.
168+
:param subfolder: The subfolder to load from.
165169
:param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
166170
If a string is passed, it is converted to a DType.
167171
:param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
@@ -172,7 +176,9 @@ def from_pretrained(
172176
"""
173177
from model2vec.hf_utils import load_pretrained
174178

175-
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False)
179+
embeddings, tokenizer, config, metadata = load_pretrained(
180+
path, token=token, from_sentence_transformers=False, subfolder=subfolder
181+
)
176182

177183
if quantize_to is not None:
178184
quantize_to = DType(quantize_to)
@@ -375,7 +381,9 @@ def _batch(sentences: list[str], batch_size: int) -> Iterator[list[str]]:
375381
"""Batch the sentences into equal-sized."""
376382
return (sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size))
377383

378-
def push_to_hub(self, repo_id: str, private: bool = False, token: str | None = None) -> None:
384+
def push_to_hub(
385+
self, repo_id: str, private: bool = False, token: str | None = None, subfolder: str | None = None
386+
) -> None:
379387
"""
380388
Push the model to the huggingface hub.
381389
@@ -385,12 +393,13 @@ def push_to_hub(self, repo_id: str, private: bool = False, token: str | None = N
385393
:param private: Whether the repo, if created is set to private.
386394
If the repo already exists, this doesn't change the visibility.
387395
:param token: The huggingface token to use.
396+
:param subfolder: The subfolder to push to.
388397
"""
389398
from model2vec.hf_utils import push_folder_to_hub
390399

391400
with TemporaryDirectory() as temp_dir:
392401
self.save_pretrained(temp_dir, model_name=repo_id)
393-
push_folder_to_hub(Path(temp_dir), repo_id, private, token)
402+
push_folder_to_hub(Path(temp_dir), subfolder=subfolder, repo_id=repo_id, private=private, token=token)
394403

395404
@classmethod
396405
def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:

0 commit comments

Comments
 (0)