4343from enum import Enum
4444from functools import lru_cache , singledispatch
4545from typing import (
46+ IO ,
4647 TYPE_CHECKING ,
4748 Any ,
4849 Generic ,
122123 OutputStream ,
123124)
124125from pyiceberg .io .fileformat import DataFileStatistics as DataFileStatistics
126+ from pyiceberg .io .fileformat import FileFormatFactory , FileFormatModel , FileFormatWriter
125127from pyiceberg .manifest import (
126128 DataFile ,
127129 DataFileContent ,
@@ -1884,6 +1886,7 @@ def _to_requested_schema(
18841886 include_field_ids : bool = False ,
18851887 projected_missing_fields : dict [int , Any ] = EMPTY_DICT ,
18861888 allow_timestamp_tz_mismatch : bool = False ,
1889+ file_format : FileFormat = FileFormat .PARQUET ,
18871890) -> pa .RecordBatch :
18881891 # We could reuse some of these visitors
18891892 struct_array = visit_with_partner (
@@ -1895,6 +1898,7 @@ def _to_requested_schema(
18951898 include_field_ids ,
18961899 projected_missing_fields = projected_missing_fields ,
18971900 allow_timestamp_tz_mismatch = allow_timestamp_tz_mismatch ,
1901+ file_format = file_format ,
18981902 ),
18991903 ArrowAccessor (file_schema ),
19001904 )
@@ -1907,6 +1911,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, pa.Array | None]
19071911 _downcast_ns_timestamp_to_us : bool
19081912 _projected_missing_fields : dict [int , Any ]
19091913 _allow_timestamp_tz_mismatch : bool
1914+ _file_format : FileFormat
19101915
19111916 def __init__ (
19121917 self ,
@@ -1915,6 +1920,7 @@ def __init__(
19151920 include_field_ids : bool = False ,
19161921 projected_missing_fields : dict [int , Any ] = EMPTY_DICT ,
19171922 allow_timestamp_tz_mismatch : bool = False ,
1923+ file_format : FileFormat = FileFormat .PARQUET ,
19181924 ) -> None :
19191925 self ._file_schema = file_schema
19201926 self ._include_field_ids = include_field_ids
@@ -1923,6 +1929,7 @@ def __init__(
19231929 # When True, allows projecting timestamptz (UTC) to timestamp (no tz).
19241930 # Allowed for reading (aligns with Spark); disallowed for writing to enforce Iceberg spec's strict typing.
19251931 self ._allow_timestamp_tz_mismatch = allow_timestamp_tz_mismatch
1932+ self ._file_format = file_format
19261933
19271934 def _cast_if_needed (self , field : NestedField , values : pa .Array ) -> pa .Array :
19281935 file_field = self ._file_schema .find_field (field .field_id )
@@ -1981,9 +1988,12 @@ def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Fi
19811988 if field .doc :
19821989 metadata [PYARROW_FIELD_DOC_KEY ] = field .doc
19831990 if self ._include_field_ids :
1984- # For projection visitor, we don't know the file format, so default to Parquet
1985- # This is used for schema conversion during reads, not writes
1986- metadata [PYARROW_PARQUET_FIELD_ID_KEY ] = str (field .field_id )
1991+ if self ._file_format == FileFormat .ORC :
1992+ metadata [ORC_FIELD_ID_KEY ] = str (field .field_id )
1993+ else :
1994+ metadata [PYARROW_PARQUET_FIELD_ID_KEY ] = str (field .field_id )
1995+ if self ._file_format == FileFormat .ORC :
1996+ metadata [ORC_FIELD_REQUIRED_KEY ] = str (field .required ).lower ()
19871997
19881998 return pa .field (
19891999 name = field .name ,
@@ -2602,21 +2612,87 @@ def data_file_statistics_from_parquet_metadata(
26022612 )
26032613
26042614
2615+ class ParquetFormatWriter (FileFormatWriter ):
2616+ """Writes Arrow tables to a Parquet file."""
2617+
2618+ def __init__ (self , output_file : OutputFile , file_schema : Schema , properties : Properties ) -> None :
2619+ self ._output_file = output_file
2620+ self ._file_schema = file_schema
2621+ self ._properties = properties
2622+ self ._writer : pq .ParquetWriter | None = None
2623+ self ._fos : OutputStream | None = None
2624+ self ._parquet_writer_kwargs = _get_parquet_writer_kwargs (properties )
2625+ self ._row_group_size = property_as_int (
2626+ properties = properties ,
2627+ property_name = TableProperties .PARQUET_ROW_GROUP_LIMIT ,
2628+ default = TableProperties .PARQUET_ROW_GROUP_LIMIT_DEFAULT ,
2629+ )
2630+
2631+ def write (self , table : pa .Table ) -> None :
2632+ if self ._writer is None :
2633+ self ._fos = self ._output_file .create (overwrite = True )
2634+ self ._writer = pq .ParquetWriter (
2635+ cast (IO [Any ], self ._fos ),
2636+ schema = table .schema ,
2637+ store_decimal_as_integer = True ,
2638+ ** self ._parquet_writer_kwargs ,
2639+ )
2640+ self ._writer .write (table , row_group_size = self ._row_group_size )
2641+
2642+ def close (self ) -> DataFileStatistics :
2643+ if self ._result is not None :
2644+ return self ._result
2645+ try :
2646+ if self ._writer is None :
2647+ raise ValueError ("Cannot close a writer that was never written to" )
2648+ self ._writer .close ()
2649+ self ._result = data_file_statistics_from_parquet_metadata (
2650+ parquet_metadata = self ._writer .writer .metadata ,
2651+ stats_columns = compute_statistics_plan (self ._file_schema , self ._properties ),
2652+ parquet_column_mapping = parquet_path_to_id_mapping (self ._file_schema ),
2653+ )
2654+ return self ._result
2655+ finally :
2656+ if self ._fos is not None :
2657+ self ._fos .close ()
2658+
2659+
2660+ class ParquetFormatModel (FileFormatModel ):
2661+ """Format model for Apache Parquet."""
2662+
2663+ @property
2664+ def format (self ) -> FileFormat :
2665+ return FileFormat .PARQUET
2666+
2667+ def file_extension (self ) -> str :
2668+ return "parquet"
2669+
2670+ def create_writer (
2671+ self ,
2672+ output_file : OutputFile ,
2673+ file_schema : Schema ,
2674+ properties : Properties ,
2675+ ) -> ParquetFormatWriter :
2676+ return ParquetFormatWriter (output_file , file_schema , properties )
2677+
2678+
2679+ FileFormatFactory .register (ParquetFormatModel ())
2680+
2681+
26052682def write_file (io : FileIO , table_metadata : TableMetadata , tasks : Iterator [WriteTask ]) -> Iterator [DataFile ]:
26062683 from pyiceberg .table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE , TableProperties
26072684
2608- parquet_writer_kwargs = _get_parquet_writer_kwargs ( table_metadata . properties )
2609- row_group_size = property_as_int (
2610- properties = table_metadata . properties ,
2611- property_name = TableProperties .PARQUET_ROW_GROUP_LIMIT ,
2612- default = TableProperties . PARQUET_ROW_GROUP_LIMIT_DEFAULT ,
2685+ file_format = FileFormat (
2686+ table_metadata . properties . get (
2687+ TableProperties . WRITE_FILE_FORMAT ,
2688+ TableProperties .WRITE_FILE_FORMAT_DEFAULT ,
2689+ )
26132690 )
2691+ format_model = FileFormatFactory .get (file_format )
26142692 location_provider = load_location_provider (table_location = table_metadata .location , table_properties = table_metadata .properties )
26152693
2616- def write_parquet (task : WriteTask ) -> DataFile :
2694+ def write_data_file (task : WriteTask ) -> DataFile :
26172695 table_schema = table_metadata .schema ()
2618- # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
2619- # otherwise use the original schema
26202696 if (sanitized_schema := sanitize_column_names (table_schema )) != table_schema :
26212697 file_schema = sanitized_schema
26222698 else :
@@ -2630,29 +2706,25 @@ def write_parquet(task: WriteTask) -> DataFile:
26302706 batch = batch ,
26312707 downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us ,
26322708 include_field_ids = True ,
2709+ file_format = file_format ,
26332710 )
26342711 for batch in task .record_batches
26352712 ]
26362713 arrow_table = pa .Table .from_batches (batches )
26372714 file_path = location_provider .new_data_location (
2638- data_file_name = task .generate_data_file_filename ("parquet" ),
2715+ data_file_name = task .generate_data_file_filename (format_model . file_extension () ),
26392716 partition_key = task .partition_key ,
26402717 )
26412718 fo = io .new_output (file_path )
2642- with fo .create (overwrite = True ) as fos :
2643- with pq .ParquetWriter (
2644- fos , schema = arrow_table .schema , store_decimal_as_integer = True , ** parquet_writer_kwargs
2645- ) as writer :
2646- writer .write (arrow_table , row_group_size = row_group_size )
2647- statistics = data_file_statistics_from_parquet_metadata (
2648- parquet_metadata = writer .writer .metadata ,
2649- stats_columns = compute_statistics_plan (file_schema , table_metadata .properties ),
2650- parquet_column_mapping = parquet_path_to_id_mapping (file_schema ),
2651- )
2652- data_file = DataFile .from_args (
2719+ writer = format_model .create_writer (fo , file_schema , table_metadata .properties )
2720+ with writer :
2721+ writer .write (arrow_table )
2722+ statistics = writer .result ()
2723+
2724+ return DataFile .from_args (
26532725 content = DataFileContent .DATA ,
26542726 file_path = file_path ,
2655- file_format = FileFormat . PARQUET ,
2727+ file_format = file_format ,
26562728 partition = task .partition_key .partition if task .partition_key else Record (),
26572729 file_size_in_bytes = len (fo ),
26582730 # After this has been fixed:
@@ -2666,10 +2738,8 @@ def write_parquet(task: WriteTask) -> DataFile:
26662738 ** statistics .to_serialized_dict (),
26672739 )
26682740
2669- return data_file
2670-
26712741 executor = ExecutorFactory .get_or_create ()
2672- data_files = executor .map (write_parquet , tasks )
2742+ data_files = executor .map (write_data_file , tasks )
26732743
26742744 return iter (data_files )
26752745
0 commit comments