Skip to content

Commit 7ff150f

Browse files
committed
Fix string-based starts_with and not_starts_with methods
1 parent d101879 commit 7ff150f

2 files changed

Lines changed: 29 additions & 10 deletions

File tree

pyiceberg/expressions/visitors.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
509509

510510
def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
511511
eval_res = term.eval(self.struct)
512-
return eval_res is not None and str(eval_res).startswith(str(literal.value))
512+
return eval_res is not None and eval_res.startswith(literal.value)
513513

514514
def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
515515
return not self.visit_starts_with(term, literal)
@@ -712,7 +712,7 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
712712
def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
713713
pos = term.ref().accessor.position
714714
field = self.partition_fields[pos]
715-
prefix = str(literal.value)
715+
prefix = literal.value
716716
len_prefix = len(prefix)
717717

718718
if field.lower_bound is None:
@@ -736,7 +736,7 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
736736
def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
737737
pos = term.ref().accessor.position
738738
field = self.partition_fields[pos]
739-
prefix = str(literal.value)
739+
prefix = literal.value
740740
len_prefix = len(prefix)
741741

742742
if field.contains_null or field.lower_bound is None or field.upper_bound is None:
@@ -1408,20 +1408,20 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
14081408
if not isinstance(field.field_type, PrimitiveType):
14091409
raise ValueError(f"Expected PrimitiveType: {field.field_type}")
14101410

1411-
prefix = str(literal.value)
1411+
prefix = literal.value
14121412
len_prefix = len(prefix)
14131413

14141414
lower_bound_bytes = self.lower_bounds.get(field_id)
14151415
if lower_bound_bytes is not None:
1416-
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
1416+
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
14171417

14181418
# truncate lower bound so that its length is not greater than the length of prefix
14191419
if lower_bound and lower_bound[:len_prefix] > prefix:
14201420
return ROWS_CANNOT_MATCH
14211421

14221422
upper_bound_bytes = self.upper_bounds.get(field_id)
14231423
if upper_bound_bytes is not None:
1424-
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
1424+
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
14251425

14261426
# truncate upper bound so that its length is not greater than the length of prefix
14271427
if upper_bound is not None and upper_bound[:len_prefix] < prefix:
@@ -1439,16 +1439,16 @@ def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
14391439
if not isinstance(field.field_type, PrimitiveType):
14401440
raise ValueError(f"Expected PrimitiveType: {field.field_type}")
14411441

1442-
prefix = str(literal.value)
1442+
prefix = literal.value
14431443
len_prefix = len(prefix)
14441444

14451445
# not_starts_with will match unless all values must start with the prefix. This happens when
14461446
# the lower and upper bounds both start with the prefix.
14471447
lower_bound_bytes = self.lower_bounds.get(field_id)
14481448
upper_bound_bytes = self.upper_bounds.get(field_id)
14491449
if lower_bound_bytes is not None and upper_bound_bytes is not None:
1450-
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
1451-
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
1450+
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
1451+
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
14521452

14531453
# if lower is shorter than the prefix then lower doesn't start with the prefix
14541454
if len(lower_bound) < len_prefix:
@@ -1899,7 +1899,7 @@ def visit_not_in(self, term: BoundTerm, literals: set[L]) -> BooleanExpression:
18991899

19001900
def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> BooleanExpression:
19011901
eval_res = term.eval(self.struct)
1902-
if eval_res is not None and str(eval_res).startswith(str(literal.value)):
1902+
if eval_res is not None and eval_res.startswith(literal.value):
19031903
return AlwaysTrue()
19041904
else:
19051905
return AlwaysFalse()

tests/expressions/test_visitors.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from pyiceberg.schema import Accessor, Schema
8181
from pyiceberg.typedef import Record
8282
from pyiceberg.types import (
83+
BinaryType,
8384
BooleanType,
8485
DoubleType,
8586
FloatType,
@@ -1629,6 +1630,24 @@ def test_expression_evaluator_null() -> None:
16291630
assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True
16301631

16311632

1633+
def test_expression_evaluator_binary_starts_with() -> None:
1634+
schema = Schema(NestedField(1, "x", BinaryType(), required=False), schema_id=1)
1635+
struct = Record(b"aa")
1636+
assert expression_evaluator(schema, StartsWith("x", b"a"), case_sensitive=True)(struct) is True
1637+
assert expression_evaluator(schema, StartsWith("x", b"aa"), case_sensitive=True)(struct) is True
1638+
assert expression_evaluator(schema, StartsWith("x", b"aaa"), case_sensitive=True)(struct) is False
1639+
assert expression_evaluator(schema, StartsWith("x", b"b"), case_sensitive=True)(struct) is False
1640+
1641+
1642+
def test_expression_evaluator_binary_not_starts_with() -> None:
1643+
schema = Schema(NestedField(1, "x", BinaryType(), required=False), schema_id=1)
1644+
struct = Record(b"aa")
1645+
assert expression_evaluator(schema, NotStartsWith("x", b"a"), case_sensitive=True)(struct) is False
1646+
assert expression_evaluator(schema, NotStartsWith("x", b"aa"), case_sensitive=True)(struct) is False
1647+
assert expression_evaluator(schema, NotStartsWith("x", b"aaa"), case_sensitive=True)(struct) is True
1648+
assert expression_evaluator(schema, NotStartsWith("x", b"b"), case_sensitive=True)(struct) is True
1649+
1650+
16321651
def test_translate_column_names_simple_case(table_schema_simple: Schema) -> None:
16331652
"""Test translate_column_names with matching column names."""
16341653
# Create a bound expression using the original schema

0 commit comments

Comments
 (0)