Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,12 +877,17 @@ def upsert(
# get list of rows that exist so we don't have to load the entire target table
matched_predicate = upsert_util.create_match_filter(df, join_cols)

# When ``when_matched_update_all=False`` the consumer loop below
# only ever reads ``join_cols`` off each destination batch.
selected_fields: tuple[str, ...] = ("*",) if when_matched_update_all else tuple(join_cols)

# We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.

matched_iceberg_record_batches_scan = DataScan(
table_metadata=self.table_metadata,
io=self._table.io,
row_filter=matched_predicate,
selected_fields=selected_fields,
case_sensitive=case_sensitive,
)

Expand Down
179 changes: 178 additions & 1 deletion tests/table/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
from pathlib import PosixPath
from typing import Any

import pyarrow as pa
import pytest
Expand All @@ -26,11 +28,13 @@
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import Table, UpsertResult
from pyiceberg.table.snapshots import Operation
from pyiceberg.table.upsert_util import create_match_filter
from pyiceberg.types import IntegerType, NestedField, StringType, StructType
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import DateType, IntegerType, NestedField, StringType, StructType
from tests.catalog.test_base import InMemoryCatalog


Expand Down Expand Up @@ -888,3 +892,176 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None:
for snapshot in snapshots[initial_snapshot_count:]:
assert snapshot.summary is not None
assert snapshot.summary.additional_properties.get("test_prop") == "test_value"


class TestUpsertScanProjection:
Comment thread
paultmathew marked this conversation as resolved.
Outdated
"""``Transaction.upsert`` narrows the destination scan's
``selected_fields`` to ``join_cols`` when ``when_matched_update_all=False``.

Rationale: the insert-on-no-match branch only reads ``join_cols``
off each destination batch (to feed ``create_match_filter``); every
other column is unused. Projection at the scan boundary lets the
parquet reader prune wide non-key columns at the file level —
significant for tables whose payload column (e.g. a JSON ``log``)
dominates file bytes. ``_projected_field_ids`` auto-unions the
row-filter's column ids back in, so any column referenced by the
join-key predicate is still readable for filter evaluation without
needing to list it explicitly.

Falls back to ``("*",)`` when ``when_matched_update_all=True``
because ``get_rows_to_update`` reads every non-key column off the
destination row to detect value drift — narrowing would break the
no-op-write skip.
"""

@staticmethod
def _build_partitioned_table(catalog: Catalog, identifier: str) -> Table:
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "order_id", IntegerType(), required=True),
NestedField(2, "order_date", DateType(), required=True),
NestedField(3, "order_type", StringType(), required=True),
)
spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date"))
return catalog.create_table(identifier, schema=schema, partition_spec=spec)

@staticmethod
def _arrow_schema() -> pa.Schema:
return pa.schema(
[
pa.field("order_id", pa.int32(), nullable=False),
pa.field("order_date", pa.date32(), nullable=False),
pa.field("order_type", pa.string(), nullable=False),
]
)

def _seed(self, table: Table) -> None:
table.append(
pa.Table.from_pylist(
[
{"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"},
{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "A"},
],
schema=self._arrow_schema(),
)
)

@pytest.fixture
def captured_scans(self, monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]:
"""Spy on ``DataScan.__init__`` to capture every kwargs dict.

Lets the tests pin which ``selected_fields`` the upsert path
actually passes — assertions on the surfaced batch schema alone
would miss the case where the underlying projection contract
regresses but the test data happens to have only join_cols
anyway.

The spy preserves ``__init__``'s signature via
:func:`functools.wraps` so ``DataScan.update()``'s reflective
``inspect.signature(type(self).__init__).parameters`` lookup
(used by ``use_ref``) still resolves to the real parameter
names, not the spy's ``**kwargs``.
"""
import functools

from pyiceberg.table import DataScan

captured: list[dict[str, Any]] = []
original_init = DataScan.__init__

@functools.wraps(original_init)
def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None:
captured.append(dict(kwargs))
original_init(self, *args, **kwargs)

monkeypatch.setattr(DataScan, "__init__", _spy)
return captured

def test_when_matched_false_projects_join_cols_only(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None:
"""The insert-on-no-match branch never reads non-key destination
columns, so the scan must narrow the projection to ``join_cols``
— saving the parquet reader from materialising wide payload
columns just to be discarded."""
table = self._build_partitioned_table(catalog, "default.test_upsert_projection_insert_only")
self._seed(table)
upsert_df = pa.Table.from_pylist(
[
{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "B"},
{"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"},
],
schema=self._arrow_schema(),
)

# Snapshot only the scans constructed during the upsert (the
# seed append above may have created its own).
before = len(captured_scans)
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=False)
upsert_scans = captured_scans[before:]
assert res.rows_inserted == 1
assert res.rows_updated == 0

# The upsert constructs one DataScan for the destination match.
# ``use_ref`` may construct a second DataScan as an inherited
# copy (via ``self.update``), which carries the same
# ``selected_fields`` through. Pin both: at least one scan was
# constructed during the upsert, and every scan that ran
# carries the narrowed projection.
assert upsert_scans, "upsert path constructed no DataScan — projection contract regression"
selected = [s.get("selected_fields") for s in upsert_scans]
assert all(sf == ("order_id",) for sf in selected), (
f"expected every DataScan during upsert to use selected_fields=('order_id',); got {selected}"
)

def test_when_matched_true_keeps_star_projection(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None:
"""The update branch's ``get_rows_to_update`` compares non-key
columns to detect actual value changes — projecting only
``join_cols`` would feed it data with no non-key columns to
compare and silently turn every match into a write-back. Must
keep ``("*",)``."""
table = self._build_partitioned_table(catalog, "default.test_upsert_projection_update_mode")
self._seed(table)
upsert_df = pa.Table.from_pylist(
[
{"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "B"},
{"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"},
],
schema=self._arrow_schema(),
)

before = len(captured_scans)
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True)
upsert_scans = captured_scans[before:]
assert res.rows_updated == 1
assert res.rows_inserted == 1

assert upsert_scans, "upsert path constructed no DataScan — projection contract regression"
selected = [s.get("selected_fields") for s in upsert_scans]
assert all(sf == ("*",) for sf in selected), (
f"expected every DataScan during upsert to keep selected_fields=('*',) for the update branch; got {selected}"
)

def test_update_mode_actually_updates_non_key_columns(self, catalog: Catalog) -> None:
"""End-to-end correctness pin: with ``when_matched_update_all=True``
the destination scan must read non-key columns so
``get_rows_to_update`` can detect ``order_type`` changes. A
regression that narrows projection unconditionally would skip
the comparison and silently miss updates whose non-key columns
differ.
"""
identifier = "default.test_upsert_update_mode_correctness"
table = self._build_partitioned_table(catalog, identifier)
self._seed(table)
# Source has the same (order_id, order_date) as one destination
# row but a different ``order_type``. Update path must detect
# the non-key change and overwrite.
upsert_df = pa.Table.from_pylist(
[{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "CHANGED"}],
schema=self._arrow_schema(),
)
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True)
assert res.rows_updated == 1
assert res.rows_inserted == 0

# Read back: the original 'A' must have been overwritten with 'CHANGED'.
rows = {r["order_id"]: r for r in table.scan().to_arrow().to_pylist()}
assert rows[2]["order_type"] == "CHANGED"