diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index cc0d9ff341..2f7c65478a 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -766,7 +766,6 @@ def upsert( """Shorthand API for performing an upsert to an iceberg table. Args: - df: The input dataframe to upsert with the table's data. join_cols: Columns to join on, if not provided, it will use the identifier-field-ids. when_matched_update_all: Bool indicating to update rows that are matched but require an update @@ -777,26 +776,37 @@ def upsert( branch: Branch Reference to run the upsert operation snapshot_properties: Custom properties to be added to the snapshot summary - To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids + Note: + This method uses null-safe equality for matching rows on join columns, similar to + the <=> (NULL-safe equal) operator in Spark SQL and MySQL, i.e. standard SQL's IS NOT DISTINCT FROM. This means: + - NULL values in join columns will match other NULL values + - A row with (key=NULL) in the source will update a row with (key=NULL) in the target + + This is equivalent to Spark SQL: + MERGE INTO target USING source ON target.key <=> source.key + + If you want standard SQL equality semantics where NULL never matches NULL, + filter out NULL values from the join columns before calling upsert.
- Example Use Cases: - Case 1: Both Parameters = True (Full Upsert) - Existing row found → Update it - New row found → Insert it + To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids - Case 2: when_matched_update_all = False, when_not_matched_insert_all = True - Existing row found → Do nothing (no updates) - New row found → Insert it + Example Use Cases: + Case 1: Both Parameters = True (Full Upsert) + Existing row found → Update it + New row found → Insert it - Case 3: when_matched_update_all = True, when_not_matched_insert_all = False - Existing row found → Update it - New row found → Do nothing (no inserts) + Case 2: when_matched_update_all = False, when_not_matched_insert_all = True + Existing row found → Do nothing (no updates) + New row found → Insert it - Case 4: Both Parameters = False (No Merge Effect) - Existing row found → Do nothing - New row found → Do nothing - (Function effectively does nothing) + Case 3: when_matched_update_all = True, when_not_matched_insert_all = False + Existing row found → Update it + New row found → Do nothing (no inserts) + Case 4: Both Parameters = False (No Merge Effect) + Existing row found → Do nothing + New row found → Do nothing + (Function effectively does nothing) Returns: An UpsertResult class (contains details of rows updated and inserted) @@ -1368,7 +1378,6 @@ def upsert( """Shorthand API for performing an upsert to an iceberg table. Args: - df: The input dataframe to upsert with the table's data. join_cols: Columns to join on, if not provided, it will use the identifier-field-ids. 
when_matched_update_all: Bool indicating to update rows that are matched but require an update @@ -1379,26 +1388,37 @@ def upsert( branch: Branch Reference to run the upsert operation snapshot_properties: Custom properties to be added to the snapshot summary - To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids + Note: + This method uses null-safe equality for matching rows on join columns, similar to + the <=> (NULL-safe equal) operator in Spark SQL and MySQL, i.e. standard SQL's IS NOT DISTINCT FROM. This means: + - NULL values in join columns will match other NULL values + - A row with (key=NULL) in the source will update a row with (key=NULL) in the target + + This is equivalent to Spark SQL: + MERGE INTO target USING source ON target.key <=> source.key + + If you want standard SQL equality semantics where NULL never matches NULL, + filter out NULL values from the join columns before calling upsert. - Example Use Cases: - Case 1: Both Parameters = True (Full Upsert) - Existing row found → Update it - New row found → Insert it + To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids - Case 2: when_matched_update_all = False, when_not_matched_insert_all = True - Existing row found → Do nothing (no updates) - New row found → Insert it + Example Use Cases: + Case 1: Both Parameters = True (Full Upsert) + Existing row found → Update it + New row found → Insert it - Case 3: when_matched_update_all = True, when_not_matched_insert_all = False - Existing row found → Update it - New row found → Do nothing (no inserts) + Case 2: when_matched_update_all = False, when_not_matched_insert_all = True + Existing row found → Do nothing (no updates) + New row found → Insert it - Case 4: Both Parameters = False (No Merge Effect) - Existing row found → Do nothing - New row found → Do nothing - (Function effectively does nothing) + Case 3: when_matched_update_all = True, when_not_matched_insert_all = False + Existing row found → Update it + New
row found → Do nothing (no inserts) + Case 4: Both Parameters = False (No Merge Effect) + Existing row found → Do nothing + New row found → Do nothing + (Function effectively does nothing) Returns: An UpsertResult class (contains details of rows updated and inserted) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 6f32826eb0..ce933f5cde 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import functools -import operator +from math import isnan +from typing import Any import pyarrow as pa from pyarrow import Table as pyarrow_table @@ -23,29 +23,58 @@ from pyiceberg.expressions import ( AlwaysFalse, + And, BooleanExpression, EqualTo, In, + IsNaN, + IsNull, Or, ) +def _is_nan(value: Any) -> bool: + """Check if a value is NaN (only applicable to floats).""" + return isinstance(value, float) and isnan(value) + + +def _null_safe_equals(column: str, value: Any) -> BooleanExpression: + """Create a null-safe equality expression (like SQL's <=> operator).""" + if value is None: + return IsNull(column) + if _is_nan(value): + return IsNaN(column) + return EqualTo(column, value) + + def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: unique_keys = df.select(join_cols).group_by(join_cols).aggregate([]) + filters: list[BooleanExpression] = [] if len(join_cols) == 1: - return In(join_cols[0], unique_keys[0].to_pylist()) + column = join_cols[0] + values = set(unique_keys[0].to_pylist()) + + # Handle NULL and NaN separately since IN expression doesn't support them + if None in values: + filters.append(IsNull(column)) + values.discard(None) + + if nans := {v for v in values if _is_nan(v)}: + filters.append(IsNaN(column)) + values -= nans + + if values: + filters.append(In(column, values)) else: - filters = [ - 
functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist() - ] + filters = [And(*[_null_safe_equals(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()] - if len(filters) == 0: - return AlwaysFalse() - elif len(filters) == 1: - return filters[0] - else: - return Or(*filters) + if len(filters) == 0: + return AlwaysFalse() + elif len(filters) == 1: + return filters[0] + else: + return Or(*filters) def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool: @@ -97,16 +126,30 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols # Step 2: Prepare target index with join keys and a marker target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table)))) - # Step 3: Perform an inner join to find which rows from source exist in target - matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner") + # Step 3: Perform an inner join to find which rows from source exist in target. + # PyArrow's join ignores NULL values (NULL == NULL returns UNKNOWN in SQL semantics). + # We want null-safe equality where NULL == NULL is TRUE, so we fall back to Python when NULLs are present. 
+ has_nulls = any(source_index.column(col).null_count > 0 or target_index.column(col).null_count > 0 for col in join_cols) + + if has_nulls: + # Python-based null-safe join + source_keys = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()} + target_keys = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()} + matching_indices = [(s, t) for key, s in source_keys.items() if (t := target_keys.get(key)) is not None] + else: + # Fast PyArrow join (no nulls to worry about) + joined = source_index.join(target_index, keys=join_cols, join_type="inner") + matching_indices = list( + zip( + joined[SOURCE_INDEX_COLUMN_NAME].to_pylist(), + joined[TARGET_INDEX_COLUMN_NAME].to_pylist(), + strict=True, + ) + ) # Step 4: Compare all rows using Python to_update_indices = [] - for source_idx, target_idx in zip( - matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), - matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist(), - strict=True, - ): + for source_idx, target_idx in matching_indices: source_row = source_table.slice(source_idx, 1) target_row = target_table.slice(target_idx, 1) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 08f90c6600..599293af50 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -23,8 +23,8 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference -from pyiceberg.expressions.literals import LongLiteral +from pyiceberg.expressions import AlwaysTrue, And, EqualTo, In, IsNaN, IsNull, Or, Reference +from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema from pyiceberg.table import Table, UpsertResult @@ -443,6 +443,82 @@ def test_create_match_filter_single_condition() -> None: ) +def 
test_create_match_filter_single_column_without_null() -> None: + data = [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}] + + schema = pa.schema([pa.field("x", pa.float64())]) + table = pa.Table.from_pylist(data, schema=schema) + + expr = create_match_filter(table, join_cols=["x"]) + + assert expr == In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)}) + + +def test_create_match_filter_single_column_with_null() -> None: + data = [ + {"x": 1.0}, + {"x": 2.0}, + {"x": None}, + {"x": 4.0}, + {"x": float("nan")}, + ] + schema = pa.schema([pa.field("x", pa.float64())]) + table = pa.Table.from_pylist(data, schema=schema) + + expr = create_match_filter(table, join_cols=["x"]) + + assert expr == Or( + left=IsNull(term=Reference(name="x")), + right=Or( + left=IsNaN(term=Reference(name="x")), + right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}), + ), + ) + + +def test_create_match_filter_multi_column_with_null() -> None: + data = [ + {"x": 1.0, "y": 9.0}, + {"x": 2.0, "y": None}, + {"x": None, "y": 7.0}, + {"x": 4.0, "y": float("nan")}, + {"x": float("nan"), "y": 0.0}, + ] + schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())]) + table = pa.Table.from_pylist(data, schema=schema) + + expr = create_match_filter(table, join_cols=["x", "y"]) + + assert expr == Or( + left=Or( + left=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)), + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)), + ), + right=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)), + right=IsNull(term=Reference(name="y")), + ), + ), + right=Or( + left=And( + left=IsNull(term=Reference(name="x")), + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)), + ), + right=Or( + left=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)), + right=IsNaN(term=Reference(name="y")), + ), + right=And( + left=IsNaN(term=Reference(name="x")), + 
right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)), + ), + ), + ), + ) + + def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None: identifier = "default.test_upsert_with_duplicate_rows_in_table" @@ -714,6 +790,56 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: ) +def test_upsert_with_nulls_in_join_columns(catalog: Catalog) -> None: + identifier = "default.test_upsert_with_nulls_in_join_columns" + _drop_table(catalog, identifier) + + schema = pa.schema( + [ + ("foo", pa.string()), + ("bar", pa.int32()), + ("baz", pa.bool_()), + ] + ) + table = catalog.create_table(identifier, schema) + + # upsert table with null value + data_with_null = pa.Table.from_pylist( + [ + {"foo": None, "bar": 1, "baz": False}, + ], + schema=schema, + ) + upd = table.upsert(data_with_null, join_cols=["foo"]) + assert upd.rows_updated == 0 + assert upd.rows_inserted == 1 + assert table.scan().to_arrow() == pa.Table.from_pylist( + [ + {"foo": None, "bar": 1, "baz": False}, + ], + schema=schema, + ) + + # upsert table with null and non-null values, in two join columns + data_with_null = pa.Table.from_pylist( + [ + {"foo": None, "bar": 1, "baz": True}, + {"foo": "lemon", "bar": None, "baz": False}, + ], + schema=schema, + ) + upd = table.upsert(data_with_null, join_cols=["foo", "bar"]) + assert upd.rows_updated == 1 + assert upd.rows_inserted == 1 + assert table.scan().to_arrow() == pa.Table.from_pylist( + [ + {"foo": "lemon", "bar": None, "baz": False}, + {"foo": None, "bar": 1, "baz": True}, + ], + schema=schema, + ) + + def test_transaction(catalog: Catalog) -> None: """Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is rolled back."""