Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ New Features
- :func:`selectors.has_nulls` now takes a ``proportion`` parameter, which allows
selecting columns that have a fraction of null values above the given threshold.
:pr:`1881` by :user:`Gabriela Gómez Jiménez <gabrielapgomezji>`.
- :class:`TextEncoder` has cache support, by allowing to store the values it computes
in a custom parquet file. :pr:`1955` by :user:`Eloi Massoulié <emassoulie>`.

Changes
-------
Expand Down
84 changes: 76 additions & 8 deletions skrub/_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import warnings
from pathlib import Path

import polars as pl
from sklearn.decomposition import PCA
from sklearn.utils.validation import check_is_fitted

from . import _dataframe as sbd
from ._scaling_factor import scaling_factor
from ._single_column_transformer import SingleColumnTransformer
from ._to_str import ToStr
from ._utils import import_optional_dependency, unique_strings
from .datasets._utils import get_data_dir
from skrub import _dataframe as sbd
from skrub._scaling_factor import scaling_factor
from skrub._to_str import ToStr
from skrub._utils import import_optional_dependency, unique_strings
from skrub.core import SingleColumnTransformer
from skrub.datasets._utils import get_data_dir


class ModelNotFound(ValueError):
Expand Down Expand Up @@ -199,6 +200,7 @@ def __init__(
store_weights_in_pickle=False,
random_state=None,
verbose=False,
use_caching=False,
):
self.model_name = model_name
self.n_components = n_components
Expand All @@ -209,6 +211,7 @@ def __init__(
self.store_weights_in_pickle = store_weights_in_pickle
self.random_state = random_state
self.verbose = verbose
self.use_caching = use_caching

def fit_transform(self, column, y=None):
"""Fit the TextEncoder from ``column``.
Expand Down Expand Up @@ -238,7 +241,10 @@ def fit_transform(self, column, y=None):

self.input_name_ = sbd.name(column) or "text_enc"

X_out = self._vectorize(column)
if self.use_caching and self.cache_folder is not None:
X_out = self._vectorize_with_cache(column)
else:
X_out = self._vectorize(column)

if self.n_components is not None:
if (min_shape := min(X_out.shape)) >= self.n_components:
Expand Down Expand Up @@ -302,7 +308,10 @@ def transform(self, column):
raise ValueError("Input column does not contain strings.")
column = self.to_str.transform(column)

X_out = self._vectorize(column)
if self.use_caching and self.cache_folder is not None:
X_out = self._vectorize_with_cache(column)
else:
X_out = self._vectorize(column)

if hasattr(self, "pca_"):
X_out = self.pca_.transform(X_out)
Expand All @@ -318,13 +327,64 @@ def transform(self, column):

return X_out

def _vectorize_with_cache(self, column):
total_values = column.to_frame().rename({column.name: "values"})
unique_values = column.unique().to_frame().rename({column.name: "values"})

if os.path.exists(
os.path.join(
self.cache_folder,
self.model_name.replace("/", "-") + "cached_outputs.parquet",
)
):
cached_values = pl.read_parquet(
os.path.join(
self.cache_folder,
self.model_name.replace("/", "-") + "cached_outputs.parquet",
)
) # !!!!POLARS DEPENDENT
else:
cached_values = pl.DataFrame(
schema={"values": pl.String}
) #!!!!!! POLARS DEPENDENT

to_compute = unique_values.join(cached_values, on="values", how="anti")[
"values"
]
print(f"Computing {to_compute.shape[0]}")

if not to_compute.is_empty():
V = pl.DataFrame(self._vectorize(to_compute)).with_columns(to_compute)
new_cache = sbd.concat(cached_values, V)
new_cache.write_parquet(
os.path.join(self.cache_folder, "temp_cache.parquet")
)
os.rename(
os.path.join(self.cache_folder, "temp_cache.parquet"),
os.path.join(
self.cache_folder,
self.model_name.replace("/", "-") + "cached_outputs.parquet",
),
)
else:
new_cache = cached_values

import numpy as np

X_out = sbd.to_numpy(
total_values.join(new_cache, how="left", on="values", coalesce=True)
)[:, 1:].astype(np.float32)

return X_out

def _vectorize(self, column):
is_null = sbd.to_numpy(sbd.is_null(column))
column = sbd.to_numpy(column)
unique_x, indices_x = unique_strings(column, is_null)

# sentence-transformers deals with converting a torch tensor
# to a numpy array, on CPU.

return self._estimator.encode(
unique_x,
normalize_embeddings=False,
Expand Down Expand Up @@ -444,3 +504,11 @@ def get_feature_names_out(self, input_features=None):
f"{self.input_name_}_{str(i).zfill(num_digits)}"
for i in range(self.n_components_)
]

def flush_cache(self):
os.remove(
os.path.join(
self.cache_folder,
self.model_name.replace("/", "-") + "cached_outputs.parquet",
)
)
36 changes: 36 additions & 0 deletions skrub/tests/test_text_encoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import pickle
import sys

import pandas as pd
import polars as pl
import pytest
from numpy.testing import assert_array_equal
from sklearn.base import clone
Expand Down Expand Up @@ -199,3 +201,37 @@ def test_categorical_features(df_module, encoder):

out = encoder.fit(df["categorical"][:4]).transform(df["categorical"][4:])
assert len(sbd.column_names(out)) == 30


@pytest.fixture
def test_cache_size(df_module):
df = sbd.make_dataframe(
{
"id": [1, 2, 3, 4, 5],
"name": ["one", "two", "three", "four", "one"],
"answer": ["yes", "no", "yes", "yes", "perhaps"],
}
)
encoder = TextEncoder(
model_name="llm_e5-base-v2",
token_env_variable=None,
batch_size=32,
n_components=30,
use_caching=True,
cache_folder=".cache",
)

string_cols = ["name", "answer"]

for col in string_cols:
_ = encoder.fit_transform(df[col], [])

for col in string_cols:
_ = encoder.transform(df[col])

cached_values = pl.read_parquet(os.path.join(".cache", "cached_outputs.parquet"))
expected_cache = pl.DataFrame(
{"Values": ["one", "two", "three", "four", "yes", "no", "perhaps"]}
)

assert cached_values["values"] == expected_cache["values"]
Loading