Skip to content

Commit 720708a

Browse files
authored
Fix walrus truthiness on metrics bounds and identity-partition projection (#3412)
Inspired by the walrus issue in #3353.

## Summary

Several `if x := dict.get(k):` checks treated legitimate falsy values as missing:

- `lower_bounds.get(field_id)` / `upper_bounds.get(field_id)` return `bytes`. `b""` is a valid serialization of an empty string, but it was being skipped as if the bound were absent. As a result, the inclusive metrics evaluator returned `ROWS_MIGHT_MATCH` when it should have returned `ROWS_CANNOT_MATCH`, and the strict metrics evaluator returned `ROWS_MIGHT_NOT_MATCH` when it should have returned `ROWS_MUST_MATCH`.
- `accessors[...].get(file.partition)` can return `0` or `""` for an `IdentityTransform` partition. The walrus check dropped it, so projected partition columns were filled with `null` instead of the actual value.
- The `lower_bound` / `upper_bound` rendering in `inspect.py` had the same `b""` issue, showing `None` instead of `""` in `readable_metrics`.

All conditions are switched to explicit `is not None` checks. A small `_readable_bound` helper deduplicates the inspect rendering.

## Changes

- `pyiceberg/expressions/visitors.py` — `_InclusiveMetricsEvaluator` and `_StrictMetricsEvaluator` bound lookups (21 sites).
- `pyiceberg/io/pyarrow.py` — `_get_column_projection_values` and `ArrowProjectionVisitor` missing-field handling.
- `pyiceberg/table/inspect.py` — extract `_readable_bound`, use it in both the `entries` and `_get_files_from_manifest` rendering paths.

## Tests

- `tests/expressions/test_evaluator.py` — inclusive and strict evaluators with empty-string bounds, covering lower-only, upper-only, and both-bounds branches.
- `tests/io/test_pyarrow.py` — parametrized identity-transform projection with falsy partition values (`0`, `""`, `None`).
- `tests/table/test_inspect.py` — `_readable_bound` helper plus integration tests via `tbl.inspect.entries()` and `tbl.inspect.files()` for empty-string and null bounds.
1 parent 55887b4 commit 720708a

6 files changed

Lines changed: 252 additions & 35 deletions

File tree

pyiceberg/expressions/visitors.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,7 +1241,8 @@ def visit_less_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
12411241
if not isinstance(field.field_type, PrimitiveType):
12421242
raise ValueError(f"Expected PrimitiveType: {field.field_type}")
12431243

1244-
if lower_bound_bytes := self.lower_bounds.get(field_id):
1244+
lower_bound_bytes = self.lower_bounds.get(field_id)
1245+
if lower_bound_bytes is not None:
12451246
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
12461247

12471248
if self._is_nan(lower_bound):
@@ -1263,7 +1264,8 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
12631264
if not isinstance(field.field_type, PrimitiveType):
12641265
raise ValueError(f"Expected PrimitiveType: {field.field_type}")
12651266

1266-
if lower_bound_bytes := self.lower_bounds.get(field_id):
1267+
lower_bound_bytes = self.lower_bounds.get(field_id)
1268+
if lower_bound_bytes is not None:
12671269
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
12681270
if self._is_nan(lower_bound):
12691271
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
@@ -1284,7 +1286,8 @@ def visit_greater_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
12841286
if not isinstance(field.field_type, PrimitiveType):
12851287
raise ValueError(f"Expected PrimitiveType: {field.field_type}")
12861288

1287-
if upper_bound_bytes := self.upper_bounds.get(field_id):
1289+
upper_bound_bytes = self.upper_bounds.get(field_id)
1290+
if upper_bound_bytes is not None:
12881291
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
12891292
if upper_bound <= literal.value:
12901293
if self._is_nan(upper_bound):
@@ -1305,7 +1308,8 @@ def visit_greater_than_or_equal(self, term: BoundTerm, literal: LiteralValue) ->
13051308
if not isinstance(field.field_type, PrimitiveType):
13061309
raise ValueError(f"Expected PrimitiveType: {field.field_type}")
13071310

1308-
if upper_bound_bytes := self.upper_bounds.get(field_id):
1311+
upper_bound_bytes = self.upper_bounds.get(field_id)
1312+
if upper_bound_bytes is not None:
13091313
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
13101314
if upper_bound < literal.value:
13111315
if self._is_nan(upper_bound):
@@ -1326,7 +1330,8 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
13261330
if not isinstance(field.field_type, PrimitiveType):
13271331
raise ValueError(f"Expected PrimitiveType: {field.field_type}")
13281332

1329-
if lower_bound_bytes := self.lower_bounds.get(field_id):
1333+
lower_bound_bytes = self.lower_bounds.get(field_id)
1334+
if lower_bound_bytes is not None:
13301335
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
13311336
if self._is_nan(lower_bound):
13321337
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
@@ -1335,7 +1340,8 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
13351340
if lower_bound > literal.value:
13361341
return ROWS_CANNOT_MATCH
13371342

1338-
if upper_bound_bytes := self.upper_bounds.get(field_id):
1343+
upper_bound_bytes = self.upper_bounds.get(field_id)
1344+
if upper_bound_bytes is not None:
13391345
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
13401346
if self._is_nan(upper_bound):
13411347
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
@@ -1363,7 +1369,8 @@ def visit_in(self, term: BoundTerm, literals: set[L]) -> bool:
13631369
if not isinstance(field.field_type, PrimitiveType):
13641370
raise ValueError(f"Expected PrimitiveType: {field.field_type}")
13651371

1366-
if lower_bound_bytes := self.lower_bounds.get(field_id):
1372+
lower_bound_bytes = self.lower_bounds.get(field_id)
1373+
if lower_bound_bytes is not None:
13671374
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
13681375
if self._is_nan(lower_bound):
13691376
# NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more.
@@ -1373,7 +1380,8 @@ def visit_in(self, term: BoundTerm, literals: set[L]) -> bool:
13731380
if len(literals) == 0:
13741381
return ROWS_CANNOT_MATCH
13751382

1376-
if upper_bound_bytes := self.upper_bounds.get(field_id):
1383+
upper_bound_bytes = self.upper_bounds.get(field_id)
1384+
if upper_bound_bytes is not None:
13771385
upper_bound = from_bytes(field.field_type, upper_bound_bytes)
13781386
# this is different from Java, here NaN is always larger
13791387
if self._is_nan(upper_bound):
@@ -1403,14 +1411,16 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
14031411
prefix = str(literal.value)
14041412
len_prefix = len(prefix)
14051413

1406-
if lower_bound_bytes := self.lower_bounds.get(field_id):
1414+
lower_bound_bytes = self.lower_bounds.get(field_id)
1415+
if lower_bound_bytes is not None:
14071416
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
14081417

14091418
# truncate lower bound so that its length is not greater than the length of prefix
14101419
if lower_bound and lower_bound[:len_prefix] > prefix:
14111420
return ROWS_CANNOT_MATCH
14121421

1413-
if upper_bound_bytes := self.upper_bounds.get(field_id):
1422+
upper_bound_bytes = self.upper_bounds.get(field_id)
1423+
if upper_bound_bytes is not None:
14141424
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
14151425

14161426
# truncate upper bound so that its length is not greater than the length of prefix
@@ -1434,7 +1444,9 @@ def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
14341444

14351445
# not_starts_with will match unless all values must start with the prefix. This happens when
14361446
# the lower and upper bounds both start with the prefix.
1437-
if (lower_bound_bytes := self.lower_bounds.get(field_id)) and (upper_bound_bytes := self.upper_bounds.get(field_id)):
1447+
lower_bound_bytes = self.lower_bounds.get(field_id)
1448+
upper_bound_bytes = self.upper_bounds.get(field_id)
1449+
if lower_bound_bytes is not None and upper_bound_bytes is not None:
14381450
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
14391451
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
14401452

@@ -1558,7 +1570,8 @@ def visit_less_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
15581570
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
15591571
return ROWS_MIGHT_NOT_MATCH
15601572

1561-
if upper_bytes := self.upper_bounds.get(field_id):
1573+
upper_bytes = self.upper_bounds.get(field_id)
1574+
if upper_bytes is not None:
15621575
field = self._get_field(field_id)
15631576
upper = _from_byte_buffer(field.field_type, upper_bytes)
15641577

@@ -1575,7 +1588,8 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
15751588
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
15761589
return ROWS_MIGHT_NOT_MATCH
15771590

1578-
if upper_bytes := self.upper_bounds.get(field_id):
1591+
upper_bytes = self.upper_bounds.get(field_id)
1592+
if upper_bytes is not None:
15791593
field = self._get_field(field_id)
15801594
upper = _from_byte_buffer(field.field_type, upper_bytes)
15811595

@@ -1592,7 +1606,8 @@ def visit_greater_than(self, term: BoundTerm, literal: LiteralValue) -> bool:
15921606
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
15931607
return ROWS_MIGHT_NOT_MATCH
15941608

1595-
if lower_bytes := self.lower_bounds.get(field_id):
1609+
lower_bytes = self.lower_bounds.get(field_id)
1610+
if lower_bytes is not None:
15961611
field = self._get_field(field_id)
15971612
lower = _from_byte_buffer(field.field_type, lower_bytes)
15981613

@@ -1613,7 +1628,8 @@ def visit_greater_than_or_equal(self, term: BoundTerm, literal: LiteralValue) ->
16131628
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
16141629
return ROWS_MIGHT_NOT_MATCH
16151630

1616-
if lower_bytes := self.lower_bounds.get(field_id):
1631+
lower_bytes = self.lower_bounds.get(field_id)
1632+
if lower_bytes is not None:
16171633
field = self._get_field(field_id)
16181634
lower = _from_byte_buffer(field.field_type, lower_bytes)
16191635

@@ -1634,7 +1650,9 @@ def visit_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
16341650
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
16351651
return ROWS_MIGHT_NOT_MATCH
16361652

1637-
if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
1653+
lower_bytes = self.lower_bounds.get(field_id)
1654+
upper_bytes = self.upper_bounds.get(field_id)
1655+
if lower_bytes is not None and upper_bytes is not None:
16381656
field = self._get_field(field_id)
16391657
lower = _from_byte_buffer(field.field_type, lower_bytes)
16401658
upper = _from_byte_buffer(field.field_type, upper_bytes)
@@ -1655,7 +1673,8 @@ def visit_not_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
16551673

16561674
field = self._get_field(field_id)
16571675

1658-
if lower_bytes := self.lower_bounds.get(field_id):
1676+
lower_bytes = self.lower_bounds.get(field_id)
1677+
if lower_bytes is not None:
16591678
lower = _from_byte_buffer(field.field_type, lower_bytes)
16601679

16611680
if self._is_nan(lower):
@@ -1666,7 +1685,8 @@ def visit_not_equal(self, term: BoundTerm, literal: LiteralValue) -> bool:
16661685
if lower > literal.value:
16671686
return ROWS_MUST_MATCH
16681687

1669-
if upper_bytes := self.upper_bounds.get(field_id):
1688+
upper_bytes = self.upper_bounds.get(field_id)
1689+
if upper_bytes is not None:
16701690
upper = _from_byte_buffer(field.field_type, upper_bytes)
16711691

16721692
if upper < literal.value:
@@ -1682,7 +1702,9 @@ def visit_in(self, term: BoundTerm, literals: set[L]) -> bool:
16821702

16831703
field = self._get_field(field_id)
16841704

1685-
if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
1705+
lower_bytes = self.lower_bounds.get(field_id)
1706+
upper_bytes = self.upper_bounds.get(field_id)
1707+
if lower_bytes is not None and upper_bytes is not None:
16861708
# similar to the implementation in eq, first check if the lower bound is in the set
16871709
lower = _from_byte_buffer(field.field_type, lower_bytes)
16881710
if lower not in literals:
@@ -1711,7 +1733,8 @@ def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool:
17111733

17121734
field = self._get_field(field_id)
17131735

1714-
if lower_bytes := self.lower_bounds.get(field_id):
1736+
lower_bytes = self.lower_bounds.get(field_id)
1737+
if lower_bytes is not None:
17151738
lower = _from_byte_buffer(field.field_type, lower_bytes)
17161739

17171740
if self._is_nan(lower):
@@ -1723,7 +1746,8 @@ def visit_not_in(self, term: BoundTerm, literals: set[L]) -> bool:
17231746
if len(literals) == 0:
17241747
return ROWS_MUST_MATCH
17251748

1726-
if upper_bytes := self.upper_bounds.get(field_id):
1749+
upper_bytes = self.upper_bounds.get(field_id)
1750+
if upper_bytes is not None:
17271751
upper = _from_byte_buffer(field.field_type, upper_bytes)
17281752

17291753
literals = {val for val in literals if upper >= val}

pyiceberg/io/pyarrow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,8 @@ def _get_column_projection_values(
16051605
for field_id in project_schema_diff:
16061606
for partition_field in partition_spec.fields_by_source_id(field_id):
16071607
if isinstance(partition_field.transform, IdentityTransform):
1608-
if partition_value := accessors[partition_field.field_id].get(file.partition):
1608+
partition_value = accessors[partition_field.field_id].get(file.partition)
1609+
if partition_value is not None:
16091610
projected_missing_fields[field_id] = partition_value
16101611

16111612
return projected_missing_fields
@@ -2010,7 +2011,8 @@ def struct(
20102011
elif field.optional or field.initial_default is not None:
20112012
# When an optional field is added, or when a required field with a non-null initial default is added
20122013
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
2013-
if projected_value := self._projected_missing_fields.get(field.field_id):
2014+
projected_value = self._projected_missing_fields.get(field.field_id)
2015+
if projected_value is not None:
20142016
field_arrays.append(pa.repeat(pa.scalar(projected_value, type=arrow_type), len(struct_array)))
20152017
elif field.initial_default is None:
20162018
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))

pyiceberg/table/inspect.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
ALWAYS_TRUE = AlwaysTrue()
3939

4040

41+
def _readable_bound(field_type: PrimitiveType, bound: bytes | None) -> Any | None:
42+
return from_bytes(field_type, bound) if bound is not None else None
43+
44+
4145
class InspectTable:
4246
tbl: Table
4347

@@ -180,12 +184,8 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
180184
"null_value_count": null_value_counts.get(field.field_id),
181185
"nan_value_count": nan_value_counts.get(field.field_id),
182186
# Makes them readable
183-
"lower_bound": from_bytes(field.field_type, lower_bound)
184-
if (lower_bound := lower_bounds.get(field.field_id))
185-
else None,
186-
"upper_bound": from_bytes(field.field_type, upper_bound)
187-
if (upper_bound := upper_bounds.get(field.field_id))
188-
else None,
187+
"lower_bound": _readable_bound(field.field_type, lower_bounds.get(field.field_id)),
188+
"upper_bound": _readable_bound(field.field_type, upper_bounds.get(field.field_id)),
189189
}
190190
for field in self.tbl.metadata.schema().fields
191191
}
@@ -570,12 +570,8 @@ def _get_files_from_manifest(
570570
"value_count": value_counts.get(field.field_id),
571571
"null_value_count": null_value_counts.get(field.field_id),
572572
"nan_value_count": nan_value_counts.get(field.field_id),
573-
"lower_bound": from_bytes(field.field_type, lower_bound)
574-
if (lower_bound := lower_bounds.get(field.field_id))
575-
else None,
576-
"upper_bound": from_bytes(field.field_type, upper_bound)
577-
if (upper_bound := upper_bounds.get(field.field_id))
578-
else None,
573+
"lower_bound": _readable_bound(field.field_type, lower_bounds.get(field.field_id)),
574+
"upper_bound": _readable_bound(field.field_type, upper_bounds.get(field.field_id)),
579575
}
580576
for field in self.tbl.metadata.schema().fields
581577
}

tests/expressions/test_evaluator.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,62 @@ def test_string_starts_with(
900900
# assert not should_read, "Should not read: range doesn't match"
901901

902902

903+
def test_inclusive_metrics_evaluator_uses_empty_byte_lower_bound() -> None:
904+
schema = Schema(NestedField(1, "empty_string", StringType(), required=True))
905+
data_file = DataFile.from_args(
906+
file_path="file.parquet",
907+
file_format=FileFormat.PARQUET,
908+
partition={},
909+
record_count=10,
910+
file_size_in_bytes=1,
911+
value_counts={1: 10},
912+
null_value_counts={1: 0},
913+
nan_value_counts=None,
914+
lower_bounds={1: to_bytes(StringType(), "")},
915+
upper_bounds={1: to_bytes(StringType(), "")},
916+
)
917+
918+
# Lower-bound branch: LessThan reads lower_bound only.
919+
should_read = _InclusiveMetricsEvaluator(schema, LessThan("empty_string", "")).eval(data_file)
920+
assert not should_read, "Should not read: lower bound is present and equal to the literal"
921+
922+
# Upper-bound branch: GreaterThan reads upper_bound only.
923+
should_read = _InclusiveMetricsEvaluator(schema, GreaterThan("empty_string", "abc")).eval(data_file)
924+
assert not should_read, "Should not read: upper bound '' is not greater than 'abc'"
925+
926+
# Both-bounds branch: EqualTo reads lower_bound and upper_bound.
927+
should_read = _InclusiveMetricsEvaluator(schema, EqualTo("empty_string", "abc")).eval(data_file)
928+
assert not should_read, "Should not read: 'abc' falls outside ['', '']"
929+
930+
931+
def test_strict_metrics_evaluator_uses_empty_byte_bounds() -> None:
932+
schema = Schema(NestedField(1, "empty_string", StringType(), required=True))
933+
data_file = DataFile.from_args(
934+
file_path="file.parquet",
935+
file_format=FileFormat.PARQUET,
936+
partition={},
937+
record_count=10,
938+
file_size_in_bytes=1,
939+
value_counts={1: 10},
940+
null_value_counts={1: 0},
941+
nan_value_counts=None,
942+
lower_bounds={1: to_bytes(StringType(), "")},
943+
upper_bounds={1: to_bytes(StringType(), "")},
944+
)
945+
946+
# Both-bounds branch: EqualTo reads lower_bound and upper_bound.
947+
should_read = _StrictMetricsEvaluator(schema, EqualTo("empty_string", "")).eval(data_file)
948+
assert should_read, "Should match: lower and upper bounds are present and equal to the literal"
949+
950+
# Upper-bound branch: LessThan reads upper_bound only.
951+
should_read = _StrictMetricsEvaluator(schema, LessThan("empty_string", "a")).eval(data_file)
952+
assert should_read, "Should match: upper bound '' is strictly less than 'a'"
953+
954+
# Both-bounds branch: NotEqualTo reads lower_bound and upper_bound.
955+
should_read = _StrictMetricsEvaluator(schema, NotEqualTo("empty_string", "abc")).eval(data_file)
956+
assert should_read, "Should match: 'abc' falls outside ['', '']"
957+
958+
903959
def test_string_not_starts_with(
904960
schema_data_file: Schema, data_file: DataFile, data_file_2: DataFile, data_file_3: DataFile, data_file_4: DataFile
905961
) -> None:

0 commit comments

Comments
 (0)