1616# limitations under the License.
1717################################################################################
1818
19+ from abc import ABC , ABCMeta , abstractmethod
1920from dataclasses import dataclass
2021from functools import reduce
2122from typing import Any , Dict , List , Optional
23+ from typing import ClassVar
2224
2325import pyarrow
2426from pyarrow import compute as pyarrow_compute
@@ -35,6 +37,8 @@ class Predicate:
3537 field : Optional [str ]
3638 literals : Optional [List [Any ]] = None
3739
40+ testers : ClassVar [Dict [str , Any ]] = {}
41+
3842 def new_index (self , index : int ):
3943 return Predicate (
4044 method = self .method ,
@@ -56,26 +60,10 @@ def test(self, record: InternalRow) -> bool:
5660 t = any (p .test (record ) for p in self .literals )
5761 return t
5862
59- dispatch = {
60- 'equal' : lambda val , literals : val == literals [0 ],
61- 'notEqual' : lambda val , literals : val != literals [0 ],
62- 'lessThan' : lambda val , literals : val < literals [0 ],
63- 'lessOrEqual' : lambda val , literals : val <= literals [0 ],
64- 'greaterThan' : lambda val , literals : val > literals [0 ],
65- 'greaterOrEqual' : lambda val , literals : val >= literals [0 ],
66- 'isNull' : lambda val , literals : val is None ,
67- 'isNotNull' : lambda val , literals : val is not None ,
68- 'startsWith' : lambda val , literals : isinstance (val , str ) and val .startswith (literals [0 ]),
69- 'endsWith' : lambda val , literals : isinstance (val , str ) and val .endswith (literals [0 ]),
70- 'contains' : lambda val , literals : isinstance (val , str ) and literals [0 ] in val ,
71- 'in' : lambda val , literals : val in literals ,
72- 'notIn' : lambda val , literals : val not in literals ,
73- 'between' : lambda val , literals : literals [0 ] <= val <= literals [1 ],
74- }
75- func = dispatch .get (self .method )
76- if func :
77- field_value = record .get_field (self .index )
78- return func (field_value , self .literals )
63+ field_value = record .get_field (self .index )
64+ tester = Predicate .testers .get (self .method )
65+ if tester :
66+ return tester .test_by_value (field_value , self .literals )
7967 raise ValueError (f"Unsupported predicate method: { self .method } " )
8068
8169 def test_by_simple_stats (self , stat : SimpleStats , row_count : int ) -> bool :
@@ -110,25 +98,9 @@ def test_by_stats(self, stat: Dict) -> bool:
11098 # invalid stats, skip validation
11199 return True
112100
113- dispatch = {
114- 'equal' : lambda literals : min_value <= literals [0 ] <= max_value ,
115- 'notEqual' : lambda literals : not (min_value == literals [0 ] == max_value ),
116- 'lessThan' : lambda literals : literals [0 ] > min_value ,
117- 'lessOrEqual' : lambda literals : literals [0 ] >= min_value ,
118- 'greaterThan' : lambda literals : literals [0 ] < max_value ,
119- 'greaterOrEqual' : lambda literals : literals [0 ] <= max_value ,
120- 'in' : lambda literals : any (min_value <= l <= max_value for l in literals ),
121- 'notIn' : lambda literals : not any (min_value == l == max_value for l in literals ),
122- 'between' : lambda literals : literals [0 ] <= max_value and literals [1 ] >= min_value ,
123- 'startsWith' : lambda literals : ((isinstance (min_value , str ) and isinstance (max_value , str )) and
124- ((min_value .startswith (literals [0 ]) or min_value < literals [0 ]) and
125- (max_value .startswith (literals [0 ]) or max_value > literals [0 ]))),
126- 'endsWith' : lambda literals : True ,
127- 'contains' : lambda literals : True ,
128- }
129- func = dispatch .get (self .method )
130- if func :
131- return func (self .literals )
101+ tester = Predicate .testers .get (self .method )
102+ if tester :
103+ return tester .test_by_stats (min_value , max_value , self .literals )
132104 raise ValueError (f"Unsupported predicate method: { self .method } " )
133105
134106 def to_arrow (self ) -> Any :
@@ -177,22 +149,236 @@ def to_arrow(self) -> Any:
177149 return pyarrow_dataset .field (self .field ).is_valid () | pyarrow_dataset .field (self .field ).is_null ()
178150
179151 field = pyarrow_dataset .field (self .field )
180- dispatch = {
181- 'equal' : lambda literals : field == literals [0 ],
182- 'notEqual' : lambda literals : field != literals [0 ],
183- 'lessThan' : lambda literals : field < literals [0 ],
184- 'lessOrEqual' : lambda literals : field <= literals [0 ],
185- 'greaterThan' : lambda literals : field > literals [0 ],
186- 'greaterOrEqual' : lambda literals : field >= literals [0 ],
187- 'isNull' : lambda literals : field .is_null (),
188- 'isNotNull' : lambda literals : field .is_valid (),
189- 'in' : lambda literals : field .isin (literals ),
190- 'notIn' : lambda literals : ~ field .isin (literals ),
191- 'between' : lambda literals : (field >= literals [0 ]) & (field <= literals [1 ]),
192- }
193-
194- func = dispatch .get (self .method )
195- if func :
196- return func (self .literals )
152+ tester = Predicate .testers .get (self .method )
153+ if tester :
154+ return tester .test_by_arrow (field , self .literals )
197155
198156 raise ValueError ("Unsupported predicate method: {}" .format (self .method ))
157+
158+
class RegisterMeta(ABCMeta):
    """Metaclass that registers concrete tester classes as they are defined.

    When a class created through this metaclass has no abstract methods left,
    one instance of it is stored in ``Predicate.testers`` under the class's
    ``name`` attribute. Abstract classes are never instantiated or registered.
    """

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        if cls.__abstractmethods__:
            # Still abstract: it cannot be instantiated, so there is nothing
            # to register yet.
            return
        key = getattr(cls, 'name', None)
        if key is None:
            # A concrete class that declares no predicate name cannot be looked
            # up by method name; skip it instead of polluting the registry
            # with a ``None`` key.
            return
        Predicate.testers[key] = cls()
164+
165+
class Tester(ABC, metaclass=RegisterMeta):
    """Abstract base for the per-method predicate evaluators.

    A concrete subclass sets ``name`` to the predicate method string it handles
    and implements the three evaluation hooks below. ``RegisterMeta`` handles
    the registration of concrete subclasses.
    """

    # Predicate method string handled by the subclass; ``None`` on this base.
    name: Optional[str] = None

    @abstractmethod
    def test_by_value(self, val, literals) -> bool:
        """
        Evaluate the predicate against a single field value.

        :param val: the field value to check.
        :param literals: the literal operands of the predicate.
        :return: True when the value satisfies the predicate.
        """

    @abstractmethod
    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """
        Evaluate the predicate against a min/max value range.

        :param min_v: the minimum value of the range.
        :param max_v: the maximum value of the range.
        :param literals: the literal operands of the predicate.
        :return: the subclass's verdict for the given range.
        """

    @abstractmethod
    def test_by_arrow(self, val, literals) -> Any:
        """
        Build the arrow-side form of the predicate.

        ``Predicate.to_arrow`` passes a ``pyarrow_dataset.field(...)`` object as
        ``val`` and returns the result unchanged, so implementations normally
        return an arrow expression rather than a plain ``bool``.

        :param val: the arrow field the predicate applies to.
        :param literals: the literal operands of the predicate.
        """
187+
188+
class Equal(Tester):
    """Evaluator for the ``equal`` predicate; the operand is ``literals[0]``."""

    name = 'equal'

    def test_by_value(self, val, literals) -> bool:
        """Compare ``val`` with the first literal using ``==``."""
        expected = literals[0]
        return val == expected

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Check that the first literal lies inside ``[min_v, max_v]`` (inclusive)."""
        expected = literals[0]
        return min_v <= expected and expected <= max_v

    def test_by_arrow(self, val, literals) -> bool:
        """Apply ``==`` between the arrow field ``val`` and the first literal."""
        expected = literals[0]
        return val == expected
201+
202+
class NotEqual(Tester):
    """Evaluator for the ``notEqual`` predicate; the operand is ``literals[0]``."""

    name = "notEqual"

    def test_by_value(self, val, literals) -> bool:
        """Compare ``val`` with the first literal using ``!=``."""
        excluded = literals[0]
        return val != excluded

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Return False only when both bounds equal the first literal."""
        excluded = literals[0]
        # De Morgan form of ``not (min_v == excluded == max_v)``.
        return not (min_v == excluded) or not (excluded == max_v)

    def test_by_arrow(self, val, literals) -> bool:
        """Apply ``!=`` between the arrow field ``val`` and the first literal."""
        excluded = literals[0]
        return val != excluded
215+
216+
class LessThan(Tester):
    """Evaluator for the ``lessThan`` predicate; the operand is ``literals[0]``."""

    name = "lessThan"

    def test_by_value(self, val, literals) -> bool:
        """
        Return True when ``val`` is strictly less than the first literal.

        A ``None`` value never matches. Ordering ``None`` against a literal
        would typically raise ``TypeError``, and the string testers in this
        module likewise return False for values they cannot compare.
        """
        if val is None:
            return False
        return val < literals[0]

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Return True when ``min_v`` is below the first literal; ``max_v`` is unused."""
        return literals[0] > min_v

    def test_by_arrow(self, val, literals) -> Any:
        """Apply ``<`` between the arrow field ``val`` and the first literal."""
        return val < literals[0]
229+
230+
class LessOrEqual(Tester):
    """Evaluator for the ``lessOrEqual`` predicate; the operand is ``literals[0]``."""

    name = "lessOrEqual"

    def test_by_value(self, val, literals) -> bool:
        """
        Return True when ``val`` is less than or equal to the first literal.

        A ``None`` value never matches. Ordering ``None`` against a literal
        would typically raise ``TypeError``, and the string testers in this
        module likewise return False for values they cannot compare.
        """
        if val is None:
            return False
        return val <= literals[0]

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Return True when ``min_v`` does not exceed the first literal; ``max_v`` is unused."""
        return literals[0] >= min_v

    def test_by_arrow(self, val, literals) -> Any:
        """Apply ``<=`` between the arrow field ``val`` and the first literal."""
        return val <= literals[0]
243+
244+
class GreaterThan(Tester):
    """Evaluator for the ``greaterThan`` predicate; the operand is ``literals[0]``."""

    name = "greaterThan"

    def test_by_value(self, val, literals) -> bool:
        """
        Return True when ``val`` is strictly greater than the first literal.

        A ``None`` value never matches. Ordering ``None`` against a literal
        would typically raise ``TypeError``, and the string testers in this
        module likewise return False for values they cannot compare.
        """
        if val is None:
            return False
        return val > literals[0]

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Return True when ``max_v`` is above the first literal; ``min_v`` is unused."""
        return literals[0] < max_v

    def test_by_arrow(self, val, literals) -> Any:
        """Apply ``>`` between the arrow field ``val`` and the first literal."""
        return val > literals[0]
257+
258+
class GreaterOrEqual(Tester):
    """Evaluator for the ``greaterOrEqual`` predicate; the operand is ``literals[0]``."""

    name = "greaterOrEqual"

    def test_by_value(self, val, literals) -> bool:
        """
        Return True when ``val`` is greater than or equal to the first literal.

        A ``None`` value never matches. Ordering ``None`` against a literal
        would typically raise ``TypeError``, and the string testers in this
        module likewise return False for values they cannot compare.
        """
        if val is None:
            return False
        return val >= literals[0]

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Return True when ``max_v`` is at least the first literal; ``min_v`` is unused."""
        return literals[0] <= max_v

    def test_by_arrow(self, val, literals) -> Any:
        """Apply ``>=`` between the arrow field ``val`` and the first literal."""
        return val >= literals[0]
271+
272+
class In(Tester):
    """Evaluator for the ``in`` predicate; every entry of ``literals`` is a candidate."""

    name = "in"

    def test_by_value(self, val, literals) -> bool:
        """Report whether ``val`` is a member of ``literals``."""
        is_member = val in literals
        return is_member

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Report whether any literal lies inside ``[min_v, max_v]`` (inclusive)."""
        for candidate in literals:
            if min_v <= candidate <= max_v:
                return True
        return False

    def test_by_arrow(self, val, literals) -> bool:
        """Delegate to ``val.isin`` with the full literal list."""
        membership = val.isin(literals)
        return membership
285+
286+
class NotIn(Tester):
    """Evaluator for the ``notIn`` predicate; every entry of ``literals`` is excluded."""

    name = "notIn"

    def test_by_value(self, val, literals) -> bool:
        """Report whether ``val`` is absent from ``literals``."""
        is_absent = val not in literals
        return is_absent

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Return False when some literal equals both ``min_v`` and ``max_v``, else True."""
        for candidate in literals:
            if min_v == candidate == max_v:
                return False
        return True

    def test_by_arrow(self, val, literals) -> bool:
        """Negate ``val.isin(literals)`` with the ``~`` operator."""
        membership = val.isin(literals)
        return ~membership
299+
300+
class Between(Tester):
    """Evaluator for the ``between`` predicate.

    ``literals[0]`` is the lower bound and ``literals[1]`` the upper bound;
    both bounds are inclusive.
    """

    name = "between"

    def test_by_value(self, val, literals) -> bool:
        """
        Return True when ``literals[0] <= val <= literals[1]``.

        A ``None`` value never matches. Ordering ``None`` against a literal
        would typically raise ``TypeError``, and the string testers in this
        module likewise return False for values they cannot compare.
        """
        if val is None:
            return False
        return literals[0] <= val <= literals[1]

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Return True when ``[literals[0], literals[1]]`` overlaps ``[min_v, max_v]``."""
        return literals[0] <= max_v and literals[1] >= min_v

    def test_by_arrow(self, val, literals) -> Any:
        """Combine ``val >= literals[0]`` and ``val <= literals[1]`` with ``&``."""
        return (val >= literals[0]) & (val <= literals[1])
313+
314+
class StartsWith(Tester):
    """Evaluator for the ``startsWith`` predicate; the prefix is ``literals[0]``."""

    name = "startsWith"

    def test_by_value(self, val, literals) -> bool:
        """Return True when ``val`` is a ``str`` that starts with the prefix."""
        if not isinstance(val, str):
            return False
        return val.startswith(literals[0])

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """
        Check the prefix against both bounds of a string range.

        Returns False unless both bounds are ``str``. Otherwise the lower bound
        must start with the prefix or sort before it, and the upper bound must
        start with the prefix or sort after it.
        """
        if not (isinstance(min_v, str) and isinstance(max_v, str)):
            return False
        prefix = literals[0]
        lower_ok = min_v.startswith(prefix) or min_v < prefix
        if not lower_ok:
            return False
        return max_v.startswith(prefix) or max_v > prefix

    def test_by_arrow(self, val, literals) -> bool:
        """Always return True; no arrow expression is built here."""
        return True
329+
330+
class EndsWith(Tester):
    """Evaluator for the ``endsWith`` predicate; the suffix is ``literals[0]``."""

    name = "endsWith"

    def test_by_value(self, val, literals) -> bool:
        """Return True when ``val`` is a ``str`` that ends with the suffix."""
        if not isinstance(val, str):
            return False
        return val.endswith(literals[0])

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Always return True; the bounds and literals are not inspected."""
        return True

    def test_by_arrow(self, val, literals) -> bool:
        """Always return True; no arrow expression is built here."""
        return True
343+
344+
class Contains(Tester):
    """Evaluator for the ``contains`` predicate; the fragment is ``literals[0]``."""

    name = "contains"

    def test_by_value(self, val, literals) -> bool:
        """Return True when ``val`` is a ``str`` that contains the fragment."""
        if not isinstance(val, str):
            return False
        return literals[0] in val

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Always return True; the bounds and literals are not inspected."""
        return True

    def test_by_arrow(self, val, literals) -> bool:
        """Always return True; no arrow expression is built here."""
        return True
357+
358+
class IsNull(Tester):
    """Evaluator for the ``isNull`` predicate; ``literals`` is not used."""

    name = "isNull"

    def test_by_value(self, val, literals) -> bool:
        """Return True exactly when ``val`` is ``None``."""
        is_missing = val is None
        return is_missing

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Always return True; the bounds are not inspected."""
        return True

    def test_by_arrow(self, val, literals) -> bool:
        """Delegate to ``val.is_null()``."""
        null_check = val.is_null()
        return null_check
371+
372+
class IsNotNull(Tester):
    """Evaluator for the ``isNotNull`` predicate; ``literals`` is not used."""

    name = "isNotNull"

    def test_by_value(self, val, literals) -> bool:
        """Return True exactly when ``val`` is not ``None``."""
        is_present = val is not None
        return is_present

    def test_by_stats(self, min_v, max_v, literals) -> bool:
        """Always return True; the bounds are not inspected."""
        return True

    def test_by_arrow(self, val, literals) -> bool:
        """Delegate to ``val.is_valid()``."""
        valid_check = val.is_valid()
        return valid_check
0 commit comments