Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ Changes
- The row indices of training and testing samples are now also included in the
dictionaries produced by :meth:`DataOp.skb.iter_cv_splits`. :pr:`2012` by
:user:`Jérôme Dockès <jeromedockes>`.
- Added a ``metric`` parameter to :func:`fuzzy_join` and :class:`Joiner` to configure
the nearest-neighbor distance used for matching. The metric can be any value
supported by :class:`~sklearn.neighbors.NearestNeighbors` (see its docstring).
:pr:`1861` by :user:`Saba Siddique <sabasiddique1>`.

Bugfixes
--------
Expand Down
11 changes: 8 additions & 3 deletions skrub/_fuzzy_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def fuzzy_join(
string_encoder=DEFAULT_STRING_ENCODER,
add_match_info=False,
drop_unmatched=False,
metric="euclidean",
):
"""Fuzzy (approximate) join.

Expand All @@ -35,9 +36,9 @@ def fuzzy_join(

To identify the best match for each row, values from the matching columns
(``left_key`` and ``right_key``) are vectorized, i.e. represented by vectors of
continuous values. Then, the Euclidean distances between these vectors are
computed to find, for each left table row, its nearest neighbor within the
right table.
continuous values. Then, distances between these vectors are computed
(using the specified metric) to find, for each left table row, its nearest
neighbor within the right table.

Optionally, a maximum distance threshold, ``max_dist``, can be set. Matches
between vectors that are separated by a distance (strictly) greater than
Expand Down Expand Up @@ -120,6 +121,9 @@ def fuzzy_join(
drop_unmatched : bool, default=False
Remove rows for which a match was not found in the right table (i.e. for
which the nearest neighbor is further than `max_dist`).
metric : str, default='euclidean'
The distance metric to use for nearest neighbor search.
See :class:`~sklearn.neighbors.NearestNeighbors` for all available metrics.

Returns
-------
Expand Down Expand Up @@ -208,6 +212,7 @@ def fuzzy_join(
ref_dist=ref_dist,
string_encoder=string_encoder,
add_match_info=True,
metric=metric,
).fit_transform(left)
if drop_unmatched:
join = sbd.filter(join, sbd.col(join, "skrub_Joiner_match_accepted"))
Expand Down
13 changes: 9 additions & 4 deletions skrub/_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class Joiner(TransformerMixin, BaseEstimator):

To identify the best match for each row, values from the matching columns
(`main_key` and `aux_key`) are vectorized, i.e. represented by vectors of
continuous values. Then, the Euclidean distances between these vectors are
computed to find, for each main table row, its nearest neighbor within the
auxiliary table.
continuous values. Then, the distances between these vectors are computed
(using the specified metric) to find, for each main table row, its
nearest neighbor within the auxiliary table.

Optionally, a maximum distance threshold, `max_dist`, can be set. Matches
between vectors that are separated by a distance (strictly) greater than
Expand Down Expand Up @@ -168,6 +168,9 @@ class Joiner(TransformerMixin, BaseEstimator):
above the threshold. Those values can be helpful for an estimator that
uses the joined features, or to inspect the result of the join and set
a `max_dist` threshold.
metric : str, default='euclidean'
The distance metric to use for nearest neighbor search.
See :class:`~sklearn.neighbors.NearestNeighbors` for all available metrics.

Attributes
----------
Expand Down Expand Up @@ -236,6 +239,7 @@ def __init__(
ref_dist=DEFAULT_REF_DIST,
string_encoder=DEFAULT_STRING_ENCODER,
add_match_info=True,
metric="euclidean",
):
self.aux_table = aux_table
self.key = key
Expand All @@ -250,6 +254,7 @@ def __init__(
else string_encoder
)
self.add_match_info = add_match_info
self.metric = metric

def _check_max_dist(self):
if (
Expand All @@ -267,7 +272,7 @@ def _check_ref_dist(self):
f"'ref_dist' should be one of {list(_MATCHERS.keys())}. Got"
f" {self.ref_dist!r}"
)
self._matching = _MATCHERS[self.ref_dist]()
self._matching = _MATCHERS[self.ref_dist](metric=self.metric)

def fit_transform(self, X, y=None):
"""Fit the instance to the main table.
Expand Down
51 changes: 47 additions & 4 deletions skrub/_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@ class Matching(BaseEstimator):
nearest match are rescaled. This base class does not apply any rescaling:
the reference distance is 1.0. Subclasses can override
``_get_reference_distances`` to modify the rescaling behavior.

Parameters
----------
metric : str, default='euclidean'
The distance metric to use for nearest neighbor search.
See :class:`~sklearn.neighbors.NearestNeighbors` for all available metrics.
"""

def __init__(self, metric="euclidean"):
self.metric = metric

def fit(self, aux):
self.aux_ = aux
self.neighbors_ = NearestNeighbors(n_neighbors=1).fit(aux)
self.neighbors_ = NearestNeighbors(n_neighbors=1, metric=self.metric).fit(aux)
return self

def match(self, main, max_dist):
Expand Down Expand Up @@ -72,9 +81,27 @@ class RandomPairs(Matching):
Pairs of (different) rows are sampled randomly from the auxiliary table,
and the distance that separates each pair is computed. The reference
distance is a percentile (by default 25%) of the resulting distances.

Parameters
----------
percentile : float, default=25.0
Percentile to use for reference distance.
n_sampled_pairs : int, default=500
Number of random pairs to sample.
random_state : int or RandomState, default=0
Random state for sampling.
metric : str, default='euclidean'
The distance metric to use. See Matching class.
"""

def __init__(self, percentile=25.0, n_sampled_pairs=500, random_state=0):
def __init__(
self,
percentile=25.0,
n_sampled_pairs=500,
random_state=0,
metric="euclidean",
):
super().__init__(metric=metric)
self.percentile = percentile
self.n_sampled_pairs = n_sampled_pairs
self.random_state = random_state
Expand Down Expand Up @@ -121,9 +148,17 @@ class SelfJoinNeighbor(Matching):

Instead of the nearest neighbor, another (more distant neighbor) can be
chosen by setting ``reference_neighbor``.

Parameters
----------
reference_neighbor : int, default=1
Which neighbor to use as reference (1 = nearest, 2 = second nearest, etc.).
metric : str, default='euclidean'
The distance metric to use. See Matching class.
"""

def __init__(self, reference_neighbor=1):
def __init__(self, reference_neighbor=1, metric="euclidean"):
super().__init__(metric=metric)
self.reference_neighbor = reference_neighbor

def _get_reference_distances(self, main, indices):
Expand All @@ -149,9 +184,17 @@ class OtherNeighbor(Matching):

Instead of the nearest neighbor, another (more distant neighbor) can be
chosen by setting ``reference_neighbor``.

Parameters
----------
reference_neighbor : int, default=1
Which neighbor to use as reference (1 = nearest, 2 = second nearest, etc.).
metric : str, default='euclidean'
The distance metric to use. See Matching class.
"""

def __init__(self, reference_neighbor=1):
def __init__(self, reference_neighbor=1, metric="euclidean"):
super().__init__(metric=metric)
self.reference_neighbor = reference_neighbor

def _get_reference_distances(self, main, indices):
Expand Down
81 changes: 81 additions & 0 deletions skrub/tests/test_fuzzy_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,84 @@ def test_missing_values(df_module):

c = fuzzy_join(b, a, left_on="col3", right_on="col1", add_match_info=True)
assert ns.shape(c)[0] == len(b)


@pytest.mark.parametrize(
    ("metric", "expected_match"),
    [
        ("euclidean", {"xr": 1.0, "yr": 1.0, "idr": "A"}),
        ("manhattan", {"xr": 1.0, "yr": 1.0, "idr": "A"}),
        ("cosine", {"xr": 3.0, "yr": 0.0, "idr": "B"}),
    ],
)
def test_fuzzy_join_distance_metrics(df_module, metric, expected_match):
    """Check the row joined by ``fuzzy_join`` for each parametrized metric.

    A single query row ``(x=1, y=0)`` is joined on ``["x", "y"]`` against two
    candidate rows, ``A = (1, 1)`` and ``B = (3, 0)``, with
    ``ref_dist="no_rescaling"`` and without match-info columns. The
    parametrization expects ``"euclidean"`` and ``"manhattan"`` to select
    ``A`` and ``"cosine"`` to select ``B``.

    Parameters
    ----------
    df_module : fixture
        Dataframe-backend fixture providing ``make_dataframe`` and
        ``assert_frame_equal``.
    metric : str
        Value forwarded to the ``metric`` parameter of ``fuzzy_join``.
    expected_match : dict
        Expected values of the suffixed right-hand columns
        (``"xr"``, ``"yr"``, ``"idr"``) in the joined row.
    """
    query = df_module.make_dataframe({"x": [1.0], "y": [0.0]})
    candidates = df_module.make_dataframe(
        {"x": [1.0, 3.0], "y": [1.0, 0.0], "id": ["A", "B"]}
    )

    join_options = {
        "on": ["x", "y"],
        "suffix": "r",
        "metric": metric,
        "ref_dist": "no_rescaling",
        "add_match_info": False,
    }
    joined = fuzzy_join(query, candidates, **join_options)

    # Left columns come first, then the right-hand columns with the "r" suffix.
    expected_columns = {"x": [1.0], "y": [0.0]}
    for column in ("xr", "yr", "idr"):
        expected_columns[column] = [expected_match[column]]

    df_module.assert_frame_equal(joined, df_module.make_dataframe(expected_columns))


def test_fuzzy_join_distance_metric_changes_match(df_module):
    """Check that switching the metric can change which row is matched.

    The same one-row left table is joined against the same two-row right
    table twice, first with ``metric="euclidean"`` and then with
    ``metric="cosine"``; the matched ``"idr"`` values must differ.

    Parameters
    ----------
    df_module : fixture
        Dataframe-backend fixture providing ``make_dataframe``.
    """
    query = df_module.make_dataframe({"x": [1.0], "y": [0.0]})
    candidates = df_module.make_dataframe(
        {"x": [1.0, 3.0], "y": [1.0, 0.0], "id": ["A", "B"]}
    )

    def matched_ids(metric):
        # Only the metric varies between the two joins.
        joined = fuzzy_join(
            query,
            candidates,
            on=["x", "y"],
            suffix="r",
            metric=metric,
            ref_dist="no_rescaling",
            add_match_info=False,
        )
        return ns.to_list(ns.col(joined, "idr"))

    euclidean_ids = matched_ids("euclidean")
    cosine_ids = matched_ids("cosine")
    assert euclidean_ids != cosine_ids


def test_fuzzy_join_invalid_metric_raises(df_module):
    """Check that an unknown metric name is rejected with a ``ValueError``.

    The error message must mention the offending metric name, so the test
    cannot pass because of an unrelated ``ValueError`` raised elsewhere in
    the join (for instance during key validation).

    Parameters
    ----------
    df_module : fixture
        Dataframe-backend fixture providing ``make_dataframe``.
    """
    left = df_module.make_dataframe({"A": ["aa"]})
    right = df_module.make_dataframe({"A": ["aa"], "B": [1]})

    # NOTE(review): assumes the error comes from scikit-learn's
    # NearestNeighbors, whose message quotes the rejected metric value --
    # confirm against the supported scikit-learn versions.
    with pytest.raises(ValueError, match="invalid_metric"):
        fuzzy_join(left, right, on="A", metric="invalid_metric")