Skip to content

Commit 4acf26b

Browse files
committed
Add support for partitioning by nested columns
1 parent dc43940 commit 4acf26b

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2728,9 +2728,11 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
27282728

27292729
for partition, name in zip(spec.fields, partition_fields):
27302730
source_field = schema.find_field(partition.source_id)
2731-
arrow_table = arrow_table.append_column(
2732-
name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
2733-
)
2731+
full_field_name = schema.find_column_name(partition.source_id)
2732+
if full_field_name is None:
2733+
raise ValueError(f"Could not find column name for field ID: {partition.source_id}")
2734+
field_array = _get_field_from_arrow_table(arrow_table, full_field_name)
2735+
arrow_table = arrow_table.append_column(name, partition.transform.pyarrow_transform(source_field.field_type)(field_array))
27342736

27352737
unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
27362738

@@ -2765,3 +2767,22 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
27652767
)
27662768

27672769
return table_partitions
2770+
2771+
2772+
def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Array:
    """Resolve a (possibly nested) field in an Arrow table from a dot-separated path.

    Args:
        arrow_table: Table to read the field from.
        field_path: Dot-separated path, e.g. "name" or "bar.baz.timestamp".
            NOTE(review): assumes dots only appear as struct separators — a
            top-level column whose name itself contains a dot would be
            mis-split; confirm upstream naming guarantees.

    Returns:
        The referenced field as a PyArrow array, pulled out of its parent struct(s).
    """
    top_level, _, remainder = field_path.partition(".")
    if not remainder:
        # Plain top-level column: direct lookup, no struct traversal needed.
        return arrow_table[field_path]
    # Descend through the struct levels beneath the top-level column.
    return pc.struct_field(arrow_table[top_level], remainder.split("."))  # type: ignore

tests/io/test_pyarrow.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
from pyiceberg.table import FileScanTask, TableProperties
8585
from pyiceberg.table.metadata import TableMetadataV2
8686
from pyiceberg.table.name_mapping import create_mapping_from_schema
87-
from pyiceberg.transforms import IdentityTransform
87+
from pyiceberg.transforms import HourTransform, IdentityTransform
8888
from pyiceberg.typedef import UTF8, Properties, Record
8989
from pyiceberg.types import (
9090
BinaryType,
@@ -2350,6 +2350,72 @@ def test_partition_for_demo() -> None:
23502350
)
23512351

23522352

2353+
def test_partition_for_nested_field() -> None:
    """Partition by a timestamp that lives one struct level deep (HourTransform)."""
    schema = Schema(
        NestedField(id=1, name="foo", field_type=StringType(), required=True),
        NestedField(
            id=2,
            name="bar",
            field_type=StructType(
                NestedField(id=3, name="baz", field_type=TimestampType(), required=False),
                NestedField(id=4, name="qux", field_type=IntegerType(), required=False),
            ),
            required=True,
        ),
    )

    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts"))

    from datetime import datetime

    rows = [
        {"foo": "a", "bar": {"baz": datetime(2025, 7, 11, 9, 30, 0), "qux": 1}},
        {"foo": "b", "bar": {"baz": datetime(2025, 7, 11, 10, 30, 0), "qux": 2}},
    ]
    table = pa.Table.from_pylist(rows, schema=schema.as_arrow())

    partitions = _determine_partitions(spec, schema, table)

    # 486729/486730 are the hours-since-epoch for 2025-07-11T09 and T10.
    partition_values = {p.partition_key.partition[0] for p in partitions}
    assert partition_values == {486729, 486730}
2384+
2385+
2386+
def test_partition_for_deep_nested_field() -> None:
    """Identity-partition by a string field nested two struct levels deep."""
    schema = Schema(
        NestedField(
            id=1,
            name="foo",
            field_type=StructType(
                NestedField(
                    id=2,
                    name="bar",
                    field_type=StructType(NestedField(id=3, name="baz", field_type=StringType(), required=False)),
                    required=True,
                )
            ),
            required=True,
        )
    )

    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="qux"))

    # "data-1" appears twice, so only two distinct partitions should result.
    rows = [{"foo": {"bar": {"baz": value}}} for value in ("data-1", "data-2", "data-1")]
    table = pa.Table.from_pylist(rows, schema=schema.as_arrow())

    partitions = _determine_partitions(spec, schema, table)

    assert len(partitions) == 2  # 2 unique partitions
    observed = {p.partition_key.partition[0] for p in partitions}
    assert observed == {"data-1", "data-2"}
2417+
2418+
23532419
def test_identity_partition_on_multi_columns() -> None:
23542420
test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
23552421
test_schema = Schema(

0 commit comments

Comments
 (0)