Skip to content

Commit 28b54d6

Browse files
author
Yingjian Wu
committed
perf: build partition filter with balanced tree to avoid RecursionError
1 parent 939a6e5 commit 28b54d6

2 files changed

Lines changed: 29 additions & 15 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -366,21 +366,19 @@ def _build_partition_predicate(
366 366
Returns:
367 367
A predicate matching any of the input partition records.
368 368
"""
369-
partition_fields = [schema.find_field(field.source_id).name for field in spec.fields]
370-
371-
expr: BooleanExpression = AlwaysFalse()
372-
for partition_record in partition_records:
373-
match_partition_expression: BooleanExpression = AlwaysTrue()
374-
375-
for pos, partition_field in enumerate(partition_fields):
376-
predicate = (
377-
EqualTo(Reference(partition_field), partition_record[pos])
378-
if partition_record[pos] is not None
379-
else IsNull(Reference(partition_field))
380-
)
381-
match_partition_expression = And(match_partition_expression, predicate)
382-
expr = Or(expr, match_partition_expression)
383-
return expr
369+
partition_fields = [schema.find_field(f.source_id).name for f in spec.fields]
370+
if not partition_records or not partition_fields:
371+
return AlwaysFalse()
372+
373+
def _match(record: Record) -> BooleanExpression:
374+
parts: list[BooleanExpression] = [
375+
EqualTo(Reference(name), record[pos]) if record[pos] is not None else IsNull(Reference(name))
376+
for pos, name in enumerate(partition_fields)
377+
]
378+
return And(*parts) if len(parts) > 1 else parts[0]
379+
380+
per_record = [_match(r) for r in partition_records]
381+
return Or(*per_record) if len(per_record) > 1 else per_record[0]
384 382

385 383
def _append_snapshot_producer(
386 384
self, snapshot_properties: dict[str, str], branch: str | None = MAIN_BRANCH

tests/table/test_init.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,3 +1753,19 @@ def test_check_uuid_passes_when_match(table_v2: Table, example_table_metadata_v2
1753 1753
new_metadata = TableMetadataV2(**example_table_metadata_v2)
1754 1754
# Should not raise with same uuid
1755 1755
Table._check_uuid(table_v2.metadata, new_metadata)
1756+
1757+
1758+
def test_build_large_partition_predicate(table_v2: Table) -> None:
1759+
"""A left-folded Or chain over 5000 records would be depth-5000 and crash bind()
1760+
(Python's default recursion limit is ~1000). The balanced tree has depth ~14."""
1761+
from pyiceberg.expressions.visitors import bind
1762+
from pyiceberg.typedef import Record
1763+
1764+
with table_v2.transaction() as tx:
1765+
expr = tx._build_partition_predicate(
1766+
partition_records={Record(i) for i in range(5000)},
1767+
spec=table_v2.metadata.spec(),
1768+
schema=table_v2.metadata.schema(),
1769+
)
1770+
1771+
bind(table_v2.metadata.schema(), expr, case_sensitive=True)

0 commit comments

Comments (0)