Skip to content

Commit a2d9ea7

Browse files
committed
PR Comments
1 parent e64df3c commit a2d9ea7

5 files changed

Lines changed: 59 additions & 72 deletions

File tree

pyiceberg/io/fileformat.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pyiceberg.partitioning import PartitionField, PartitionSpec, partition_record_value
2929
from pyiceberg.schema import Schema
3030
from pyiceberg.typedef import Properties, Record
31+
from pyiceberg.types import NestedField
3132

3233
if TYPE_CHECKING:
3334
import pyarrow as pa
@@ -161,6 +162,10 @@ def create_writer(
161162
properties: Properties,
162163
) -> FileFormatWriter: ...
163164

165+
@abstractmethod
166+
def add_field_metadata(self, field: NestedField, metadata: dict[bytes, bytes], include_field_ids: bool) -> None:
167+
"""Add format-specific Arrow field metadata."""
168+
164169

165170
class FileFormatFactory:
166171
"""Registry of FileFormatModel implementations."""

pyiceberg/io/pyarrow.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,7 @@ def _to_requested_schema(
18861886
include_field_ids: bool = False,
18871887
projected_missing_fields: dict[int, Any] = EMPTY_DICT,
18881888
allow_timestamp_tz_mismatch: bool = False,
1889-
file_format: FileFormat = FileFormat.PARQUET,
1889+
format_model: FileFormatModel | None = None,
18901890
) -> pa.RecordBatch:
18911891
# We could reuse some of these visitors
18921892
struct_array = visit_with_partner(
@@ -1898,7 +1898,7 @@ def _to_requested_schema(
18981898
include_field_ids,
18991899
projected_missing_fields=projected_missing_fields,
19001900
allow_timestamp_tz_mismatch=allow_timestamp_tz_mismatch,
1901-
file_format=file_format,
1901+
format_model=format_model,
19021902
),
19031903
ArrowAccessor(file_schema),
19041904
)
@@ -1911,7 +1911,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, pa.Array | None]
19111911
_downcast_ns_timestamp_to_us: bool
19121912
_projected_missing_fields: dict[int, Any]
19131913
_allow_timestamp_tz_mismatch: bool
1914-
_file_format: FileFormat
1914+
_format_model: FileFormatModel | None
19151915

19161916
def __init__(
19171917
self,
@@ -1920,16 +1920,18 @@ def __init__(
19201920
include_field_ids: bool = False,
19211921
projected_missing_fields: dict[int, Any] = EMPTY_DICT,
19221922
allow_timestamp_tz_mismatch: bool = False,
1923-
file_format: FileFormat = FileFormat.PARQUET,
1923+
format_model: FileFormatModel | None = None,
19241924
) -> None:
1925+
if include_field_ids and format_model is None:
1926+
raise ValueError("format_model is required when include_field_ids=True")
19251927
self._file_schema = file_schema
19261928
self._include_field_ids = include_field_ids
19271929
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
19281930
self._projected_missing_fields = projected_missing_fields
19291931
# When True, allows projecting timestamptz (UTC) to timestamp (no tz).
19301932
# Allowed for reading (aligns with Spark); disallowed for writing to enforce Iceberg spec's strict typing.
19311933
self._allow_timestamp_tz_mismatch = allow_timestamp_tz_mismatch
1932-
self._file_format = file_format
1934+
self._format_model = format_model
19331935

19341936
def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
19351937
file_field = self._file_schema.find_field(field.field_id)
@@ -1984,16 +1986,11 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
19841986
return values
19851987

19861988
def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
1987-
metadata = {}
1989+
metadata: dict[bytes, bytes] = {}
19881990
if field.doc:
1989-
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
1990-
if self._include_field_ids:
1991-
if self._file_format == FileFormat.ORC:
1992-
metadata[ORC_FIELD_ID_KEY] = str(field.field_id)
1993-
else:
1994-
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
1995-
if self._file_format == FileFormat.ORC:
1996-
metadata[ORC_FIELD_REQUIRED_KEY] = str(field.required).lower()
1991+
metadata[PYARROW_FIELD_DOC_KEY] = field.doc.encode()
1992+
if self._format_model is not None:
1993+
self._format_model.add_field_metadata(field, metadata, self._include_field_ids)
19971994

19981995
return pa.field(
19991996
name=field.name,
@@ -2675,6 +2672,10 @@ def create_writer(
26752672
) -> ParquetFormatWriter:
26762673
return ParquetFormatWriter(output_file, file_schema, properties)
26772674

2675+
def add_field_metadata(self, field: NestedField, metadata: dict[bytes, bytes], include_field_ids: bool) -> None:
2676+
if include_field_ids:
2677+
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id).encode()
2678+
26782679

26792680
FileFormatFactory.register(ParquetFormatModel())
26802681

@@ -2706,7 +2707,7 @@ def write_data_file(task: WriteTask) -> DataFile:
27062707
batch=batch,
27072708
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
27082709
include_field_ids=True,
2709-
file_format=file_format,
2710+
format_model=format_model,
27102711
)
27112712
for batch in task.record_batches
27122713
]

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ def table_schema_simple() -> Schema:
172172
)
173173

174174

175+
@pytest.fixture(scope="session")
176+
def arrow_table_simple() -> "pa.Table":
177+
"""Pyarrow table that pairs with `table_schema_simple` (3 rows, no nulls)."""
178+
import pyarrow as pa
179+
180+
return pa.table(
181+
{
182+
"foo": ["a", "b", "c"],
183+
"bar": pa.array([1, 2, 3], type=pa.int32()),
184+
"baz": [True, False, True],
185+
}
186+
)
187+
188+
175189
@pytest.fixture(scope="session")
176190
def table_schema_with_full_nested_fields() -> Schema:
177191
return Schema(

tests/io/test_fileformat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def file_extension(self) -> str:
5151
def create_writer(self, output_file: Any, file_schema: Any, properties: Any) -> Any:
5252
raise NotImplementedError
5353

54+
def add_field_metadata(self, field: Any, metadata: Any, include_field_ids: bool) -> None:
55+
pass
56+
5457
original = dict(FileFormatFactory._registry)
5558
try:
5659
model = _DummyModel()

tests/io/test_format_writers.py

Lines changed: 21 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -35,51 +35,27 @@ def format_model(request: pytest.FixtureRequest) -> FileFormatModel:
3535
return FileFormatFactory.get(request.param)
3636

3737

38-
@pytest.fixture
39-
def simple_table() -> pa.Table:
40-
return pa.table(
41-
{
42-
"foo": ["a", "b", "c"],
43-
"bar": pa.array([1, 2, 3], type=pa.int32()),
44-
"baz": [True, False, True],
45-
}
46-
)
47-
48-
4938
def test_parquet_registered() -> None:
5039
"""ParquetFormatModel is registered in the factory."""
5140
model = FileFormatFactory.get(FileFormat.PARQUET)
5241
assert model.format == FileFormat.PARQUET
5342
assert model.file_extension() == "parquet"
5443

5544

56-
def test_round_trip(format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path) -> None:
45+
def test_round_trip(
46+
format_model: FileFormatModel, table_schema_simple: Schema, arrow_table_simple: pa.Table, tmp_path: Path
47+
) -> None:
5748
"""Write a table and read it back, to verify equality and record count."""
5849
file_path = str(tmp_path / f"test.{format_model.file_extension()}")
5950
writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {})
60-
writer.write(simple_table)
51+
writer.write(arrow_table_simple)
6152
statistics = writer.close()
6253

6354
result = ds.dataset(file_path).to_table()
64-
assert result.equals(simple_table)
55+
assert result.equals(arrow_table_simple)
6556
assert statistics.record_count == 3
6657

6758

68-
def test_statistics_record_count(format_model: FileFormatModel, table_schema_simple: Schema, tmp_path: Path) -> None:
69-
"""close() returns DataFileStatistics with correct record count."""
70-
table = pa.table(
71-
{
72-
"foo": ["a", "b", "c", "d", "e"],
73-
"bar": pa.array([10, 20, 30, 40, 50], type=pa.int32()),
74-
"baz": [True] * 5,
75-
}
76-
)
77-
file_path = str(tmp_path / f"test.{format_model.file_extension()}")
78-
writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {})
79-
writer.write(table)
80-
assert writer.close().record_count == 5
81-
82-
8359
def test_null_handling(format_model: FileFormatModel, table_schema_simple: Schema, tmp_path: Path) -> None:
8460
"""Nullable columns produce correct null_value_counts in statistics."""
8561
table = pa.table(
@@ -98,23 +74,23 @@ def test_null_handling(format_model: FileFormatModel, table_schema_simple: Schem
9874

9975

10076
def test_context_manager_caches_result(
101-
format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path
77+
format_model: FileFormatModel, table_schema_simple: Schema, arrow_table_simple: pa.Table, tmp_path: Path
10278
) -> None:
10379
"""writer.result() returns cached statistics after context manager exit."""
10480
file_path = str(tmp_path / f"test.{format_model.file_extension()}")
10581
writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {})
10682
with writer:
107-
writer.write(simple_table)
83+
writer.write(arrow_table_simple)
10884
assert writer.result().record_count == 3
10985

11086

11187
def test_close_is_idempotent(
112-
format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path
88+
format_model: FileFormatModel, table_schema_simple: Schema, arrow_table_simple: pa.Table, tmp_path: Path
11389
) -> None:
11490
"""Calling close() twice returns the same cached statistics object."""
11591
file_path = str(tmp_path / f"test.{format_model.file_extension()}")
11692
writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {})
117-
writer.write(simple_table)
93+
writer.write(arrow_table_simple)
11894
stats1 = writer.close()
11995
stats2 = writer.close()
12096
assert stats1 is stats2
@@ -128,28 +104,16 @@ def test_close_without_write_raises(format_model: FileFormatModel, table_schema_
128104
writer.close()
129105

130106

131-
def test_construct_field_uses_orc_field_id_key() -> None:
132-
"""ArrowProjectionVisitor uses ORC field ID and required keys when file_format is ORC."""
133-
from pyiceberg.io.pyarrow import (
134-
ORC_FIELD_ID_KEY,
135-
ORC_FIELD_REQUIRED_KEY,
136-
PYARROW_PARQUET_FIELD_ID_KEY,
137-
ArrowProjectionVisitor,
138-
)
107+
def test_parquet_format_model_adds_field_id_metadata() -> None:
108+
"""ParquetFormatModel.add_field_metadata writes the Parquet field-id key when requested."""
109+
from pyiceberg.io.pyarrow import PYARROW_PARQUET_FIELD_ID_KEY, ParquetFormatModel
110+
111+
field = NestedField(field_id=1, name="x", field_type=LongType(), required=True)
112+
113+
metadata: dict[bytes, bytes] = {}
114+
ParquetFormatModel().add_field_metadata(field, metadata, include_field_ids=True)
115+
assert metadata == {PYARROW_PARQUET_FIELD_ID_KEY: b"1"}
139116

140-
schema = Schema(NestedField(field_id=1, name="x", field_type=LongType(), required=True))
141-
142-
visitor = ArrowProjectionVisitor(schema, include_field_ids=True, file_format=FileFormat.ORC)
143-
field = visitor._construct_field(schema.find_field(1), pa.int64())
144-
assert field.metadata is not None
145-
assert ORC_FIELD_ID_KEY in field.metadata
146-
assert ORC_FIELD_REQUIRED_KEY in field.metadata
147-
assert field.metadata[ORC_FIELD_REQUIRED_KEY] == b"true"
148-
assert PYARROW_PARQUET_FIELD_ID_KEY not in field.metadata
149-
150-
visitor_pq = ArrowProjectionVisitor(schema, include_field_ids=True, file_format=FileFormat.PARQUET)
151-
field_pq = visitor_pq._construct_field(schema.find_field(1), pa.int64())
152-
assert field_pq.metadata is not None
153-
assert PYARROW_PARQUET_FIELD_ID_KEY in field_pq.metadata
154-
assert ORC_FIELD_ID_KEY not in field_pq.metadata
155-
assert ORC_FIELD_REQUIRED_KEY not in field_pq.metadata
117+
metadata_no_ids: dict[bytes, bytes] = {}
118+
ParquetFormatModel().add_field_metadata(field, metadata_no_ids, include_field_ids=False)
119+
assert metadata_no_ids == {}

0 commit comments

Comments
 (0)