|
1 | 1 | import abc |
| 2 | +import pathlib |
2 | 3 | import warnings |
3 | 4 | from collections.abc import Callable, Iterable, Iterator, MutableSequence |
4 | 5 | from enum import Enum |
|
7 | 8 |
|
8 | 9 | import joblib |
9 | 10 | import numpy as np |
| 11 | +from sklearn.exceptions import NotFittedError |
10 | 12 |
|
11 | 13 | from .._base_classes import FeatureExtractorBase, SimilarityMetric |
12 | 14 | from .._config import PICKLE_MODEL_FILES_PATH, setup_logging |
|
16 | 18 |
|
17 | 19 | setup_logging() |
18 | 20 |
|
# Suffix that ``save_to_disk`` enforces on every serialized encoder file.
_ENCODER_FILE_SUFFIX = ".encoder"

# Version stamp written under the "format_version" key of each saved state.
_ENCODER_FILE_FORMAT_VERSION = 1

# Keys that ``load_from_disk`` requires to be present in a loaded state dict.
# NOTE: "format_version" is written by ``save_to_disk`` but is not part of
# this required set.
_ENCODER_STATE_KEYS = frozenset(
    (
        "encoder_class",
        "clustering_model",
        "pca",
        "power_norm_weight",
        "norm_order",
        "epsilon",
        "flatten",
        "raise_error_when_pca_incompatible",
    )
)
| 35 | + |
19 | 36 |
|
20 | 37 | # Helper Functions |
21 | 38 | def check_desired_output( |
@@ -103,6 +120,7 @@ def fallback(vecs1: np.ndarray, vecs2: np.ndarray) -> np.ndarray: |
103 | 120 |
|
104 | 121 |
|
# Type variable bounded by ``Callable[..., Any]``; ``_tupleize_first_arg``
# uses it so the decorated callable keeps its own type.
MethodT = TypeVar("MethodT", bound=Callable[..., Any])
# Type variable bounded by ``ImageEncoderBase``; ``load_from_disk`` uses it so
# the classmethod is typed as returning an instance of the class it is called on.
_EncoderT = TypeVar("_EncoderT", bound="ImageEncoderBase")
106 | 124 |
|
107 | 125 |
|
108 | 126 | def _tupleize_first_arg(func: MethodT) -> MethodT: # noqa: UP047 |
@@ -414,6 +432,87 @@ def learn( |
414 | 432 | print(" - New dimension after PCA reduction:", self._pca.n_components) |
415 | 433 | self._clustering_model.fit(features) |
416 | 434 |
|
def save_to_disk(self, path: str | pathlib.Path) -> pathlib.Path:
    """
    Saves the learned state of this encoder to a ``.encoder`` file.

    The file contains the fitted clustering model, the PCA model (if any)
    and the normalization hyperparameters. The feature extractor and the
    similarity function are not serialized; provide them again when
    calling :meth:`load_from_disk`.

    Missing parent directories of the target path are created, so that a
    finished training run is not lost to a ``FileNotFoundError`` at save
    time.

    :param path: Target file path. The ``.encoder`` suffix is appended if missing.
    :return: The path of the written file.
    :raises NotFittedError: If the clustering model is missing or not fitted.
    """
    # Validate before touching the file system so a failed call leaves no trace.
    if self._clustering_model is None or not self._clustering_model.is_fitted:
        raise NotFittedError(
            "Cannot save an encoder whose clustering model is not fitted. "
            "Call 'learn' first."
        )
    path = pathlib.Path(path)
    if path.suffix != _ENCODER_FILE_SUFFIX:
        # Append instead of replacing the suffix, so "model.v2" becomes
        # "model.v2.encoder" rather than "model.encoder".
        path = path.with_name(path.name + _ENCODER_FILE_SUFFIX)
    # joblib.dump does not create directories; do it here.
    path.parent.mkdir(parents=True, exist_ok=True)
    state = {
        "format_version": _ENCODER_FILE_FORMAT_VERSION,
        "encoder_class": type(self).__name__,
        "clustering_model": self._clustering_model,
        "pca": self._pca,
        "power_norm_weight": self.power_norm_weight,
        "norm_order": self.norm_order,
        "epsilon": self.epsilon,
        "flatten": self.flatten,
        "raise_error_when_pca_incompatible": self.raise_error_when_pca_incompatible,
    }
    joblib.dump(state, path)
    return path
| 469 | + |
@classmethod
def load_from_disk(
    cls: type[_EncoderT],
    path: str | pathlib.Path,
    *,
    feature_extractor: FeatureExtractorBase | None = None,
    similarity_func: Callable[
        [np.ndarray, np.ndarray], np.ndarray
    ] = cosine_similarity,
) -> _EncoderT:
    """
    Loads an encoder previously saved with :meth:`save_to_disk`.

    .. warning::
        The file is read with :func:`joblib.load`, which is pickle-based
        and can execute arbitrary code while loading. Only load
        ``.encoder`` files that come from a trusted source.

    :param path: Path to the ``.encoder`` file. If the given path does not
        exist and lacks the ``.encoder`` suffix, the suffixed path is tried,
        mirroring the suffix that :meth:`save_to_disk` appends.
    :param feature_extractor: Feature extractor to use with the loaded
        encoder. Defaults to RootSIFT. Its output dimension has to match
        the input dimension of the saved PCA or clustering model.
    :param similarity_func: Similarity function to use with the loaded encoder.
    :return: A ready-to-use encoder instance.
    :raises ValueError: If the file is not a valid ``.encoder`` file, has an
        unsupported format version, or was saved by a different encoder class.
    """
    path = pathlib.Path(path)
    if path.suffix != _ENCODER_FILE_SUFFIX and not path.exists():
        # save_to_disk("model") writes "model.encoder"; accept the same
        # argument here so a save/load round trip works unchanged.
        suffixed_path = path.with_name(path.name + _ENCODER_FILE_SUFFIX)
        if suffixed_path.exists():
            path = suffixed_path
    # SECURITY: joblib.load unpickles arbitrary objects; see docstring warning.
    state = joblib.load(path)
    if not isinstance(state, dict) or not _ENCODER_STATE_KEYS.issubset(state):
        raise ValueError(f"File {path} is not a valid .encoder file.")
    # The version is written by save_to_disk; reject anything this code does
    # not know how to interpret instead of loading it on a best-effort basis.
    format_version = state.get("format_version")
    if format_version != _ENCODER_FILE_FORMAT_VERSION:
        raise ValueError(
            f"File {path} has unsupported .encoder format version "
            f"{format_version!r}; expected {_ENCODER_FILE_FORMAT_VERSION}."
        )
    if state["encoder_class"] != cls.__name__:
        raise ValueError(
            f"File {path} was saved by {state['encoder_class']}. "
            f"Load it with {state['encoder_class']}.load_from_disk instead."
        )
    encoder = cls(
        feature_extractor=feature_extractor,
        similarity_func=similarity_func,
        power_norm_weight=state["power_norm_weight"],
        norm_order=state["norm_order"],
        epsilon=state["epsilon"],
        flatten=state["flatten"],
        raise_error_when_pca_incompatible=state[
            "raise_error_when_pca_incompatible"
        ],
    )
    # Assign via the public attributes (not the underscored fields), and keep
    # the original order: PCA first, then the clustering model.
    if state["pca"] is not None:
        encoder.pca = state["pca"]
    encoder.clustering_model = state["clustering_model"]
    return encoder
| 515 | + |
417 | 516 | @_tupleize_first_arg |
418 | 517 | # @lru_cache(maxsize=4) |
419 | 518 | def generate_encoding_map( |
|
0 commit comments