|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 | # pylint: disable=redefined-outer-name,arguments-renamed,fixme |
| 18 | +from collections.abc import Callable |
18 | 19 | from tempfile import TemporaryDirectory |
19 | 20 |
|
20 | 21 | import fastavro |
|
31 | 32 | ManifestEntry, |
32 | 33 | ManifestEntryStatus, |
33 | 34 | ManifestFile, |
| 35 | + ManifestWriter, |
34 | 36 | PartitionFieldSummary, |
| 37 | + RollingManifestWriter, |
35 | 38 | _manifest_cache, |
36 | 39 | _manifests, |
37 | 40 | read_manifest_list, |
@@ -932,3 +935,98 @@ def test_manifest_writer_tell(format_version: TableVersion) -> None: |
932 | 935 | after_entry_bytes = writer.tell() |
933 | 936 |
|
934 | 937 | assert after_entry_bytes > initial_bytes, "Bytes should increase after adding entry" |
| 938 | + |
| 939 | + |
@pytest.mark.parametrize("format_version", [1, 2])
def test_rolling_manifest_writer_stays_in_one_file_under_target(format_version: TableVersion) -> None:
    """Check that 100 entries written with a 10000-byte target yield exactly one manifest file."""
    with TemporaryDirectory() as workdir:
        id_schema = Schema(NestedField(1, "id", IntegerType(), required=True))
        make_writer = _create_manifest_writer_supplier(workdir, format_version, id_schema)
        entries = [_create_simple_entry(index) for index in range(100)]

        with RollingManifestWriter(supplier=make_writer, target_file_size_in_bytes=10000) as rolling_writer:
            for manifest_entry in entries:
                rolling_writer.add_entry(manifest_entry)

        # The manifest files are inspected only after the writer context has exited.
        produced = rolling_writer.to_manifest_files()
        assert len(produced) == 1
| 953 | + |
| 954 | + |
@pytest.mark.parametrize("format_version", [1, 2])
def test_rolling_manifest_writer_splits_when_over_target(format_version: TableVersion) -> None:
    """Check that a 1-byte target splits 500 entries into two manifests and that a closed writer rejects entries."""
    with TemporaryDirectory() as workdir:
        id_schema = Schema(NestedField(1, "id", IntegerType(), required=True))
        make_writer = _create_manifest_writer_supplier(workdir, format_version, id_schema)
        entries = [_create_simple_entry(index) for index in range(500)]

        with RollingManifestWriter(supplier=make_writer, target_file_size_in_bytes=1) as rolling_writer:
            for manifest_entry in entries:
                rolling_writer.add_entry(manifest_entry)

        # The writer checks its size every 250 entries, so a target of 1 byte
        # forces a split at entry 250 and again at entry 500.
        manifest_files = rolling_writer.to_manifest_files()
        assert len(manifest_files) == 2

        # Once the context has exited, adding another entry must fail.
        first_entry = entries[0]
        with pytest.raises(RuntimeError, match="Cannot add entry to closed"):
            rolling_writer.add_entry(first_entry)
| 973 | + |
| 974 | + |
@pytest.mark.parametrize("format_version", [1, 2])
def test_rolling_manifest_writer_empty(format_version: TableVersion) -> None:
    """Check that a rolling writer which receives no entries reports an empty list of manifest files."""
    with TemporaryDirectory() as workdir:
        id_schema = Schema(NestedField(1, "id", IntegerType(), required=True))
        make_writer = _create_manifest_writer_supplier(workdir, format_version, id_schema)

        # Enter and leave the context without adding anything.
        with RollingManifestWriter(supplier=make_writer, target_file_size_in_bytes=42) as rolling_writer:
            pass

        produced = rolling_writer.to_manifest_files()
        assert produced == []
| 986 | + |
| 987 | + |
def _create_manifest_writer_supplier(
    tmpdir: str,
    format_version: TableVersion,
    schema: Schema,
    snapshot_id: int = 1,
) -> Callable[[], ManifestWriter]:
    """Build a zero-argument factory that returns a new manifest writer on every call.

    Each call targets ``{tmpdir}/manifest-<n>.avro``, where ``<n>`` starts at 0
    and increases by one per call. Writers use the unpartitioned spec and
    ``"null"`` Avro compression.

    Args:
        tmpdir: Directory under which the manifest files are placed.
        format_version: Table format version forwarded to ``write_manifest``.
        schema: Schema forwarded to ``write_manifest``.
        snapshot_id: Snapshot id forwarded to ``write_manifest``.

    Returns:
        A callable producing a new ``ManifestWriter`` each time it is invoked.
    """
    # A one-element list lets the nested function mutate the running index.
    counter = [0]
    file_io = PyArrowFileIO()

    def _supplier() -> ManifestWriter:
        manifest_path = f"{tmpdir}/manifest-{counter[0]}.avro"
        destination = file_io.new_output(manifest_path)
        counter[0] += 1
        writer_kwargs = {
            "format_version": format_version,
            "spec": UNPARTITIONED_PARTITION_SPEC,
            "schema": schema,
            "output_file": destination,
            "snapshot_id": snapshot_id,
            "avro_compression": "null",
        }
        return write_manifest(**writer_kwargs)

    return _supplier
| 1010 | + |
| 1011 | + |
def _create_simple_entry(
    i: int,
    status: ManifestEntryStatus = ManifestEntryStatus.ADDED,
    sequence_number: int | None = 1,
) -> ManifestEntry:
    """Build a minimal manifest entry that points at ``data-{i}.parquet``.

    The data file is a Parquet DATA file with an empty partition record, a
    record count of 1 and a size of 1000 bytes. The entry uses snapshot id 1
    and data/file sequence numbers of 1.

    Args:
        i: Index embedded in the data file path.
        status: Status assigned to the entry.
        sequence_number: Sequence number assigned to the entry.

    Returns:
        The constructed ``ManifestEntry``.
    """
    return ManifestEntry.from_args(
        status=status,
        snapshot_id=1,
        sequence_number=sequence_number,
        data_sequence_number=1,
        file_sequence_number=1,
        data_file=DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=f"data-{i}.parquet",
            file_format=FileFormat.PARQUET,
            partition=Record(),
            record_count=1,
            file_size_in_bytes=1000,
        ),
    )
0 commit comments