|
86 | 86 | from pyiceberg.manifest import DataFile, DataFileContent, FileFormat |
87 | 87 | from pyiceberg.partitioning import PartitionField, PartitionSpec |
88 | 88 | from pyiceberg.schema import Schema, make_compatible_name, visit |
89 | | -from pyiceberg.table import FileScanTask, TableProperties |
| 89 | +from pyiceberg.table import FileScanTask, ScanOrder, TableProperties |
90 | 90 | from pyiceberg.table.metadata import TableMetadataV2 |
91 | 91 | from pyiceberg.table.name_mapping import create_mapping_from_schema |
92 | 92 | from pyiceberg.transforms import HourTransform, IdentityTransform |
@@ -3106,6 +3106,176 @@ def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None: |
3106 | 3106 | assert len(batches[0]) == num_rows |
3107 | 3107 |
|
3108 | 3108 |
|
def _create_scan_and_tasks(
    tmpdir: str,
    num_files: int = 1,
    rows_per_file: int = 100,
    limit: int | None = None,
    delete_rows_per_file: list[list[int]] | None = None,
) -> tuple[ArrowScan, list[FileScanTask]]:
    """Helper to create an ArrowScan and FileScanTasks for testing.

    Each data file holds ``rows_per_file`` consecutive integers, so file ``i``
    covers the value range ``[i * rows_per_file, (i + 1) * rows_per_file)``.

    Args:
        tmpdir: Directory to write the Parquet data/delete files into.
        num_files: Number of data files to write (one FileScanTask per file).
        rows_per_file: Number of rows written to each data file.
        limit: Optional row limit to configure on the ArrowScan.
        delete_rows_per_file: If provided, a list of lists of row positions to delete
            per file. Length must match num_files. Each inner list contains 0-based
            row positions within that file to mark as positionally deleted.

    Returns:
        A tuple of the configured ArrowScan and the list of FileScanTasks.

    Raises:
        ValueError: If ``delete_rows_per_file`` is given but its length does not
            match ``num_files``.
    """
    if delete_rows_per_file is not None and len(delete_rows_per_file) != num_files:
        raise ValueError(
            f"delete_rows_per_file must have one entry per file: expected {num_files}, got {len(delete_rows_per_file)}"
        )

    table_schema = Schema(NestedField(1, "col", LongType(), required=True))
    pa_schema = pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})])
    tasks = []
    for i in range(num_files):
        start = i * rows_per_file
        arrow_table = pa.table({"col": pa.array(range(start, start + rows_per_file))}, schema=pa_schema)
        data_file = _write_table_to_data_file(f"{tmpdir}/file_{i}.parquet", pa_schema, arrow_table)
        data_file.spec_id = 0

        delete_files = set()
        if delete_rows_per_file and delete_rows_per_file[i]:
            # Positional delete files reference the target data file by path
            # plus 0-based row position.
            delete_table = pa.table(
                {
                    "file_path": [data_file.file_path] * len(delete_rows_per_file[i]),
                    "pos": delete_rows_per_file[i],
                }
            )
            delete_path = f"{tmpdir}/deletes_{i}.parquet"
            pq.write_table(delete_table, delete_path)
            delete_files.add(
                DataFile.from_args(
                    content=DataFileContent.POSITION_DELETES,
                    file_path=delete_path,
                    file_format=FileFormat.PARQUET,
                    partition={},
                    record_count=len(delete_rows_per_file[i]),
                    file_size_in_bytes=22,  # arbitrary non-zero placeholder size
                )
            )

        tasks.append(FileScanTask(data_file=data_file, delete_files=delete_files))

    scan = ArrowScan(
        table_metadata=TableMetadataV2(
            location="file://a/b/",
            last_column_id=1,
            format_version=2,
            schemas=[table_schema],
            partition_specs=[PartitionSpec()],
        ),
        io=PyArrowFileIO(),
        projected_schema=table_schema,
        row_filter=AlwaysTrue(),
        case_sensitive=True,
        limit=limit,
    )
    return scan, tasks
| 3170 | + |
| 3171 | + |
def test_task_order_produces_same_results(tmpdir: str) -> None:
    """Test that order=ScanOrder.TASK produces the same results as the default behavior."""
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    # Default behavior: no explicit `order` argument. (Previously this passed
    # order=ScanOrder.TASK, so the test compared TASK with TASK and never
    # exercised the default path.)
    batches_default = list(scan.to_record_batches(tasks))
    # Re-create tasks since iterators are consumed
    _, tasks2 = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)
    batches_task_order = list(scan.to_record_batches(tasks2, order=ScanOrder.TASK))

    total_default = sum(len(b) for b in batches_default)
    total_task_order = sum(len(b) for b in batches_task_order)
    assert total_default == 300
    assert total_task_order == 300
    # Same values in both modes, regardless of how rows are batched.
    values_default = sorted(v for b in batches_default for v in b.column("col").to_pylist())
    values_task_order = sorted(v for b in batches_task_order for v in b.column("col").to_pylist())
    assert values_default == values_task_order
| 3185 | + |
| 3186 | + |
def test_arrival_order_yields_all_batches(tmpdir: str) -> None:
    """Test that order=ScanOrder.ARRIVAL yields all batches correctly."""
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    record_batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL))

    assert sum(len(batch) for batch in record_batches) == 300
    # Every value from every file must come back exactly once.
    observed = sorted(value for batch in record_batches for value in batch.column("col").to_pylist())
    assert observed == list(range(300))
| 3198 | + |
| 3199 | + |
def test_arrival_order_with_limit(tmpdir: str) -> None:
    """Test that order=ScanOrder.ARRIVAL respects the row limit."""
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100, limit=150)

    # 300 rows are available, but the scan was configured with limit=150.
    row_count = sum(len(batch) for batch in scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL))

    assert row_count == 150
| 3208 | + |
| 3209 | + |
def test_arrival_order_file_ordering_preserved(tmpdir: str) -> None:
    """Test that file ordering is preserved in arrival order mode."""
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    flattened = [
        value
        for batch in scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL)
        for value in batch.column("col").to_pylist()
    ]

    # File 0 holds 0-99, file 1 holds 100-199, file 2 holds 200-299, so
    # in-file-order arrival yields the full range back-to-back.
    # NOTE(review): if ARRIVAL mode ever reads files concurrently, exact
    # inter-file ordering may not be guaranteed — confirm against the scan
    # implementation.
    assert flattened == list(range(300))
| 3219 | + |
| 3220 | + |
def test_arrival_order_with_positional_deletes(tmpdir: str) -> None:
    """Test that order=ScanOrder.ARRIVAL correctly applies positional deletes."""
    # 3 files of 10 rows each; drop positions 0 and 5 from file 0, position 3
    # from file 1, and nothing from file 2.
    scan, tasks = _create_scan_and_tasks(
        tmpdir,
        num_files=3,
        rows_per_file=10,
        delete_rows_per_file=[[0, 5], [3], []],
    )

    surviving = sorted(
        value
        for batch in scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL)
        for value in batch.column("col").to_pylist()
    )

    # 30 rows written, 3 positionally deleted.
    assert len(surviving) == 27
    # File 0 (values 0-9) loses positions 0 and 5; file 1 (10-19) loses
    # position 3 (value 13); file 2 (20-29) is untouched.
    expected = [1, 2, 3, 4, 6, 7, 8, 9] + [10, 11, 12, 14, 15, 16, 17, 18, 19] + list(range(20, 30))
    assert surviving == sorted(expected)
| 3241 | + |
| 3242 | + |
def test_arrival_order_with_positional_deletes_and_limit(tmpdir: str) -> None:
    """Test that order=ScanOrder.ARRIVAL with positional deletes respects the row limit."""
    # 3 files of 10 rows each with position 0 deleted from every file, leaving
    # 27 live rows; the scan limit of 15 must still cap the output.
    scan, tasks = _create_scan_and_tasks(
        tmpdir,
        num_files=3,
        rows_per_file=10,
        limit=15,
        delete_rows_per_file=[[0], [0], [0]],
    )

    row_count = sum(len(batch) for batch in scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL))

    assert row_count == 15
| 3258 | + |
| 3259 | + |
def test_task_order_with_positional_deletes(tmpdir: str) -> None:
    """Test that the default task order mode correctly applies positional deletes."""
    # 3 files of 10 rows each; drop positions 0 and 5 from file 0, position 3
    # from file 1, and nothing from file 2.
    scan, tasks = _create_scan_and_tasks(
        tmpdir,
        num_files=3,
        rows_per_file=10,
        delete_rows_per_file=[[0, 5], [3], []],
    )

    surviving = sorted(
        value
        for batch in scan.to_record_batches(tasks, order=ScanOrder.TASK)
        for value in batch.column("col").to_pylist()
    )

    # 30 rows written, 3 positionally deleted.
    assert len(surviving) == 27
    expected = [1, 2, 3, 4, 6, 7, 8, 9] + [10, 11, 12, 14, 15, 16, 17, 18, 19] + list(range(20, 30))
    assert surviving == sorted(expected)
| 3277 | + |
| 3278 | + |
3109 | 3279 | def test_parse_location_defaults() -> None: |
3110 | 3280 | """Test that parse_location uses defaults.""" |
3111 | 3281 |
|
|
0 commit comments