diff --git a/CHANGES.rst b/CHANGES.rst
index e04d01357..592bf1b5f 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -29,6 +29,9 @@ New Features
 - :func:`selectors.has_nulls` now takes a ``proportion`` parameter, which allows
   selecting columns that have a fraction of null values above the given threshold.
   :pr:`1881` by :user:`Gabriela Gómez Jiménez `.
+- :class:`TextEncoder` can now cache the embeddings it computes in a parquet file
+  inside ``cache_folder`` when ``use_caching=True``. :pr:`1955` by
+  :user:`Eloi Massoulié `.
 
 Changes
 -------
diff --git a/skrub/_text_encoder.py b/skrub/_text_encoder.py
index 005efeb6c..2dbb640cd 100644
--- a/skrub/_text_encoder.py
+++ b/skrub/_text_encoder.py
@@ -4,15 +4,15 @@
 import warnings
 from pathlib import Path
 
 from sklearn.decomposition import PCA
 from sklearn.utils.validation import check_is_fitted
 
-from . import _dataframe as sbd
-from ._scaling_factor import scaling_factor
-from ._single_column_transformer import SingleColumnTransformer
-from ._to_str import ToStr
-from ._utils import import_optional_dependency, unique_strings
-from .datasets._utils import get_data_dir
+from skrub import _dataframe as sbd
+from skrub._scaling_factor import scaling_factor
+from skrub._to_str import ToStr
+from skrub._utils import import_optional_dependency, unique_strings
+from skrub.core import SingleColumnTransformer
+from skrub.datasets._utils import get_data_dir
 
 
 class ModelNotFound(ValueError):
@@ -199,6 +199,7 @@ def __init__(
         store_weights_in_pickle=False,
         random_state=None,
         verbose=False,
+        use_caching=False,
     ):
         self.model_name = model_name
         self.n_components = n_components
@@ -209,6 +210,7 @@ def __init__(
         self.store_weights_in_pickle = store_weights_in_pickle
         self.random_state = random_state
         self.verbose = verbose
+        self.use_caching = use_caching
 
     def fit_transform(self, column, y=None):
         """Fit the TextEncoder from ``column``.
@@ -238,7 +240,10 @@ def fit_transform(self, column, y=None):
 
         self.input_name_ = sbd.name(column) or "text_enc"
 
-        X_out = self._vectorize(column)
+        if self.use_caching and self.cache_folder is not None:
+            X_out = self._vectorize_with_cache(column)
+        else:
+            X_out = self._vectorize(column)
 
         if self.n_components is not None:
             if (min_shape := min(X_out.shape)) >= self.n_components:
@@ -302,7 +307,10 @@ def transform(self, column):
             raise ValueError("Input column does not contain strings.")
 
         column = self.to_str.transform(column)
-        X_out = self._vectorize(column)
+        if self.use_caching and self.cache_folder is not None:
+            X_out = self._vectorize_with_cache(column)
+        else:
+            X_out = self._vectorize(column)
 
         if hasattr(self, "pca_"):
             X_out = self.pca_.transform(X_out)
@@ -318,6 +326,65 @@ def transform(self, column):
 
         return X_out
 
+    def _cache_path(self):
+        """Return the path of the parquet file caching this model's embeddings."""
+        file_name = self.model_name.replace("/", "-") + "cached_outputs.parquet"
+        return Path(self.cache_folder) / file_name
+
+    def _vectorize_with_cache(self, column):
+        """Vectorize ``column``, reusing embeddings stored in the on-disk cache.
+
+        Only the unique strings that are missing from the parquet cache are
+        passed to ``_vectorize``; their embeddings are then added to the cache.
+
+        Parameters
+        ----------
+        column : polars.Series
+            The strings to encode. This code path relies on the polars API, so
+            other dataframe libraries are not supported here.
+
+        Returns
+        -------
+        numpy.ndarray of shape (len(column), n_dims)
+            One float32 embedding per entry of ``column``, in the same order.
+        """
+        # Imported here so that polars is only required when caching is enabled.
+        import numpy as np
+        import polars as pl
+
+        cache_path = self._cache_path()
+        unique_values = column.rename("values").unique().to_frame()
+
+        if cache_path.exists():
+            cached = pl.read_parquet(cache_path)
+        else:
+            cached = pl.DataFrame(schema={"values": pl.String})
+
+        # Nulls never match in a polars join, so a null entry is always part of
+        # ``to_compute``; it is kept out of the file below to avoid duplicates.
+        to_compute = unique_values.join(cached, on="values", how="anti")["values"]
+
+        if to_compute.is_empty():
+            table = cached
+        else:
+            computed = pl.DataFrame(
+                self._vectorize(to_compute), orient="row"
+            ).with_columns(to_compute)
+            table = pl.concat([cached, computed], how="diagonal_relaxed")
+            # Write to a process-specific temporary file and move it into place,
+            # so that readers never see a partially written cache.
+            cache_path.parent.mkdir(parents=True, exist_ok=True)
+            tmp_file = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
+            table.filter(pl.col("values").is_not_null()).write_parquet(tmp_file)
+            tmp_file.replace(cache_path)
+
+        # Look rows up by value rather than joining, so that the output order
+        # always matches ``column``.
+        row_of = {value: i for i, value in enumerate(table["values"].to_list())}
+        embeddings = table.drop("values").to_numpy().astype(np.float32)
+        rows = np.asarray([row_of[value] for value in column.to_list()], dtype=np.intp)
+        return embeddings[rows]
+
     def _vectorize(self, column):
         is_null = sbd.to_numpy(sbd.is_null(column))
         column = sbd.to_numpy(column)
@@ -444,3 +511,12 @@ def get_feature_names_out(self, input_features=None):
             f"{self.input_name_}_{str(i).zfill(num_digits)}"
             for i in range(self.n_components_)
         ]
+
+    def flush_cache(self):
+        """Delete the parquet file holding the cached embeddings, if it exists.
+
+        Does nothing when ``cache_folder`` is ``None`` or when nothing has been
+        cached yet.
+        """
+        if self.cache_folder is not None:
+            self._cache_path().unlink(missing_ok=True)
diff --git a/skrub/tests/test_text_encoder.py b/skrub/tests/test_text_encoder.py
index c7a6530c0..7125bff45 100644
--- a/skrub/tests/test_text_encoder.py
+++ b/skrub/tests/test_text_encoder.py
@@ -1,6 +1,7 @@
 import pickle
 import sys
 
+import numpy as np
 import pandas as pd
 import pytest
 from numpy.testing import assert_array_equal
@@ -199,3 +200,44 @@ def test_categorical_features(df_module, encoder):
 
     out = encoder.fit(df["categorical"][:4]).transform(df["categorical"][4:])
     assert len(sbd.column_names(out)) == 30
+
+
+def test_cache_size(tmp_path):
+    pl = pytest.importorskip("polars")
+    encoder = TextEncoder(
+        model_name="org/model", use_caching=True, cache_folder=str(tmp_path)
+    )
+    computed = []
+
+    def fake_vectorize(column):
+        # Stand-in for the model: records what it is asked to embed.
+        values = column.to_list()
+        computed.extend(values)
+        return np.array([[len(v), 1.0] for v in values], dtype="float32")
+
+    encoder._vectorize = fake_vectorize
+
+    df = pl.DataFrame(
+        {
+            "name": ["one", "two", "three", "four", "one"],
+            "answer": ["yes", "no", "yes", "yes", "perhaps"],
+        }
+    )
+    expected = {"one", "two", "three", "four", "yes", "no", "perhaps"}
+
+    for col in df.columns:
+        out = encoder._vectorize_with_cache(df[col])
+        assert_array_equal(out[:, 0], [len(v) for v in df[col].to_list()])
+
+    cached = pl.read_parquet(tmp_path / "org-modelcached_outputs.parquet")
+    assert set(cached["values"].to_list()) == expected
+    assert sorted(computed) == sorted(expected)
+
+    # Everything is cached now: a second pass must not compute anything.
+    computed.clear()
+    for col in df.columns:
+        encoder._vectorize_with_cache(df[col])
+    assert computed == []
+
+    encoder.flush_cache()
+    assert list(tmp_path.iterdir()) == []