Skip to content

Commit 96f5b91

Browse files
MechaCritterclaude
andcommitted
feat(encoders): add save_to_disk/load_from_disk with .encoder files
Encoders can now persist their learned state to a versioned .encoder file (fitted clustering model, PCA model and normalization hyperparameters) and be restored from it via the load_from_disk classmethod. The feature extractor and similarity function are not serialized and are provided again at load time; dimension validation runs on restore. This is the designated replacement for loading pretrained models via the KMeansWeights/GMMWeights enums. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
1 parent 1ec8a2b commit 96f5b91

1 file changed

Lines changed: 99 additions & 0 deletions

File tree

pyvisim/encoders/_base_encoder.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import abc
2+
import pathlib
23
import warnings
34
from collections.abc import Callable, Iterable, Iterator, MutableSequence
45
from enum import Enum
@@ -7,6 +8,7 @@
78

89
import joblib
910
import numpy as np
11+
from sklearn.exceptions import NotFittedError
1012

1113
from .._base_classes import FeatureExtractorBase, SimilarityMetric
1214
from .._config import PICKLE_MODEL_FILES_PATH, setup_logging
@@ -16,6 +18,21 @@
1618

1719
setup_logging()
1820

21+
_ENCODER_FILE_SUFFIX = ".encoder"
22+
_ENCODER_FILE_FORMAT_VERSION = 1
23+
_ENCODER_STATE_KEYS = frozenset(
24+
{
25+
"encoder_class",
26+
"clustering_model",
27+
"pca",
28+
"power_norm_weight",
29+
"norm_order",
30+
"epsilon",
31+
"flatten",
32+
"raise_error_when_pca_incompatible",
33+
}
34+
)
35+
1936

2037
# Helper Functions
2138
def check_desired_output(
@@ -103,6 +120,7 @@ def fallback(vecs1: np.ndarray, vecs2: np.ndarray) -> np.ndarray:
103120

104121

105122
MethodT = TypeVar("MethodT", bound=Callable[..., Any])
123+
_EncoderT = TypeVar("_EncoderT", bound="ImageEncoderBase")
106124

107125

108126
def _tupleize_first_arg(func: MethodT) -> MethodT: # noqa: UP047
@@ -414,6 +432,87 @@ def learn(
414432
print(" - New dimension after PCA reduction:", self._pca.n_components)
415433
self._clustering_model.fit(features)
416434

435+
def save_to_disk(self, path: str | pathlib.Path) -> pathlib.Path:
436+
"""
437+
Saves the learned state of this encoder to a ``.encoder`` file.
438+
439+
The file contains the fitted clustering model, the PCA model (if any)
440+
and the normalization hyperparameters. The feature extractor and the
441+
similarity function are not serialized; provide them again when
442+
calling :meth:`load_from_disk`.
443+
444+
:param path: Target file path. The ``.encoder`` suffix is appended if missing.
445+
:return: The path of the written file.
446+
:raises NotFittedError: If the clustering model is missing or not fitted.
447+
"""
448+
if self._clustering_model is None or not self._clustering_model.is_fitted:
449+
raise NotFittedError(
450+
"Cannot save an encoder whose clustering model is not fitted. "
451+
"Call 'learn' first."
452+
)
453+
path = pathlib.Path(path)
454+
if path.suffix != _ENCODER_FILE_SUFFIX:
455+
path = path.with_name(path.name + _ENCODER_FILE_SUFFIX)
456+
state = {
457+
"format_version": _ENCODER_FILE_FORMAT_VERSION,
458+
"encoder_class": type(self).__name__,
459+
"clustering_model": self._clustering_model,
460+
"pca": self._pca,
461+
"power_norm_weight": self.power_norm_weight,
462+
"norm_order": self.norm_order,
463+
"epsilon": self.epsilon,
464+
"flatten": self.flatten,
465+
"raise_error_when_pca_incompatible": self.raise_error_when_pca_incompatible,
466+
}
467+
joblib.dump(state, path)
468+
return path
469+
470+
@classmethod
471+
def load_from_disk(
472+
cls: type[_EncoderT],
473+
path: str | pathlib.Path,
474+
*,
475+
feature_extractor: FeatureExtractorBase | None = None,
476+
similarity_func: Callable[
477+
[np.ndarray, np.ndarray], np.ndarray
478+
] = cosine_similarity,
479+
) -> _EncoderT:
480+
"""
481+
Loads an encoder previously saved with :meth:`save_to_disk`.
482+
483+
:param path: Path to the ``.encoder`` file.
484+
:param feature_extractor: Feature extractor to use with the loaded
485+
encoder. Defaults to RootSIFT. Its output dimension has to match
486+
the input dimension of the saved PCA or clustering model.
487+
:param similarity_func: Similarity function to use with the loaded encoder.
488+
:return: A ready-to-use encoder instance.
489+
:raises ValueError: If the file is not a valid ``.encoder`` file or
490+
was saved by a different encoder class.
491+
"""
492+
state = joblib.load(path)
493+
if not isinstance(state, dict) or not _ENCODER_STATE_KEYS.issubset(state):
494+
raise ValueError(f"File {path} is not a valid .encoder file.")
495+
if state["encoder_class"] != cls.__name__:
496+
raise ValueError(
497+
f"File {path} was saved by {state['encoder_class']}. "
498+
f"Load it with {state['encoder_class']}.load_from_disk instead."
499+
)
500+
encoder = cls(
501+
feature_extractor=feature_extractor,
502+
similarity_func=similarity_func,
503+
power_norm_weight=state["power_norm_weight"],
504+
norm_order=state["norm_order"],
505+
epsilon=state["epsilon"],
506+
flatten=state["flatten"],
507+
raise_error_when_pca_incompatible=state[
508+
"raise_error_when_pca_incompatible"
509+
],
510+
)
511+
if state["pca"] is not None:
512+
encoder.pca = state["pca"]
513+
encoder.clustering_model = state["clustering_model"]
514+
return encoder
515+
417516
@_tupleize_first_arg
418517
# @lru_cache(maxsize=4)
419518
def generate_encoding_map(

0 commit comments

Comments
 (0)