Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,10 @@ def create_table(
iceberg_schema = self._convert_schema_if_needed(schema)
iceberg_schema = assign_fresh_schema_ids(iceberg_schema)

properties = properties.copy()
Comment thread
jonashaag marked this conversation as resolved.
Outdated
for copy_key in ["write.parquet.compression-codec", "write.parquet.compression-level"]:
if copy_key in self.properties:
properties[copy_key] = self.properties[copy_key]
Comment thread
jonashaag marked this conversation as resolved.
Outdated
namespace_and_table = self._split_identifier_for_path(identifier)
request = CreateTableRequest(
name=namespace_and_table["table"],
Expand Down
20 changes: 13 additions & 7 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1720,13 +1720,23 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
except StopIteration:
pass

compression_codec = table.properties.get("write.parquet.compression-codec")
compression_level = table.properties.get("write.parquet.compression-level")
compression_options: Dict[str, Any]
if compression_codec == "uncompressed":
Comment thread
jonashaag marked this conversation as resolved.
Outdated
compression_options = {"compression": "none"}
else:
Comment thread
jonashaag marked this conversation as resolved.
Outdated
compression_options = {
"compression": compression_codec,
"compression_level": None if compression_level is None else int(compression_level),
}

file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
file_schema = schema_to_pyarrow(table.schema())

collected_metrics: List[pq.FileMetaData] = []
fo = table.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=file_schema, version="1.0", metadata_collector=collected_metrics) as writer:
with pq.ParquetWriter(fos, schema=file_schema, version="1.0", **compression_options) as writer:
writer.write_table(task.df)

data_file = DataFile(
Expand All @@ -1745,13 +1755,9 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
key_metadata=None,
)

if len(collected_metrics) != 1:
# One file has been written
raise ValueError(f"Expected 1 entry, got: {collected_metrics}")

fill_parquet_file_metadata(
data_file=data_file,
parquet_metadata=collected_metrics[0],
parquet_metadata=writer.writer.metadata,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked this through the debugger, and this looks good. Nice change @jonashaag 👍

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also tell from the PyArrow code that it's identical :)

stats_columns=compute_statistics_plan(table.schema(), table.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/test_reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,11 @@ def test_ray_all_types(catalog: Catalog) -> None:
@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
def test_pyarrow_to_iceberg_all_types(catalog: Catalog) -> None:
table_test_all_types = catalog.load_table("default.test_all_types")
fs = S3FileSystem(endpoint_override="http://localhost:9000", access_key="admin", secret_key="password")
fs = S3FileSystem(
endpoint_override=catalog.properties["s3.endpoint"],
access_key=catalog.properties["s3.access-key-id"],
secret_key=catalog.properties["s3.secret-access-key"],
)
data_file_paths = [task.file.file_path for task in table_test_all_types.scan().plan_files()]
for data_file_path in data_file_paths:
uri = urlparse(data_file_path)
Expand Down
55 changes: 55 additions & 0 deletions tests/integration/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
# pylint:disable=redefined-outer-name
import uuid
from datetime import date, datetime
from urllib.parse import urlparse

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from pyarrow.fs import S3FileSystem
from pyspark.sql import SparkSession

from pyiceberg.catalog import Catalog, load_catalog
Expand Down Expand Up @@ -489,6 +492,58 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0]


@pytest.mark.integration
@pytest.mark.parametrize(
    "compression",
    # List of (compression_properties, expected_compression_name)
    [
        # REST catalog uses Zstandard by default: https://github.com/apache/iceberg/pull/8593
        ({}, "ZSTD"),
        ({"write.parquet.compression-codec": "uncompressed"}, "UNCOMPRESSED"),
        ({"write.parquet.compression-codec": "gzip", "write.parquet.compression-level": "1"}, "GZIP"),
        ({"write.parquet.compression-codec": "zstd", "write.parquet.compression-level": "1"}, "ZSTD"),
        ({"write.parquet.compression-codec": "snappy"}, "SNAPPY"),
    ],
)
def test_parquet_compression(spark: SparkSession, arrow_table_with_null: pa.Table, compression) -> None:
    """Verify that write properties control the Parquet codec actually used on disk.

    ``compression`` is a ``(compression_properties, expected_compression_name)``
    pair. A REST-catalog table is created with the given properties, one data
    file is written, and the file's Parquet metadata is then read back directly
    from S3 to assert the codec matches what the properties requested.
    """
    compression_properties, expected_compression_name = compression

    # Compression settings are passed as catalog properties so they propagate
    # to tables created through this catalog (see the rest.py change above).
    catalog = load_catalog(
        "local",
        **{
            "type": "rest",
            "uri": "http://localhost:8181",
            "s3.endpoint": "http://localhost:9000",
            "s3.access-key-id": "admin",
            "s3.secret-access-key": "password",
            **compression_properties,
        },
    )
    identifier = "default.arrow_data_files"

    # Start from a clean slate; the table may be left over from a previous run.
    try:
        catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass
    tbl = catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'})

    tbl.overwrite(arrow_table_with_null)

    data_file_paths = [task.file.file_path for task in tbl.scan().plan_files()]

    # Inspect the written file with a raw S3 client so we see the real on-disk
    # Parquet metadata instead of going through the Iceberg read path.
    fs = S3FileSystem(
        endpoint_override=catalog.properties["s3.endpoint"],
        access_key=catalog.properties["s3.access-key-id"],
        secret_key=catalog.properties["s3.secret-access-key"],
    )
    uri = urlparse(data_file_paths[0])
    with fs.open_input_file(f"{uri.netloc}{uri.path}") as f:
        parquet_metadata = pq.read_metadata(f)
        # Named distinctly so we do not shadow the `compression` parameter.
        actual_compression = parquet_metadata.row_group(0).column(0).compression

    assert actual_compression == expected_compression_name


@pytest.mark.integration
def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_data_files"
Expand Down