Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
DoubleType,
FloatType,
IcebergType,
IntegerType,
LongType,
NestedField,
PrimitiveType,
StructType,
Expand All @@ -73,6 +75,18 @@
T = TypeVar("T")


def _from_bytes_with_promotion(field_type: PrimitiveType, b: bytes) -> Any:
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
# Integer, Float, Date are 4 bytes
# Long, Double, Timestamps are 8 bytes
# If we have 4 bytes, we may have to handle type promotion.
if len(b) == 4:
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
if isinstance(field_type, LongType):
return from_bytes(IntegerType(), b)
elif isinstance(field_type, DoubleType):
return from_bytes(FloatType(), b)
return from_bytes(field_type, b)


class BooleanExpressionVisitor(Generic[T], ABC):
@abstractmethod
def visit_true(self) -> T:
Expand Down Expand Up @@ -540,7 +554,7 @@ def visit_or(self, left_result: bool, right_result: bool) -> bool:
def _from_byte_buffer(field_type: IcebergType, val: bytes) -> Any:
if not isinstance(field_type, PrimitiveType):
raise ValueError(f"Expected a PrimitiveType, got: {type(field_type)}")
return from_bytes(field_type, val)
return _from_bytes_with_promotion(field_type, val)


class _ManifestEvalVisitor(BoundBooleanExpressionVisitor[bool]):
Expand Down Expand Up @@ -1242,7 +1256,7 @@ def visit_less_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
lower_bound = _from_bytes_with_promotion(field.field_type, lower_bound_bytes)

if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1264,7 +1278,7 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
lower_bound = _from_bytes_with_promotion(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
return ROWS_MIGHT_MATCH
Expand All @@ -1285,7 +1299,7 @@ def visit_greater_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
upper_bound = _from_bytes_with_promotion(field.field_type, upper_bound_bytes)
if upper_bound <= literal.value:
if self._is_nan(upper_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1306,7 +1320,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm, literal: LiteralValue) ->
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
upper_bound = _from_bytes_with_promotion(field.field_type, upper_bound_bytes)
if upper_bound < literal.value:
if self._is_nan(upper_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
Expand All @@ -1327,7 +1341,7 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
lower_bound = _from_bytes_with_promotion(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
return ROWS_MIGHT_MATCH
Expand All @@ -1336,7 +1350,7 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
upper_bound = _from_bytes_with_promotion(field.field_type, upper_bound_bytes)
if self._is_nan(upper_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
return ROWS_MIGHT_MATCH
Expand Down Expand Up @@ -1364,22 +1378,22 @@ def visit_in(self, term: BoundTerm, literals: set[L]) -> bool:
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
lower_bound = _from_bytes_with_promotion(field.field_type, lower_bound_bytes)
if self._is_nan(lower_bound):
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
return ROWS_MIGHT_MATCH

literals = {lit for lit in literals if lower_bound <= lit} # type: ignore[operator]
literals = {lit for lit in literals if lower_bound <= lit}
if len(literals) == 0:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
upper_bound = _from_bytes_with_promotion(field.field_type, upper_bound_bytes)
# this is different from Java, here NaN is always larger
if self._is_nan(upper_bound):
return ROWS_MIGHT_MATCH

literals = {lit for lit in literals if upper_bound >= lit} # type: ignore[operator]
literals = {lit for lit in literals if upper_bound >= lit}
if len(literals) == 0:
return ROWS_CANNOT_MATCH

Expand All @@ -1404,14 +1418,14 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
len_prefix = len(prefix)

if lower_bound_bytes := self.lower_bounds.get(field_id):
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
lower_bound = str(_from_bytes_with_promotion(field.field_type, lower_bound_bytes))

# truncate lower bound so that its length is not greater than the length of prefix
if lower_bound and lower_bound[:len_prefix] > prefix:
return ROWS_CANNOT_MATCH

if upper_bound_bytes := self.upper_bounds.get(field_id):
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
upper_bound = str(_from_bytes_with_promotion(field.field_type, upper_bound_bytes))

# truncate upper bound so that its length is not greater than the length of prefix
if upper_bound is not None and upper_bound[:len_prefix] < prefix:
Expand All @@ -1435,8 +1449,8 @@ def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
# not_starts_with will match unless all values must start with the prefix. This happens when
# the lower and upper bounds both start with the prefix.
if (lower_bound_bytes := self.lower_bounds.get(field_id)) and (upper_bound_bytes := self.upper_bounds.get(field_id)):
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
lower_bound = str(_from_bytes_with_promotion(field.field_type, lower_bound_bytes))
upper_bound = str(_from_bytes_with_promotion(field.field_type, upper_bound_bytes))

# if lower is shorter than the prefix then lower doesn't start with the prefix
if len(lower_bound) < len_prefix:
Expand Down
94 changes: 93 additions & 1 deletion tests/expressions/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,14 @@
Or,
StartsWith,
)
from pyiceberg.expressions.visitors import _InclusiveMetricsEvaluator, _StrictMetricsEvaluator
from pyiceberg.expressions.visitors import (
ROWS_CANNOT_MATCH,
ROWS_MIGHT_MATCH,
ROWS_MIGHT_NOT_MATCH,
ROWS_MUST_MATCH,
_InclusiveMetricsEvaluator,
_StrictMetricsEvaluator,
)
from pyiceberg.manifest import DataFile, FileFormat
from pyiceberg.schema import Schema
from pyiceberg.typedef import Record
Expand All @@ -50,6 +57,7 @@
FloatType,
IcebergType,
IntegerType,
LongType,
NestedField,
PrimitiveType,
StringType,
Expand Down Expand Up @@ -1463,3 +1471,87 @@ def test_strict_integer_not_in(strict_data_file_schema: Schema, strict_data_file

should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("no_nulls", {"abc", "def"})).eval(strict_data_file_1)
assert not should_read, "Should not match: no_nulls field does not have bounds"


@pytest.mark.parametrize(
"file_type, evolved_type, lower_bound, upper_bound, op, lit, expected",
[
# Int -> Long
(IntegerType(), LongType(), 30, 79, GreaterThan, 100, ROWS_CANNOT_MATCH),
(IntegerType(), LongType(), 30, 79, LessThan, 50, ROWS_MIGHT_MATCH),
# Float -> Double
(FloatType(), DoubleType(), 30.0, 79.0, GreaterThan, 100.0, ROWS_CANNOT_MATCH),
(FloatType(), DoubleType(), 30.0, 79.0, LessThan, 50.0, ROWS_MIGHT_MATCH),
],
)
def test_inclusive_metrics_evaluator_with_type_promotion(
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
file_type: PrimitiveType,
evolved_type: PrimitiveType,
lower_bound: Any,
upper_bound: Any,
op: Any,
lit: Any,
expected: bool,
) -> None:
# Schema defines 'col' with evolved state
Comment thread
rambleraptor marked this conversation as resolved.
Outdated
schema = Schema(NestedField(1, "col", evolved_type, required=True))

# Historical manifest contains file_type bounds
data_file = DataFile.from_args(
file_path="file_1.parquet",
file_format=FileFormat.PARQUET,
partition={},
record_count=100,
file_size_in_bytes=1024,
lower_bounds={1: to_bytes(file_type, lower_bound)},
upper_bounds={1: to_bytes(file_type, upper_bound)},
)

# Predicate refers to 'col'
evaluator = _InclusiveMetricsEvaluator(schema, op("col", lit))
assert evaluator.eval(data_file) == expected


@pytest.mark.parametrize(
"file_type, evolved_type, lower_bound, upper_bound, op, lit, expected",
[
# Int -> Long
(IntegerType(), LongType(), 30, 79, GreaterThan, 20, ROWS_MUST_MATCH),
(IntegerType(), LongType(), 30, 79, GreaterThan, 100, ROWS_MIGHT_NOT_MATCH),
(IntegerType(), LongType(), 30, 79, LessThan, 100, ROWS_MUST_MATCH),
(IntegerType(), LongType(), 30, 79, LessThan, 20, ROWS_MIGHT_NOT_MATCH),
# Float -> Double
(FloatType(), DoubleType(), 30.0, 79.0, GreaterThan, 20.0, ROWS_MUST_MATCH),
(FloatType(), DoubleType(), 30.0, 79.0, GreaterThan, 100.0, ROWS_MIGHT_NOT_MATCH),
(FloatType(), DoubleType(), 30.0, 79.0, LessThan, 100.0, ROWS_MUST_MATCH),
(FloatType(), DoubleType(), 30.0, 79.0, LessThan, 20.0, ROWS_MIGHT_NOT_MATCH),
],
)
def test_strict_metrics_evaluator_with_type_promotion(
file_type: PrimitiveType,
evolved_type: PrimitiveType,
lower_bound: Any,
upper_bound: Any,
op: Any,
lit: Any,
expected: bool,
) -> None:
# Schema defines 'col' with evolved state
schema = Schema(NestedField(1, "col", evolved_type, required=True))

# Historical manifest contains file_type bounds
data_file = DataFile.from_args(
file_path="file_1.parquet",
file_format=FileFormat.PARQUET,
partition={},
record_count=100,
file_size_in_bytes=1024,
lower_bounds={1: to_bytes(file_type, lower_bound)},
upper_bounds={1: to_bytes(file_type, upper_bound)},
null_value_counts={1: 0},
nan_value_counts={1: 0},
)

# Predicate refers to 'col'
evaluator = _StrictMetricsEvaluator(schema, op("col", lit))
assert evaluator.eval(data_file) == expected
Loading