Skip to content

Commit b2ce070

Browse files
Paul Mathewcursoragent
andcommitted
perf(upsert): prune destination scan via df partition-column ranges and project join_cols only
Two complementary optimizations to ``Transaction.upsert`` for tables whose partition spec sources from columns NOT in ``join_cols`` (a common pattern for append-only event logs partitioned by time but keyed by composite IDs): 1. Partition-range augmentation: ``upsert_util.augment_filter_with_partition_ranges`` derives ``[min, max]`` predicates from ``df`` for every partition source column present in the frame and ANDs them into the row filter built by ``create_match_filter``. ``inclusive_projection`` then projects each range through the partition transform at scan plan time, enabling manifest- and file-level pruning that the key-only filter can't trigger. 2. Column-projection for the insert-only path: when ``when_matched_update_all=False`` the consumer loop only reads ``join_cols`` off each destination batch. Passing ``selected_fields=tuple(join_cols)`` to ``DataScan`` lets the parquet reader prune wide non-key columns. The existing ``_projected_field_ids`` auto-union with row-filter columns keeps the partition-range predicate's data accessible. Correctness guards skip the augmentation per-column when the source column is absent from df, entirely null, or partially null (a non-null range predicate would exclude NULL-partition destination rows whose keys may collide with the null-partition source rows). Related to #2138, #2159, #3129. Complementary to (closed-stale) #2943's "coarse match filter" approach: that PR shrinks the row predicate itself; this one adds partition pruning the row predicate can't trigger on its own. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent d339391 commit b2ce070

3 files changed

Lines changed: 795 additions & 10 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -877,12 +877,35 @@ def upsert(
877877
# get list of rows that exist so we don't have to load the entire target table
878878
matched_predicate = upsert_util.create_match_filter(df, join_cols)
879879

880+
# Augment the row filter with [min, max] predicates on any
881+
# partition source column present in ``df``. ``inclusive_projection``
882+
# projects the range through monotonic partition transforms when
883+
# planning the scan, so ``DataScan.plan_files`` can prune the
884+
# destination at the manifest + file level.
885+
matched_predicate = upsert_util.augment_filter_with_partition_ranges(
886+
matched_predicate,
887+
df,
888+
self.table_metadata.schema(),
889+
self.table_metadata.spec(),
890+
)
891+
892+
# When ``when_matched_update_all=False`` the consumer loop below
893+
# only ever reads ``join_cols`` off each destination batch (to
894+
# build the per-batch match filter via
895+
# ``upsert_util.create_match_filter``). Project ``join_cols``
896+
# only so the parquet reader can prune wide non-key columns.
897+
#
898+
# ``when_matched_update_all=True`` falls back to the legacy
899+
# ``("*",)`` projection.
900+
selected_fields: tuple[str, ...] = ("*",) if when_matched_update_all else tuple(join_cols)
901+
880902
# We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
881903

882904
matched_iceberg_record_batches_scan = DataScan(
883905
table_metadata=self.table_metadata,
884906
io=self._table.io,
885907
row_filter=matched_predicate,
908+
selected_fields=selected_fields,
886909
case_sensitive=case_sensitive,
887910
)
888911

@@ -2072,13 +2095,11 @@ def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], Residu
20722095
# The lambda created here is run in multiple threads.
20732096
# So we avoid creating _EvaluatorExpression methods bound to a single
20742097
# shared instance across multiple threads.
2075-
return lambda datafile: (
2076-
residual_evaluator_of(
2077-
spec=spec,
2078-
expr=self.row_filter,
2079-
case_sensitive=self.case_sensitive,
2080-
schema=self.table_metadata.schema(),
2081-
)
2098+
return lambda datafile: residual_evaluator_of(
2099+
spec=spec,
2100+
expr=self.row_filter,
2101+
case_sensitive=self.case_sensitive,
2102+
schema=self.table_metadata.schema(),
20822103
)
20832104

20842105
@staticmethod

pyiceberg/table/upsert_util.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,16 @@
2323

2424
from pyiceberg.expressions import (
2525
AlwaysFalse,
26+
And,
2627
BooleanExpression,
2728
EqualTo,
29+
GreaterThanOrEqual,
2830
In,
31+
LessThanOrEqual,
2932
Or,
3033
)
34+
from pyiceberg.partitioning import PartitionSpec
35+
from pyiceberg.schema import Schema
3136

3237

3338
def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
@@ -53,6 +58,104 @@ def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
5358
return len(df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")]).filter(pc.field("count_all") > 1)) > 0
5459

5560

61+
def augment_filter_with_partition_ranges(
62+
matched_predicate: BooleanExpression,
63+
df: pyarrow_table,
64+
schema: Schema,
65+
spec: PartitionSpec,
66+
) -> BooleanExpression:
67+
"""Return *matched_predicate* AND'd with ``[min, max]`` predicates on partition source columns.
68+
69+
Iceberg's ``inclusive_projection`` projects each range through the
70+
partition transform (``hours``, ``days``, ``months``, ``years``,
71+
``identity``, ``truncate``) when planning the scan, so
72+
``DataScan.plan_files`` can prune manifests and data files that
73+
don't overlap the source's value range. Without this augmentation,
74+
tables whose partition spec sources from columns NOT in
75+
``join_cols`` (a common pattern for append-only event logs
76+
partitioned by time but keyed by composite IDs) fall through to a
77+
full table scan on every upsert because the row filter built from
78+
``join_cols`` alone projects to ``AlwaysTrue`` against the
79+
partition spec.
80+
81+
Bucket and other non-monotonic transforms return ``None`` from
82+
their ``project`` method for inequalities, so the augmentation is
83+
safe — it either prunes or contributes ``AlwaysTrue`` (no harm).
84+
85+
A partition source column is skipped from augmentation when:
86+
87+
- It isn't present on ``df`` (no source value to bound).
88+
- It is entirely null in ``df`` (no meaningful min/max).
89+
- It contains any null in ``df`` (preserving correctness: a
90+
``GreaterThanOrEqual(col, non_null_min)`` predicate would
91+
exclude destination rows whose partition value is ``NULL``,
92+
potentially missing a key match. Without partition pruning
93+
those NULL-partition rows are scanned normally.)
94+
95+
When ``min == max`` for a column, an ``EqualTo`` predicate is
96+
emitted instead of the range pair — tighter, and lets exact
97+
partition pruning fire.
98+
99+
Args:
100+
matched_predicate: The row filter built from ``join_cols``.
101+
df: Source data frame whose values bound the augmentation.
102+
schema: Iceberg schema, used to resolve partition source ids
103+
to column names.
104+
spec: Active partition spec.
105+
106+
Returns:
107+
The augmented predicate, or *matched_predicate* unchanged
108+
when no partition source column qualifies.
109+
"""
110+
if spec.is_unpartitioned():
111+
return matched_predicate
112+
113+
df_columns = set(df.column_names)
114+
augmentations: list[BooleanExpression] = []
115+
116+
# Iterate distinct source columns rather than partition fields —
117+
# multiple partition fields can share a source column (e.g.
118+
# ``bucket(8, id), truncate(4, id)``) but we only need to add the
119+
# source-column range once; ``inclusive_projection`` projects
120+
# through each partition field independently.
121+
seen_source_ids: set[int] = set()
122+
for field in spec.fields:
123+
if field.source_id in seen_source_ids:
124+
continue
125+
seen_source_ids.add(field.source_id)
126+
127+
col_name = schema.find_field(field.source_id).name
128+
if col_name not in df_columns:
129+
continue
130+
131+
col = df[col_name]
132+
if col.null_count > 0:
133+
# Mixing null with a bounded predicate would exclude
134+
# destination rows whose partition value is null,
135+
# potentially missing key matches. Skip pruning rather
136+
# than risk a correctness regression.
137+
continue
138+
139+
col_min = pc.min(col).as_py()
140+
col_max = pc.max(col).as_py()
141+
if col_min is None or col_max is None:
142+
# Defensive — ``null_count == 0`` should imply both bounds
143+
# are non-null, but pyarrow's min/max can still return None
144+
# on empty columns.
145+
continue
146+
147+
if col_min == col_max:
148+
augmentations.append(EqualTo(col_name, col_min))
149+
else:
150+
augmentations.append(GreaterThanOrEqual(col_name, col_min))
151+
augmentations.append(LessThanOrEqual(col_name, col_max))
152+
153+
if not augmentations:
154+
return matched_predicate
155+
156+
return functools.reduce(And, [matched_predicate, *augmentations])
157+
158+
56159
def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
57160
"""
58161
Return a table with rows that need to be updated in the target table based on the join columns.

0 commit comments

Comments
 (0)