|
21 | 21 | import uuid |
22 | 22 | import warnings |
23 | 23 | from datetime import date, datetime, timezone |
| 24 | +from pathlib import Path |
24 | 25 | from typing import Any, List, Optional |
25 | 26 | from unittest.mock import MagicMock, patch |
26 | 27 | from uuid import uuid4 |
|
71 | 72 | _determine_partitions, |
72 | 73 | _primitive_to_physical, |
73 | 74 | _read_deletes, |
| 75 | + _resolve_row_group_size, |
74 | 76 | _task_to_record_batches, |
75 | 77 | _to_requested_schema, |
76 | 78 | bin_pack_arrow_table, |
|
79 | 81 | expression_to_pyarrow, |
80 | 82 | parquet_path_to_id_mapping, |
81 | 83 | schema_to_pyarrow, |
| 84 | + write_file, |
82 | 85 | ) |
83 | 86 | from pyiceberg.manifest import DataFile, DataFileContent, FileFormat |
84 | 87 | from pyiceberg.partitioning import PartitionField, PartitionSpec |
@@ -2825,3 +2828,86 @@ def test_parse_location_defaults() -> None: |
2825 | 2828 | assert scheme == "hdfs" |
2826 | 2829 | assert netloc == "netloc:8000" |
2827 | 2830 | assert path == "/foo/bar" |
| 2831 | + |
| 2832 | + |
@pytest.mark.parametrize(
    "arrow_table,row_group_limit,row_group_size_bytes,expected",
    [
        # Two int64 columns cost 16 bytes per row, so a 1024-byte budget
        # allows 64 rows per group -- well under the 10_000-row limit.
        (pa.table({"a": list(range(1000)), "b": list(range(1000))}), 10_000, 1024, 64),
        # A very large byte budget leaves the row limit as the tighter bound.
        (pa.table({"a": list(range(1000))}), 10, 10**9, 10),
        # A byte budget of 0 means "disabled": the row limit is expected.
        (pa.table({"a": list(range(1000))}), 500, 0, 500),
        # Empty table: the row limit is expected.
        (pa.table({"a": pa.array([], type=pa.int64())}), 500, 1024, 500),
    ],
)
def test__resolve_row_group_size(arrow_table: pa.Table, row_group_limit: int, row_group_size_bytes: int, expected: int) -> None:
    """Check the row-group size resolved from a row limit and a byte budget.

    Each case supplies a table, a row limit, a byte budget and the row count
    ``_resolve_row_group_size`` is expected to return for that combination.
    """
    resolved = _resolve_row_group_size(arrow_table, row_group_limit, row_group_size_bytes)
    assert resolved == expected
| 2850 | + |
| 2851 | + |
def test_write_file_byte_limit_produces_more_row_groups_than_row_limit_alone(tmp_path: Path) -> None:
    """Writing with a small byte budget yields more row groups than the defaults.

    The same 10_000-row, two-column table is written twice under ``tmp_path``:
    once with no table properties and once with
    ``TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES`` set to ``"1024"``. The
    first file must contain exactly one row group, the second more than one.
    """
    from pyiceberg.table import WriteTask

    arrow_data = pa.table({"a": list(range(10_000)), "b": list(range(10_000))})
    table_schema = Schema(
        NestedField(1, "a", LongType(), required=False),
        NestedField(2, "b", LongType(), required=False),
    )

    def _row_groups_written(properties: dict[str, str], subdir: str) -> int:
        """Write ``arrow_data`` under ``subdir`` and return the file's row-group count."""
        metadata = TableMetadataV2(
            location=f"file://{tmp_path}/{subdir}",
            last_column_id=2,
            format_version=2,
            schemas=[table_schema],
            partition_specs=[PartitionSpec()],
            properties=properties,
        )
        write_task = WriteTask(
            write_uuid=uuid.uuid4(),
            task_id=0,
            record_batches=arrow_data.to_batches(),
            schema=table_schema,
        )
        written = list(write_file(io=PyArrowFileIO(), table_metadata=metadata, tasks=iter([write_task])))
        # The data file location carries a "file://" scheme; strip it to get a local path.
        local_path = Path(written[0].file_path.removeprefix("file://"))
        return pq.ParquetFile(local_path).num_row_groups

    default_groups = _row_groups_written({}, "default")
    constrained_groups = _row_groups_written({TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES: "1024"}, "constrained")

    assert default_groups == 1
    assert constrained_groups > 1
| 2886 | + |
| 2887 | + |
def test_write_file_byte_limit_respects_row_limit_upper_bound(tmp_path: Path) -> None:
    """A huge byte budget must not lift the row-count cap on row groups.

    10_000 rows are written with a row-group limit of 1000 and a byte budget
    of 10**12; the resulting file is expected to hold exactly 10 row groups.
    """
    from pyiceberg.table import WriteTask

    table_schema = Schema(NestedField(1, "a", LongType(), required=False))
    arrow_data = pa.table({"a": list(range(10_000))})

    write_properties = {
        TableProperties.PARQUET_ROW_GROUP_LIMIT: "1000",
        TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES: str(10**12),
    }
    table_metadata = TableMetadataV2(
        location=f"file://{tmp_path}",
        last_column_id=1,
        format_version=2,
        schemas=[table_schema],
        partition_specs=[PartitionSpec()],
        properties=write_properties,
    )
    write_task = WriteTask(
        write_uuid=uuid.uuid4(),
        task_id=0,
        record_batches=arrow_data.to_batches(),
        schema=table_schema,
    )

    written = list(write_file(io=PyArrowFileIO(), table_metadata=table_metadata, tasks=iter([write_task])))
    # The data file location carries a "file://" scheme; strip it before opening locally.
    local_path = written[0].file_path.removeprefix("file://")
    parquet_file = pq.ParquetFile(local_path)
    assert parquet_file.num_row_groups == 10
0 commit comments