diff --git a/CHANGES.rst b/CHANGES.rst index ea617c8fa..446e78bbe 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -22,6 +22,10 @@ Changes - The row indices of training and testing samples are now also included in the dictionaries produced by :meth:`DataOp.skb.iter_cv_splits`. :pr:`2012` by :user:`Jérôme Dockès `. +- Added a ``metric`` parameter to :func:`fuzzy_join` and :class:`Joiner` to configure + the nearest-neighbor distance used for matching. The metric can be any value + supported by :class:`~sklearn.neighbors.NearestNeighbors` (see its docstring). + :pr:`1861` by :user:`Saba Siddique `. Bugfixes -------- diff --git a/skrub/_fuzzy_join.py b/skrub/_fuzzy_join.py index f10af3c2f..629e3eb73 100644 --- a/skrub/_fuzzy_join.py +++ b/skrub/_fuzzy_join.py @@ -22,6 +22,7 @@ def fuzzy_join( string_encoder=DEFAULT_STRING_ENCODER, add_match_info=False, drop_unmatched=False, + metric="euclidean", ): """Fuzzy (approximate) join. @@ -35,9 +36,9 @@ def fuzzy_join( To identify the best match for each row, values from the matching columns (``left_key`` and ``right_key``) are vectorized, i.e. represented by vectors of - continuous values. Then, the Euclidean distances between these vectors are - computed to find, for each left table row, its nearest neighbor within the - right table. + continuous values. Then, distances between these vectors are computed + (using the specified metric) to find, for each left table row, its nearest + neighbor within the right table. Optionally, a maximum distance threshold, ``max_dist``, can be set. Matches between vectors that are separated by a distance (strictly) greater than @@ -120,6 +121,9 @@ def fuzzy_join( drop_unmatched : bool, default=False Remove rows for which a match was not found in the right table (i.e. for which the nearest neighbor is further than `max_dist`). + metric : str, default='euclidean' + The distance metric to use for nearest neighbor search. + See :class:`~sklearn.neighbors.NearestNeighbors` for all available metrics. Returns ------- @@ -208,6 +212,7 @@ def fuzzy_join( ref_dist=ref_dist, string_encoder=string_encoder, add_match_info=True, + metric=metric, ).fit_transform(left) if drop_unmatched: join = sbd.filter(join, sbd.col(join, "skrub_Joiner_match_accepted")) diff --git a/skrub/_joiner.py b/skrub/_joiner.py index c94569be7..fcc751e26 100644 --- a/skrub/_joiner.py +++ b/skrub/_joiner.py @@ -87,9 +87,9 @@ class Joiner(TransformerMixin, BaseEstimator): To identify the best match for each row, values from the matching columns (`main_key` and `aux_key`) are vectorized, i.e. represented by vectors of - continuous values. Then, the Euclidean distances between these vectors are - computed to find, for each main table row, its nearest neighbor within the - auxiliary table. + continuous values. Then, the distances between these vectors are computed + (using the specified metric) to find, for each main table row, its + nearest neighbor within the auxiliary table. Optionally, a maximum distance threshold, `max_dist`, can be set. Matches between vectors that are separated by a distance (strictly) greater than @@ -168,6 +168,9 @@ class Joiner(TransformerMixin, BaseEstimator): above the threshold. Those values can be helpful for an estimator that uses the joined features, or to inspect the result of the join and set a `max_dist` threshold. + metric : str, default='euclidean' + The distance metric to use for nearest neighbor search. + See :class:`~sklearn.neighbors.NearestNeighbors` for all available metrics. Attributes ---------- @@ -236,6 +239,7 @@ def __init__( ref_dist=DEFAULT_REF_DIST, string_encoder=DEFAULT_STRING_ENCODER, add_match_info=True, + metric="euclidean", ): self.aux_table = aux_table self.key = key @@ -250,6 +254,7 @@ def __init__( else string_encoder ) self.add_match_info = add_match_info + self.metric = metric def _check_max_dist(self): if ( @@ -267,7 +272,7 @@ def _check_ref_dist(self): f"'ref_dist' should be one of {list(_MATCHERS.keys())}. Got" f" {self.ref_dist!r}" ) - self._matching = _MATCHERS[self.ref_dist]() + self._matching = _MATCHERS[self.ref_dist](metric=self.metric) def fit_transform(self, X, y=None): """Fit the instance to the main table. diff --git a/skrub/_matching.py b/skrub/_matching.py index 26791b1c2..9c9933eff 100644 --- a/skrub/_matching.py +++ b/skrub/_matching.py @@ -15,11 +15,20 @@ class Matching(BaseEstimator): nearest match are rescaled. This base class does not apply any rescaling: the reference distance is 1.0. Subclasses can override ``_get_reference_distances`` to modify the rescaling behavior. + + Parameters + ---------- + metric : str, default='euclidean' + The distance metric to use for nearest neighbor search. + See :class:`~sklearn.neighbors.NearestNeighbors` for all available metrics. """ + def __init__(self, metric="euclidean"): + self.metric = metric + def fit(self, aux): self.aux_ = aux - self.neighbors_ = NearestNeighbors(n_neighbors=1).fit(aux) + self.neighbors_ = NearestNeighbors(n_neighbors=1, metric=self.metric).fit(aux) return self def match(self, main, max_dist): @@ -72,9 +81,27 @@ class RandomPairs(Matching): Pairs of (different) rows are sampled randomly from the auxiliary table, and the distance that separates each pair is computed. The reference distance is a percentile (by default 25%) of the resulting distances. + + Parameters + ---------- + percentile : float, default=25.0 + Percentile to use for reference distance. + n_sampled_pairs : int, default=500 + Number of random pairs to sample. + random_state : int or RandomState, default=0 + Random state for sampling. + metric : str, default='euclidean' + The distance metric to use. See Matching class. """ - def __init__(self, percentile=25.0, n_sampled_pairs=500, random_state=0): + def __init__( + self, + percentile=25.0, + n_sampled_pairs=500, + random_state=0, + metric="euclidean", + ): + super().__init__(metric=metric) self.percentile = percentile self.n_sampled_pairs = n_sampled_pairs self.random_state = random_state @@ -121,9 +148,17 @@ class SelfJoinNeighbor(Matching): Instead of the nearest neighbor, another (more distant neighbor) can be chosen by setting ``reference_neighbor``. + + Parameters + ---------- + reference_neighbor : int, default=1 + Which neighbor to use as reference (1 = nearest, 2 = second nearest, etc.) + metric : str, default='euclidean' + The distance metric to use. See Matching class. """ - def __init__(self, reference_neighbor=1): + def __init__(self, reference_neighbor=1, metric="euclidean"): + super().__init__(metric=metric) self.reference_neighbor = reference_neighbor def _get_reference_distances(self, main, indices): @@ -149,9 +184,17 @@ class OtherNeighbor(Matching): Instead of the nearest neighbor, another (more distant neighbor) can be chosen by setting ``reference_neighbor``. + + Parameters + ---------- + reference_neighbor : int, default=1 + Which neighbor to use as reference (1 = nearest, 2 = second nearest, etc.) + metric : str, default='euclidean' + The distance metric to use. See Matching class. """ - def __init__(self, reference_neighbor=1): + def __init__(self, reference_neighbor=1, metric="euclidean"): + super().__init__(metric=metric) self.reference_neighbor = reference_neighbor def _get_reference_distances(self, main, indices): diff --git a/skrub/tests/test_fuzzy_join.py b/skrub/tests/test_fuzzy_join.py index 34aa2397d..6a963576c 100644 --- a/skrub/tests/test_fuzzy_join.py +++ b/skrub/tests/test_fuzzy_join.py @@ -510,3 +510,84 @@ def test_missing_values(df_module): c = fuzzy_join(b, a, left_on="col3", right_on="col1", add_match_info=True) assert ns.shape(c)[0] == len(b) + + +@pytest.mark.parametrize( + ("metric", "expected_match"), + [ + ("euclidean", {"xr": 1.0, "yr": 1.0, "idr": "A"}), + ("manhattan", {"xr": 1.0, "yr": 1.0, "idr": "A"}), + ("cosine", {"xr": 3.0, "yr": 0.0, "idr": "B"}), + ], +) +def test_fuzzy_join_distance_metrics(df_module, metric, expected_match): + """ + Test that each metric produces the expected joined values. + """ + left = df_module.make_dataframe({"x": [1.0], "y": [0.0]}) + right = df_module.make_dataframe( + {"x": [1.0, 3.0], "y": [1.0, 0.0], "id": ["A", "B"]} + ) + + result = fuzzy_join( + left, + right, + on=["x", "y"], + suffix="r", + metric=metric, + ref_dist="no_rescaling", + add_match_info=False, + ) + + expected = df_module.make_dataframe( + { + "x": [1.0], + "y": [0.0], + "xr": [expected_match["xr"]], + "yr": [expected_match["yr"]], + "idr": [expected_match["idr"]], + } + ) + df_module.assert_frame_equal(result, expected) + + +def test_fuzzy_join_distance_metric_changes_match(df_module): + """ + Test that changing metric can change which row is selected as the match. + """ + left = df_module.make_dataframe({"x": [1.0], "y": [0.0]}) + right = df_module.make_dataframe( + {"x": [1.0, 3.0], "y": [1.0, 0.0], "id": ["A", "B"]} + ) + result_euclidean = fuzzy_join( + left, + right, + on=["x", "y"], + suffix="r", + metric="euclidean", + ref_dist="no_rescaling", + add_match_info=False, + ) + result_cosine = fuzzy_join( + left, + right, + on=["x", "y"], + suffix="r", + metric="cosine", + ref_dist="no_rescaling", + add_match_info=False, + ) + assert ns.to_list(ns.col(result_euclidean, "idr")) != ns.to_list( + ns.col(result_cosine, "idr") + ) + + +def test_fuzzy_join_invalid_metric_raises(df_module): + """ + Test that an invalid metric raises an error. + """ + left = df_module.make_dataframe({"A": ["aa"]}) + right = df_module.make_dataframe({"A": ["aa"], "B": [1]}) + + with pytest.raises(ValueError): + fuzzy_join(left, right, on="A", metric="invalid_metric")