Skip to content

Commit 7b22b18

Browse files
committed
Fix NotStartsWith residual evaluation to return correct result
1 parent d101879 commit 7b22b18

2 files changed

Lines changed: 18 additions & 4 deletions

File tree

pyiceberg/expressions/visitors.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,10 +1905,11 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> BooleanEx
19051905
return AlwaysFalse()
19061906

19071907
def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> BooleanExpression:
1908-
if not self.visit_starts_with(term, literal):
1909-
return AlwaysTrue()
1910-
else:
1908+
starts_with_result = self.visit_starts_with(term, literal)
1909+
if isinstance(starts_with_result, AlwaysTrue):
19111910
return AlwaysFalse()
1911+
else:
1912+
return AlwaysTrue()
19121913

19131914
def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression:
19141915
"""

tests/expressions/test_residual_evaluator.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from pyiceberg.schema import Schema
4242
from pyiceberg.transforms import DayTransform, IdentityTransform
4343
from pyiceberg.typedef import Record
44-
from pyiceberg.types import DoubleType, FloatType, IntegerType, NestedField, TimestampType
44+
from pyiceberg.types import DoubleType, FloatType, IntegerType, NestedField, StringType, TimestampType
4545

4646

4747
def test_identity_transform_residual() -> None:
@@ -249,3 +249,16 @@ def test_not_in_timestamp() -> None:
249249
ts_day += 3 # type: ignore
250250
residual = res_eval.residual_for(Record(ts_day))
251251
assert residual == AlwaysTrue()
252+
253+
254+
def test_not_starts_with() -> None:
255+
schema = Schema(NestedField(1, "x", StringType()))
256+
spec = PartitionSpec(PartitionField(1, 1001, IdentityTransform(), "x_part"))
257+
258+
predicate = NotStartsWith("x", "a")
259+
res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
260+
261+
assert res_eval.residual_for(Record("bb")) == AlwaysTrue()
262+
assert res_eval.residual_for(Record("abc")) == AlwaysFalse()
263+
assert res_eval.residual_for(Record("a")) == AlwaysFalse()
264+
assert res_eval.residual_for(Record("zoo")) == AlwaysTrue()

0 commit comments

Comments
 (0)