Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,16 @@
from pyiceberg.schema import Schema
from pyiceberg.typedef import EMPTY_DICT, L, LiteralValue, Record, StructProtocol
from pyiceberg.types import (
DateType,
DoubleType,
FloatType,
IcebergType,
IntegerType,
LongType,
NestedField,
PrimitiveType,
StructType,
TimestampNanoType,
TimestampType,
TimestamptzType,
)
Expand All @@ -73,6 +77,17 @@
T = TypeVar("T")


def _from_bytes_with_promotion(field_type: PrimitiveType, b: bytes) -> Any:
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
if len(b) == 4:
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
if isinstance(field_type, LongType):
return from_bytes(IntegerType(), b)
elif isinstance(field_type, DoubleType):
return from_bytes(FloatType(), b)
elif isinstance(field_type, (TimestampType, TimestampNanoType)):
return from_bytes(DateType(), b)
return from_bytes(field_type, b)


class BooleanExpressionVisitor(Generic[T], ABC):
@abstractmethod
def visit_true(self) -> T:
Expand Down Expand Up @@ -1242,7 +1257,7 @@ def visit_less_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
lower_bound = _from_bytes_with_promotion(field.field_type, lower_bound_bytes)

if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1264,7 +1279,7 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
lower_bound = _from_bytes_with_promotion(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
return ROWS_MIGHT_MATCH
Expand All @@ -1285,7 +1300,7 @@ def visit_greater_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
upper_bound = _from_bytes_with_promotion(field.field_type, upper_bound_bytes)
if upper_bound <= literal.value:
if self._is_nan(upper_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1306,7 +1321,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm, literal: LiteralValue) ->
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
upper_bound = _from_bytes_with_promotion(field.field_type, upper_bound_bytes)
if upper_bound < literal.value:
if self._is_nan(upper_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1327,7 +1342,7 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
lower_bound = _from_bytes_with_promotion(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
return ROWS_MIGHT_MATCH
Expand All @@ -1336,7 +1351,7 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
upper_bound = _from_bytes_with_promotion(field.field_type, upper_bound_bytes)
if self._is_nan(upper_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
return ROWS_MIGHT_MATCH
Expand Down Expand Up @@ -1364,22 +1379,22 @@ def visit_in(self, term: BoundTerm, literals: set[L]) -> bool:
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
lower_bound = _from_bytes_with_promotion(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
return ROWS_MIGHT_MATCH

literals = {lit for lit in literals if lower_bound <= lit} # type: ignore[operator]
literals = {lit for lit in literals if lower_bound <= lit}
if len(literals) == 0:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
upper_bound = _from_bytes_with_promotion(field.field_type, upper_bound_bytes)
# this is different from Java, here NaN is always larger
if self._is_nan(upper_bound):
return ROWS_MIGHT_MATCH

literals = {lit for lit in literals if upper_bound >= lit} # type: ignore[operator]
literals = {lit for lit in literals if upper_bound >= lit}
if len(literals) == 0:
return ROWS_CANNOT_MATCH

Expand All @@ -1404,14 +1419,14 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
len_prefix = len(prefix)

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
lower_bound = str(_from_bytes_with_promotion(field.field_type, lower_bound_bytes))

# truncate lower bound so that its length is not greater than the length of prefix
if lower_bound and lower_bound[:len_prefix] > prefix:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
upper_bound = str(_from_bytes_with_promotion(field.field_type, upper_bound_bytes))

# truncate upper bound so that its length is not greater than the length of prefix
if upper_bound is not None and upper_bound[:len_prefix] < prefix:
Expand All @@ -1435,8 +1450,8 @@ def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
# not_starts_with will match unless all values must start with the prefix. This happens when
# the lower and upper bounds both start with the prefix.
if (lower_bound_bytes := self.lower_bounds.get(field_id)) and (upper_bound_bytes := self.upper_bounds.get(field_id)):
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
lower_bound = str(_from_bytes_with_promotion(field.field_type, lower_bound_bytes))
upper_bound = str(_from_bytes_with_promotion(field.field_type, upper_bound_bytes))

# if lower is shorter than the prefix then lower doesn't start with the prefix
if len(lower_bound) < len_prefix:
Expand Down
44 changes: 44 additions & 0 deletions tests/expressions/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
FloatType,
IcebergType,
IntegerType,
LongType,
NestedField,
PrimitiveType,
StringType,
Expand Down Expand Up @@ -1463,3 +1464,46 @@ def test_strict_integer_not_in(strict_data_file_schema: Schema, strict_data_file

should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("no_nulls", {"abc", "def"})).eval(strict_data_file_1)
assert not should_read, "Should not match: no_nulls field does not have bounds"


def test_inclusive_metrics_evaluator_with_type_promotion_crash() -> None:
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
# Schema defines 'id' as LongType (evolved state)
schema = Schema(NestedField(1, "id", LongType(), required=True))

# Historical manifest contains 4-byte integer bounds
data_file = DataFile.from_args(
file_path="file_1.parquet",
file_format=FileFormat.PARQUET,
partition={},
record_count=100,
file_size_in_bytes=1024,
lower_bounds={1: to_bytes(IntegerType(), 30)},
upper_bounds={1: to_bytes(IntegerType(), 79)},
)

# Predicate: id > 100
# Decodes 4-byte bounds correctly and skips the file
evaluator_pruning = _InclusiveMetricsEvaluator(schema, GreaterThan("id", 100))
assert not evaluator_pruning.eval(data_file)


def test_inclusive_metrics_evaluator_with_float_to_double_promotion() -> None:
# Schema defines 'val' as DoubleType (evolved state)
schema = Schema(NestedField(1, "val", DoubleType(), required=True))

# Historical manifest contains 4-byte float bounds
data_file = DataFile.from_args(
file_path="file_1.parquet",
file_format=FileFormat.PARQUET,
partition={},
record_count=100,
file_size_in_bytes=1024,
lower_bounds={1: to_bytes(FloatType(), 30.0)},
upper_bounds={1: to_bytes(FloatType(), 79.0)},
)

# Predicate: val < 50.0
evaluator = _InclusiveMetricsEvaluator(schema, LessThan("val", 50.0))

# Should not crash and should correctly identify that the file might match
assert evaluator.eval(data_file)
Loading