Skip to content

Commit 011b98b

Browse files
committed
fix: upsert with null values in join columns
1 parent 78615d2 commit 011b98b

File tree

2 files changed

+76
-13
lines changed

2 files changed

+76
-13
lines changed

pyiceberg/table/upsert_util.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,38 +14,61 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
import functools
18-
import operator
17+
from math import isnan
18+
from typing import Any
1919

2020
import pyarrow as pa
2121
from pyarrow import Table as pyarrow_table
2222
from pyarrow import compute as pc
2323

2424
from pyiceberg.expressions import (
2525
AlwaysFalse,
26+
And,
2627
BooleanExpression,
2728
EqualTo,
2829
In,
30+
IsNaN,
31+
IsNull,
2932
Or,
3033
)
3134

3235

3336
def _is_nan(value: Any) -> bool:
    """Return True when *value* is a float NaN."""
    return isinstance(value, float) and isnan(value)


def _equals(column: str, value: Any) -> BooleanExpression:
    """Return the predicate matching *column* against a single key *value*.

    None maps to ``IsNull`` and a float NaN maps to ``IsNaN``; every other
    value maps to ``EqualTo``.
    """
    if value is None:
        return IsNull(column)
    if _is_nan(value):
        return IsNaN(column)
    return EqualTo(column, value)


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
    """Build a filter expression matching the distinct join-key values in *df*.

    Args:
        df: Arrow table that holds the join-key columns.
        join_cols: Names of the columns that make up the join key.

    Returns:
        ``AlwaysFalse()`` when no predicate is produced, the single predicate
        when there is exactly one, and otherwise an ``Or`` of all predicates.
        With one join column the non-null, non-NaN values are combined into a
        single ``In``; with several join columns each distinct key row becomes
        an ``And`` of per-column predicates. None and NaN key values are
        expressed with ``IsNull`` / ``IsNaN`` in both cases.
    """
    # Grouping with an empty aggregation leaves one row per distinct key.
    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
    filters: list[BooleanExpression] = []

    if len(join_cols) == 1:
        column = join_cols[0]
        values = set(unique_keys[0].to_pylist())

        if None in values:
            filters.append(IsNull(column))
            values.discard(None)

        # NaN objects are removed by identity: the set difference uses the
        # very same objects that were collected from ``values``.
        nans = {v for v in values if _is_nan(v)}
        if nans:
            filters.append(IsNaN(column))
            values -= nans

        # Skip the In predicate when nothing remains, so the result does not
        # depend on how an empty In is handled downstream.
        if values:
            filters.append(In(column, values))
    else:
        filters = [And(*(_equals(col, row[col]) for col in join_cols)) for row in unique_keys.to_pylist()]

    if not filters:
        return AlwaysFalse()
    if len(filters) == 1:
        return filters[0]
    return Or(*filters)
4972

5073

5174
def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:

tests/table/test_upsert.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,46 @@ def test_upsert_with_nulls(catalog: Catalog) -> None:
713713
schema=schema,
714714
)
715715

716+
# upsert table with null value
717+
data_with_null = pa.Table.from_pylist(
718+
[
719+
{"foo": None, "bar": 1, "baz": False},
720+
],
721+
schema=schema,
722+
)
723+
upd = table.upsert(data_with_null, join_cols=["foo"])
724+
assert upd.rows_updated == 0
725+
assert upd.rows_inserted == 1
726+
assert table.scan().to_arrow() == pa.Table.from_pylist(
727+
[
728+
{"foo": None, "bar": 1, "baz": False},
729+
{"foo": "apple", "bar": 7, "baz": False},
730+
{"foo": "banana", "bar": None, "baz": False},
731+
],
732+
schema=schema,
733+
)
734+
735+
# upsert table with null and non-null values, in two join columns
736+
data_with_null = pa.Table.from_pylist(
737+
[
738+
{"foo": None, "bar": 1, "baz": True},
739+
{"foo": "lemon", "bar": None, "baz": False},
740+
],
741+
schema=schema,
742+
)
743+
upd = table.upsert(data_with_null, join_cols=["foo", "bar"])
744+
assert upd.rows_updated == 1
745+
assert upd.rows_inserted == 1
746+
assert table.scan().to_arrow() == pa.Table.from_pylist(
747+
[
748+
{"foo": "lemon", "bar": None, "baz": False},
749+
{"foo": None, "bar": 1, "baz": True},
750+
{"foo": "apple", "bar": 7, "baz": False},
751+
{"foo": "banana", "bar": None, "baz": False},
752+
],
753+
schema=schema,
754+
)
755+
716756

717757
def test_transaction(catalog: Catalog) -> None:
718758
"""Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is

0 commit comments

Comments
 (0)