@@ -50,95 +50,33 @@ def new_literals(self, literals: List[Any]):
5050 literals = literals )
5151
5252 def test (self , record : InternalRow ) -> bool :
53- if self .method == 'equal' :
54- return record .get_field (self .index ) == self .literals [0 ]
55- elif self .method == 'notEqual' :
56- return record .get_field (self .index ) != self .literals [0 ]
57- elif self .method == 'lessThan' :
58- return record .get_field (self .index ) < self .literals [0 ]
59- elif self .method == 'lessOrEqual' :
60- return record .get_field (self .index ) <= self .literals [0 ]
61- elif self .method == 'greaterThan' :
62- return record .get_field (self .index ) > self .literals [0 ]
63- elif self .method == 'greaterOrEqual' :
64- return record .get_field (self .index ) >= self .literals [0 ]
65- elif self .method == 'isNull' :
66- return record .get_field (self .index ) is None
67- elif self .method == 'isNotNull' :
68- return record .get_field (self .index ) is not None
69- elif self .method == 'startsWith' :
70- field_value = record .get_field (self .index )
71- if not isinstance (field_value , str ):
72- return False
73- return field_value .startswith (self .literals [0 ])
74- elif self .method == 'endsWith' :
75- field_value = record .get_field (self .index )
76- if not isinstance (field_value , str ):
77- return False
78- return field_value .endswith (self .literals [0 ])
79- elif self .method == 'contains' :
80- field_value = record .get_field (self .index )
81- if not isinstance (field_value , str ):
82- return False
83- return self .literals [0 ] in field_value
84- elif self .method == 'in' :
85- return record .get_field (self .index ) in self .literals
86- elif self .method == 'notIn' :
87- return record .get_field (self .index ) not in self .literals
88- elif self .method == 'between' :
89- field_value = record .get_field (self .index )
90- return self .literals [0 ] <= field_value <= self .literals [1 ]
91- elif self .method == 'and' :
92- return all (p .test (record ) for p in self .literals )
93- elif self .method == 'or' :
94- t = any (p .test (record ) for p in self .literals )
95- return t
96- else :
97- raise ValueError ("Unsupported predicate method: {}" .format (self .method ))
98-
99- def test_by_value (self , value : Any ) -> bool :
10053 if self .method == 'and' :
101- return all (p .test_by_value ( value ) for p in self .literals )
54+ return all (p .test ( record ) for p in self .literals )
10255 if self .method == 'or' :
103- t = any (p .test_by_value ( value ) for p in self .literals )
56+ t = any (p .test ( record ) for p in self .literals )
10457 return t
10558
106- if self .method == 'equal' :
107- return value == self .literals [0 ]
108- if self .method == 'notEqual' :
109- return value != self .literals [0 ]
110- if self .method == 'lessThan' :
111- return value < self .literals [0 ]
112- if self .method == 'lessOrEqual' :
113- return value <= self .literals [0 ]
114- if self .method == 'greaterThan' :
115- return value > self .literals [0 ]
116- if self .method == 'greaterOrEqual' :
117- return value >= self .literals [0 ]
118- if self .method == 'isNull' :
119- return value is None
120- if self .method == 'isNotNull' :
121- return value is not None
122- if self .method == 'startsWith' :
123- if not isinstance (value , str ):
124- return False
125- return value .startswith (self .literals [0 ])
126- if self .method == 'endsWith' :
127- if not isinstance (value , str ):
128- return False
129- return value .endswith (self .literals [0 ])
130- if self .method == 'contains' :
131- if not isinstance (value , str ):
132- return False
133- return self .literals [0 ] in value
134- if self .method == 'in' :
135- return value in self .literals
136- if self .method == 'notIn' :
137- return value not in self .literals
138- if self .method == 'between' :
139- return self .literals [0 ] <= value <= self .literals [1 ]
140-
141- raise ValueError ("Unsupported predicate method: {}" .format (self .method ))
59+ dispatch = {
60+ 'equal' : lambda val , literals : val == literals [0 ],
61+ 'notEqual' : lambda val , literals : val != literals [0 ],
62+ 'lessThan' : lambda val , literals : val < literals [0 ],
63+ 'lessOrEqual' : lambda val , literals : val <= literals [0 ],
64+ 'greaterThan' : lambda val , literals : val > literals [0 ],
65+ 'greaterOrEqual' : lambda val , literals : val >= literals [0 ],
66+ 'isNull' : lambda val , literals : val is None ,
67+ 'isNotNull' : lambda val , literals : val is not None ,
68+ 'startsWith' : lambda val , literals : isinstance (val , str ) and val .startswith (literals [0 ]),
69+ 'endsWith' : lambda val , literals : isinstance (val , str ) and val .endswith (literals [0 ]),
70+ 'contains' : lambda val , literals : isinstance (val , str ) and literals [0 ] in val ,
71+ 'in' : lambda val , literals : val in literals ,
72+ 'notIn' : lambda val , literals : val not in literals ,
73+ 'between' : lambda val , literals : literals [0 ] <= val <= literals [1 ],
74+ }
75+ func = dispatch .get (self .method )
76+ if func :
77+ field_value = record .get_field (self .index )
78+ return func (field_value , self .literals )
79+ raise ValueError (f"Unsupported predicate method: { self .method } " )
14280
14381 def test_by_simple_stats (self , stat : SimpleStats , row_count : int ) -> bool :
14482 return self .test_by_stats ({
@@ -169,66 +107,39 @@ def test_by_stats(self, stat: Dict) -> bool:
169107 max_value = stat ["max_values" ][self .field ]
170108
171109 if min_value is None or max_value is None or (null_count is not None and null_count == row_count ):
172- return False
173-
174- if self .method == 'equal' :
175- return min_value <= self .literals [0 ] <= max_value
176- if self .method == 'notEqual' :
177- return not (min_value == self .literals [0 ] == max_value )
178- if self .method == 'lessThan' :
179- return self .literals [0 ] > min_value
180- if self .method == 'lessOrEqual' :
181- return self .literals [0 ] >= min_value
182- if self .method == 'greaterThan' :
183- return self .literals [0 ] < max_value
184- if self .method == 'greaterOrEqual' :
185- return self .literals [0 ] <= max_value
186- if self .method == 'startsWith' :
187- if not isinstance (min_value , str ) or not isinstance (max_value , str ):
188- raise RuntimeError ("startsWith predicate on non-str field" )
189- return ((min_value .startswith (self .literals [0 ]) or min_value < self .literals [0 ])
190- and (max_value .startswith (self .literals [0 ]) or max_value > self .literals [0 ]))
191- if self .method == 'endsWith' :
110+ # invalid stats, skip validation
192111 return True
193- if self .method == 'contains' :
194- return True
195- if self .method == 'in' :
196- for literal in self .literals :
197- if min_value <= literal <= max_value :
198- return True
199- return False
200- if self .method == 'notIn' :
201- for literal in self .literals :
202- if min_value == literal == max_value :
203- return False
204- return True
205- if self .method == 'between' :
206- return self .literals [0 ] <= max_value and self .literals [1 ] >= min_value
207- else :
208- raise ValueError ("Unsupported predicate method: {}" .format (self .method ))
112+
113+ dispatch = {
114+ 'equal' : lambda literals : min_value <= literals [0 ] <= max_value ,
115+ 'notEqual' : lambda literals : not (min_value == literals [0 ] == max_value ),
116+ 'lessThan' : lambda literals : literals [0 ] > min_value ,
117+ 'lessOrEqual' : lambda literals : literals [0 ] >= min_value ,
118+ 'greaterThan' : lambda literals : literals [0 ] < max_value ,
119+ 'greaterOrEqual' : lambda literals : literals [0 ] <= max_value ,
120+ 'in' : lambda literals : any (min_value <= l <= max_value for l in literals ),
121+ 'notIn' : lambda literals : not any (min_value == l == max_value for l in literals ),
122+ 'between' : lambda literals : literals [0 ] <= max_value and literals [1 ] >= min_value ,
123+ 'startsWith' : lambda literals : ((isinstance (min_value , str ) and isinstance (max_value , str )) and
124+ ((min_value .startswith (literals [0 ]) or min_value < literals [0 ]) and
125+ (max_value .startswith (literals [0 ]) or max_value > literals [0 ]))),
126+ 'endsWith' : lambda literals : True ,
127+ 'contains' : lambda literals : True ,
128+ }
129+ func = dispatch .get (self .method )
130+ if func :
131+ return func (self .literals )
132+ raise ValueError (f"Unsupported predicate method: { self .method } " )
209133
210134 def to_arrow (self ) -> Any :
211- if self .method == 'equal' :
212- return pyarrow_dataset .field (self .field ) == self .literals [0 ]
213- elif self .method == 'notEqual' :
214- return pyarrow_dataset .field (self .field ) != self .literals [0 ]
215- elif self .method == 'lessThan' :
216- return pyarrow_dataset .field (self .field ) < self .literals [0 ]
217- elif self .method == 'lessOrEqual' :
218- return pyarrow_dataset .field (self .field ) <= self .literals [0 ]
219- elif self .method == 'greaterThan' :
220- return pyarrow_dataset .field (self .field ) > self .literals [0 ]
221- elif self .method == 'greaterOrEqual' :
222- return pyarrow_dataset .field (self .field ) >= self .literals [0 ]
223- elif self .method == 'isNull' :
224- return pyarrow_dataset .field (self .field ).is_null ()
225- elif self .method == 'isNotNull' :
226- return pyarrow_dataset .field (self .field ).is_valid ()
227- elif self .method == 'in' :
228- return pyarrow_dataset .field (self .field ).isin (self .literals )
229- elif self .method == 'notIn' :
230- return ~ pyarrow_dataset .field (self .field ).isin (self .literals )
231- elif self .method == 'startsWith' :
135+ if self .method == 'and' :
136+ return reduce (lambda x , y : x & y ,
137+ [p .to_arrow () for p in self .literals ])
138+ if self .method == 'or' :
139+ return reduce (lambda x , y : x | y ,
140+ [p .to_arrow () for p in self .literals ])
141+
142+ if self .method == 'startsWith' :
232143 pattern = self .literals [0 ]
233144 # For PyArrow compatibility - improved approach
234145 try :
@@ -240,7 +151,7 @@ def to_arrow(self) -> Any:
240151 except Exception :
241152 # Fallback to True
242153 return pyarrow_dataset .field (self .field ).is_valid () | pyarrow_dataset .field (self .field ).is_null ()
243- elif self .method == 'endsWith' :
154+ if self .method == 'endsWith' :
244155 pattern = self .literals [0 ]
245156 # For PyArrow compatibility
246157 try :
@@ -252,7 +163,7 @@ def to_arrow(self) -> Any:
252163 except Exception :
253164 # Fallback to True
254165 return pyarrow_dataset .field (self .field ).is_valid () | pyarrow_dataset .field (self .field ).is_null ()
255- elif self .method == 'contains' :
166+ if self .method == 'contains' :
256167 pattern = self .literals [0 ]
257168 # For PyArrow compatibility
258169 try :
@@ -264,14 +175,24 @@ def to_arrow(self) -> Any:
264175 except Exception :
265176 # Fallback to True
266177 return pyarrow_dataset .field (self .field ).is_valid () | pyarrow_dataset .field (self .field ).is_null ()
267- elif self .method == 'between' :
268- return (pyarrow_dataset .field (self .field ) >= self .literals [0 ]) & \
269- (pyarrow_dataset .field (self .field ) <= self .literals [1 ])
270- elif self .method == 'and' :
271- return reduce (lambda x , y : x & y ,
272- [p .to_arrow () for p in self .literals ])
273- elif self .method == 'or' :
274- return reduce (lambda x , y : x | y ,
275- [p .to_arrow () for p in self .literals ])
276- else :
277- raise ValueError ("Unsupported predicate method: {}" .format (self .method ))
178+
179+ field = pyarrow_dataset .field (self .field )
180+ dispatch = {
181+ 'equal' : lambda literals : field == literals [0 ],
182+ 'notEqual' : lambda literals : field != literals [0 ],
183+ 'lessThan' : lambda literals : field < literals [0 ],
184+ 'lessOrEqual' : lambda literals : field <= literals [0 ],
185+ 'greaterThan' : lambda literals : field > literals [0 ],
186+ 'greaterOrEqual' : lambda literals : field >= literals [0 ],
187+ 'isNull' : lambda literals : field .is_null (),
188+ 'isNotNull' : lambda literals : field .is_valid (),
189+ 'in' : lambda literals : field .isin (literals ),
190+ 'notIn' : lambda literals : ~ field .isin (literals ),
191+ 'between' : lambda literals : (field >= literals [0 ]) & (field <= literals [1 ]),
192+ }
193+
194+ func = dispatch .get (self .method )
195+ if func :
196+ return func (self .literals )
197+
198+ raise ValueError ("Unsupported predicate method: {}" .format (self .method ))
0 commit comments