Skip to content

Commit ebda25f

Browse files
committed
Apply on top struct fields
1 parent c06e320 commit ebda25f

5 files changed

Lines changed: 43 additions & 5 deletions

File tree

dev/provision.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@
328328
CREATE TABLE {catalog_name}.default.test_table_empty_list_and_map (
329329
col_list array<int>,
330330
col_map map<int, int>,
331+
col_struct struct<test:int>,
331332
col_list_with_struct array<struct<test:int>>
332333
)
333334
USING iceberg
@@ -340,8 +341,8 @@
340341
spark.sql(
341342
f"""
342343
INSERT INTO {catalog_name}.default.test_table_empty_list_and_map
343-
VALUES (null, null, null),
344-
(array(), map(), array(struct(1)))
344+
VALUES (null, null, null, null),
345+
(array(), map(), struct(1), array(struct(1)))
345346
"""
346347
)
347348

pyiceberg/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,7 @@ class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]):
11991199
... 1: Accessor(position=1, inner=None),
12001200
... 5: Accessor(position=2, inner=Accessor(position=0, inner=None)),
12011201
... 6: Accessor(position=2, inner=Accessor(position=1, inner=None)),
1202+
... 3: Accessor(position=2, inner=None),
12021203
... }
12031204
>>> result == expected
12041205
True
@@ -1214,8 +1215,7 @@ def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor
12141215
if field_results[position]:
12151216
for inner_field_id, acc in field_results[position].items():
12161217
result[inner_field_id] = Accessor(position, inner=acc)
1217-
else:
1218-
result[field.field_id] = Accessor(position)
1218+
result[field.field_id] = Accessor(position)
12191219

12201220
return result
12211221

tests/expressions/test_expressions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,23 @@ def test_notnull_bind_required() -> None:
168168
assert NotNull(Reference("a")).bind(schema) == AlwaysTrue()
169169

170170

171+
def test_notnull_bind_top_struct() -> None:
172+
schema = Schema(
173+
NestedField(
174+
3,
175+
"struct_col",
176+
required=False,
177+
field_type=StructType(
178+
NestedField(1, "id", IntegerType(), required=True),
179+
NestedField(2, "cost", DecimalType(38, 18), required=False),
180+
),
181+
),
182+
schema_id=1,
183+
)
184+
bound = BoundNotNull(BoundReference(schema.find_field(3), schema.accessor_for_field(3)))
185+
assert NotNull(Reference("struct_col")).bind(schema) == bound
186+
187+
171188
def test_isnan_inverse() -> None:
172189
assert ~IsNaN(Reference("f")) == NotNaN(Reference("f"))
173190

tests/integration/test_reads.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
LessThan,
4242
NotEqualTo,
4343
NotNaN,
44+
NotNull,
4445
)
4546
from pyiceberg.io import PYARROW_USE_LARGE_TYPES_ON_READ
4647
from pyiceberg.io.pyarrow import (
@@ -668,6 +669,24 @@ def test_filter_case_insensitive(catalog: Catalog) -> None:
668669
assert arrow_table["b"].to_pylist() == ["2"]
669670

670671

672+
@pytest.mark.integration
673+
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
674+
def test_filters_on_top_level_struct(catalog: Catalog) -> None:
675+
test_empty_struct = catalog.load_table("default.test_table_empty_list_and_map")
676+
677+
arrow_table = test_empty_struct.scan().to_arrow()
678+
assert None in arrow_table["col_struct"].to_pylist()
679+
680+
arrow_table = test_empty_struct.scan(row_filter=NotNull("col_struct")).to_arrow()
681+
assert arrow_table["col_struct"].to_pylist() == [{"test": 1}]
682+
683+
arrow_table = test_empty_struct.scan(row_filter="col_struct is not null", case_sensitive=False).to_arrow()
684+
assert arrow_table["col_struct"].to_pylist() == [{"test": 1}]
685+
686+
arrow_table = test_empty_struct.scan(row_filter="COL_STRUCT is null", case_sensitive=False).to_arrow()
687+
assert arrow_table["col_struct"].to_pylist() == [None]
688+
689+
671690
@pytest.mark.integration
672691
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
673692
def test_upgrade_table_version(catalog: Catalog) -> None:

tests/test_schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def test_build_position_accessors(table_schema_nested: Schema) -> None:
398398
4: Accessor(position=3, inner=None),
399399
6: Accessor(position=4, inner=None),
400400
11: Accessor(position=5, inner=None),
401+
15: Accessor(position=6, inner=None),
401402
16: Accessor(position=6, inner=Accessor(position=0, inner=None)),
402403
17: Accessor(position=6, inner=Accessor(position=1, inner=None)),
403404
}
@@ -925,7 +926,7 @@ def primitive_fields() -> List[NestedField]:
925926
]
926927

927928

928-
def test_add_top_level_primitives(primitive_fields: NestedField) -> None:
929+
def test_add_top_level_primitives(primitive_fields: List[NestedField]) -> None:
929930
for primitive_field in primitive_fields:
930931
new_schema = Schema(primitive_field)
931932
applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore

0 commit comments

Comments
 (0)