Skip to content

Commit f474c1b

Browse files
Ayush Patel
authored and committed
feat: Add merge() to Table and Transaction
Atomic delete-insert merge by join columns using per-column In filters for file pruning and in-memory anti-join for row-level correctness, committed as a single OVERWRITE snapshot. Unlike upsert(), does not enforce uniqueness on source or target.
1 parent a8a577f commit f474c1b

3 files changed

Lines changed: 725 additions & 0 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,112 @@ def upsert(
885885

886886
return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
887887

888+
def merge(
    self,
    df: pa.Table,
    join_cols: List[str],
    snapshot_properties: Dict[str, str] = EMPTY_DICT,
    branch: Optional[str] = MAIN_BRANCH,
) -> None:
    """Atomic delete-insert merge by join columns.

    Deletes all target rows matching the source data's join column
    values and inserts the source rows, all in a single OVERWRITE
    snapshot.

    Uses per-column ``In`` filters for file pruning (O(sum of
    cardinalities) instead of O(product)), then an in-memory
    anti-join for row-level correctness.

    Unlike ``upsert()``, does not enforce uniqueness on source or
    target.

    Null join-key values in the source are left out of the pruning
    filters; source rows carrying them are still written.

    Args:
        df: The Arrow dataframe containing replacement rows.
        join_cols: Columns used to match source rows against target rows.
        snapshot_properties: Custom properties to be added to the snapshot summary.
        branch: Branch reference to run the operation.

    Raises:
        ModuleNotFoundError: If PyArrow is not installed.
        ValueError: If ``df`` is not a PyArrow table, if ``join_cols`` is empty,
            or if a join column is missing from ``df``.
    """
    try:
        import pyarrow as pa
        import pyarrow.compute as pc
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

    import functools

    from pyiceberg.expressions import In
    from pyiceberg.io.pyarrow import ArrowScan, _check_pyarrow_schema_compatible, _dataframe_to_data_files

    if not isinstance(df, pa.Table):
        raise ValueError(f"Expected PyArrow table, got: {df}")

    if not join_cols:
        raise ValueError("join_cols must be a non-empty list of column names.")

    missing = set(join_cols) - set(df.column_names)
    if missing:
        raise ValueError(f"join_cols not found in source data: {missing}")

    if df.num_rows == 0:
        # Nothing to delete or insert; do not create a snapshot.
        return

    downcast_ns = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
    _check_pyarrow_schema_compatible(
        self.table_metadata.schema(),
        provided_schema=df.schema,
        downcast_ns_timestamp_to_us=downcast_ns,
        format_version=self.table_metadata.format_version,
    )

    # Step 1: Build per-column In filters for file pruning.
    # O(sum of cardinalities) instead of O(product).
    # Over-approximates the match set, which is fine - row-level
    # correctness is enforced by the anti-join in step 3.
    #
    # Null keys are dropped before building the filters: None cannot be
    # turned into an expression literal, and null keys do not match in
    # the anti-join below, so they never require a file to be rewritten.
    in_filters = [In(col, pc.unique(df[col]).drop_null().to_pylist()) for col in join_cols]
    candidate_filter: BooleanExpression = functools.reduce(And, in_filters)

    # Step 2: Find candidate files via manifest pruning.
    scan = self._scan(row_filter=candidate_filter, case_sensitive=True)
    if branch is not None and branch in self.table_metadata.refs:
        scan = scan.use_ref(branch)
    tasks = list(scan.plan_files())

    if not tasks:
        # No files overlap - just append.
        self.append(df, snapshot_properties=snapshot_properties, branch=branch)
        return

    # Step 3: Read ALL rows from candidate files, anti-join to keep
    # non-matching rows. The candidate_filter was only for file
    # pruning - row-level correctness comes from the anti-join.
    arrow_scan = ArrowScan(
        self.table_metadata,
        self._table.io,
        projected_schema=self.table_metadata.schema(),
        row_filter=ALWAYS_TRUE,
        case_sensitive=True,
    )
    target_data = arrow_scan.to_table(tasks)
    source_keys = df.select(join_cols)
    kept_rows = target_data.join(source_keys, keys=join_cols, join_type="left anti")
    new_content = pa.concat_tables([kept_rows, df], promote_options="default")

    # Step 4: Atomic single-snapshot commit.
    # Delete old files, append rewritten content.
    with self.update_snapshot(
        snapshot_properties=snapshot_properties, branch=branch
    ).overwrite() as overwrite_op:
        for task in tasks:
            overwrite_op.delete_data_file(task.file)
        for data_file in _dataframe_to_data_files(
            table_metadata=self.table_metadata,
            df=new_content,
            io=self._table.io,
            write_uuid=overwrite_op.commit_uuid,
        ):
            overwrite_op.append_data_file(data_file)
993+
888994
def add_files(
889995
self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
890996
) -> None:
@@ -1415,6 +1521,31 @@ def upsert(
14151521
branch=branch,
14161522
)
14171523

1524+
def merge(
    self,
    df: pa.Table,
    join_cols: List[str],
    snapshot_properties: Dict[str, str] = EMPTY_DICT,
    branch: Optional[str] = MAIN_BRANCH,
) -> None:
    """Shorthand API for merging a PyArrow table into the table by join columns.

    Opens a transaction and delegates to its ``merge()``. Unlike ``upsert()``,
    uniqueness of the join columns is not enforced on source or target.

    Args:
        df: The Arrow dataframe containing replacement rows.
        join_cols: Columns used to match source rows against target rows.
        snapshot_properties: Custom properties to be added to the snapshot summary.
        branch: Branch reference to run the operation.
    """
    with self.transaction() as txn:
        txn.merge(df=df, join_cols=join_cols, snapshot_properties=snapshot_properties, branch=branch)
1548+
14181549
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
14191550
"""
14201551
Shorthand API for appending a PyArrow table to the table.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""End-to-end benchmark: merge() vs create_match_filter + overwrite().
18+
19+
Compares the full write path (filter construction + file I/O + snapshot commit)
20+
between the new merge() implementation and the previous approach.
21+
22+
Usage:
23+
poetry run pytest tests/benchmark/test_merge_filter.py -v -s -m benchmark
24+
"""
25+
26+
import gc
import itertools
import timeit
import tracemalloc
from collections.abc import Callable
from pathlib import PosixPath

import pyarrow as pa
import pytest

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.table.upsert_util import create_match_filter
from pyiceberg.types import IntegerType, NestedField, StringType
from tests.catalog.test_base import InMemoryCatalog
42+
43+
def _make_schema(col_cardinalities: dict[str, int]) -> Schema:
    """Build the benchmark Iceberg schema: one required field per key column plus ``v``.

    ``date_id`` is an integer column, every other key column is a string, and
    the trailing value column ``v`` is an integer. Field ids are assigned
    sequentially starting at 1, in the mapping's iteration order.
    """
    key_fields = [
        NestedField(field_id, name, IntegerType() if name == "date_id" else StringType(), required=True)
        for field_id, name in enumerate(col_cardinalities, start=1)
    ]
    value_field = NestedField(len(col_cardinalities) + 1, "v", IntegerType(), required=True)
    return Schema(*key_fields, value_field)
50+
51+
52+
def _build_table(col_cardinalities: dict[str, int]) -> tuple[pa.Table, list[str], Schema]:
    """Create the benchmark data set as the cross product of all key-column values.

    Returns:
        A ``(arrow_table, key_column_names, iceberg_schema)`` tuple. The Arrow
        table has one row per key combination and a ``v`` column numbering the
        rows from 0.
    """
    from pyiceberg.io.pyarrow import schema_to_pyarrow

    def _column_values(name: str, cardinality: int) -> list:
        # date_id gets consecutive integers; other columns get "<name>_<i>" strings.
        if name == "date_id":
            first = 20260101
            return list(range(first, first + cardinality))
        return [f"{name}_{idx}" for idx in range(cardinality)]

    schema = _make_schema(col_cardinalities)
    key_columns = list(col_cardinalities)

    per_column = [_column_values(name, card) for name, card in col_cardinalities.items()]
    rows = list(itertools.product(*per_column))

    data = {name: [row[pos] for row in rows] for pos, name in enumerate(key_columns)}
    data["v"] = list(range(len(rows)))

    return pa.table(data, schema=schema_to_pyarrow(schema)), key_columns, schema
68+
69+
70+
def _fresh_table(catalog: Catalog, name: str, schema: Schema, data: pa.Table) -> Table:
    """(Re)create ``default.<name>`` in ``catalog`` and seed it with ``data``.

    Any existing table with the same identifier is dropped first so each
    benchmark starts from a known state.

    Args:
        catalog: Catalog in which to create the table.
        name: Table name inside the ``default`` namespace.
        schema: Iceberg schema for the new table.
        data: Arrow data appended to the table after creation.

    Returns:
        The newly created table, already containing ``data``.
    """
    ident = f"default.{name}"
    try:
        catalog.drop_table(ident)
    except NoSuchTableError:
        # Nothing to clean up: the table did not exist yet.
        pass
    tbl = catalog.create_table(ident, schema=schema)
    tbl.append(data)
    return tbl
79+
80+
81+
def _measure(fn: Callable[[], object], runs: int = 3) -> tuple[float, int]:
    """Time ``fn`` over several runs and record peak traced memory.

    Args:
        fn: Zero-argument callable to benchmark; its return value is ignored.
        runs: Number of times to invoke ``fn``. Must be at least 1.

    Returns:
        A ``(avg_seconds, peak_memory_bytes)`` tuple: the mean wall-clock time
        per run and the largest tracemalloc peak observed across the runs.

    Raises:
        ValueError: If ``runs`` is less than 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be at least 1, got: {runs}")

    times = []
    peak = 0
    for _ in range(runs):
        gc.collect()
        tracemalloc.start()
        try:
            t0 = timeit.default_timer()
            fn()
            times.append(timeit.default_timer() - t0)
            _, run_peak = tracemalloc.get_traced_memory()
        finally:
            # Always stop tracing, even if fn() raises, so a failing
            # benchmark does not leave tracemalloc enabled for later tests.
            tracemalloc.stop()
        peak = max(peak, run_peak)
    return sum(times) / len(times), peak
95+
96+
97+
def _fmt(secs: float, mem_bytes: int) -> str:
    """Render a timing/memory pair as ``"<ms> ms, peak <size>"``.

    Memory below 1048576 bytes is shown in whole KB; anything larger is shown
    in MB with one decimal place.
    """
    bytes_per_mb = 1048576
    if mem_bytes < bytes_per_mb:
        mem_text = f"{mem_bytes / 1024:.0f} KB"
    else:
        mem_text = f"{mem_bytes / bytes_per_mb:.1f} MB"
    millis = secs * 1000
    return f"{millis:.0f} ms, peak {mem_text}"
100+
101+
102+
# Key columns of the benchmark target table mapped to their cardinalities;
# the target holds one row per combination of key values.
COLS = {"date_id": 252, "account": 100}  # 25,200 target rows
103+
104+
105+
@pytest.mark.benchmark
@pytest.mark.parametrize("n_source", [100, 5000])
def test_e2e_merge(n_source: int, tmp_path: PosixPath) -> None:
    """Benchmark the full merge() path for ``n_source`` replacement rows.

    Covers per-column In pruning, the anti-join and the single OVERWRITE
    snapshot commit.
    """
    target_data, join_cols, schema = _build_table(COLS)

    catalog = InMemoryCatalog("bench", warehouse=str(tmp_path))
    catalog.create_namespace("default")
    tbl = _fresh_table(catalog, f"merge_{n_source}", schema, target_data)

    # Source rows are the first n_source target rows with "v" shifted by 9000,
    # so every source key already exists in the target.
    source_dict = {}
    for col in target_data.column_names:
        head = target_data.column(col).to_pylist()[:n_source]
        source_dict[col] = [x + 9000 for x in head] if col == "v" else head
    source = pa.table(source_dict, schema=target_data.schema)

    def run_merge() -> None:
        tbl.merge(source, join_cols=join_cols)

    avg, peak = _measure(run_merge)
    print(f"\n merge(): {target_data.num_rows:,} target, {n_source:,} source -> {_fmt(avg, peak)}")
121+
122+
123+
@pytest.mark.benchmark
def test_e2e_overwrite_100src(tmp_path: PosixPath) -> None:
    """Benchmark create_match_filter + overwrite() end to end with 100 source rows."""
    target_data, join_cols, schema = _build_table(COLS)

    catalog = InMemoryCatalog("bench", warehouse=str(tmp_path))
    catalog.create_namespace("default")
    tbl = _fresh_table(catalog, "overwrite_100", schema, target_data)

    # Source rows are the first 100 target rows with "v" shifted by 9000.
    source_dict = {}
    for col in target_data.column_names:
        head = target_data.column(col).to_pylist()[:100]
        source_dict[col] = [x + 9000 for x in head] if col == "v" else head
    source = pa.table(source_dict, schema=target_data.schema)

    def run_overwrite() -> None:
        # The match filter is rebuilt on every run so its cost is part of the measurement.
        tbl.overwrite(source, overwrite_filter=create_match_filter(source, join_cols))

    avg, peak = _measure(run_overwrite)
    print(f"\n overwrite(): {target_data.num_rows:,} target, 100 source -> {_fmt(avg, peak)}")
140+
141+
142+
@pytest.mark.benchmark
def test_e2e_overwrite_5ksrc_filter_only(tmp_path: PosixPath) -> None:
    """Benchmark only create_match_filter() for 5,000 source rows.

    At this size just constructing the filter takes seconds. The full
    overwrite() is deliberately not run: per the original author's note, the
    resulting 20,000-node expression tree causes process termination during
    manifest evaluation.
    """
    target_data, join_cols, schema = _build_table(COLS)

    # Source rows are the first 5,000 target rows with "v" shifted by 9000.
    source_dict = {}
    for col in target_data.column_names:
        head = target_data.column(col).to_pylist()[:5000]
        source_dict[col] = [x + 9000 for x in head] if col == "v" else head
    source = pa.table(source_dict, schema=target_data.schema)

    def build_filter() -> None:
        create_match_filter(source, join_cols)

    avg, peak = _measure(build_filter, runs=1)
    print(f"\n create_match_filter only (no overwrite): 5,000 source -> {_fmt(avg, peak)}")

0 commit comments

Comments
 (0)