|
14 | 14 | # KIND, either express or implied. See the License for the |
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | | -import datetime |
18 | 17 | from pathlib import PosixPath |
19 | 18 | from typing import Any |
20 | 19 |
|
|
28 | 27 | from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference |
29 | 28 | from pyiceberg.expressions.literals import LongLiteral |
30 | 29 | from pyiceberg.io.pyarrow import schema_to_pyarrow |
31 | | -from pyiceberg.partitioning import PartitionField, PartitionSpec |
32 | 30 | from pyiceberg.schema import Schema |
33 | 31 | from pyiceberg.table import Table, UpsertResult |
34 | 32 | from pyiceberg.table.snapshots import Operation |
35 | 33 | from pyiceberg.table.upsert_util import create_match_filter |
36 | | -from pyiceberg.transforms import IdentityTransform |
37 | | -from pyiceberg.types import DateType, IntegerType, NestedField, StringType, StructType |
| 34 | +from pyiceberg.types import IntegerType, NestedField, StringType, StructType |
38 | 35 | from tests.catalog.test_base import InMemoryCatalog |
39 | 36 |
|
40 | 37 |
|
@@ -894,174 +891,61 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None: |
894 | 891 | assert snapshot.summary.additional_properties.get("test_prop") == "test_value" |
895 | 892 |
|
896 | 893 |
|
897 | | -class TestUpsertScanProjection: |
| 894 | +def test_upsert_narrows_destination_scan_projection_to_join_cols( |
| 895 | + catalog: Catalog, |
| 896 | + monkeypatch: pytest.MonkeyPatch, |
| 897 | +) -> None: |
898 | 898 | """``Transaction.upsert`` narrows the destination scan's |
899 | | - ``selected_fields`` to ``join_cols`` when ``when_matched_update_all=False``. |
900 | | -
|
901 | | - Rationale: the insert-on-no-match branch only reads ``join_cols`` |
902 | | - off each destination batch (to feed ``create_match_filter``); every |
903 | | - other column is unused. Projection at the scan boundary lets the |
904 | | - parquet reader prune wide non-key columns at the file level — |
905 | | - significant for tables whose payload column (e.g. a JSON ``log``) |
906 | | - dominates file bytes. ``_projected_field_ids`` auto-unions the |
907 | | - row-filter's column ids back in, so any column referenced by the |
908 | | - join-key predicate is still readable for filter evaluation without |
909 | | - needing to list it explicitly. |
910 | | -
|
911 | | - Falls back to ``("*",)`` when ``when_matched_update_all=True`` |
912 | | - because ``get_rows_to_update`` reads every non-key column off the |
913 | | - destination row to detect value drift — narrowing would break the |
914 | | - no-op-write skip. |
| 899 | + ``selected_fields`` to ``join_cols`` when |
| 900 | + ``when_matched_update_all=False``. |
| 901 | +
|
| 902 | + The insert-on-no-match branch only reads ``join_cols`` from each |
| 903 | + destination batch (to feed ``create_match_filter``), so projection |
| 904 | + at the scan boundary lets the parquet reader skip wide non-key |
| 905 | + columns. The ``("*",)`` fallback used when |
| 906 | + ``when_matched_update_all=True`` is exercised by the rest of this module — ``get_rows_to_update``'s |
| 907 | + value-drift detection would silently break if that fallback ever regressed. |
915 | 908 | """ |
| 909 | + import functools |
916 | 910 |
|
917 | | - @staticmethod |
918 | | - def _build_partitioned_table(catalog: Catalog, identifier: str) -> Table: |
919 | | - _drop_table(catalog, identifier) |
920 | | - schema = Schema( |
921 | | - NestedField(1, "order_id", IntegerType(), required=True), |
922 | | - NestedField(2, "order_date", DateType(), required=True), |
923 | | - NestedField(3, "order_type", StringType(), required=True), |
924 | | - ) |
925 | | - spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) |
926 | | - return catalog.create_table(identifier, schema=schema, partition_spec=spec) |
927 | | - |
928 | | - @staticmethod |
929 | | - def _arrow_schema() -> pa.Schema: |
930 | | - return pa.schema( |
931 | | - [ |
932 | | - pa.field("order_id", pa.int32(), nullable=False), |
933 | | - pa.field("order_date", pa.date32(), nullable=False), |
934 | | - pa.field("order_type", pa.string(), nullable=False), |
935 | | - ] |
936 | | - ) |
| 911 | + from pyiceberg.table import DataScan |
937 | 912 |
|
938 | | - def _seed(self, table: Table) -> None: |
939 | | - table.append( |
940 | | - pa.Table.from_pylist( |
941 | | - [ |
942 | | - {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, |
943 | | - {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "A"}, |
944 | | - ], |
945 | | - schema=self._arrow_schema(), |
946 | | - ) |
947 | | - ) |
948 | | - |
949 | | - @pytest.fixture |
950 | | - def captured_scans(self, monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]: |
951 | | - """Spy on ``DataScan.__init__`` to capture every kwargs dict. |
952 | | -
|
953 | | - Lets the tests pin which ``selected_fields`` the upsert path |
954 | | - actually passes — assertions on the surfaced batch schema alone |
955 | | - would miss the case where the underlying projection contract |
956 | | - regresses but the test data happens to have only join_cols |
957 | | - anyway. |
958 | | -
|
959 | | - The spy preserves ``__init__``'s signature via |
960 | | - :func:`functools.wraps` so ``DataScan.update()``'s reflective |
961 | | - ``inspect.signature(type(self).__init__).parameters`` lookup |
962 | | - (used by ``use_ref``) still resolves to the real parameter |
963 | | - names, not the spy's ``**kwargs``. |
964 | | - """ |
965 | | - import functools |
966 | | - |
967 | | - from pyiceberg.table import DataScan |
968 | | - |
969 | | - captured: list[dict[str, Any]] = [] |
970 | | - original_init = DataScan.__init__ |
971 | | - |
972 | | - @functools.wraps(original_init) |
973 | | - def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None: |
974 | | - captured.append(dict(kwargs)) |
975 | | - original_init(self, *args, **kwargs) |
976 | | - |
977 | | - monkeypatch.setattr(DataScan, "__init__", _spy) |
978 | | - return captured |
979 | | - |
980 | | - def test_when_matched_false_projects_join_cols_only(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None: |
981 | | - """The insert-on-no-match branch never reads non-key destination |
982 | | - columns, so the scan must narrow the projection to ``join_cols`` |
983 | | - — saving the parquet reader from materialising wide payload |
984 | | - columns just to be discarded.""" |
985 | | - table = self._build_partitioned_table(catalog, "default.test_upsert_projection_insert_only") |
986 | | - self._seed(table) |
987 | | - upsert_df = pa.Table.from_pylist( |
988 | | - [ |
989 | | - {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, |
990 | | - {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"}, |
991 | | - ], |
992 | | - schema=self._arrow_schema(), |
993 | | - ) |
994 | | - |
995 | | - # Snapshot only the scans constructed during the upsert (the |
996 | | - # seed append above may have created its own). |
997 | | - before = len(captured_scans) |
998 | | - res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=False) |
999 | | - upsert_scans = captured_scans[before:] |
1000 | | - assert res.rows_inserted == 1 |
1001 | | - assert res.rows_updated == 0 |
1002 | | - |
1003 | | - # The upsert constructs one DataScan for the destination match. |
1004 | | - # ``use_ref`` may construct a second DataScan as an inherited |
1005 | | - # copy (via ``self.update``), which carries the same |
1006 | | - # ``selected_fields`` through. Pin both: at least one scan was |
1007 | | - # constructed during the upsert, and every scan that ran |
1008 | | - # carries the narrowed projection. |
1009 | | - assert upsert_scans, "upsert path constructed no DataScan — projection contract regression" |
1010 | | - selected = [s.get("selected_fields") for s in upsert_scans] |
1011 | | - assert all(sf == ("order_id",) for sf in selected), ( |
1012 | | - f"expected every DataScan during upsert to use selected_fields=('order_id',); got {selected}" |
1013 | | - ) |
| 913 | + identifier = "default.test_upsert_narrows_projection" |
| 914 | + _drop_table(catalog, identifier) |
| 915 | + table = catalog.create_table( |
| 916 | + identifier, |
| 917 | + schema=Schema( |
| 918 | + NestedField(1, "id", IntegerType(), required=True), |
| 919 | + NestedField(2, "payload", StringType(), required=True), |
| 920 | + ), |
| 921 | + ) |
| 922 | + arrow_schema = pa.schema([pa.field("id", pa.int32(), nullable=False), pa.field("payload", pa.string(), nullable=False)]) |
| 923 | + table.append(pa.Table.from_pylist([{"id": 1, "payload": "a"}], schema=arrow_schema)) |
1014 | 924 |
|
1015 | | - def test_when_matched_true_keeps_star_projection(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None: |
1016 | | - """The update branch's ``get_rows_to_update`` compares non-key |
1017 | | - columns to detect actual value changes — projecting only |
1018 | | - ``join_cols`` would feed it data with no non-key columns to |
1019 | | - compare and silently turn every match into a write-back. Must |
1020 | | - keep ``("*",)``.""" |
1021 | | - table = self._build_partitioned_table(catalog, "default.test_upsert_projection_update_mode") |
1022 | | - self._seed(table) |
1023 | | - upsert_df = pa.Table.from_pylist( |
1024 | | - [ |
1025 | | - {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "B"}, |
1026 | | - {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"}, |
1027 | | - ], |
1028 | | - schema=self._arrow_schema(), |
1029 | | - ) |
| 925 | + # Spy on ``DataScan.__init__`` to capture the ``selected_fields`` |
| 926 | + # keyword argument of each constructed scan (recorded as ``None`` when it is not passed by keyword). ``functools.wraps`` preserves the original |
| 927 | + # signature so ``DataScan.update()``'s reflective parameter lookup |
| 928 | + # (used inside ``use_ref``) still resolves correctly. |
| 929 | + captured: list[tuple[str, ...] | None] = [] |
| 930 | + original_init = DataScan.__init__ |
1030 | 931 |
|
1031 | | - before = len(captured_scans) |
1032 | | - res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True) |
1033 | | - upsert_scans = captured_scans[before:] |
1034 | | - assert res.rows_updated == 1 |
1035 | | - assert res.rows_inserted == 1 |
| 932 | + @functools.wraps(original_init) |
| 933 | + def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None: |
| 934 | + original_init(self, *args, **kwargs) |
| 935 | + captured.append(kwargs.get("selected_fields")) |
1036 | 936 |
|
1037 | | - assert upsert_scans, "upsert path constructed no DataScan — projection contract regression" |
1038 | | - selected = [s.get("selected_fields") for s in upsert_scans] |
1039 | | - assert all(sf == ("*",) for sf in selected), ( |
1040 | | - f"expected every DataScan during upsert to keep selected_fields=('*',) for the update branch; got {selected}" |
1041 | | - ) |
| 937 | + monkeypatch.setattr(DataScan, "__init__", _spy) |
1042 | 938 |
|
1043 | | - def test_update_mode_actually_updates_non_key_columns(self, catalog: Catalog) -> None: |
1044 | | - """End-to-end correctness pin: with ``when_matched_update_all=True`` |
1045 | | - the destination scan must read non-key columns so |
1046 | | - ``get_rows_to_update`` can detect ``order_type`` changes. A |
1047 | | - regression that narrows projection unconditionally would skip |
1048 | | - the comparison and silently miss updates whose non-key columns |
1049 | | - differ. |
1050 | | - """ |
1051 | | - identifier = "default.test_upsert_update_mode_correctness" |
1052 | | - table = self._build_partitioned_table(catalog, identifier) |
1053 | | - self._seed(table) |
1054 | | - # Source has the same (order_id, order_date) as one destination |
1055 | | - # row but a different ``order_type``. Update path must detect |
1056 | | - # the non-key change and overwrite. |
1057 | | - upsert_df = pa.Table.from_pylist( |
1058 | | - [{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "CHANGED"}], |
1059 | | - schema=self._arrow_schema(), |
1060 | | - ) |
1061 | | - res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True) |
1062 | | - assert res.rows_updated == 1 |
1063 | | - assert res.rows_inserted == 0 |
| 939 | + table.upsert( |
| 940 | + df=pa.Table.from_pylist( |
| 941 | + [{"id": 1, "payload": "a-new"}, {"id": 2, "payload": "b"}], |
| 942 | + schema=arrow_schema, |
| 943 | + ), |
| 944 | + join_cols=["id"], |
| 945 | + when_matched_update_all=False, |
| 946 | + ) |
1064 | 947 |
|
1065 | | - # Read back: the original 'A' must have been overwritten with 'CHANGED'. |
1066 | | - rows = {r["order_id"]: r for r in table.scan().to_arrow().to_pylist()} |
1067 | | - assert rows[2]["order_type"] == "CHANGED" |
| 948 | + assert captured, "upsert path constructed no DataScan — projection contract regression" |
| 949 | + assert all(sf == ("id",) for sf in captured), ( |
| 950 | + f"expected every DataScan during upsert to use selected_fields=('id',); got {captured}" |
| 951 | + ) |
0 commit comments