From 44b81f18739acac1ae8acdd36134eb0091403aec Mon Sep 17 00:00:00 2001 From: stephantul Date: Tue, 6 May 2025 14:48:45 +0200 Subject: [PATCH] feat: add support for other path types --- model2vec/hf_utils.py | 5 ++++- model2vec/model.py | 11 ++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index 1d12f43..99bde32 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -122,7 +122,10 @@ def load_pretrained( tokenizer_file = "tokenizer.json" config_name = "config.json" - folder_or_repo_path = Path(folder_or_repo_path) + # This check allows users to pass other things than Path objects, e.g., + # cloudpathlib.Anypath. + if isinstance(folder_or_repo_path, str): + folder_or_repo_path = Path(folder_or_repo_path) local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path diff --git a/model2vec/model.py b/model2vec/model.py index 200b95d..9aae58a 100644 --- a/model2vec/model.py +++ b/model2vec/model.py @@ -465,9 +465,14 @@ def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel: :return: A StaticModel :raises: ValueError if the path is not a directory. """ - path = Path(path) - if not path.is_dir(): - raise ValueError(f"Path {path} is not a directory.") + if isinstance(path, str): + path = Path(path) + + if isinstance(path, Path): + # Only check if we're sure this is a path. + # It could be a cloudpathlib path, or something else. + if not path.is_dir(): + raise ValueError(f"Path {path} is not a directory.") embeddings, tokenizer, config = load_local_model(path)