Skip to content

Commit ea9065e

Browse files
Ayush Patel
authored and committed
feat: Add merge() to Table and Transaction
Atomic delete-insert merge by join columns using per-column In filters for file pruning and in-memory anti-join for row-level correctness, committed as a single OVERWRITE snapshot. Unlike upsert(), does not enforce uniqueness on source or target.
1 parent a8a577f commit ea9065e

3 files changed

Lines changed: 1177 additions & 0 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,14 @@ class UpsertResult:
178178
rows_inserted: int = 0
179179

180180

181+
@dataclass()
class MergeResult:
    """Summary of the merge operation.

    Attributes:
        rows_deleted: Number of target rows removed because their join-column
            values matched the source data.
        rows_inserted: Number of source rows written by the merge.
    """

    rows_deleted: int = 0
    rows_inserted: int = 0
187+
188+
181189
class TableProperties:
182190
PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
183191
PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024 # 128 MB
@@ -885,6 +893,126 @@ def upsert(
885893

886894
return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
887895

896+
def merge(
    self,
    df: pa.Table,
    join_cols: List[str],
    snapshot_properties: Dict[str, str] = EMPTY_DICT,
    branch: Optional[str] = MAIN_BRANCH,
    check_duplicate_keys: bool = False,
) -> MergeResult:
    """Atomic delete-insert merge by join columns.

    Deletes all target rows matching the source data's join column
    values and inserts the source rows, all in a single OVERWRITE
    snapshot.

    Uses per-column ``In`` filters for file pruning (O(sum of
    cardinalities) instead of O(product)), then an in-memory
    anti-join for row-level correctness.

    Unlike ``upsert()``, does not enforce uniqueness on source or
    target by default.

    Args:
        df: The Arrow dataframe containing replacement rows.
        join_cols: Columns used to match source rows against target rows.
        snapshot_properties: Custom properties to be added to the snapshot summary.
        branch: Branch reference to run the operation.
        check_duplicate_keys: If True, raise ValueError when the source
            data contains duplicate key tuples based on the join columns.
            This is a data quality guard, not a correctness requirement -
            merge() produces correct results with duplicate keys.

    Returns:
        A MergeResult with row counts (deleted from target, inserted from source).

    Raises:
        ModuleNotFoundError: If PyArrow is not installed.
        ValueError: If ``df`` is not a PyArrow table, if ``join_cols`` is empty
            or names columns missing from ``df``, or if ``check_duplicate_keys``
            is set and the source contains duplicate key tuples.
    """
    try:
        import pyarrow as pa
        import pyarrow.compute as pc
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

    import functools

    from pyiceberg.expressions import In
    from pyiceberg.io.pyarrow import ArrowScan, _check_pyarrow_schema_compatible, _dataframe_to_data_files
    from pyiceberg.table import upsert_util

    if not isinstance(df, pa.Table):
        raise ValueError(f"Expected PyArrow table, got: {df}")

    if not join_cols:
        raise ValueError("join_cols must be a non-empty list of column names.")

    missing = set(join_cols) - set(df.column_names)
    if missing:
        raise ValueError(f"join_cols not found in source data: {missing}")

    # Nothing to merge: leave the table untouched and report zero counts.
    if df.num_rows == 0:
        return MergeResult()

    if check_duplicate_keys and upsert_util.has_duplicate_rows(df, join_cols):
        raise ValueError("Duplicate rows found in source data based on join columns.")

    downcast_ns = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
    _check_pyarrow_schema_compatible(
        self.table_metadata.schema(),
        provided_schema=df.schema,
        downcast_ns_timestamp_to_us=downcast_ns,
        format_version=self.table_metadata.format_version,
    )

    # Step 1: Build per-column In filters for file pruning.
    # O(sum of cardinalities) instead of O(product).
    # Over-approximates the match set, which is fine - row-level
    # correctness is enforced by the anti-join in step 3.
    # Null key values are dropped from the pruning sets because a null
    # cannot be expressed as an ``In`` literal.
    # NOTE(review): this relies on the Arrow join in step 3 not treating
    # null keys as equal, so null-keyed source rows never match - confirm.
    in_filters: list[BooleanExpression] = [
        In(col, pc.unique(pc.drop_null(df[col])).to_pylist()) for col in join_cols
    ]
    candidate_filter: BooleanExpression = functools.reduce(And, in_filters)

    # Step 2: Find candidate files via manifest pruning.
    scan = self._scan(row_filter=candidate_filter, case_sensitive=True)
    if branch is not None and branch in self.table_metadata.refs:
        scan = scan.use_ref(branch)
    tasks = list(scan.plan_files())

    if not tasks:
        # No files overlap - just append.
        self.append(df, snapshot_properties=snapshot_properties, branch=branch)
        return MergeResult(rows_inserted=df.num_rows)

    # Step 3: Read ALL rows from candidate files, anti-join to keep
    # non-matching rows. The candidate_filter was only for file
    # pruning - row-level correctness comes from the anti-join.
    arrow_scan = ArrowScan(
        self.table_metadata,
        self._table.io,
        projected_schema=self.table_metadata.schema(),
        row_filter=ALWAYS_TRUE,
        case_sensitive=True,
    )
    target_data = arrow_scan.to_table(tasks)
    source_keys = df.select(join_cols)

    # Anti-join on the key columns only, carrying a synthetic row index.
    # Joining the full table would push every non-key column through the
    # Arrow hash join, which rejects some nested types (e.g. struct/list).
    row_idx_name = "__merge_row_idx"
    while row_idx_name in target_data.column_names:
        row_idx_name = f"_{row_idx_name}"
    target_keys = target_data.select(join_cols).append_column(
        row_idx_name, pa.array(range(target_data.num_rows), type=pa.int64())
    )
    surviving_idx = target_keys.join(source_keys, keys=join_cols, join_type="left anti")[row_idx_name]
    # A boolean mask keeps the surviving rows in their original order.
    keep_mask = pc.is_in(target_keys[row_idx_name], value_set=surviving_idx.combine_chunks())
    kept_rows = target_data.filter(keep_mask)

    rows_deleted = target_data.num_rows - kept_rows.num_rows
    new_content = pa.concat_tables([kept_rows, df], promote_options="default")

    # Step 4: Atomic single-snapshot commit.
    # Delete old files, append rewritten content.
    with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).overwrite() as overwrite_op:
        for task in tasks:
            overwrite_op.delete_data_file(task.file)
        for data_file in _dataframe_to_data_files(
            table_metadata=self.table_metadata,
            df=new_content,
            io=self._table.io,
            write_uuid=overwrite_op.commit_uuid,
        ):
            overwrite_op.append_data_file(data_file)

    return MergeResult(rows_deleted=rows_deleted, rows_inserted=df.num_rows)
1015+
8881016
def add_files(
8891017
self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
8901018
) -> None:
@@ -1415,6 +1543,40 @@ def upsert(
14151543
branch=branch,
14161544
)
14171545

1546+
def merge(
    self,
    df: pa.Table,
    join_cols: List[str],
    snapshot_properties: Dict[str, str] = EMPTY_DICT,
    branch: Optional[str] = MAIN_BRANCH,
    check_duplicate_keys: bool = False,
) -> MergeResult:
    """Shorthand for an atomic delete-insert merge by join columns.

    Opens a transaction on this table and forwards every argument to the
    transaction's ``merge()``.

    Unlike ``upsert()``, does not enforce uniqueness on source or
    target by default.

    Args:
        df: The Arrow dataframe containing replacement rows.
        join_cols: Columns used to match source rows against target rows.
        snapshot_properties: Custom properties to be added to the snapshot summary.
        branch: Branch reference to run the operation.
        check_duplicate_keys: If True, raise ValueError when the source
            data contains duplicate key tuples based on the join columns.
            This is a data quality guard, not a correctness requirement.

    Returns:
        A MergeResult with row counts (deleted from target, inserted from source).
    """
    with self.transaction() as tx:
        merge_result = tx.merge(
            df=df,
            join_cols=join_cols,
            snapshot_properties=snapshot_properties,
            branch=branch,
            check_duplicate_keys=check_duplicate_keys,
        )
    return merge_result
1579+
14181580
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
14191581
"""
14201582
Shorthand API for appending a PyArrow table to the table.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""End-to-end benchmark: merge() vs create_match_filter + overwrite().
18+
19+
Compares the full write path (filter construction + file I/O + snapshot commit)
20+
between the new merge() implementation and the previous approach.
21+
22+
Usage:
23+
poetry run pytest tests/benchmark/test_merge_filter.py -v -s -m benchmark
24+
"""
25+
26+
import gc
27+
import itertools
28+
import timeit
29+
import tracemalloc
30+
from pathlib import PosixPath
31+
from typing import Any, Callable
32+
33+
import pyarrow as pa
34+
import pytest
35+
36+
from pyiceberg.catalog import Catalog
37+
from pyiceberg.exceptions import NoSuchTableError
38+
from pyiceberg.schema import Schema
39+
from pyiceberg.table import Table
40+
from pyiceberg.table.upsert_util import create_match_filter
41+
from pyiceberg.types import IntegerType, NestedField, StringType
42+
from tests.catalog.test_base import InMemoryCatalog
43+
44+
45+
def _make_schema(col_cardinalities: dict[str, int]) -> Schema:
    """Build the benchmark schema: one required field per key column plus a value column.

    Key columns get consecutive field ids starting at 1; ``date_id`` is an
    integer, every other key is a string. A required integer column ``v`` is
    appended with the next field id.
    """
    key_fields = [
        NestedField(field_id, name, IntegerType() if name == "date_id" else StringType(), required=True)
        for field_id, name in enumerate(col_cardinalities, start=1)
    ]
    value_field = NestedField(len(col_cardinalities) + 1, "v", IntegerType(), required=True)
    return Schema(*key_fields, value_field)
52+
53+
54+
def _build_table(col_cardinalities: dict[str, int]) -> tuple[pa.Table, list[str], Schema]:
    """Build the full cross product of key values as an Arrow table.

    Returns:
        A tuple of (arrow table, key column names, Iceberg schema). The table
        has one row per combination of key values and a running counter ``v``.
    """
    from pyiceberg.io.pyarrow import schema_to_pyarrow

    schema = _make_schema(col_cardinalities)
    key_names = list(col_cardinalities.keys())

    def _domain(name: str, size: int) -> list[Any]:
        # ``date_id`` is an integer range; other keys are "<name>_<i>" strings.
        if name == "date_id":
            return list(range(20260101, 20260101 + size))
        return [f"{name}_{i}" for i in range(size)]

    domains: list[list[Any]] = [_domain(name, size) for name, size in col_cardinalities.items()]
    rows = list(itertools.product(*domains))

    columns = {name: [row[pos] for row in rows] for pos, name in enumerate(key_names)}
    columns["v"] = list(range(len(rows)))

    return pa.table(columns, schema=schema_to_pyarrow(schema)), key_names, schema
70+
71+
72+
def _fresh_table(catalog: Catalog, name: str, schema: Schema, data: pa.Table) -> Table:
    """Drop ``default.<name>`` if it exists, recreate it, and load ``data`` into it."""
    identifier = f"default.{name}"
    try:
        catalog.drop_table(identifier)
    except NoSuchTableError:
        # Nothing to clean up: the table did not exist yet.
        pass

    table = catalog.create_table(identifier, schema=schema)
    table.append(data)
    return table
81+
82+
83+
def _measure(fn: Callable[[], Any], runs: int = 3) -> tuple[float, int]:
84+
"""Returns (avg_seconds, peak_memory_bytes)."""
85+
times = []
86+
peak = 0
87+
for _ in range(runs):
88+
gc.collect()
89+
tracemalloc.start()
90+
t0 = timeit.default_timer()
91+
fn()
92+
times.append(timeit.default_timer() - t0)
93+
_, p = tracemalloc.get_traced_memory()
94+
tracemalloc.stop()
95+
peak = max(peak, p)
96+
return sum(times) / len(times), peak
97+
98+
99+
def _fmt(secs: float, mem_bytes: int) -> str:
100+
mem = f"{mem_bytes / 1024:.0f} KB" if mem_bytes < 1048576 else f"{mem_bytes / 1048576:.1f} MB"
101+
return f"{secs * 1000:.0f} ms, peak {mem}"
102+
103+
104+
# Per-key-column cardinalities for the benchmark target table. _build_table
# takes the cross product of the key values, so 252 * 100 rows are generated.
COLS = {"date_id": 252, "account": 100} # 25,200 target rows
105+
106+
107+
@pytest.mark.benchmark
@pytest.mark.parametrize("n_source", [100, 5000])
def test_e2e_merge(n_source: int, tmp_path: PosixPath) -> None:
    """End-to-end merge(): per-column In + anti-join + single OVERWRITE snapshot."""
    target_data, join_cols, schema = _build_table(COLS)

    catalog = InMemoryCatalog("bench", warehouse=str(tmp_path))
    catalog.create_namespace("default")
    tbl = _fresh_table(catalog, f"merge_{n_source}", schema, target_data)

    # Source = the first n_source target rows with the value column shifted.
    head = target_data.slice(0, n_source)
    source_dict = {name: head.column(name).to_pylist() for name in head.column_names}
    source_dict["v"] = [value + 9000 for value in source_dict["v"]]
    source = pa.table(source_dict, schema=target_data.schema)

    def _run_merge() -> Any:
        return tbl.merge(source, join_cols=join_cols)

    avg, peak = _measure(_run_merge)
    print(f"\n merge(): {target_data.num_rows:,} target, {n_source:,} source -> {_fmt(avg, peak)}")
123+
124+
125+
@pytest.mark.benchmark
def test_e2e_overwrite_100src(tmp_path: PosixPath) -> None:
    """End-to-end overwrite() with 100 source rows."""
    target_data, join_cols, schema = _build_table(COLS)

    catalog = InMemoryCatalog("bench", warehouse=str(tmp_path))
    catalog.create_namespace("default")
    tbl = _fresh_table(catalog, "overwrite_100", schema, target_data)

    # Source = the first 100 target rows with the value column shifted.
    head = target_data.slice(0, 100)
    source_dict = {name: head.column(name).to_pylist() for name in head.column_names}
    source_dict["v"] = [value + 9000 for value in source_dict["v"]]
    source = pa.table(source_dict, schema=target_data.schema)

    def _run_overwrite() -> Any:
        # The match filter is rebuilt on every run so its cost is included.
        return tbl.overwrite(source, overwrite_filter=create_match_filter(source, join_cols))

    avg, peak = _measure(_run_overwrite)
    print(f"\n overwrite(): {target_data.num_rows:,} target, 100 source -> {_fmt(avg, peak)}")
140+
141+
142+
@pytest.mark.benchmark
def test_e2e_overwrite_5ksrc_filter_only(tmp_path: PosixPath) -> None:
    """At 5,000 source rows, just constructing the filter takes seconds.

    Only filter construction is measured here because the full overwrite()
    with a 20,000-node expression tree causes process termination during
    manifest evaluation.
    """
    target_data, join_cols, schema = _build_table(COLS)

    # Source = the first 5,000 target rows with the value column shifted.
    head = target_data.slice(0, 5000)
    source_dict = {name: head.column(name).to_pylist() for name in head.column_names}
    source_dict["v"] = [value + 9000 for value in source_dict["v"]]
    source = pa.table(source_dict, schema=target_data.schema)

    def _build_filter() -> Any:
        return create_match_filter(source, join_cols)

    avg, peak = _measure(_build_filter, runs=1)
    print(f"\n create_match_filter only (no overwrite): 5,000 source -> {_fmt(avg, peak)}")

0 commit comments

Comments
 (0)