1515# specific language governing permissions and limitations
1616# under the License.
1717# pylint:disable=redefined-outer-name
18- from collections .abc import Generator
19- from datetime import datetime
18+ from collections .abc import Generator , Iterator
19+ from datetime import date , datetime
2020
2121import pyarrow as pa
22+ import pyarrow .parquet as pq
2223import pytest
2324from pyspark .sql import SparkSession
2425
26+ from pyiceberg .catalog import Catalog
2527from pyiceberg .catalog .rest import RestCatalog
2628from pyiceberg .exceptions import NoSuchTableError
2729from pyiceberg .expressions import AlwaysTrue , EqualTo , LessThanOrEqual
30+ from pyiceberg .io import FileIO
2831from pyiceberg .manifest import ManifestEntryStatus
29- from pyiceberg .partitioning import PartitionField , PartitionSpec
32+ from pyiceberg .partitioning import UNPARTITIONED_PARTITION_SPEC , PartitionField , PartitionSpec
3033from pyiceberg .schema import Schema
3134from pyiceberg .table import Table
3235from pyiceberg .table .snapshots import Operation , Summary
3336from pyiceberg .transforms import IdentityTransform
34- from pyiceberg .types import FloatType , IntegerType , LongType , NestedField , StringType , TimestampType
37+ from pyiceberg .types import BooleanType , DateType , FloatType , IntegerType , LongType , NestedField , StringType , TimestampType
38+
39+
40+ # Schema and data used by delete_files tests (moved from test_add_files)
# Fixed schema and one-row payload shared by the delete_files tests
# (originally defined in test_add_files).
TABLE_SCHEMA_DELETE_FILES = Schema(
    NestedField(field_id=1, name="foo", field_type=BooleanType(), required=False),
    NestedField(field_id=2, name="bar", field_type=StringType(), required=False),
    NestedField(field_id=4, name="baz", field_type=IntegerType(), required=False),
    NestedField(field_id=10, name="qux", field_type=DateType(), required=False),
)

# Arrow counterpart of TABLE_SCHEMA_DELETE_FILES, used when staging parquet files.
ARROW_SCHEMA_DELETE_FILES = pa.schema(
    [
        pa.field("foo", pa.bool_()),
        pa.field("bar", pa.string()),
        pa.field("baz", pa.int32()),
        pa.field("qux", pa.date32()),
    ]
)

# Single-row table written into every staged parquet file.
ARROW_TABLE_DELETE_FILES = pa.Table.from_pydict(
    {
        "foo": [True],
        "bar": ["bar_string"],
        "baz": [123],
        "qux": [date(2024, 3, 7)],
    },
    schema=ARROW_SCHEMA_DELETE_FILES,
)
68+
69+
def _write_parquet(io: FileIO, file_path: str, arrow_schema: pa.Schema, arrow_table: pa.Table) -> None:
    """Write ``arrow_table`` to ``file_path`` as a parquet file using the table's FileIO."""
    output_file = io.new_output(file_path)
    with output_file.create(overwrite=True) as stream, pq.ParquetWriter(stream, schema=arrow_schema) as writer:
        writer.write_table(arrow_table)
75+
76+
def _create_table_for_delete_files(
    session_catalog: Catalog,
    identifier: str,
    format_version: int,
    partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
    schema: Schema = TABLE_SCHEMA_DELETE_FILES,
) -> Table:
    """Drop ``identifier`` if it already exists, then create a fresh table for a delete_files test.

    The table is created with the requested format version and (optionally) partition spec.
    """
    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass  # no stale table to clean up

    table_properties = {"format-version": str(format_version)}
    return session_catalog.create_table(
        identifier=identifier,
        schema=schema,
        properties=table_properties,
        partition_spec=partition_spec,
    )
3595
3696
3797def run_spark_commands (spark : SparkSession , sqls : list [str ]) -> None :
@@ -57,6 +117,12 @@ def test_table(session_catalog: RestCatalog) -> Generator[Table, None, None]:
57117 session_catalog .drop_table (identifier )
58118
59119
@pytest.fixture(params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")], name="format_version")
def format_version_fixture(request: "pytest.FixtureRequest") -> Iterator[int]:
    """Parametrized fixture yielding each supported table format version for the delete_files tests."""
    yield request.param
124+
125+
60126@pytest .mark .integration
61127@pytest .mark .parametrize ("format_version" , [1 , 2 ])
62128def test_partitioned_table_delete_full_file (spark : SparkSession , session_catalog : RestCatalog , format_version : int ) -> None :
@@ -975,3 +1041,93 @@ def assert_manifest_entry(expected_status: ManifestEntryStatus, expected_snapsho
9751041 assert after_delete_snapshot is not None
9761042
9771043 assert_manifest_entry (ManifestEntryStatus .DELETED , after_delete_snapshot .snapshot_id )
1044+
1045+
@pytest.mark.integration
def test_delete_files_from_unpartitioned_table(
    spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
    """delete_files on an unpartitioned table removes exactly the requested data files."""
    identifier = f"default.delete_files_unpartitioned_v{format_version}"
    tbl = _create_table_for_delete_files(session_catalog, identifier, format_version)

    # Stage five identical single-row parquet files and register them all.
    parquet_paths = [f"s3://warehouse/default/delete_unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)]
    for parquet_path in parquet_paths:
        _write_parquet(tbl.io, parquet_path, ARROW_SCHEMA_DELETE_FILES, ARROW_TABLE_DELETE_FILES)
    tbl.add_files(file_paths=parquet_paths)
    assert len(tbl.scan().to_arrow()) == 5

    # Drop the first two files from the table.
    tbl.delete_files(file_paths=parquet_paths[:2])

    # The manifests must record exactly two deleted data files.
    manifest_rows = spark.sql(
        f"""
        SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count
        FROM {identifier}.all_manifests
    """
    ).collect()
    assert sum(manifest_row.deleted_data_files_count for manifest_row in manifest_rows) == 2

    # Both Spark and pyiceberg should now see only the remaining three rows.
    assert spark.table(identifier).count() == 3
    assert len(tbl.scan().to_arrow()) == 3
1075+
1076+
@pytest.mark.integration
def test_delete_files_raises_on_nonexistent_file(session_catalog: Catalog, format_version: int) -> None:
    """delete_files must reject a path that the table does not reference."""
    identifier = f"default.delete_files_nonexistent_v{format_version}"
    tbl = _create_table_for_delete_files(session_catalog, identifier, format_version)

    # Register three real files so the table is non-empty.
    registered_paths = [f"s3://warehouse/default/delete_nonexistent/v{format_version}/test-{i}.parquet" for i in range(3)]
    for registered_path in registered_paths:
        _write_parquet(tbl.io, registered_path, ARROW_SCHEMA_DELETE_FILES, ARROW_TABLE_DELETE_FILES)
    tbl.add_files(file_paths=registered_paths)

    # Deleting an unknown path is an error, not a no-op.
    with pytest.raises(ValueError, match="Cannot delete files that are not referenced by table"):
        tbl.delete_files(file_paths=["s3://warehouse/default/does-not-exist.parquet"])
1090+
1091+
@pytest.mark.integration
def test_delete_files_raises_on_duplicate_paths(session_catalog: Catalog, format_version: int) -> None:
    """delete_files must reject a request that names the same file more than once."""
    identifier = f"default.delete_files_duplicate_v{format_version}"
    tbl = _create_table_for_delete_files(session_catalog, identifier, format_version)

    parquet_path = f"s3://warehouse/default/delete_duplicate/v{format_version}/test.parquet"
    _write_parquet(tbl.io, parquet_path, ARROW_SCHEMA_DELETE_FILES, ARROW_TABLE_DELETE_FILES)
    tbl.add_files(file_paths=[parquet_path])

    # Passing the same path twice must fail fast.
    with pytest.raises(ValueError, match="File paths must be unique"):
        tbl.delete_files(file_paths=[parquet_path, parquet_path])
1104+
1105+
@pytest.mark.integration
def test_delete_files_from_branch(
    spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
    """delete_files with ``branch=`` only affects that branch; main stays untouched."""
    identifier = f"default.delete_files_branch_v{format_version}"
    branch = "branch1"
    tbl = _create_table_for_delete_files(session_catalog, identifier, format_version)

    staged_paths = [f"s3://warehouse/default/delete_branch/v{format_version}/test-{i}.parquet" for i in range(5)]
    for staged_path in staged_paths:
        _write_parquet(tbl.io, staged_path, ARROW_SCHEMA_DELETE_FILES, ARROW_TABLE_DELETE_FILES)

    # One row committed to main, then branch off from that snapshot.
    tbl.append(ARROW_TABLE_DELETE_FILES)
    assert tbl.metadata.current_snapshot_id is not None
    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()

    # Add the five staged files to the branch only: 1 (main row) + 5 = 6 on the branch.
    tbl.add_files(file_paths=staged_paths, branch=branch)
    assert spark.table(f"{identifier}.branch_{branch}").count() == 6

    # Delete three of the staged files from the branch.
    tbl.delete_files(file_paths=staged_paths[:3], branch=branch)

    # Branch drops to 3 rows; main still holds its single appended row.
    assert spark.table(f"{identifier}.branch_{branch}").count() == 3
    assert spark.table(identifier).count() == 1
0 commit comments