From 48afb264e4cde60f0f1f8ddd29a49310ed44648c Mon Sep 17 00:00:00 2001 From: Eloi Massoulie Date: Mon, 9 Mar 2026 14:10:57 +0100 Subject: [PATCH 1/3] Adding and corresponding test --- skrub/_text_encoder.py | 84 +++++++++++++++++++++++++++++--- skrub/tests/test_text_encoder.py | 34 +++++++++++++ 2 files changed, 110 insertions(+), 8 deletions(-) diff --git a/skrub/_text_encoder.py b/skrub/_text_encoder.py index 005efeb6c..a9dbf3de1 100644 --- a/skrub/_text_encoder.py +++ b/skrub/_text_encoder.py @@ -4,15 +4,16 @@ import warnings from pathlib import Path +import polars as pl from sklearn.decomposition import PCA from sklearn.utils.validation import check_is_fitted -from . import _dataframe as sbd -from ._scaling_factor import scaling_factor -from ._single_column_transformer import SingleColumnTransformer -from ._to_str import ToStr -from ._utils import import_optional_dependency, unique_strings -from .datasets._utils import get_data_dir +from skrub import _dataframe as sbd +from skrub._apply_to_cols import SingleColumnTransformer +from skrub._scaling_factor import scaling_factor +from skrub._to_str import ToStr +from skrub._utils import import_optional_dependency, unique_strings +from skrub.datasets._utils import get_data_dir class ModelNotFound(ValueError): @@ -199,6 +200,7 @@ def __init__( store_weights_in_pickle=False, random_state=None, verbose=False, + use_caching=False, ): self.model_name = model_name self.n_components = n_components @@ -209,6 +211,7 @@ def __init__( self.store_weights_in_pickle = store_weights_in_pickle self.random_state = random_state self.verbose = verbose + self.use_caching = use_caching def fit_transform(self, column, y=None): """Fit the TextEncoder from ``column``. 
@@ -238,7 +241,10 @@ def fit_transform(self, column, y=None): self.input_name_ = sbd.name(column) or "text_enc" - X_out = self._vectorize(column) + if self.use_caching and self.cache_folder is not None: + X_out = self._vectorize_with_cache(column) + else: + X_out = self._vectorize(column) if self.n_components is not None: if (min_shape := min(X_out.shape)) >= self.n_components: @@ -302,7 +308,10 @@ def transform(self, column): raise ValueError("Input column does not contain strings.") column = self.to_str.transform(column) - X_out = self._vectorize(column) + if self.use_caching and self.cache_folder is not None: + X_out = self._vectorize_with_cache(column) + else: + X_out = self._vectorize(column) if hasattr(self, "pca_"): X_out = self.pca_.transform(X_out) @@ -318,6 +327,56 @@ def transform(self, column): return X_out + def _vectorize_with_cache(self, column): + total_values = column.to_frame().rename({column.name: "values"}) + unique_values = column.unique().to_frame().rename({column.name: "values"}) + + if os.path.exists( + os.path.join( + self.cache_folder, + self.model_name.replace("/", "-") + "cached_outputs.parquet", + ) + ): + cached_values = pl.read_parquet( + os.path.join( + self.cache_folder, + self.model_name.replace("/", "-") + "cached_outputs.parquet", + ) + ) # !!!!POLARS DEPENDENT + else: + cached_values = pl.DataFrame( + schema={"values": pl.String} + ) #!!!!!! 
POLARS DEPENDENT + + to_compute = unique_values.join(cached_values, on="values", how="anti")[ + "values" + ] + print(f"Computing {to_compute.shape[0]}") + + if not to_compute.is_empty(): + V = pl.DataFrame(self._vectorize(to_compute)).with_columns(to_compute) + new_cache = sbd.concat(cached_values, V) + new_cache.write_parquet( + os.path.join(self.cache_folder, "temp_cache.parquet") + ) + os.rename( + os.path.join(self.cache_folder, "temp_cache.parquet"), + os.path.join( + self.cache_folder, + self.model_name.replace("/", "-") + "cached_outputs.parquet", + ), + ) + else: + new_cache = cached_values + + import numpy as np + + X_out = sbd.to_numpy( + total_values.join(new_cache, how="left", on="values", coalesce=True) + )[:, 1:].astype(np.float32) + + return X_out + def _vectorize(self, column): is_null = sbd.to_numpy(sbd.is_null(column)) column = sbd.to_numpy(column) @@ -325,6 +384,7 @@ def _vectorize(self, column): # sentence-transformers deals with converting a torch tensor # to a numpy array, on CPU. 
+ return self._estimator.encode( unique_x, normalize_embeddings=False, @@ -444,3 +504,11 @@ def get_feature_names_out(self, input_features=None): f"{self.input_name_}_{str(i).zfill(num_digits)}" for i in range(self.n_components_) ] + + def flush_cache(self): + os.remove( + os.path.join( + self.cache_folder, + self.model_name.replace("/", "-") + "cached_outputs.parquet", + ) + ) diff --git a/skrub/tests/test_text_encoder.py b/skrub/tests/test_text_encoder.py index c7a6530c0..f9c80ba44 100644 --- a/skrub/tests/test_text_encoder.py +++ b/skrub/tests/test_text_encoder.py @@ -199,3 +199,37 @@ def test_categorical_features(df_module, encoder): out = encoder.fit(df["categorical"][:4]).transform(df["categorical"][4:]) assert len(sbd.column_names(out)) == 30 + + +@pytest.fixture +def test_cache_size(df_module): + df = sbd.make_dataframe( + { + "id": [1, 2, 3, 4, 5], + "name": ["one", "two", "three", "four", "one"], + "answer": ["yes", "no", "yes", "yes", "perhaps"], + } + ) + encoder = TextEncoder( + model_name="llm_e5-base-v2", + token_env_variable=None, + batch_size=32, + n_components=30, + use_caching=True, + cache_folder=".cache", + ) + + string_cols = ["name", "answer"] + + for col in string_cols: + _ = encoder.fit_transform(df[col], []) + + for col in string_cols: + _ = encoder.transform(df[col]) + + cached_values = pl.read_parquet(os.path.join(".cache", "cached_outputs.parquet")) + expected_cache = pl.DataFrame( + {"Values": ["one", "two", "three", "four", "yes", "no", "perhaps"]} + ) + + assert cached_values["values"] == expected_cache["values"] From d1701bfde4c504a81c12cf336d0b77293de388d5 Mon Sep 17 00:00:00 2001 From: Eloi Massoulie Date: Tue, 10 Mar 2026 17:24:05 +0100 Subject: [PATCH 2/3] Changelog --- CHANGES.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index dbf169549..e145ffd2a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -26,6 +26,8 @@ New Features some more attributes for inspection by scikit-learn: 
``__sklearn_tags__``, ``classes_``, ``_estimator_type``. :pr:`1931` by :user:`Jérôme Dockès `. +- :class:`TextEncoder` can now cache the embeddings it computes in a parquet file + when ``use_caching=True`` and ``cache_folder`` is set. :pr:`1955` by :user:`Eloi Massoulié `. Changes ------- From 5dbc66f39b8c9782f144c82df37f515868a92f59 Mon Sep 17 00:00:00 2001 From: Eloi Massoulie Date: Wed, 11 Mar 2026 14:15:48 +0100 Subject: [PATCH 3/3] Fixed imports --- skrub/_text_encoder.py | 2 +- skrub/tests/test_text_encoder.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/skrub/_text_encoder.py b/skrub/_text_encoder.py index a9dbf3de1..2dbb640cd 100644 --- a/skrub/_text_encoder.py +++ b/skrub/_text_encoder.py @@ -9,10 +9,10 @@ from sklearn.utils.validation import check_is_fitted from skrub import _dataframe as sbd -from skrub._apply_to_cols import SingleColumnTransformer from skrub._scaling_factor import scaling_factor from skrub._to_str import ToStr from skrub._utils import import_optional_dependency, unique_strings +from skrub.core import SingleColumnTransformer from skrub.datasets._utils import get_data_dir diff --git a/skrub/tests/test_text_encoder.py b/skrub/tests/test_text_encoder.py index f9c80ba44..7125bff45 100644 --- a/skrub/tests/test_text_encoder.py +++ b/skrub/tests/test_text_encoder.py @@ -1,7 +1,9 @@ +import os import pickle import sys import pandas as pd +import polars as pl import pytest from numpy.testing import assert_array_equal from sklearn.base import clone