Skip to content

Commit 9bd9c41

Browse files
authored
[python] Refactor dicts to static fields to improve performance (#6436)
1 parent fa3c6c0 commit 9bd9c41

1 file changed

Lines changed: 242 additions & 56 deletions

File tree

paimon-python/pypaimon/common/predicate.py

Lines changed: 242 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
# limitations under the License.
1717
################################################################################
1818

19+
from abc import ABC, ABCMeta, abstractmethod
1920
from dataclasses import dataclass
2021
from functools import reduce
2122
from typing import Any, Dict, List, Optional
23+
from typing import ClassVar
2224

2325
import pyarrow
2426
from pyarrow import compute as pyarrow_compute
@@ -35,6 +37,8 @@ class Predicate:
3537
field: Optional[str]
3638
literals: Optional[List[Any]] = None
3739

40+
testers: ClassVar[Dict[str, Any]] = {}
41+
3842
def new_index(self, index: int):
3943
return Predicate(
4044
method=self.method,
@@ -56,26 +60,10 @@ def test(self, record: InternalRow) -> bool:
5660
t = any(p.test(record) for p in self.literals)
5761
return t
5862

59-
dispatch = {
60-
'equal': lambda val, literals: val == literals[0],
61-
'notEqual': lambda val, literals: val != literals[0],
62-
'lessThan': lambda val, literals: val < literals[0],
63-
'lessOrEqual': lambda val, literals: val <= literals[0],
64-
'greaterThan': lambda val, literals: val > literals[0],
65-
'greaterOrEqual': lambda val, literals: val >= literals[0],
66-
'isNull': lambda val, literals: val is None,
67-
'isNotNull': lambda val, literals: val is not None,
68-
'startsWith': lambda val, literals: isinstance(val, str) and val.startswith(literals[0]),
69-
'endsWith': lambda val, literals: isinstance(val, str) and val.endswith(literals[0]),
70-
'contains': lambda val, literals: isinstance(val, str) and literals[0] in val,
71-
'in': lambda val, literals: val in literals,
72-
'notIn': lambda val, literals: val not in literals,
73-
'between': lambda val, literals: literals[0] <= val <= literals[1],
74-
}
75-
func = dispatch.get(self.method)
76-
if func:
77-
field_value = record.get_field(self.index)
78-
return func(field_value, self.literals)
63+
field_value = record.get_field(self.index)
64+
tester = Predicate.testers.get(self.method)
65+
if tester:
66+
return tester.test_by_value(field_value, self.literals)
7967
raise ValueError(f"Unsupported predicate method: {self.method}")
8068

8169
def test_by_simple_stats(self, stat: SimpleStats, row_count: int) -> bool:
@@ -110,25 +98,9 @@ def test_by_stats(self, stat: Dict) -> bool:
11098
# invalid stats, skip validation
11199
return True
112100

113-
dispatch = {
114-
'equal': lambda literals: min_value <= literals[0] <= max_value,
115-
'notEqual': lambda literals: not (min_value == literals[0] == max_value),
116-
'lessThan': lambda literals: literals[0] > min_value,
117-
'lessOrEqual': lambda literals: literals[0] >= min_value,
118-
'greaterThan': lambda literals: literals[0] < max_value,
119-
'greaterOrEqual': lambda literals: literals[0] <= max_value,
120-
'in': lambda literals: any(min_value <= l <= max_value for l in literals),
121-
'notIn': lambda literals: not any(min_value == l == max_value for l in literals),
122-
'between': lambda literals: literals[0] <= max_value and literals[1] >= min_value,
123-
'startsWith': lambda literals: ((isinstance(min_value, str) and isinstance(max_value, str)) and
124-
((min_value.startswith(literals[0]) or min_value < literals[0]) and
125-
(max_value.startswith(literals[0]) or max_value > literals[0]))),
126-
'endsWith': lambda literals: True,
127-
'contains': lambda literals: True,
128-
}
129-
func = dispatch.get(self.method)
130-
if func:
131-
return func(self.literals)
101+
tester = Predicate.testers.get(self.method)
102+
if tester:
103+
return tester.test_by_stats(min_value, max_value, self.literals)
132104
raise ValueError(f"Unsupported predicate method: {self.method}")
133105

134106
def to_arrow(self) -> Any:
@@ -177,22 +149,236 @@ def to_arrow(self) -> Any:
177149
return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
178150

179151
field = pyarrow_dataset.field(self.field)
180-
dispatch = {
181-
'equal': lambda literals: field == literals[0],
182-
'notEqual': lambda literals: field != literals[0],
183-
'lessThan': lambda literals: field < literals[0],
184-
'lessOrEqual': lambda literals: field <= literals[0],
185-
'greaterThan': lambda literals: field > literals[0],
186-
'greaterOrEqual': lambda literals: field >= literals[0],
187-
'isNull': lambda literals: field.is_null(),
188-
'isNotNull': lambda literals: field.is_valid(),
189-
'in': lambda literals: field.isin(literals),
190-
'notIn': lambda literals: ~field.isin(literals),
191-
'between': lambda literals: (field >= literals[0]) & (field <= literals[1]),
192-
}
193-
194-
func = dispatch.get(self.method)
195-
if func:
196-
return func(self.literals)
152+
tester = Predicate.testers.get(self.method)
153+
if tester:
154+
return tester.test_by_arrow(field, self.literals)
197155

198156
raise ValueError("Unsupported predicate method: {}".format(self.method))
157+
158+
159+
class RegisterMeta(ABCMeta):
160+
def __init__(cls, name, bases, dct):
161+
super().__init__(name, bases, dct)
162+
if not bool(cls.__abstractmethods__):
163+
Predicate.testers[cls.name] = cls()
164+
165+
166+
class Tester(ABC, metaclass=RegisterMeta):
167+
168+
name = None
169+
170+
@abstractmethod
171+
def test_by_value(self, val, literals) -> bool:
172+
"""
173+
Test based on the specific val and literals.
174+
"""
175+
176+
@abstractmethod
177+
def test_by_stats(self, min_v, max_v, literals) -> bool:
178+
"""
179+
Test based on the specific min_value and max_value and literals.
180+
"""
181+
182+
@abstractmethod
183+
def test_by_arrow(self, val, literals) -> bool:
184+
"""
185+
Test based on the specific arrow value and literals.
186+
"""
187+
188+
189+
class Equal(Tester):
190+
191+
name = 'equal'
192+
193+
def test_by_value(self, val, literals) -> bool:
194+
return val == literals[0]
195+
196+
def test_by_stats(self, min_v, max_v, literals) -> bool:
197+
return min_v <= literals[0] <= max_v
198+
199+
def test_by_arrow(self, val, literals) -> bool:
200+
return val == literals[0]
201+
202+
203+
class NotEqual(Tester):
204+
205+
name = "notEqual"
206+
207+
def test_by_value(self, val, literals) -> bool:
208+
return val != literals[0]
209+
210+
def test_by_stats(self, min_v, max_v, literals) -> bool:
211+
return not (min_v == literals[0] == max_v)
212+
213+
def test_by_arrow(self, val, literals) -> bool:
214+
return val != literals[0]
215+
216+
217+
class LessThan(Tester):
218+
219+
name = "lessThan"
220+
221+
def test_by_value(self, val, literals) -> bool:
222+
return val < literals[0]
223+
224+
def test_by_stats(self, min_v, max_v, literals) -> bool:
225+
return literals[0] > min_v
226+
227+
def test_by_arrow(self, val, literals) -> bool:
228+
return val < literals[0]
229+
230+
231+
class LessOrEqual(Tester):
232+
233+
name = "lessOrEqual"
234+
235+
def test_by_value(self, val, literals) -> bool:
236+
return val <= literals[0]
237+
238+
def test_by_stats(self, min_v, max_v, literals) -> bool:
239+
return literals[0] >= min_v
240+
241+
def test_by_arrow(self, val, literals) -> bool:
242+
return val <= literals[0]
243+
244+
245+
class GreaterThan(Tester):
246+
247+
name = "greaterThan"
248+
249+
def test_by_value(self, val, literals) -> bool:
250+
return val > literals[0]
251+
252+
def test_by_stats(self, min_v, max_v, literals) -> bool:
253+
return literals[0] < max_v
254+
255+
def test_by_arrow(self, val, literals) -> bool:
256+
return val > literals[0]
257+
258+
259+
class GreaterOrEqual(Tester):
260+
261+
name = "greaterOrEqual"
262+
263+
def test_by_value(self, val, literals) -> bool:
264+
return val >= literals[0]
265+
266+
def test_by_stats(self, min_v, max_v, literals) -> bool:
267+
return literals[0] <= max_v
268+
269+
def test_by_arrow(self, val, literals) -> bool:
270+
return val >= literals[0]
271+
272+
273+
class In(Tester):
274+
275+
name = "in"
276+
277+
def test_by_value(self, val, literals) -> bool:
278+
return val in literals
279+
280+
def test_by_stats(self, min_v, max_v, literals) -> bool:
281+
return any(min_v <= l <= max_v for l in literals)
282+
283+
def test_by_arrow(self, val, literals) -> bool:
284+
return val.isin(literals)
285+
286+
287+
class NotIn(Tester):
288+
289+
name = "notIn"
290+
291+
def test_by_value(self, val, literals) -> bool:
292+
return val not in literals
293+
294+
def test_by_stats(self, min_v, max_v, literals) -> bool:
295+
return not any(min_v == l == max_v for l in literals)
296+
297+
def test_by_arrow(self, val, literals) -> bool:
298+
return ~val.isin(literals)
299+
300+
301+
class Between(Tester):
302+
303+
name = "between"
304+
305+
def test_by_value(self, val, literals) -> bool:
306+
return literals[0] <= val <= literals[1]
307+
308+
def test_by_stats(self, min_v, max_v, literals) -> bool:
309+
return literals[0] <= max_v and literals[1] >= min_v
310+
311+
def test_by_arrow(self, val, literals) -> bool:
312+
return (val >= literals[0]) & (val <= literals[1])
313+
314+
315+
class StartsWith(Tester):
316+
317+
name = "startsWith"
318+
319+
def test_by_value(self, val, literals) -> bool:
320+
return isinstance(val, str) and val.startswith(literals[0])
321+
322+
def test_by_stats(self, min_v, max_v, literals) -> bool:
323+
return ((isinstance(min_v, str) and isinstance(max_v, str)) and
324+
((min_v.startswith(literals[0]) or min_v < literals[0]) and
325+
(max_v.startswith(literals[0]) or max_v > literals[0])))
326+
327+
def test_by_arrow(self, val, literals) -> bool:
328+
return True
329+
330+
331+
class EndsWith(Tester):
332+
333+
name = "endsWith"
334+
335+
def test_by_value(self, val, literals) -> bool:
336+
return isinstance(val, str) and val.endswith(literals[0])
337+
338+
def test_by_stats(self, min_v, max_v, literals) -> bool:
339+
return True
340+
341+
def test_by_arrow(self, val, literals) -> bool:
342+
return True
343+
344+
345+
class Contains(Tester):
346+
347+
name = "contains"
348+
349+
def test_by_value(self, val, literals) -> bool:
350+
return isinstance(val, str) and literals[0] in val
351+
352+
def test_by_stats(self, min_v, max_v, literals) -> bool:
353+
return True
354+
355+
def test_by_arrow(self, val, literals) -> bool:
356+
return True
357+
358+
359+
class IsNull(Tester):
360+
361+
name = "isNull"
362+
363+
def test_by_value(self, val, literals) -> bool:
364+
return val is None
365+
366+
def test_by_stats(self, min_v, max_v, literals) -> bool:
367+
return True
368+
369+
def test_by_arrow(self, val, literals) -> bool:
370+
return val.is_null()
371+
372+
373+
class IsNotNull(Tester):
374+
375+
name = "isNotNull"
376+
377+
def test_by_value(self, val, literals) -> bool:
378+
return val is not None
379+
380+
def test_by_stats(self, min_v, max_v, literals) -> bool:
381+
return True
382+
383+
def test_by_arrow(self, val, literals) -> bool:
384+
return val.is_valid()

0 commit comments

Comments
 (0)