Skip to content

Commit 67b5f43

Browse files
Paul Mathew and cursoragent committed
test(upsert): collapse projection coverage to a single function test
Reviewer feedback: ``TestUpsertScanProjection`` was excessive for what's effectively one boolean projection contract. Replaced the class (helpers, fixture, two tests) with one function-level test that pins the narrow ``join_cols`` projection on the ``when_matched_update_all=False`` branch. The ``("*",)`` fallback on the ``=True`` branch is covered transitively by every other upsert test in this module — ``get_rows_to_update``'s value-drift detection would surface any regression that narrows it. Dropped: ``test_when_matched_true_keeps_star_projection``, ``test_update_mode_actually_updates_non_key_columns``, and unused imports (``datetime``, ``PartitionField``, ``PartitionSpec``, ``IdentityTransform``, ``DateType``). Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent b81990d commit 67b5f43

1 file changed

Lines changed: 50 additions & 166 deletions

File tree

tests/table/test_upsert.py

Lines changed: 50 additions & 166 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
import datetime
1817
from pathlib import PosixPath
1918
from typing import Any
2019

@@ -28,13 +27,11 @@
2827
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
2928
from pyiceberg.expressions.literals import LongLiteral
3029
from pyiceberg.io.pyarrow import schema_to_pyarrow
31-
from pyiceberg.partitioning import PartitionField, PartitionSpec
3230
from pyiceberg.schema import Schema
3331
from pyiceberg.table import Table, UpsertResult
3432
from pyiceberg.table.snapshots import Operation
3533
from pyiceberg.table.upsert_util import create_match_filter
36-
from pyiceberg.transforms import IdentityTransform
37-
from pyiceberg.types import DateType, IntegerType, NestedField, StringType, StructType
34+
from pyiceberg.types import IntegerType, NestedField, StringType, StructType
3835
from tests.catalog.test_base import InMemoryCatalog
3936

4037

@@ -894,174 +891,61 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None:
894891
assert snapshot.summary.additional_properties.get("test_prop") == "test_value"
895892

896893

897-
class TestUpsertScanProjection:
894+
def test_upsert_narrows_destination_scan_projection_to_join_cols(
895+
catalog: Catalog,
896+
monkeypatch: pytest.MonkeyPatch,
897+
) -> None:
898898
"""``Transaction.upsert`` narrows the destination scan's
899-
``selected_fields`` to ``join_cols`` when ``when_matched_update_all=False``.
900-
901-
Rationale: the insert-on-no-match branch only reads ``join_cols``
902-
off each destination batch (to feed ``create_match_filter``); every
903-
other column is unused. Projection at the scan boundary lets the
904-
parquet reader prune wide non-key columns at the file level —
905-
significant for tables whose payload column (e.g. a JSON ``log``)
906-
dominates file bytes. ``_projected_field_ids`` auto-unions the
907-
row-filter's column ids back in, so any column referenced by the
908-
join-key predicate is still readable for filter evaluation without
909-
needing to list it explicitly.
910-
911-
Falls back to ``("*",)`` when ``when_matched_update_all=True``
912-
because ``get_rows_to_update`` reads every non-key column off the
913-
destination row to detect value drift — narrowing would break the
914-
no-op-write skip.
899+
``selected_fields`` to ``join_cols`` when
900+
``when_matched_update_all=False``.
901+
902+
The insert-on-no-match branch only reads ``join_cols`` from each
903+
destination batch (to feed ``create_match_filter``), so projection
904+
at the scan boundary lets the parquet reader skip wide non-key
905+
columns. The ``("*",)`` fallback on the ``=True`` branch is
906+
exercised by the rest of this module — ``get_rows_to_update``'s
907+
value-drift detection would silently break if it ever regressed.
915908
"""
909+
import functools
916910

917-
@staticmethod
918-
def _build_partitioned_table(catalog: Catalog, identifier: str) -> Table:
919-
_drop_table(catalog, identifier)
920-
schema = Schema(
921-
NestedField(1, "order_id", IntegerType(), required=True),
922-
NestedField(2, "order_date", DateType(), required=True),
923-
NestedField(3, "order_type", StringType(), required=True),
924-
)
925-
spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date"))
926-
return catalog.create_table(identifier, schema=schema, partition_spec=spec)
927-
928-
@staticmethod
929-
def _arrow_schema() -> pa.Schema:
930-
return pa.schema(
931-
[
932-
pa.field("order_id", pa.int32(), nullable=False),
933-
pa.field("order_date", pa.date32(), nullable=False),
934-
pa.field("order_type", pa.string(), nullable=False),
935-
]
936-
)
911+
from pyiceberg.table import DataScan
937912

938-
def _seed(self, table: Table) -> None:
939-
table.append(
940-
pa.Table.from_pylist(
941-
[
942-
{"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"},
943-
{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "A"},
944-
],
945-
schema=self._arrow_schema(),
946-
)
947-
)
948-
949-
@pytest.fixture
950-
def captured_scans(self, monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]:
951-
"""Spy on ``DataScan.__init__`` to capture every kwargs dict.
952-
953-
Lets the tests pin which ``selected_fields`` the upsert path
954-
actually passes — assertions on the surfaced batch schema alone
955-
would miss the case where the underlying projection contract
956-
regresses but the test data happens to have only join_cols
957-
anyway.
958-
959-
The spy preserves ``__init__``'s signature via
960-
:func:`functools.wraps` so ``DataScan.update()``'s reflective
961-
``inspect.signature(type(self).__init__).parameters`` lookup
962-
(used by ``use_ref``) still resolves to the real parameter
963-
names, not the spy's ``**kwargs``.
964-
"""
965-
import functools
966-
967-
from pyiceberg.table import DataScan
968-
969-
captured: list[dict[str, Any]] = []
970-
original_init = DataScan.__init__
971-
972-
@functools.wraps(original_init)
973-
def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None:
974-
captured.append(dict(kwargs))
975-
original_init(self, *args, **kwargs)
976-
977-
monkeypatch.setattr(DataScan, "__init__", _spy)
978-
return captured
979-
980-
def test_when_matched_false_projects_join_cols_only(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None:
981-
"""The insert-on-no-match branch never reads non-key destination
982-
columns, so the scan must narrow the projection to ``join_cols``
983-
— saving the parquet reader from materialising wide payload
984-
columns just to be discarded."""
985-
table = self._build_partitioned_table(catalog, "default.test_upsert_projection_insert_only")
986-
self._seed(table)
987-
upsert_df = pa.Table.from_pylist(
988-
[
989-
{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "B"},
990-
{"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"},
991-
],
992-
schema=self._arrow_schema(),
993-
)
994-
995-
# Snapshot only the scans constructed during the upsert (the
996-
# seed append above may have created its own).
997-
before = len(captured_scans)
998-
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=False)
999-
upsert_scans = captured_scans[before:]
1000-
assert res.rows_inserted == 1
1001-
assert res.rows_updated == 0
1002-
1003-
# The upsert constructs one DataScan for the destination match.
1004-
# ``use_ref`` may construct a second DataScan as an inherited
1005-
# copy (via ``self.update``), which carries the same
1006-
# ``selected_fields`` through. Pin both: at least one scan was
1007-
# constructed during the upsert, and every scan that ran
1008-
# carries the narrowed projection.
1009-
assert upsert_scans, "upsert path constructed no DataScan — projection contract regression"
1010-
selected = [s.get("selected_fields") for s in upsert_scans]
1011-
assert all(sf == ("order_id",) for sf in selected), (
1012-
f"expected every DataScan during upsert to use selected_fields=('order_id',); got {selected}"
1013-
)
913+
identifier = "default.test_upsert_narrows_projection"
914+
_drop_table(catalog, identifier)
915+
table = catalog.create_table(
916+
identifier,
917+
schema=Schema(
918+
NestedField(1, "id", IntegerType(), required=True),
919+
NestedField(2, "payload", StringType(), required=True),
920+
),
921+
)
922+
arrow_schema = pa.schema([pa.field("id", pa.int32(), nullable=False), pa.field("payload", pa.string(), nullable=False)])
923+
table.append(pa.Table.from_pylist([{"id": 1, "payload": "a"}], schema=arrow_schema))
1014924

1015-
def test_when_matched_true_keeps_star_projection(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None:
1016-
"""The update branch's ``get_rows_to_update`` compares non-key
1017-
columns to detect actual value changes — projecting only
1018-
``join_cols`` would feed it data with no non-key columns to
1019-
compare and silently turn every match into a write-back. Must
1020-
keep ``("*",)``."""
1021-
table = self._build_partitioned_table(catalog, "default.test_upsert_projection_update_mode")
1022-
self._seed(table)
1023-
upsert_df = pa.Table.from_pylist(
1024-
[
1025-
{"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "B"},
1026-
{"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"},
1027-
],
1028-
schema=self._arrow_schema(),
1029-
)
925+
# Spy on ``DataScan.__init__`` to capture each constructed scan's
926+
# ``selected_fields``. ``functools.wraps`` preserves the original
927+
# signature so ``DataScan.update()``'s reflective parameter lookup
928+
# (used inside ``use_ref``) still resolves correctly.
929+
captured: list[tuple[str, ...] | None] = []
930+
original_init = DataScan.__init__
1030931

1031-
before = len(captured_scans)
1032-
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True)
1033-
upsert_scans = captured_scans[before:]
1034-
assert res.rows_updated == 1
1035-
assert res.rows_inserted == 1
932+
@functools.wraps(original_init)
933+
def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None:
934+
original_init(self, *args, **kwargs)
935+
captured.append(kwargs.get("selected_fields"))
1036936

1037-
assert upsert_scans, "upsert path constructed no DataScan — projection contract regression"
1038-
selected = [s.get("selected_fields") for s in upsert_scans]
1039-
assert all(sf == ("*",) for sf in selected), (
1040-
f"expected every DataScan during upsert to keep selected_fields=('*',) for the update branch; got {selected}"
1041-
)
937+
monkeypatch.setattr(DataScan, "__init__", _spy)
1042938

1043-
def test_update_mode_actually_updates_non_key_columns(self, catalog: Catalog) -> None:
1044-
"""End-to-end correctness pin: with ``when_matched_update_all=True``
1045-
the destination scan must read non-key columns so
1046-
``get_rows_to_update`` can detect ``order_type`` changes. A
1047-
regression that narrows projection unconditionally would skip
1048-
the comparison and silently miss updates whose non-key columns
1049-
differ.
1050-
"""
1051-
identifier = "default.test_upsert_update_mode_correctness"
1052-
table = self._build_partitioned_table(catalog, identifier)
1053-
self._seed(table)
1054-
# Source has the same (order_id, order_date) as one destination
1055-
# row but a different ``order_type``. Update path must detect
1056-
# the non-key change and overwrite.
1057-
upsert_df = pa.Table.from_pylist(
1058-
[{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "CHANGED"}],
1059-
schema=self._arrow_schema(),
1060-
)
1061-
res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True)
1062-
assert res.rows_updated == 1
1063-
assert res.rows_inserted == 0
939+
table.upsert(
940+
df=pa.Table.from_pylist(
941+
[{"id": 1, "payload": "a-new"}, {"id": 2, "payload": "b"}],
942+
schema=arrow_schema,
943+
),
944+
join_cols=["id"],
945+
when_matched_update_all=False,
946+
)
1064947

1065-
# Read back: the original 'A' must have been overwritten with 'CHANGED'.
1066-
rows = {r["order_id"]: r for r in table.scan().to_arrow().to_pylist()}
1067-
assert rows[2]["order_type"] == "CHANGED"
948+
assert captured, "upsert path constructed no DataScan — projection contract regression"
949+
assert all(sf == ("id",) for sf in captured), (
950+
f"expected every DataScan during upsert to use selected_fields=('id',); got {captured}"
951+
)

0 commit comments

Comments
 (0)