Skip to content

Commit fa3c6c0

Browse files
authored
[Python] clean code for pypaimon (#6433)
1 parent 3d3b097 commit fa3c6c0

5 files changed

Lines changed: 95 additions & 179 deletions

File tree

paimon-python/pypaimon/common/predicate.py

Lines changed: 76 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -50,95 +50,33 @@ def new_literals(self, literals: List[Any]):
5050
literals=literals)
5151

5252
def test(self, record: InternalRow) -> bool:
53-
if self.method == 'equal':
54-
return record.get_field(self.index) == self.literals[0]
55-
elif self.method == 'notEqual':
56-
return record.get_field(self.index) != self.literals[0]
57-
elif self.method == 'lessThan':
58-
return record.get_field(self.index) < self.literals[0]
59-
elif self.method == 'lessOrEqual':
60-
return record.get_field(self.index) <= self.literals[0]
61-
elif self.method == 'greaterThan':
62-
return record.get_field(self.index) > self.literals[0]
63-
elif self.method == 'greaterOrEqual':
64-
return record.get_field(self.index) >= self.literals[0]
65-
elif self.method == 'isNull':
66-
return record.get_field(self.index) is None
67-
elif self.method == 'isNotNull':
68-
return record.get_field(self.index) is not None
69-
elif self.method == 'startsWith':
70-
field_value = record.get_field(self.index)
71-
if not isinstance(field_value, str):
72-
return False
73-
return field_value.startswith(self.literals[0])
74-
elif self.method == 'endsWith':
75-
field_value = record.get_field(self.index)
76-
if not isinstance(field_value, str):
77-
return False
78-
return field_value.endswith(self.literals[0])
79-
elif self.method == 'contains':
80-
field_value = record.get_field(self.index)
81-
if not isinstance(field_value, str):
82-
return False
83-
return self.literals[0] in field_value
84-
elif self.method == 'in':
85-
return record.get_field(self.index) in self.literals
86-
elif self.method == 'notIn':
87-
return record.get_field(self.index) not in self.literals
88-
elif self.method == 'between':
89-
field_value = record.get_field(self.index)
90-
return self.literals[0] <= field_value <= self.literals[1]
91-
elif self.method == 'and':
92-
return all(p.test(record) for p in self.literals)
93-
elif self.method == 'or':
94-
t = any(p.test(record) for p in self.literals)
95-
return t
96-
else:
97-
raise ValueError("Unsupported predicate method: {}".format(self.method))
98-
99-
def test_by_value(self, value: Any) -> bool:
10053
if self.method == 'and':
101-
return all(p.test_by_value(value) for p in self.literals)
54+
return all(p.test(record) for p in self.literals)
10255
if self.method == 'or':
103-
t = any(p.test_by_value(value) for p in self.literals)
56+
t = any(p.test(record) for p in self.literals)
10457
return t
10558

106-
if self.method == 'equal':
107-
return value == self.literals[0]
108-
if self.method == 'notEqual':
109-
return value != self.literals[0]
110-
if self.method == 'lessThan':
111-
return value < self.literals[0]
112-
if self.method == 'lessOrEqual':
113-
return value <= self.literals[0]
114-
if self.method == 'greaterThan':
115-
return value > self.literals[0]
116-
if self.method == 'greaterOrEqual':
117-
return value >= self.literals[0]
118-
if self.method == 'isNull':
119-
return value is None
120-
if self.method == 'isNotNull':
121-
return value is not None
122-
if self.method == 'startsWith':
123-
if not isinstance(value, str):
124-
return False
125-
return value.startswith(self.literals[0])
126-
if self.method == 'endsWith':
127-
if not isinstance(value, str):
128-
return False
129-
return value.endswith(self.literals[0])
130-
if self.method == 'contains':
131-
if not isinstance(value, str):
132-
return False
133-
return self.literals[0] in value
134-
if self.method == 'in':
135-
return value in self.literals
136-
if self.method == 'notIn':
137-
return value not in self.literals
138-
if self.method == 'between':
139-
return self.literals[0] <= value <= self.literals[1]
140-
141-
raise ValueError("Unsupported predicate method: {}".format(self.method))
59+
dispatch = {
60+
'equal': lambda val, literals: val == literals[0],
61+
'notEqual': lambda val, literals: val != literals[0],
62+
'lessThan': lambda val, literals: val < literals[0],
63+
'lessOrEqual': lambda val, literals: val <= literals[0],
64+
'greaterThan': lambda val, literals: val > literals[0],
65+
'greaterOrEqual': lambda val, literals: val >= literals[0],
66+
'isNull': lambda val, literals: val is None,
67+
'isNotNull': lambda val, literals: val is not None,
68+
'startsWith': lambda val, literals: isinstance(val, str) and val.startswith(literals[0]),
69+
'endsWith': lambda val, literals: isinstance(val, str) and val.endswith(literals[0]),
70+
'contains': lambda val, literals: isinstance(val, str) and literals[0] in val,
71+
'in': lambda val, literals: val in literals,
72+
'notIn': lambda val, literals: val not in literals,
73+
'between': lambda val, literals: literals[0] <= val <= literals[1],
74+
}
75+
func = dispatch.get(self.method)
76+
if func:
77+
field_value = record.get_field(self.index)
78+
return func(field_value, self.literals)
79+
raise ValueError(f"Unsupported predicate method: {self.method}")
14280

14381
def test_by_simple_stats(self, stat: SimpleStats, row_count: int) -> bool:
14482
return self.test_by_stats({
@@ -169,66 +107,39 @@ def test_by_stats(self, stat: Dict) -> bool:
169107
max_value = stat["max_values"][self.field]
170108

171109
if min_value is None or max_value is None or (null_count is not None and null_count == row_count):
172-
return False
173-
174-
if self.method == 'equal':
175-
return min_value <= self.literals[0] <= max_value
176-
if self.method == 'notEqual':
177-
return not (min_value == self.literals[0] == max_value)
178-
if self.method == 'lessThan':
179-
return self.literals[0] > min_value
180-
if self.method == 'lessOrEqual':
181-
return self.literals[0] >= min_value
182-
if self.method == 'greaterThan':
183-
return self.literals[0] < max_value
184-
if self.method == 'greaterOrEqual':
185-
return self.literals[0] <= max_value
186-
if self.method == 'startsWith':
187-
if not isinstance(min_value, str) or not isinstance(max_value, str):
188-
raise RuntimeError("startsWith predicate on non-str field")
189-
return ((min_value.startswith(self.literals[0]) or min_value < self.literals[0])
190-
and (max_value.startswith(self.literals[0]) or max_value > self.literals[0]))
191-
if self.method == 'endsWith':
110+
# invalid stats, skip validation
192111
return True
193-
if self.method == 'contains':
194-
return True
195-
if self.method == 'in':
196-
for literal in self.literals:
197-
if min_value <= literal <= max_value:
198-
return True
199-
return False
200-
if self.method == 'notIn':
201-
for literal in self.literals:
202-
if min_value == literal == max_value:
203-
return False
204-
return True
205-
if self.method == 'between':
206-
return self.literals[0] <= max_value and self.literals[1] >= min_value
207-
else:
208-
raise ValueError("Unsupported predicate method: {}".format(self.method))
112+
113+
dispatch = {
114+
'equal': lambda literals: min_value <= literals[0] <= max_value,
115+
'notEqual': lambda literals: not (min_value == literals[0] == max_value),
116+
'lessThan': lambda literals: literals[0] > min_value,
117+
'lessOrEqual': lambda literals: literals[0] >= min_value,
118+
'greaterThan': lambda literals: literals[0] < max_value,
119+
'greaterOrEqual': lambda literals: literals[0] <= max_value,
120+
'in': lambda literals: any(min_value <= l <= max_value for l in literals),
121+
'notIn': lambda literals: not any(min_value == l == max_value for l in literals),
122+
'between': lambda literals: literals[0] <= max_value and literals[1] >= min_value,
123+
'startsWith': lambda literals: ((isinstance(min_value, str) and isinstance(max_value, str)) and
124+
((min_value.startswith(literals[0]) or min_value < literals[0]) and
125+
(max_value.startswith(literals[0]) or max_value > literals[0]))),
126+
'endsWith': lambda literals: True,
127+
'contains': lambda literals: True,
128+
}
129+
func = dispatch.get(self.method)
130+
if func:
131+
return func(self.literals)
132+
raise ValueError(f"Unsupported predicate method: {self.method}")
209133

210134
def to_arrow(self) -> Any:
211-
if self.method == 'equal':
212-
return pyarrow_dataset.field(self.field) == self.literals[0]
213-
elif self.method == 'notEqual':
214-
return pyarrow_dataset.field(self.field) != self.literals[0]
215-
elif self.method == 'lessThan':
216-
return pyarrow_dataset.field(self.field) < self.literals[0]
217-
elif self.method == 'lessOrEqual':
218-
return pyarrow_dataset.field(self.field) <= self.literals[0]
219-
elif self.method == 'greaterThan':
220-
return pyarrow_dataset.field(self.field) > self.literals[0]
221-
elif self.method == 'greaterOrEqual':
222-
return pyarrow_dataset.field(self.field) >= self.literals[0]
223-
elif self.method == 'isNull':
224-
return pyarrow_dataset.field(self.field).is_null()
225-
elif self.method == 'isNotNull':
226-
return pyarrow_dataset.field(self.field).is_valid()
227-
elif self.method == 'in':
228-
return pyarrow_dataset.field(self.field).isin(self.literals)
229-
elif self.method == 'notIn':
230-
return ~pyarrow_dataset.field(self.field).isin(self.literals)
231-
elif self.method == 'startsWith':
135+
if self.method == 'and':
136+
return reduce(lambda x, y: x & y,
137+
[p.to_arrow() for p in self.literals])
138+
if self.method == 'or':
139+
return reduce(lambda x, y: x | y,
140+
[p.to_arrow() for p in self.literals])
141+
142+
if self.method == 'startsWith':
232143
pattern = self.literals[0]
233144
# For PyArrow compatibility - improved approach
234145
try:
@@ -240,7 +151,7 @@ def to_arrow(self) -> Any:
240151
except Exception:
241152
# Fallback to True
242153
return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
243-
elif self.method == 'endsWith':
154+
if self.method == 'endsWith':
244155
pattern = self.literals[0]
245156
# For PyArrow compatibility
246157
try:
@@ -252,7 +163,7 @@ def to_arrow(self) -> Any:
252163
except Exception:
253164
# Fallback to True
254165
return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
255-
elif self.method == 'contains':
166+
if self.method == 'contains':
256167
pattern = self.literals[0]
257168
# For PyArrow compatibility
258169
try:
@@ -264,14 +175,24 @@ def to_arrow(self) -> Any:
264175
except Exception:
265176
# Fallback to True
266177
return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
267-
elif self.method == 'between':
268-
return (pyarrow_dataset.field(self.field) >= self.literals[0]) & \
269-
(pyarrow_dataset.field(self.field) <= self.literals[1])
270-
elif self.method == 'and':
271-
return reduce(lambda x, y: x & y,
272-
[p.to_arrow() for p in self.literals])
273-
elif self.method == 'or':
274-
return reduce(lambda x, y: x | y,
275-
[p.to_arrow() for p in self.literals])
276-
else:
277-
raise ValueError("Unsupported predicate method: {}".format(self.method))
178+
179+
field = pyarrow_dataset.field(self.field)
180+
dispatch = {
181+
'equal': lambda literals: field == literals[0],
182+
'notEqual': lambda literals: field != literals[0],
183+
'lessThan': lambda literals: field < literals[0],
184+
'lessOrEqual': lambda literals: field <= literals[0],
185+
'greaterThan': lambda literals: field > literals[0],
186+
'greaterOrEqual': lambda literals: field >= literals[0],
187+
'isNull': lambda literals: field.is_null(),
188+
'isNotNull': lambda literals: field.is_valid(),
189+
'in': lambda literals: field.isin(literals),
190+
'notIn': lambda literals: ~field.isin(literals),
191+
'between': lambda literals: (field >= literals[0]) & (field <= literals[1]),
192+
}
193+
194+
func = dispatch.get(self.method)
195+
if func:
196+
return func(self.literals)
197+
198+
raise ValueError("Unsupported predicate method: {}".format(self.method))

paimon-python/pypaimon/read/split_read.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919
import os
2020
from abc import ABC, abstractmethod
2121
from functools import partial
22-
from typing import List, Optional, Tuple
22+
from typing import List, Optional, Tuple, Any
2323

2424
from pypaimon.common.core_options import CoreOptions
2525
from pypaimon.common.predicate import Predicate
2626
from pypaimon.manifest.schema.data_file_meta import DataFileMeta
2727
from pypaimon.read.interval_partition import IntervalPartition, SortedRun
2828
from pypaimon.read.partition_info import PartitionInfo
29+
from pypaimon.read.push_down_utils import trim_predicate_by_fields
2930
from pypaimon.read.reader.concat_batch_reader import ConcatBatchReader, ShardBatchReader, MergeAllBatchReader
3031
from pypaimon.read.reader.concat_record_reader import ConcatRecordReader
3132
from pypaimon.read.reader.data_file_batch_reader import DataFileBatchReader
@@ -54,21 +55,31 @@
5455
class SplitRead(ABC):
5556
"""Abstract base class for split reading operations."""
5657

57-
def __init__(self, table, predicate: Optional[Predicate], push_down_predicate,
58-
read_type: List[DataField], split: Split):
58+
def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], split: Split):
5959
from pypaimon.table.file_store_table import FileStoreTable
6060

6161
self.table: FileStoreTable = table
6262
self.predicate = predicate
63-
self.push_down_predicate = push_down_predicate
63+
self.push_down_predicate = self._push_down_predicate()
6464
self.split = split
6565
self.value_arity = len(read_type)
6666

67-
self.trimmed_primary_key = [field.name for field in self.table.table_schema.get_trimmed_primary_key_fields()]
67+
self.trimmed_primary_key = self.table.table_schema.get_trimmed_primary_keys()
6868
self.read_fields = read_type
6969
if isinstance(self, MergeFileSplitRead):
7070
self.read_fields = self._create_key_value_fields(read_type)
7171

72+
def _push_down_predicate(self) -> Any:
73+
if self.predicate is None:
74+
return None
75+
elif self.table.is_primary_key_table:
76+
pk_predicate = trim_predicate_by_fields(self.predicate, self.table.primary_keys)
77+
if not pk_predicate:
78+
return None
79+
return pk_predicate.to_arrow()
80+
else:
81+
return self.predicate.to_arrow()
82+
7283
@abstractmethod
7384
def create_reader(self) -> RecordReader:
7485
"""Create a record reader for the given split."""

paimon-python/pypaimon/read/table_read.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
################################################################################
18-
from typing import Any, Iterator, List, Optional
18+
from typing import Iterator, List, Optional
1919

2020
import pandas
2121
import pyarrow
2222

2323
from pypaimon.common.core_options import CoreOptions
2424
from pypaimon.common.predicate import Predicate
25-
from pypaimon.read.push_down_utils import trim_predicate_by_fields
2625
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
2726
from pypaimon.read.split import Split
2827
from pypaimon.read.split_read import (MergeFileSplitRead, RawFileSplitRead,
@@ -39,7 +38,6 @@ def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataFi
3938

4039
self.table: FileStoreTable = table
4140
self.predicate = predicate
42-
self.push_down_predicate = self._push_down_predicate()
4341
self.read_type = read_type
4442

4543
def to_iterator(self, splits: List[Split]) -> Iterator:
@@ -108,39 +106,25 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":
108106

109107
return ray.data.from_arrow(self.to_arrow(splits))
110108

111-
def _push_down_predicate(self) -> Any:
112-
if self.predicate is None:
113-
return None
114-
elif self.table.is_primary_key_table:
115-
pk_predicate = trim_predicate_by_fields(self.predicate, self.table.primary_keys)
116-
if not pk_predicate:
117-
return None
118-
return pk_predicate.to_arrow()
119-
else:
120-
return self.predicate.to_arrow()
121-
122109
def _create_split_read(self, split: Split) -> SplitRead:
123110
if self.table.is_primary_key_table and not split.raw_convertible:
124111
return MergeFileSplitRead(
125112
table=self.table,
126113
predicate=self.predicate,
127-
push_down_predicate=self.push_down_predicate,
128114
read_type=self.read_type,
129115
split=split
130116
)
131117
elif self.table.options.get(CoreOptions.DATA_EVOLUTION_ENABLED, 'false').lower() == 'true':
132118
return DataEvolutionSplitRead(
133119
table=self.table,
134120
predicate=self.predicate,
135-
push_down_predicate=self.push_down_predicate,
136121
read_type=self.read_type,
137122
split=split
138123
)
139124
else:
140125
return RawFileSplitRead(
141126
table=self.table,
142127
predicate=self.predicate,
143-
push_down_predicate=self.push_down_predicate,
144128
read_type=self.read_type,
145129
split=split
146130
)

paimon-python/pypaimon/write/writer/data_blob_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int):
8383
self.blob_column_name = self._get_blob_columns_from_schema()
8484

8585
# Split schema into normal and blob columns
86-
all_column_names = [field.name for field in self.table.table_schema.fields]
86+
all_column_names = self.table.field_names
8787
self.normal_column_names = [col for col in all_column_names if col != self.blob_column_name]
8888
self.write_cols = self.normal_column_names
8989

0 commit comments

Comments
 (0)