5 changes: 5 additions & 0 deletions CHANGES.rst
@@ -17,6 +17,9 @@ Changes

Bugfixes
--------
- :func:`deduplicate` now returns its input unchanged when the clustering of the
  strings fails, instead of raising an error. Pass ``warn=True`` to emit a warning
  in that case.
  :pr:`1996` by :user:`Marie Sacksick <MarieSacksick>`.

Deprecations
------------
@@ -79,12 +82,14 @@ Changes
- The ``exclude_cols`` of :meth:`DataOp.skb.apply` can now be a DataOp.
:pr:`2050` by :user:`Jérôme Dockès <jeromedockes>`.


Bugfixes
--------
- An error that could arise when calling ``score`` on a ``SkrubLearner`` that
contains an inner transformer that has a ``score`` method has been fixed.
:pr:`2052` by :user:`Jérôme Dockès <jeromedockes>`.


Deprecations
------------
- The parameter ``numeric_dtype`` in the :class:`Cleaner` has been deprecated in
29 changes: 21 additions & 8 deletions skrub/_deduplicate.py
@@ -2,6 +2,8 @@
Implements deduplication based on clustering string distance matrices.
"""

import warnings

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
@@ -136,6 +138,7 @@ def deduplicate(
    analyzer="char_wb",
    linkage_method="average",
    n_jobs=None,
    warn=False,
):
    """Deduplicate categorical data by hierarchically clustering similar strings.

@@ -168,6 +171,9 @@ def deduplicate(
        average distance between data points in the first and second cluster.
    n_jobs : int, default=None
        The number of jobs to run in parallel.
    warn : bool, default=False
Member

I'm wondering if the warning is needed, we almost never use warnings in skrub 🤔

Contributor Author

True that it's a bit unusual for skrub. But it seems to me that it would be nice to warn the user that no deduplication was done. And I don't want to add a systematic warning that will be ignored and will push the user to remove all warnings...

Member

I also wouldn't add the warn parameter. I would say either we consider that there are no duplicates to find and return the original list without warnings, or we consider it a genuine failure and raise an (informative) exception.

Contributor Author

This situation might happen regularly, and I don't want to break pipelines where deduplicate is used to clean data "just in case". I'll remove the warn parameter, but document this behavior.

Member

I'm fine with that solution 👍

Contributor Author

Now that I think of it, why is it a function and not a transformer? Is there a way for the user to know the clusters or not?
It looks like I can only deduplicate a whole dataset, and cannot fit on the training set and then apply the result at predict time.
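
The fit-then-apply split raised in the last comment could look roughly like the sketch below. It is only an illustration: `FittedDeduplicator`, its `mapping_` attribute, and `toy_dedup` are hypothetical names, not part of skrub, and the sketch assumes the wrapped function maps a list of strings to a same-length list of corrected strings.

```python
class FittedDeduplicator:
    """Hypothetical wrapper: learn a correction mapping once, reuse it later."""

    def __init__(self, dedup_fn):
        # dedup_fn: any callable mapping a list of strings to a same-length
        # list of corrected strings (an assumption of this sketch).
        self.dedup_fn = dedup_fn

    def fit(self, X):
        X = list(X)
        corrected = list(self.dedup_fn(X))
        # Remember which canonical form each training value was mapped to.
        self.mapping_ = dict(zip(X, corrected))
        return self

    def transform(self, X):
        # Values never seen during fit pass through unchanged.
        return [self.mapping_.get(x, x) for x in X]


def toy_dedup(values):
    """Stand-in for a real deduplication function: fixes one known typo."""
    return ["black" if v == "blac" else v for v in values]


model = FittedDeduplicator(toy_dedup).fit(["black", "blac", "white"])
print(model.transform(["blac", "white", "grey"]))
```

Exposing the learned mapping would also answer the question about inspecting the clusters, since each key shows which canonical spelling it was assigned to.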

        If True, emit a warning when clustering fails (e.g. too few or too
        similar entries). In that case the input is returned unchanged,
        whether or not the warning is emitted.

Returns
-------
@@ -261,14 +267,21 @@ def deduplicate(
9 white 9 white
"""
     unique_words, counts = np.unique(X, return_counts=True)
-    distance_mat = _compute_ngram_distance(
-        unique_words, ngram_range=ngram_range, analyzer=analyzer
-    )
-
-    Z = linkage(distance_mat, method=linkage_method, optimal_ordering=True)
-    if n_clusters is None:
-        n_clusters = _guess_clusters(Z, distance_mat, n_jobs)
-    clusters = fcluster(Z, n_clusters, criterion="maxclust")
+    try:
+        distance_mat = _compute_ngram_distance(
+            unique_words, ngram_range=ngram_range, analyzer=analyzer
+        )
+        Z = linkage(distance_mat, method=linkage_method, optimal_ordering=True)
+        if n_clusters is None:
+            n_clusters = _guess_clusters(Z, distance_mat, n_jobs)
+        clusters = fcluster(Z, n_clusters, criterion="maxclust")
+    except ValueError:
+        if warn:
+            warnings.warn(
+                "Deduplication could not cluster the data (too few or too similar"
+                " entries). Returning the input unchanged."
+            )
+        return list(X)
 
     translation_table = _create_spelling_correction(unique_words, counts, clusters)
     unrolled_corrections = translation_table[X]
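
The fallback introduced in this hunk can be isolated into a small, dependency-free sketch. The names `cluster_or_passthrough` and `failing_cluster_fn` are hypothetical and stand in for the real clustering steps; only the control flow (catch `ValueError`, optionally warn, return the input as a list) mirrors the change above.

```python
import warnings


def cluster_or_passthrough(values, cluster_fn, warn=False):
    """Return cluster_fn(values), or the input unchanged if clustering fails."""
    try:
        return cluster_fn(values)
    except ValueError:
        if warn:
            warnings.warn(
                "Could not cluster the data. Returning the input unchanged."
            )
        return list(values)


def failing_cluster_fn(values):
    # Stand-in for a clustering step that rejects its input.
    raise ValueError("not enough distinct values to cluster")


data = ["black", "black", "blac"]
result = cluster_or_passthrough(data, failing_cluster_fn)
print(result == data)
```

Catching only `ValueError` keeps the fallback narrow: any other exception type still propagates to the caller instead of being silently turned into a pass-through.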
35 changes: 35 additions & 0 deletions skrub/tests/test_deduplicate.py
@@ -1,3 +1,5 @@
import warnings

import joblib
import numpy as np
import pandas as pd
@@ -145,6 +147,39 @@ def start_call(self):
joblib.register_parallel_backend("testing", DummyBackend)


@pytest.mark.parametrize(
    "X",
    [
        # Too few unique entries for silhouette score (only 2 unique values)
        ["black", "black", "black", "blac"],
        # Too few unique entries (3 unique values, only 1 cluster possible)
        ["black", "white", "black", "black", "blac"],
        # 4 unique values but still not enough for clustering to succeed
        ["black", "black", "black", "black", "white", "white", "white", "red", "green"],
    ],
)
def test_deduplicate_failure_returns_input(X):
    result = deduplicate(X)
    assert isinstance(result, list)
    assert result == X


@pytest.mark.parametrize(
    "X",
    [
        ["black", "black", "black", "blac"],
        ["black", "white", "black", "black", "blac"],
    ],
)
def test_deduplicate_warn(X):
    with pytest.warns(UserWarning, match="Returning the input unchanged"):
        deduplicate(X, warn=True)

    with warnings.catch_warnings():
        warnings.simplefilter("error")
        deduplicate(X, warn=False)  # raises if any warning is emitted


@skip_if_no_parallel
def test_backend_respected():
    """
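
The "no warning is emitted" half of `test_deduplicate_warn` relies on `warnings.simplefilter("error")`, which turns any warning raised inside the block into an exception. A standard-library-only sketch of that idiom follows; `maybe_warn` and `emits_warning` are hypothetical helpers used for illustration, not part of skrub or pytest.

```python
import warnings


def maybe_warn(warn=False):
    # Stand-in for a function with an opt-in warning.
    if warn:
        warnings.warn("Returning the input unchanged.")
    return "done"


def emits_warning(fn, **kwargs):
    """Report whether calling fn(**kwargs) emits any warning."""
    with warnings.catch_warnings():
        # Inside this block every warning is raised as an exception.
        warnings.simplefilter("error")
        try:
            fn(**kwargs)
        except Warning:
            return True
    return False


print(emits_warning(maybe_warn, warn=True))
print(emits_warning(maybe_warn, warn=False))
```

Because `catch_warnings` restores the previous filters on exit, the stricter setting does not leak into other tests.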