Skip to content

Commit 4508b7d

Browse files
committed
fix: improve the error message provided by deduplicate
1 parent 4c5d4ac commit 4508b7d

2 files changed

Lines changed: 58 additions & 9 deletions

File tree

skrub/_deduplicate.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Implements deduplication based on clustering string distance matrices.
33
"""
44

5+
import warnings
6+
57
import numpy as np
68
import pandas as pd
79
from joblib import Parallel, delayed
@@ -75,7 +77,7 @@ def _guess_clusters(Z, distance_mat, n_jobs=None):
7577
number of clusters that maximize the silhouette score.
7678
"""
7779
max_clusters = Z.shape[0]
78-
n_clusters = np.arange(2, max_clusters)
80+
n_clusters = np.arange(1, max_clusters)
7981
# silhouette score needs a redundant distance matrix
8082
redundant_dist = squareform(distance_mat)
8183
silhouette_scores = Parallel(n_jobs=n_jobs, prefer="processes")(
@@ -136,6 +138,7 @@ def deduplicate(
136138
analyzer="char_wb",
137139
linkage_method="average",
138140
n_jobs=None,
141+
warn=False,
139142
):
140143
"""Deduplicate categorical data by hierarchically clustering similar strings.
141144
@@ -168,6 +171,9 @@ def deduplicate(
168171
average distance between data points in the first and second cluster.
169172
n_jobs : int, default=None
170173
The number of jobs to run in parallel.
174+
warn : bool, default=False
175+
When clustering fails (e.g. too few or too similar entries), the input
176+
is returned unchanged. If True, a warning is also emitted in that case.
171177
172178
Returns
173179
-------
@@ -260,14 +266,23 @@ def deduplicate(
260266
9 white 9 white
261267
"""
262268
unique_words, counts = np.unique(X, return_counts=True)
263-
distance_mat = _compute_ngram_distance(
264-
unique_words, ngram_range=ngram_range, analyzer=analyzer
265-
)
266-
267-
Z = linkage(distance_mat, method=linkage_method, optimal_ordering=True)
268-
if n_clusters is None:
269-
n_clusters = _guess_clusters(Z, distance_mat, n_jobs)
270-
clusters = fcluster(Z, n_clusters, criterion="maxclust")
269+
try:
270+
distance_mat = _compute_ngram_distance(
271+
unique_words, ngram_range=ngram_range, analyzer=analyzer
272+
)
273+
Z = linkage(distance_mat, method=linkage_method, optimal_ordering=True)
274+
if n_clusters is None:
275+
n_clusters = _guess_clusters(Z, distance_mat, n_jobs)
276+
clusters = fcluster(Z, n_clusters, criterion="maxclust")
277+
except Exception:
278+
if warn:
279+
warnings.warn(
280+
"Deduplication could not cluster the data (too few or too similar"
281+
" entries). Returning the input unchanged.",
282+
UserWarning,
283+
stacklevel=2,
284+
)
285+
return list(X)
271286

272287
translation_table = _create_spelling_correction(unique_words, counts, clusters)
273288
unrolled_corrections = translation_table[X]

skrub/tests/test_deduplicate.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import pandas as pd
44
import pytest
5+
import warnings
56
from scipy.cluster.hierarchy import linkage
67
from scipy.spatial.distance import squareform
78
from sklearn.utils._testing import assert_array_equal, skip_if_no_parallel
@@ -145,6 +146,39 @@ def start_call(self):
145146
joblib.register_parallel_backend("testing", DummyBackend)
146147

147148

149+
@pytest.mark.parametrize(
150+
"X",
151+
[
152+
# Too few unique entries for silhouette score (only 2 unique values)
153+
["black", "black", "black", "blac"],
154+
# Too few unique entries (3 unique values, only 1 cluster possible)
155+
["black", "white", "black", "black", "blac"],
156+
# 4 unique values but still not enough for clustering to succeed
157+
["black", "black", "black", "black", "white", "white", "white", "red", "green"],
158+
],
159+
)
160+
def test_deduplicate_failure_returns_input(X):
161+
result = deduplicate(X)
162+
assert isinstance(result, list)
163+
assert result == X
164+
165+
166+
@pytest.mark.parametrize(
167+
"X",
168+
[
169+
["black", "black", "black", "blac"],
170+
["black", "white", "black", "black", "blac"],
171+
],
172+
)
173+
def test_deduplicate_warn(X):
174+
with pytest.warns(UserWarning, match="Returning the input unchanged"):
175+
deduplicate(X, warn=True)
176+
177+
with warnings.catch_warnings():
178+
warnings.simplefilter("error")
179+
deduplicate(X, warn=False) # raises if any warning is emitted
180+
181+
148182
@skip_if_no_parallel
149183
def test_backend_respected():
150184
"""

0 commit comments

Comments
 (0)