Skip to content

Commit 846d23e

Browse files
timsaucer and claude
committed
Consolidate except_all/except_distinct and intersect/intersect_distinct into single methods with distinct flag
Follows the same pattern as union(distinct=) and union_by_name(distinct=). Also deprecates union_distinct() in favor of union(distinct=True).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1f35ec1 commit 846d23e

File tree

3 files changed

+66
-100
lines changed

3 files changed

+66
-100
lines changed

crates/core/src/dataframe.rs

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -890,17 +890,6 @@ impl PyDataFrame {
890890
Ok(Self::new(new_df))
891891
}
892892

893-
/// Calculate the distinct union of two `DataFrame`s. The
894-
/// two `DataFrame`s must have exactly the same schema
895-
fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
896-
let new_df = self
897-
.df
898-
.as_ref()
899-
.clone()
900-
.union_distinct(py_df.df.as_ref().clone())?;
901-
Ok(Self::new(new_df))
902-
}
903-
904893
#[pyo3(signature = (column, preserve_nulls=true, recursions=None))]
905894
fn unnest_column(
906895
&self,
@@ -935,38 +924,28 @@ impl PyDataFrame {
935924
}
936925

937926
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema.
/// When `distinct` is true, duplicate rows are removed from the result.
938-
fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
939-
let new_df = self
940-
.df
941-
.as_ref()
942-
.clone()
943-
.intersect(py_df.df.as_ref().clone())?;
927+
#[pyo3(signature = (py_df, distinct=false))]
928+
fn intersect(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
929+
let base = self.df.as_ref().clone();
930+
let other = py_df.df.as_ref().clone();
931+
let new_df = if distinct {
932+
base.intersect_distinct(other)?
933+
} else {
934+
base.intersect(other)?
935+
};
944936
Ok(Self::new(new_df))
945937
}
946938

947939
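/// Calculate the set difference of two `DataFrame`s: the rows of this `DataFrame` that are not in `py_df`.
/// The two `DataFrame`s must have exactly the same schema.
/// When `distinct` is true, duplicate rows are removed from the result.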
948-
fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
949-
let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
950-
Ok(Self::new(new_df))
951-
}
952-
953-
/// Calculate the set difference with deduplication
954-
fn except_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
955-
let new_df = self
956-
.df
957-
.as_ref()
958-
.clone()
959-
.except_distinct(py_df.df.as_ref().clone())?;
960-
Ok(Self::new(new_df))
961-
}
962-
963-
/// Calculate the intersection with deduplication
964-
fn intersect_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
965-
let new_df = self
966-
.df
967-
.as_ref()
968-
.clone()
969-
.intersect_distinct(py_df.df.as_ref().clone())?;
940+
#[pyo3(signature = (py_df, distinct=false))]
941+
fn except_all(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
942+
let base = self.df.as_ref().clone();
943+
let other = py_df.df.as_ref().clone();
944+
let new_df = if distinct {
945+
base.except_distinct(other)?
946+
} else {
947+
base.except(other)?
948+
};
970949
Ok(Self::new(new_df))
971950
}
972951

python/datafusion/dataframe.py

Lines changed: 34 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,96 +1179,76 @@ def union(self, other: DataFrame, distinct: bool = False) -> DataFrame:
11791179
"""
11801180
return DataFrame(self.df.union(other.df, distinct))
11811181

1182+
@deprecated(
1183+
"union_distinct() is deprecated. Use union(other, distinct=True) instead."
1184+
)
11821185
def union_distinct(self, other: DataFrame) -> DataFrame:
11831186
"""Calculate the distinct union of two :py:class:`DataFrame`.
11841187
1185-
The two :py:class:`DataFrame` must have exactly the same schema.
1186-
Any duplicate rows are discarded.
1187-
1188-
Args:
1189-
other: DataFrame to union with.
1190-
1191-
Returns:
1192-
DataFrame after union.
1188+
See Also:
1189+
:py:meth:`union`
11931190
"""
1194-
return DataFrame(self.df.union_distinct(other.df))
1191+
return self.union(other, distinct=True)
11951192

1196-
def intersect(self, other: DataFrame) -> DataFrame:
1193+
def intersect(self, other: DataFrame, distinct: bool = False) -> DataFrame:
11971194
"""Calculate the intersection of two :py:class:`DataFrame`.
11981195
11991196
The two :py:class:`DataFrame` must have exactly the same schema.
12001197
12011198
Args:
1202-
other: DataFrame to intersect with.
1199+
other: DataFrame to intersect with.
1200+
distinct: If ``True``, duplicate rows are removed from the result.
12031201
12041202
Returns:
12051203
DataFrame after intersection.
1206-
"""
1207-
return DataFrame(self.df.intersect(other.df))
12081204
1209-
def except_all(self, other: DataFrame) -> DataFrame:
1210-
"""Calculate the exception of two :py:class:`DataFrame`.
1205+
Examples:
1206+
Find rows common to both DataFrames:
12111207
1212-
The two :py:class:`DataFrame` must have exactly the same schema.
1208+
>>> ctx = dfn.SessionContext()
1209+
>>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]})
1210+
>>> df2 = ctx.from_pydict({"a": [1, 4], "b": [10, 40]})
1211+
>>> df1.intersect(df2).to_pydict()
1212+
{'a': [1], 'b': [10]}
12131213
1214-
Args:
1215-
other: DataFrame to calculate exception with.
1214+
Intersect with deduplication:
12161215
1217-
Returns:
1218-
DataFrame after exception.
1216+
>>> df1 = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 10, 20]})
1217+
>>> df2 = ctx.from_pydict({"a": [1, 1], "b": [10, 10]})
1218+
>>> df1.intersect(df2, distinct=True).to_pydict()
1219+
{'a': [1], 'b': [10]}
12191220
"""
1220-
return DataFrame(self.df.except_all(other.df))
1221+
return DataFrame(self.df.intersect(other.df, distinct))
12211222

1222-
def except_distinct(self, other: DataFrame) -> DataFrame:
1223-
"""Calculate the set difference with deduplication.
1223+
def except_all(self, other: DataFrame, distinct: bool = False) -> DataFrame:
1224+
"""Calculate the set difference of two :py:class:`DataFrame`.
12241225
1225-
Returns rows that are in this DataFrame but not in ``other``,
1226-
removing any duplicates. In contrast, :py:meth:`except_all` preserves
1227-
duplicate rows.
1226+
Returns rows that are in this DataFrame but not in ``other``.
12281227
12291228
The two :py:class:`DataFrame` must have exactly the same schema.
12301229
12311230
Args:
12321231
other: DataFrame to calculate the set difference with.
1232+
distinct: If ``True``, duplicate rows are removed from the result.
12331233
12341234
Returns:
1235-
DataFrame after set difference with deduplication.
1235+
DataFrame after set difference.
12361236
12371237
Examples:
1238-
Remove rows present in ``df2`` and deduplicate:
1238+
Remove rows present in ``df2``:
12391239
12401240
>>> ctx = dfn.SessionContext()
1241-
>>> df1 = ctx.from_pydict({"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]})
1241+
>>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]})
12421242
>>> df2 = ctx.from_pydict({"a": [1, 2], "b": [10, 20]})
1243-
>>> df1.except_distinct(df2).sort("a").to_pydict()
1243+
>>> df1.except_all(df2).sort("a").to_pydict()
12441244
{'a': [3], 'b': [30]}
1245-
"""
1246-
return DataFrame(self.df.except_distinct(other.df))
1247-
1248-
def intersect_distinct(self, other: DataFrame) -> DataFrame:
1249-
"""Calculate the intersection with deduplication.
1250-
1251-
Returns distinct rows that appear in both DataFrames. In contrast,
1252-
:py:meth:`intersect` preserves duplicate rows.
1253-
1254-
The two :py:class:`DataFrame` must have exactly the same schema.
1255-
1256-
Args:
1257-
other: DataFrame to intersect with.
1258-
1259-
Returns:
1260-
DataFrame after intersection with deduplication.
12611245
1262-
Examples:
1263-
Find rows common to both DataFrames:
1246+
Remove rows present in ``df2`` and deduplicate:
12641247
1265-
>>> ctx = dfn.SessionContext()
1266-
>>> df1 = ctx.from_pydict({"a": [1, 2, 3], "b": [10, 20, 30]})
1267-
>>> df2 = ctx.from_pydict({"a": [1, 4], "b": [10, 40]})
1268-
>>> df1.intersect_distinct(df2).to_pydict()
1269-
{'a': [1], 'b': [10]}
1248+
>>> df1 = ctx.from_pydict({"a": [1, 2, 3, 3], "b": [10, 20, 30, 30]})
>>> df1.except_all(df2, distinct=True).sort("a").to_pydict()
1249+
{'a': [3], 'b': [30]}
12701250
"""
1271-
return DataFrame(self.df.intersect_distinct(other.df))
1251+
return DataFrame(self.df.except_all(other.df, distinct))
12721252

12731253
def union_by_name(self, other: DataFrame, distinct: bool = False) -> DataFrame:
12741254
"""Union two :py:class:`DataFrame` matching columns by name.

python/tests/test_dataframe.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3573,40 +3573,47 @@ def test_read_parquet_file_sort_order(tmp_path, file_sort_order):
35733573

35743574

35753575
@pytest.mark.parametrize(
3576-
("df1_data", "df2_data", "method", "expected_a", "expected_b"),
3576+
("df1_data", "df2_data", "method", "kwargs", "expected_a", "expected_b"),
35773577
[
35783578
pytest.param(
35793579
{"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]},
35803580
{"a": [1, 2], "b": [10, 20]},
3581-
"except_distinct",
3581+
"except_all",
3582+
{"distinct": True},
35823583
[3],
35833584
[30],
3584-
id="except_distinct: removes matching rows and deduplicates",
3585+
id="except_all(distinct=True): removes matching rows and deduplicates",
35853586
),
35863587
pytest.param(
35873588
{"a": [1, 2, 3, 1], "b": [10, 20, 30, 10]},
35883589
{"a": [1, 4], "b": [10, 40]},
3589-
"intersect_distinct",
3590+
"intersect",
3591+
{"distinct": True},
35903592
[1],
35913593
[10],
3592-
id="intersect_distinct: keeps common rows and deduplicates",
3594+
id="intersect(distinct=True): keeps common rows and deduplicates",
35933595
),
35943596
pytest.param(
35953597
{"a": [1], "b": [10]},
35963598
{"b": [20], "a": [2]}, # reversed column order tests matching by name
35973599
"union_by_name",
3600+
{},
35983601
[1, 2],
35993602
[10, 20],
36003603
id="union_by_name: matches columns by name not position",
36013604
),
36023605
],
36033606
)
3604-
def test_set_operations_distinct(df1_data, df2_data, method, expected_a, expected_b):
3607+
def test_set_operations_distinct(
3608+
df1_data, df2_data, method, kwargs, expected_a, expected_b
3609+
):
36053610
ctx = SessionContext()
36063611
df1 = ctx.from_pydict(df1_data)
36073612
df2 = ctx.from_pydict(df2_data)
36083613
result = (
3609-
getattr(df1, method)(df2).sort(column("a").sort(ascending=True)).collect()[0]
3614+
getattr(df1, method)(df2, **kwargs)
3615+
.sort(column("a").sort(ascending=True))
3616+
.collect()[0]
36103617
)
36113618
assert result.column(0).to_pylist() == expected_a
36123619
assert result.column(1).to_pylist() == expected_b

0 commit comments

Comments
 (0)