Skip to content

Commit e64df3c

Browse files
committed
Implement ParquetFormatModel and wire write_file to use the format API
1 parent 43d1f1f commit e64df3c

3 files changed

Lines changed: 309 additions & 28 deletions

File tree

pyiceberg/io/pyarrow.py

Lines changed: 97 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from enum import Enum
4444
from functools import lru_cache, singledispatch
4545
from typing import (
46+
IO,
4647
TYPE_CHECKING,
4748
Any,
4849
Generic,
@@ -122,6 +123,7 @@
122123
OutputStream,
123124
)
124125
from pyiceberg.io.fileformat import DataFileStatistics as DataFileStatistics
126+
from pyiceberg.io.fileformat import FileFormatFactory, FileFormatModel, FileFormatWriter
125127
from pyiceberg.manifest import (
126128
DataFile,
127129
DataFileContent,
@@ -1884,6 +1886,7 @@ def _to_requested_schema(
18841886
include_field_ids: bool = False,
18851887
projected_missing_fields: dict[int, Any] = EMPTY_DICT,
18861888
allow_timestamp_tz_mismatch: bool = False,
1889+
file_format: FileFormat = FileFormat.PARQUET,
18871890
) -> pa.RecordBatch:
18881891
# We could reuse some of these visitors
18891892
struct_array = visit_with_partner(
@@ -1895,6 +1898,7 @@ def _to_requested_schema(
18951898
include_field_ids,
18961899
projected_missing_fields=projected_missing_fields,
18971900
allow_timestamp_tz_mismatch=allow_timestamp_tz_mismatch,
1901+
file_format=file_format,
18981902
),
18991903
ArrowAccessor(file_schema),
19001904
)
@@ -1907,6 +1911,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, pa.Array | None]
19071911
_downcast_ns_timestamp_to_us: bool
19081912
_projected_missing_fields: dict[int, Any]
19091913
_allow_timestamp_tz_mismatch: bool
1914+
_file_format: FileFormat
19101915

19111916
def __init__(
19121917
self,
@@ -1915,6 +1920,7 @@ def __init__(
19151920
include_field_ids: bool = False,
19161921
projected_missing_fields: dict[int, Any] = EMPTY_DICT,
19171922
allow_timestamp_tz_mismatch: bool = False,
1923+
file_format: FileFormat = FileFormat.PARQUET,
19181924
) -> None:
19191925
self._file_schema = file_schema
19201926
self._include_field_ids = include_field_ids
@@ -1923,6 +1929,7 @@ def __init__(
19231929
# When True, allows projecting timestamptz (UTC) to timestamp (no tz).
19241930
# Allowed for reading (aligns with Spark); disallowed for writing to enforce Iceberg spec's strict typing.
19251931
self._allow_timestamp_tz_mismatch = allow_timestamp_tz_mismatch
1932+
self._file_format = file_format
19261933

19271934
def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
19281935
file_field = self._file_schema.find_field(field.field_id)
@@ -1981,9 +1988,12 @@ def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Fi
19811988
if field.doc:
19821989
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
19831990
if self._include_field_ids:
1984-
# For projection visitor, we don't know the file format, so default to Parquet
1985-
# This is used for schema conversion during reads, not writes
1986-
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
1991+
if self._file_format == FileFormat.ORC:
1992+
metadata[ORC_FIELD_ID_KEY] = str(field.field_id)
1993+
else:
1994+
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
1995+
if self._file_format == FileFormat.ORC:
1996+
metadata[ORC_FIELD_REQUIRED_KEY] = str(field.required).lower()
19871997

19881998
return pa.field(
19891999
name=field.name,
@@ -2602,21 +2612,87 @@ def data_file_statistics_from_parquet_metadata(
26022612
)
26032613

26042614

2615+
class ParquetFormatWriter(FileFormatWriter):
    """Writes Arrow tables to a Parquet file.

    The output stream and the underlying ``pq.ParquetWriter`` are created lazily on the
    first call to :meth:`write`, using the schema of the first table that is written.
    :meth:`close` finalises the file and returns the statistics derived from the Parquet
    metadata; the statistics are cached so repeated calls return the same object.
    """

    def __init__(self, output_file: OutputFile, file_schema: Schema, properties: Properties) -> None:
        """Prepare a writer for ``output_file``; nothing is opened until the first write.

        Args:
            output_file: Destination that is created (with overwrite) on the first write.
            file_schema: Iceberg schema used to plan statistics and map Parquet paths to field IDs.
            properties: Table properties; the Parquet writer options and the row group
                limit are read from them.
        """
        self._output_file = output_file
        self._file_schema = file_schema
        self._properties = properties
        self._writer: pq.ParquetWriter | None = None
        self._fos: OutputStream | None = None
        # Initialised here as well so close()/write() never depend on the base class
        # having provided this attribute.
        self._result: DataFileStatistics | None = None
        self._parquet_writer_kwargs = _get_parquet_writer_kwargs(properties)
        self._row_group_size = property_as_int(
            properties=properties,
            property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT,
            default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT,
        )

    def write(self, table: pa.Table) -> None:
        """Append ``table`` to the Parquet file, opening the file on the first call.

        Raises:
            ValueError: If the writer has already been closed successfully.
        """
        if self._result is not None:
            # Statistics are only cached after a successful close(); the underlying
            # Parquet writer is no longer usable at that point.
            raise ValueError("Cannot write to a writer that has already been closed")
        if self._writer is None:
            fos = self._output_file.create(overwrite=True)
            try:
                self._writer = pq.ParquetWriter(
                    cast(IO[Any], fos),
                    schema=table.schema,
                    store_decimal_as_integer=True,
                    **self._parquet_writer_kwargs,
                )
            except Exception:
                # Do not leak the freshly opened stream when the Parquet writer
                # cannot be constructed (e.g. invalid writer options).
                fos.close()
                raise
            self._fos = fos
        self._writer.write(table, row_group_size=self._row_group_size)

    def close(self) -> DataFileStatistics:
        """Finalise the Parquet file and return its statistics.

        Returns:
            The statistics computed from the written file's Parquet metadata. The same
            object is returned on every subsequent call.

        Raises:
            ValueError: If nothing was ever written, so no file or metadata exists.
        """
        if self._result is not None:
            return self._result
        try:
            if self._writer is None:
                raise ValueError("Cannot close a writer that was never written to")
            # The Parquet metadata is read from the writer after it has been closed.
            self._writer.close()
            self._result = data_file_statistics_from_parquet_metadata(
                parquet_metadata=self._writer.writer.metadata,
                stats_columns=compute_statistics_plan(self._file_schema, self._properties),
                parquet_column_mapping=parquet_path_to_id_mapping(self._file_schema),
            )
            return self._result
        finally:
            # Always release the output stream, even when finalising the file failed.
            if self._fos is not None:
                self._fos.close()
2658+
2659+
2660+
class ParquetFormatModel(FileFormatModel):
    """Format model describing how data files are written as Apache Parquet."""

    @property
    def format(self) -> FileFormat:
        """Return the file format handled by this model, ``FileFormat.PARQUET``."""
        return FileFormat.PARQUET

    def file_extension(self) -> str:
        """Return the extension, without a leading dot, used for Parquet file names."""
        return "parquet"

    def create_writer(
        self,
        output_file: OutputFile,
        file_schema: Schema,
        properties: Properties,
    ) -> ParquetFormatWriter:
        """Build a ``ParquetFormatWriter`` for ``output_file``.

        The three arguments are forwarded unchanged to the writer's constructor.
        """
        parquet_writer = ParquetFormatWriter(output_file, file_schema, properties)
        return parquet_writer
2677+
2678+
2679+
# Register the Parquet model when this module is imported so that lookups such as
# FileFormatFactory.get(FileFormat.PARQUET) can return it.
FileFormatFactory.register(ParquetFormatModel())
2680+
2681+
26052682
def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
26062683
from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties
26072684

2608-
parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
2609-
row_group_size = property_as_int(
2610-
properties=table_metadata.properties,
2611-
property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT,
2612-
default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT,
2685+
file_format = FileFormat(
2686+
table_metadata.properties.get(
2687+
TableProperties.WRITE_FILE_FORMAT,
2688+
TableProperties.WRITE_FILE_FORMAT_DEFAULT,
2689+
)
26132690
)
2691+
format_model = FileFormatFactory.get(file_format)
26142692
location_provider = load_location_provider(table_location=table_metadata.location, table_properties=table_metadata.properties)
26152693

2616-
def write_parquet(task: WriteTask) -> DataFile:
2694+
def write_data_file(task: WriteTask) -> DataFile:
26172695
table_schema = table_metadata.schema()
2618-
# if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
2619-
# otherwise use the original schema
26202696
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
26212697
file_schema = sanitized_schema
26222698
else:
@@ -2630,29 +2706,25 @@ def write_parquet(task: WriteTask) -> DataFile:
26302706
batch=batch,
26312707
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
26322708
include_field_ids=True,
2709+
file_format=file_format,
26332710
)
26342711
for batch in task.record_batches
26352712
]
26362713
arrow_table = pa.Table.from_batches(batches)
26372714
file_path = location_provider.new_data_location(
2638-
data_file_name=task.generate_data_file_filename("parquet"),
2715+
data_file_name=task.generate_data_file_filename(format_model.file_extension()),
26392716
partition_key=task.partition_key,
26402717
)
26412718
fo = io.new_output(file_path)
2642-
with fo.create(overwrite=True) as fos:
2643-
with pq.ParquetWriter(
2644-
fos, schema=arrow_table.schema, store_decimal_as_integer=True, **parquet_writer_kwargs
2645-
) as writer:
2646-
writer.write(arrow_table, row_group_size=row_group_size)
2647-
statistics = data_file_statistics_from_parquet_metadata(
2648-
parquet_metadata=writer.writer.metadata,
2649-
stats_columns=compute_statistics_plan(file_schema, table_metadata.properties),
2650-
parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
2651-
)
2652-
data_file = DataFile.from_args(
2719+
writer = format_model.create_writer(fo, file_schema, table_metadata.properties)
2720+
with writer:
2721+
writer.write(arrow_table)
2722+
statistics = writer.result()
2723+
2724+
return DataFile.from_args(
26532725
content=DataFileContent.DATA,
26542726
file_path=file_path,
2655-
file_format=FileFormat.PARQUET,
2727+
file_format=file_format,
26562728
partition=task.partition_key.partition if task.partition_key else Record(),
26572729
file_size_in_bytes=len(fo),
26582730
# After this has been fixed:
@@ -2666,10 +2738,8 @@ def write_parquet(task: WriteTask) -> DataFile:
26662738
**statistics.to_serialized_dict(),
26672739
)
26682740

2669-
return data_file
2670-
26712741
executor = ExecutorFactory.get_or_create()
2672-
data_files = executor.map(write_parquet, tasks)
2742+
data_files = executor.map(write_data_file, tasks)
26732743

26742744
return iter(data_files)
26752745

tests/io/test_format_writers.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Parametrized format writer tests, modeled after Java's BaseFormatModelTests."""
19+
20+
from pathlib import Path
21+
22+
import pyarrow as pa
23+
import pyarrow.dataset as ds
24+
import pytest
25+
26+
from pyiceberg.io.fileformat import FileFormatFactory, FileFormatModel
27+
from pyiceberg.io.pyarrow import PyArrowFileIO
28+
from pyiceberg.manifest import FileFormat
29+
from pyiceberg.schema import Schema
30+
from pyiceberg.types import LongType, NestedField
31+
32+
33+
@pytest.fixture(params=FileFormatFactory.available_formats(), ids=lambda f: f.name.lower())
def format_model(request: pytest.FixtureRequest) -> FileFormatModel:
    """Provide the format model for each format reported by ``FileFormatFactory.available_formats()``."""
    requested_format = request.param
    return FileFormatFactory.get(requested_format)
36+
37+
38+
@pytest.fixture
def simple_table() -> pa.Table:
    """Provide a three-row Arrow table with a string, an int32 and a boolean column."""
    foo_values = ["a", "b", "c"]
    bar_values = pa.array([1, 2, 3], type=pa.int32())
    baz_values = [True, False, True]
    # Column order matters: the round-trip test compares the full table, schema included.
    columns = {
        "foo": foo_values,
        "bar": bar_values,
        "baz": baz_values,
    }
    return pa.table(columns)
47+
48+
49+
def test_parquet_registered() -> None:
    """The factory resolves ``FileFormat.PARQUET`` to a model reporting the Parquet format and extension."""
    parquet_model = FileFormatFactory.get(FileFormat.PARQUET)
    reported_format = parquet_model.format
    assert reported_format == FileFormat.PARQUET
    reported_extension = parquet_model.file_extension()
    assert reported_extension == "parquet"
54+
55+
56+
def test_round_trip(format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path) -> None:
    """Write a table, read the file back, and check both the contents and the reported record count."""
    extension = format_model.file_extension()
    target = str(tmp_path / f"test.{extension}")
    output_file = PyArrowFileIO().new_output(target)

    format_writer = format_model.create_writer(output_file, table_schema_simple, {})
    format_writer.write(simple_table)
    stats = format_writer.close()

    # Read the file back independently of the writer under test.
    read_back = ds.dataset(target).to_table()
    assert read_back.equals(simple_table)
    assert stats.record_count == 3
66+
67+
68+
def test_statistics_record_count(format_model: FileFormatModel, table_schema_simple: Schema, tmp_path: Path) -> None:
    """The statistics returned by ``close()`` report the number of rows written."""
    five_rows = pa.table(
        {
            "foo": ["a", "b", "c", "d", "e"],
            "bar": pa.array([10, 20, 30, 40, 50], type=pa.int32()),
            "baz": [True] * 5,
        }
    )
    target = str(tmp_path / f"test.{format_model.file_extension()}")
    output_file = PyArrowFileIO().new_output(target)

    format_writer = format_model.create_writer(output_file, table_schema_simple, {})
    format_writer.write(five_rows)
    stats = format_writer.close()
    assert stats.record_count == 5
81+
82+
83+
def test_null_handling(format_model: FileFormatModel, table_schema_simple: Schema, tmp_path: Path) -> None:
    """A null in a nullable column is counted in the statistics' ``null_value_counts``."""
    with_null = pa.table(
        {
            "foo": ["a", None, "c"],  # field_id=1, optional
            "bar": pa.array([1, 2, 3], type=pa.int32()),  # field_id=2, required
            "baz": [True, False, True],  # field_id=3, optional
        }
    )
    target = str(tmp_path / f"test.{format_model.file_extension()}")
    output_file = PyArrowFileIO().new_output(target)

    format_writer = format_model.create_writer(output_file, table_schema_simple, {})
    format_writer.write(with_null)
    file_stats = format_writer.close()

    assert file_stats.record_count == 3
    # Only "foo" (field_id=1) contains a null value.
    foo_null_count = file_stats.null_value_counts.get(1)
    assert foo_null_count == 1
98+
99+
100+
def test_context_manager_caches_result(
    format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path
) -> None:
    """After the ``with`` block exits, ``writer.result()`` exposes the statistics of the written file."""
    target = str(tmp_path / f"test.{format_model.file_extension()}")
    output_file = PyArrowFileIO().new_output(target)
    format_writer = format_model.create_writer(output_file, table_schema_simple, {})

    with format_writer:
        format_writer.write(simple_table)

    cached_stats = format_writer.result()
    assert cached_stats.record_count == 3
109+
110+
111+
def test_close_is_idempotent(
    format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path
) -> None:
    """A second ``close()`` returns the very same statistics object as the first."""
    target = str(tmp_path / f"test.{format_model.file_extension()}")
    output_file = PyArrowFileIO().new_output(target)
    format_writer = format_model.create_writer(output_file, table_schema_simple, {})
    format_writer.write(simple_table)

    first_close = format_writer.close()
    second_close = format_writer.close()
    # Identity, not just equality: the result must be cached rather than recomputed.
    assert first_close is second_close
121+
122+
123+
def test_close_without_write_raises(format_model: FileFormatModel, table_schema_simple: Schema, tmp_path: Path) -> None:
    """``close()`` on a writer that never received data raises ``ValueError``."""
    target = str(tmp_path / f"test.{format_model.file_extension()}")
    output_file = PyArrowFileIO().new_output(target)
    unused_writer = format_model.create_writer(output_file, table_schema_simple, {})

    expected_message = "Cannot close a writer that was never written to"
    with pytest.raises(ValueError, match=expected_message):
        unused_writer.close()
129+
130+
131+
def test_construct_field_uses_orc_field_id_key() -> None:
    """``ArrowProjectionVisitor`` writes ORC metadata keys for ORC and the Parquet field-ID key for Parquet."""
    from pyiceberg.io.pyarrow import (
        ORC_FIELD_ID_KEY,
        ORC_FIELD_REQUIRED_KEY,
        PYARROW_PARQUET_FIELD_ID_KEY,
        ArrowProjectionVisitor,
    )

    schema = Schema(NestedField(field_id=1, name="x", field_type=LongType(), required=True))
    iceberg_field = schema.find_field(1)

    # ORC: field ID and required flag use the ORC keys; the Parquet key must be absent.
    orc_visitor = ArrowProjectionVisitor(schema, include_field_ids=True, file_format=FileFormat.ORC)
    orc_metadata = orc_visitor._construct_field(iceberg_field, pa.int64()).metadata
    assert orc_metadata is not None
    assert ORC_FIELD_ID_KEY in orc_metadata
    assert ORC_FIELD_REQUIRED_KEY in orc_metadata
    assert orc_metadata[ORC_FIELD_REQUIRED_KEY] == b"true"
    assert PYARROW_PARQUET_FIELD_ID_KEY not in orc_metadata

    # Parquet: only the Parquet field-ID key is written; no ORC keys.
    parquet_visitor = ArrowProjectionVisitor(schema, include_field_ids=True, file_format=FileFormat.PARQUET)
    parquet_metadata = parquet_visitor._construct_field(iceberg_field, pa.int64()).metadata
    assert parquet_metadata is not None
    assert PYARROW_PARQUET_FIELD_ID_KEY in parquet_metadata
    assert ORC_FIELD_ID_KEY not in parquet_metadata
    assert ORC_FIELD_REQUIRED_KEY not in parquet_metadata

0 commit comments

Comments
 (0)