Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 52 additions & 32 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,6 @@ def upsert(
"""Shorthand API for performing an upsert to an iceberg table.

Args:

df: The input dataframe to upsert with the table's data.
join_cols: Columns to join on, if not provided, it will use the identifier-field-ids.
when_matched_update_all: Bool indicating to update rows that are matched but require an update
Expand All @@ -777,26 +776,37 @@ def upsert(
branch: Branch Reference to run the upsert operation
snapshot_properties: Custom properties to be added to the snapshot summary

To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids
Note:
This method uses null-safe equality for matching rows on join columns, similar to
SQL's <=> operator or Spark's NULL-safe equal. This means:
- NULL values in join columns will match other NULL values
- A row with (key=NULL) in the source will update a row with (key=NULL) in the target

This is equivalent to Spark SQL:
MERGE INTO target USING source ON target.key <=> source.key

If you want standard SQL equality semantics where NULL never matches NULL,
filter out NULL values from the join columns before calling upsert.

Example Use Cases:
Case 1: Both Parameters = True (Full Upsert)
Existing row found → Update it
New row found → Insert it
To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids

Case 2: when_matched_update_all = False, when_not_matched_insert_all = True
Existing row found → Do nothing (no updates)
New row found → Insert it
Example Use Cases:
Case 1: Both Parameters = True (Full Upsert)
Existing row found → Update it
New row found → Insert it

Case 3: when_matched_update_all = True, when_not_matched_insert_all = False
Existing row found → Update it
New row found → Do nothing (no inserts)
Case 2: when_matched_update_all = False, when_not_matched_insert_all = True
Existing row found → Do nothing (no updates)
New row found → Insert it

Case 4: Both Parameters = False (No Merge Effect)
Existing row found → Do nothing
New row found → Do nothing
(Function effectively does nothing)
Case 3: when_matched_update_all = True, when_not_matched_insert_all = False
Existing row found → Update it
New row found → Do nothing (no inserts)

Case 4: Both Parameters = False (No Merge Effect)
Existing row found → Do nothing
New row found → Do nothing
(Function effectively does nothing)

Returns:
An UpsertResult class (contains details of rows updated and inserted)
Expand Down Expand Up @@ -1368,7 +1378,6 @@ def upsert(
"""Shorthand API for performing an upsert to an iceberg table.

Args:

df: The input dataframe to upsert with the table's data.
join_cols: Columns to join on, if not provided, it will use the identifier-field-ids.
when_matched_update_all: Bool indicating to update rows that are matched but require an update
Expand All @@ -1379,26 +1388,37 @@ def upsert(
branch: Branch Reference to run the upsert operation
snapshot_properties: Custom properties to be added to the snapshot summary

To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids
Note:
This method uses null-safe equality for matching rows on join columns, similar to
SQL's <=> operator or Spark's NULL-safe equal. This means:
- NULL values in join columns will match other NULL values
- A row with (key=NULL) in the source will update a row with (key=NULL) in the target

This is equivalent to Spark SQL:
MERGE INTO target USING source ON target.key <=> source.key

If you want standard SQL equality semantics where NULL never matches NULL,
filter out NULL values from the join columns before calling upsert.

Example Use Cases:
Case 1: Both Parameters = True (Full Upsert)
Existing row found → Update it
New row found → Insert it
To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids

Case 2: when_matched_update_all = False, when_not_matched_insert_all = True
Existing row found → Do nothing (no updates)
New row found → Insert it
Example Use Cases:
Case 1: Both Parameters = True (Full Upsert)
Existing row found → Update it
New row found → Insert it

Case 3: when_matched_update_all = True, when_not_matched_insert_all = False
Existing row found → Update it
New row found → Do nothing (no inserts)
Case 2: when_matched_update_all = False, when_not_matched_insert_all = True
Existing row found → Do nothing (no updates)
New row found → Insert it

Case 4: Both Parameters = False (No Merge Effect)
Existing row found → Do nothing
New row found → Do nothing
(Function effectively does nothing)
Case 3: when_matched_update_all = True, when_not_matched_insert_all = False
Existing row found → Update it
New row found → Do nothing (no inserts)

Case 4: Both Parameters = False (No Merge Effect)
Existing row found → Do nothing
New row found → Do nothing
(Function effectively does nothing)

Returns:
An UpsertResult class (contains details of rows updated and inserted)
Expand Down
81 changes: 62 additions & 19 deletions pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,67 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import operator
from math import isnan
from typing import Any

import pyarrow as pa
from pyarrow import Table as pyarrow_table
from pyarrow import compute as pc

from pyiceberg.expressions import (
AlwaysFalse,
And,
BooleanExpression,
EqualTo,
In,
IsNaN,
IsNull,
Or,
)


def _is_nan(value: Any) -> bool:
"""Check if a value is NaN (only applicable to floats)."""
return isinstance(value, float) and isnan(value)


def _null_safe_equals(column: str, value: Any) -> BooleanExpression:
    """Build a predicate matching *column* against *value* with null-safe semantics (SQL's <=>).

    NULL compares equal to NULL (IsNull) and NaN equal to NaN (IsNaN); any other
    value falls through to a plain equality predicate.
    """
    if value is not None:
        return IsNaN(column) if _is_nan(value) else EqualTo(column, value)
    return IsNull(column)


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
    """Build a row-matching predicate over *join_cols* for the key values present in *df*.

    The predicate uses null-safe equality semantics (like SQL's <=> operator):
    NULL matches NULL and NaN matches NaN in the join columns.

    Args:
        df: Source table whose distinct join-key values define the match set.
        join_cols: Columns to match on.

    Returns:
        A BooleanExpression selecting target rows whose join-column values appear
        in *df*; AlwaysFalse when *df* contributes no keys.
    """
    # Distinct join-key combinations present in the source.
    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
    filters: list[BooleanExpression] = []

    if len(join_cols) == 1:
        column = join_cols[0]
        values = set(unique_keys[0].to_pylist())

        # Handle NULL and NaN separately since the In expression doesn't support them.
        if None in values:
            filters.append(IsNull(column))
            values.discard(None)

        if nans := {v for v in values if _is_nan(v)}:
            filters.append(IsNaN(column))
            values -= nans

        if values:
            filters.append(In(column, values))
    else:
        # One conjunction per distinct key tuple, each term null-safe.
        filters = [And(*[_null_safe_equals(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()]

    if len(filters) == 0:
        return AlwaysFalse()
    elif len(filters) == 1:
        return filters[0]
    else:
        return Or(*filters)


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
Expand Down Expand Up @@ -97,16 +126,30 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
# Step 2: Prepare target index with join keys and a marker
target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))

# Step 3: Perform an inner join to find which rows from source exist in target
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
# Step 3: Perform an inner join to find which rows from source exist in target.
# PyArrow's join ignores NULL values (NULL == NULL returns UNKNOWN in SQL semantics).
# We want null-safe equality where NULL == NULL is TRUE, so we fall back to Python when NULLs are present.
has_nulls = any(source_index.column(col).null_count > 0 or target_index.column(col).null_count > 0 for col in join_cols)

if has_nulls:
# Python-based null-safe join
source_keys = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()}
target_keys = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()}
matching_indices = [(s, t) for key, s in source_keys.items() if (t := target_keys.get(key)) is not None]
else:
# Fast PyArrow join (no nulls to worry about)
joined = source_index.join(target_index, keys=join_cols, join_type="inner")
matching_indices = list(
zip(
joined[SOURCE_INDEX_COLUMN_NAME].to_pylist(),
joined[TARGET_INDEX_COLUMN_NAME].to_pylist(),
strict=True,
)
)

# Step 4: Compare all rows using Python
to_update_indices = []
for source_idx, target_idx in zip(
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(),
matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist(),
strict=True,
):
for source_idx, target_idx in matching_indices:
source_row = source_table.slice(source_idx, 1)
target_row = target_table.slice(target_idx, 1)
Comment on lines +129 to 154
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation for the null-safe join converts the PyArrow tables to Python lists and dictionaries using to_pylist(). This can be very inefficient and memory-intensive for large tables, potentially leading to performance bottlenecks or out-of-memory errors.

PyArrow's join method supports null-safe equality since version 7.0.0 via the null_matching_behavior='equal' parameter. Using this would be much more performant as it keeps the operations within PyArrow's memory space.

I suggest reverting to the PyArrow join and adding this parameter.

Suggested change
# Step 3: Perform an inner join to find which rows from source exist in target.
# We use a Python-based join instead of PyArrow's join because PyArrow ignores NULL values
# (NULL == NULL returns UNKNOWN in SQL semantics). We want null-safe equality where NULL == NULL is TRUE.
source_keys = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()}
target_keys = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()}
matching_indices = [(s, t) for key, s in source_keys.items() if (t := target_keys.get(key)) is not None]
# Step 4: Compare all rows using Python
to_update_indices = []
for source_idx, target_idx in zip(
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(),
matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist(),
strict=True,
):
for source_idx, target_idx in matching_indices:
source_row = source_table.slice(source_idx, 1)
target_row = target_table.slice(target_idx, 1)
# Step 3: Perform an inner join to find which rows from source exist in target.
# PyArrow's join operator can perform null-safe joins.
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner", null_matching_behavior="equal")
# Step 4: Compare all rows using Python
to_update_indices = []
for source_idx, target_idx in zip(
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(),
matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist(),
strict=True,
):
source_row = source_table.slice(source_idx, 1)
target_row = target_table.slice(target_idx, 1)


Expand Down
130 changes: 128 additions & 2 deletions tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, In, IsNaN, IsNull, Or, Reference
from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.table import Table, UpsertResult
Expand Down Expand Up @@ -443,6 +443,82 @@ def test_create_match_filter_single_condition() -> None:
)


def test_create_match_filter_single_column_without_null() -> None:
    """A single join column with no NULL/NaN values should yield a plain In predicate."""
    rows = [{"x": v} for v in (1.0, 2.0, 3.0)]
    arrow_schema = pa.schema([pa.field("x", pa.float64())])
    arrow_table = pa.Table.from_pylist(rows, schema=arrow_schema)

    result = create_match_filter(arrow_table, join_cols=["x"])

    expected = In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)})
    assert result == expected


def test_create_match_filter_single_column_with_null() -> None:
    """NULL and NaN key values should be split into IsNull/IsNaN terms OR-ed with the In predicate."""
    rows = [{"x": v} for v in (1.0, 2.0, None, 4.0, float("nan"))]
    arrow_schema = pa.schema([pa.field("x", pa.float64())])
    arrow_table = pa.Table.from_pylist(rows, schema=arrow_schema)

    result = create_match_filter(arrow_table, join_cols=["x"])

    expected = Or(
        left=IsNull(term=Reference(name="x")),
        right=Or(
            left=IsNaN(term=Reference(name="x")),
            right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}),
        ),
    )
    assert result == expected


def test_create_match_filter_multi_column_with_null() -> None:
    """With multiple join columns, each key tuple becomes an And of null-safe terms, OR-ed together.

    NULL values map to IsNull and NaN values map to IsNaN inside each conjunction;
    the expected tree below mirrors the exact Or-fold shape the implementation produces.
    """
    data = [
        {"x": 1.0, "y": 9.0},
        {"x": 2.0, "y": None},
        {"x": None, "y": 7.0},
        {"x": 4.0, "y": float("nan")},
        {"x": float("nan"), "y": 0.0},
    ]
    schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())])
    table = pa.Table.from_pylist(data, schema=schema)

    expr = create_match_filter(table, join_cols=["x", "y"])

    # One And per input row, in input order; NULL -> IsNull, NaN -> IsNaN, else EqualTo.
    assert expr == Or(
        left=Or(
            left=And(
                left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)),
                right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)),
            ),
            right=And(
                left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)),
                right=IsNull(term=Reference(name="y")),
            ),
        ),
        right=Or(
            left=And(
                left=IsNull(term=Reference(name="x")),
                right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)),
            ),
            right=Or(
                left=And(
                    left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)),
                    right=IsNaN(term=Reference(name="y")),
                ),
                right=And(
                    left=IsNaN(term=Reference(name="x")),
                    right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)),
                ),
            ),
        ),
    )


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
identifier = "default.test_upsert_with_duplicate_rows_in_table"

Expand Down Expand Up @@ -714,6 +790,56 @@ def test_upsert_with_nulls(catalog: Catalog) -> None:
)


def test_upsert_with_nulls_in_join_columns(catalog: Catalog) -> None:
    """Upsert must use null-safe matching: NULL join-key values match existing NULLs.

    First inserts a row keyed on a NULL value, then upserts on two join columns so
    the (None, 1) key updates the existing row while ("lemon", None) inserts a new one.
    """
    identifier = "default.test_upsert_with_nulls_in_join_columns"
    _drop_table(catalog, identifier)

    schema = pa.schema(
        [
            ("foo", pa.string()),
            ("bar", pa.int32()),
            ("baz", pa.bool_()),
        ]
    )
    table = catalog.create_table(identifier, schema)

    # upsert table with null value: no existing rows, so the NULL-keyed row is inserted
    data_with_null = pa.Table.from_pylist(
        [
            {"foo": None, "bar": 1, "baz": False},
        ],
        schema=schema,
    )
    upd = table.upsert(data_with_null, join_cols=["foo"])
    assert upd.rows_updated == 0
    assert upd.rows_inserted == 1
    assert table.scan().to_arrow() == pa.Table.from_pylist(
        [
            {"foo": None, "bar": 1, "baz": False},
        ],
        schema=schema,
    )

    # upsert table with null and non-null values, in two join columns:
    # (None, 1) matches the stored NULL-keyed row (null-safe) and flips baz to True;
    # ("lemon", None) has no match and is inserted
    data_with_null = pa.Table.from_pylist(
        [
            {"foo": None, "bar": 1, "baz": True},
            {"foo": "lemon", "bar": None, "baz": False},
        ],
        schema=schema,
    )
    upd = table.upsert(data_with_null, join_cols=["foo", "bar"])
    assert upd.rows_updated == 1
    assert upd.rows_inserted == 1
    assert table.scan().to_arrow() == pa.Table.from_pylist(
        [
            {"foo": "lemon", "bar": None, "baz": False},
            {"foo": None, "bar": 1, "baz": True},
        ],
        schema=schema,
    )


def test_transaction(catalog: Catalog) -> None:
"""Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is
rolled back."""
Expand Down