Skip to content

Commit e0ac877

Browse files
authored
[python] Filter manifest files by partition predicate in scan (#6419)
1 parent b0d139a commit e0ac877

10 files changed

Lines changed: 262 additions & 24 deletions

File tree

paimon-python/pypaimon/common/predicate.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pyarrow import compute as pyarrow_compute
2525
from pyarrow import dataset as pyarrow_dataset
2626

27+
from pypaimon.manifest.schema.simple_stats import SimpleStats
2728
from pypaimon.table.row.internal_row import InternalRow
2829

2930

@@ -34,6 +35,20 @@ class Predicate:
3435
field: Optional[str]
3536
literals: Optional[List[Any]] = None
3637

38+
def new_index(self, index: int):
    """Return a new predicate identical to this one except for ``index``.

    ``method``, ``field`` and ``literals`` are passed through as-is, so the
    literals object is shared with the original predicate rather than copied.

    :param index: field position the new predicate should refer to.
    :return: a freshly constructed ``Predicate``.
    """
    carried_over = {
        'method': self.method,
        'field': self.field,
        'literals': self.literals,
    }
    return Predicate(index=index, **carried_over)
44+
45+
def new_literals(self, literals: List[Any]):
    """Return a new predicate identical to this one except for ``literals``.

    ``method``, ``index`` and ``field`` are taken from this predicate
    unchanged.

    :param literals: replacement literals (for compound predicates these are
        the child predicates).
    :return: a freshly constructed ``Predicate``.
    """
    carried_over = {
        'method': self.method,
        'index': self.index,
        'field': self.field,
    }
    return Predicate(literals=literals, **carried_over)
51+
3752
def test(self, record: InternalRow) -> bool:
3853
if self.method == 'equal':
3954
return record.get_field(self.index) == self.literals[0]
@@ -125,6 +140,16 @@ def test_by_value(self, value: Any) -> bool:
125140

126141
raise ValueError("Unsupported predicate method: {}".format(self.method))
127142

143+
def test_by_simple_stats(self, stat: SimpleStats, row_count: int) -> bool:
    """Evaluate this predicate against a ``SimpleStats`` object.

    The stats are repacked into the dictionary layout consumed by
    ``test_by_stats`` (keys ``min_values``, ``max_values``, ``null_counts``
    and ``row_count``) and the decision is delegated to that method.

    :param stat: statistics exposing ``min_values``, ``max_values`` and
        ``null_counts``.
    :param row_count: value forwarded under the ``row_count`` key.
    :return: whatever ``test_by_stats`` returns for the repacked stats.
    """
    lower_bounds = stat.min_values.to_dict()
    upper_bounds = stat.max_values.to_dict()
    # null_counts is positional; key it by the field names of min_values so
    # that all three per-field entries share the same keys.
    null_counts_by_name = {
        field.name: stat.null_counts[position]
        for position, field in enumerate(stat.min_values.fields)
    }
    return self.test_by_stats({
        "min_values": lower_bounds,
        "max_values": upper_bounds,
        "null_counts": null_counts_by_name,
        "row_count": row_count,
    })
152+
128153
def test_by_stats(self, stat: Dict) -> bool:
129154
if self.method == 'and':
130155
return all(p.test_by_stats(stat) for p in self.literals)

paimon-python/pypaimon/read/push_down_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,50 @@
2121
from pypaimon.common.predicate import Predicate
2222

2323

24+
def to_partition_predicate(input_predicate: 'Predicate', all_fields: List[str], partition_keys: List[str]):
    """Derive a predicate over partition columns only from a table predicate.

    The input is split into its AND-ed conjuncts, and only conjuncts whose
    referenced fields are all partition keys are kept. The field indexes of
    the kept conjuncts are then rewritten from positions in ``all_fields`` to
    positions in ``partition_keys``.

    :param input_predicate: predicate expressed against the full table row;
        may be ``None``.
    :param all_fields: table field names, in the order the predicate indexes
        refer to.
    :param partition_keys: partition column names, in partition-row order.
    :return: an ``and`` predicate indexed against the partition row, or
        ``None`` when there is no input predicate, no partition keys, or no
        conjunct that touches partition columns exclusively.
    """
    if not input_predicate or not partition_keys:
        return None

    # NOTE(review): _get_all_fields is assumed to return a set of field names
    # (it is only used through ``issubset`` here) - confirm against its definition.
    partition_only = [
        conjunct for conjunct in _split_and(input_predicate)
        if _get_all_fields(conjunct).issubset(partition_keys)
    ]
    if not partition_only:
        # Nothing can be pruned by partition; returning None lets callers skip
        # the per-file statistics evaluation instead of testing an empty AND.
        return None

    part_to_index = {name: position for position, name in enumerate(partition_keys)}
    # Table field position -> partition row position (-1 for non-partition fields).
    mapping: Dict[int, int] = {
        table_index: part_to_index.get(name, -1)
        for table_index, name in enumerate(all_fields)
    }

    new_predicate = Predicate(
        method='and',
        index=None,
        field=None,
        literals=partition_only
    )
    return _change_index(new_predicate, mapping)
44+
45+
46+
def _split_and(input_predicate: 'Predicate'):
    """Flatten a predicate into the list of its AND-ed conjuncts.

    Nested ``and`` predicates are flattened recursively, so
    ``a AND (b AND c)`` yields ``[a, b, c]``. Any other predicate (including
    ``or``) is treated as a single conjunct and returned unchanged.

    :param input_predicate: predicate to split; may be ``None``.
    :return: a new list of conjuncts; empty when ``input_predicate`` is falsy.
    """
    if not input_predicate:
        return []

    if input_predicate.method != 'and':
        return [input_predicate]

    conjuncts = []
    for child in input_predicate.literals:
        # Recurse so conjuncts hidden inside an inner AND are exposed too.
        conjuncts.extend(_split_and(child))
    return conjuncts
54+
55+
56+
def _change_index(input_predicate: 'Predicate', mapping: Dict[int, int]):
    """Rewrite the field indexes of a predicate tree through ``mapping``.

    Leaf predicates get a copy whose index is ``mapping[old_index]``.
    ``and`` / ``or`` predicates get a copy whose children have been rewritten
    recursively.

    :param input_predicate: root of the predicate tree; may be ``None``.
    :param mapping: old field index -> new field index.
    :return: the rewritten predicate, or ``None`` for a falsy input.
    :raises KeyError: if a leaf index is missing from ``mapping``.
    """
    if not input_predicate:
        return None

    if input_predicate.method not in ('and', 'or'):
        # Leaf: only the index changes.
        return input_predicate.new_index(mapping[input_predicate.index])

    remapped_children = [_change_index(child, mapping) for child in input_predicate.literals]
    return input_predicate.new_literals(remapped_children)
66+
67+
2468
def extract_predicate_to_list(result: list, input_predicate: 'Predicate', keys: List[str]):
2569
if not input_predicate or not keys:
2670
return

paimon-python/pypaimon/read/read_builder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def new_scan(self) -> TableScan:
5252
return TableScan(
5353
table=self.table,
5454
predicate=self._predicate,
55-
limit=self._limit,
56-
read_type=self.read_type()
55+
limit=self._limit
5756
)
5857

5958
def new_read(self) -> TableRead:

paimon-python/pypaimon/read/scanner/full_starting_scanner.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,25 @@
2525
from pypaimon.manifest.manifest_list_manager import ManifestListManager
2626
from pypaimon.manifest.schema.data_file_meta import DataFileMeta
2727
from pypaimon.manifest.schema.manifest_entry import ManifestEntry
28+
from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta
2829
from pypaimon.read.interval_partition import IntervalPartition, SortedRun
2930
from pypaimon.read.plan import Plan
3031
from pypaimon.read.push_down_utils import (extract_predicate_to_dict,
31-
extract_predicate_to_list)
32+
extract_predicate_to_list,
33+
to_partition_predicate)
3234
from pypaimon.read.scanner.starting_scanner import StartingScanner
3335
from pypaimon.read.split import Split
34-
from pypaimon.schema.data_types import DataField
3536
from pypaimon.snapshot.snapshot_manager import SnapshotManager
3637
from pypaimon.table.bucket_mode import BucketMode
3738

3839

3940
class FullStartingScanner(StartingScanner):
40-
def __init__(self, table, predicate: Optional[Predicate], limit: Optional[int], read_type: List[DataField]):
41+
def __init__(self, table, predicate: Optional[Predicate], limit: Optional[int]):
4142
from pypaimon.table.file_store_table import FileStoreTable
4243

4344
self.table: FileStoreTable = table
4445
self.predicate = predicate
4546
self.limit = limit
46-
self.read_type = read_type
4747

4848
self.snapshot_manager = SnapshotManager(table)
4949
self.manifest_list_manager = ManifestListManager(table)
@@ -82,15 +82,26 @@ def scan(self) -> Plan:
8282
splits = self._apply_push_down_limit(splits)
8383
return Plan(splits)
8484

85-
def plan_files(self) -> List[ManifestEntry]:
85+
def _read_manifest_files(self) -> List[ManifestFileMeta]:
    """List the latest snapshot's manifest files that may match the predicate.

    All manifest files of the latest snapshot are read. When a partition
    predicate can be derived from ``self.predicate``, each file is kept only
    if that predicate passes against the file's ``partition_stats``.

    :return: the retained manifest file metas; an empty list when there is
        no snapshot.
    """
    snapshot = self.snapshot_manager.get_latest_snapshot()
    if not snapshot:
        return []

    candidates = self.manifest_list_manager.read_all(snapshot)
    partition_filter = to_partition_predicate(
        self.predicate, self.table.field_names, self.table.partition_keys)

    retained = []
    for manifest in candidates:
        # Without a partition filter every manifest file is kept. Otherwise
        # the number of added plus deleted files is passed as the row count
        # for the statistics test.
        if not partition_filter or partition_filter.test_by_simple_stats(
                manifest.partition_stats,
                manifest.num_added_files + manifest.num_deleted_files):
            retained.append(manifest)
    return retained
100+
101+
def plan_files(self) -> List[ManifestEntry]:
102+
manifest_files = self._read_manifest_files()
91103
deleted_entries = set()
92104
added_entries = []
93-
# TODO: filter manifest files by predicate
94105
for manifest_file in manifest_files:
95106
manifest_entries = self.manifest_file_manager.read(manifest_file.file_name,
96107
lambda row: self._bucket_filter(row))

paimon-python/pypaimon/read/scanner/incremental_starting_scanner.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,13 @@
2020
from pypaimon.common.predicate import Predicate
2121
from pypaimon.manifest.schema.manifest_entry import ManifestEntry
2222
from pypaimon.read.scanner.full_starting_scanner import FullStartingScanner
23-
from pypaimon.schema.data_types import DataField
2423
from pypaimon.snapshot.snapshot_manager import SnapshotManager
2524

2625

2726
class IncrementalStartingScanner(FullStartingScanner):
2827
def __init__(self, table, predicate: Optional[Predicate], limit: Optional[int],
29-
read_type: List[DataField], start: int, end: int):
30-
super().__init__(table, predicate, limit, read_type)
28+
start: int, end: int):
29+
super().__init__(table, predicate, limit)
3130
self.startingSnapshotId = start
3231
self.endingSnapshotId = end
3332

@@ -55,8 +54,7 @@ def plan_files(self) -> List[ManifestEntry]:
5554

5655
@staticmethod
5756
def between_timestamps(table, predicate: Optional[Predicate], limit: Optional[int],
58-
read_type: List[DataField], start_timestamp: int,
59-
end_timestamp: int) -> 'IncrementalStartingScanner':
57+
start_timestamp: int, end_timestamp: int) -> 'IncrementalStartingScanner':
6058
"""
6159
Create an IncrementalStartingScanner for snapshots between two timestamps.
6260
"""
@@ -74,4 +72,4 @@ def between_timestamps(table, predicate: Optional[Predicate], limit: Optional[in
7472
latest_snapshot = snapshot_manager.get_latest_snapshot()
7573
end_id = end_snapshot.id if end_snapshot else (latest_snapshot.id if latest_snapshot else -1)
7674

77-
return IncrementalStartingScanner(table, predicate, limit, read_type, start_id, end_id)
75+
return IncrementalStartingScanner(table, predicate, limit, start_id, end_id)

paimon-python/pypaimon/read/table_scan.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717
################################################################################
1818

19-
from typing import List, Optional
19+
from typing import Optional
2020

2121
from pypaimon.common.core_options import CoreOptions
2222
from pypaimon.common.predicate import Predicate
@@ -27,21 +27,18 @@
2727
from pypaimon.read.scanner.incremental_starting_scanner import \
2828
IncrementalStartingScanner
2929
from pypaimon.read.scanner.starting_scanner import StartingScanner
30-
from pypaimon.schema.data_types import DataField
3130
from pypaimon.snapshot.snapshot_manager import SnapshotManager
3231

3332

3433
class TableScan:
3534
"""Implementation of TableScan for native Python reading."""
3635

37-
def __init__(self, table, predicate: Optional[Predicate], limit: Optional[int],
38-
read_type: List[DataField]):
36+
def __init__(self, table, predicate: Optional[Predicate], limit: Optional[int]):
3937
from pypaimon.table.file_store_table import FileStoreTable
4038

4139
self.table: FileStoreTable = table
4240
self.predicate = predicate
4341
self.limit = limit
44-
self.read_type = read_type
4542
self.starting_scanner = self._create_starting_scanner()
4643

4744
def plan(self) -> Plan:
@@ -67,10 +64,9 @@ def _create_starting_scanner(self) -> Optional[StartingScanner]:
6764
if (start_timestamp == end_timestamp or start_timestamp > latest_snapshot.time_millis
6865
or end_timestamp < earliest_snapshot.time_millis):
6966
return EmptyStartingScanner()
70-
return IncrementalStartingScanner.between_timestamps(self.table, self.predicate, self.limit, self.read_type,
71-
start_timestamp,
72-
end_timestamp)
73-
return FullStartingScanner(self.table, self.predicate, self.limit, self.read_type)
67+
return IncrementalStartingScanner.between_timestamps(self.table, self.predicate, self.limit,
68+
start_timestamp, end_timestamp)
69+
return FullStartingScanner(self.table, self.predicate, self.limit)
7470

7571
def with_shard(self, idx_of_this_subtask, number_of_para_subtasks) -> 'TableScan':
7672
self.starting_scanner.with_shard(idx_of_this_subtask, number_of_para_subtasks)

paimon-python/pypaimon/table/file_store_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(self, file_io: FileIO, identifier: Identifier, table_path: Path,
4646

4747
self.table_schema = table_schema
4848
self.fields = table_schema.fields
49+
self.field_names = [field.name for field in table_schema.fields]
4950
self.field_dict = {field.name: field for field in self.fields}
5051
self.primary_keys = table_schema.primary_keys
5152
self.partition_keys = table_schema.partition_keys
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################
18+
19+
import os
20+
import shutil
21+
import tempfile
22+
import unittest
23+
24+
import pyarrow as pa
25+
26+
from pypaimon import CatalogFactory
27+
from pypaimon import Schema
28+
from pypaimon.read.split import Split
29+
30+
31+
class ReaderPredicateTest(unittest.TestCase):
    """Checks that a partition filter narrows the splits planned by a scan.

    The fixture creates a table partitioned by ``pt`` in a temporary
    warehouse and writes two separate commits, each with two distinct
    ``pt`` values.
    """

    @classmethod
    def setUpClass(cls):
        cls.tempdir = tempfile.mkdtemp()
        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
        cls.catalog.create_database('default', False)

        cls.pa_schema = pa.schema([('a', pa.int64()), ('pt', pa.int64())])
        table_schema = Schema.from_pyarrow_schema(cls.pa_schema, partition_keys=['pt'])
        cls.catalog.create_table('default.test_reader_predicate', table_schema, False)
        cls.table = cls.catalog.get_table('default.test_reader_predicate')

        # One commit per batch.
        cls._commit_batch({'a': [1, 2], 'pt': [1001, 1002]})
        cls._commit_batch({'a': [3, 4], 'pt': [1003, 1004]})

    @classmethod
    def _commit_batch(cls, columns):
        """Write ``columns`` to the test table and commit them as one batch."""
        batch = pa.Table.from_pydict(columns, schema=cls.pa_schema)
        builder = cls.table.new_batch_write_builder()
        writer = builder.new_write()
        committer = builder.new_commit()
        writer.write_arrow(batch)
        committer.commit(writer.prepare_commit())
        writer.close()
        committer.close()

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    def test_partition_predicate(self):
        """Filtering on ``pt == 1003`` must plan exactly one split, for that partition."""
        builder_for_predicate = self.table.new_read_builder()
        predicate = builder_for_predicate.new_predicate_builder().equal('pt', 1003)

        read_builder = self.table.new_read_builder()
        read_builder.with_filter(predicate)
        planned = read_builder.new_scan().plan().splits()

        self.assertEqual(len(planned), 1)
        self.assertEqual(planned[0].partition.to_dict().get("pt"), 1003)

0 commit comments

Comments
 (0)