Skip to content

Commit 444549f

Browse files
sumedhsakdeo authored and claude committed
feat: add streaming flag to ArrowScan.to_record_batches
When streaming=True, batches are yielded as they are produced by PyArrow without materializing entire files into memory. Files are still processed sequentially, preserving file ordering. The inner method handles the global limit correctly when called with all tasks, avoiding double-counting. This addresses the OOM issue in #3036 for single-file-at-a-time streaming. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8f8a2d2 commit 444549f

File tree

4 files changed

+202
-3
lines changed

4 files changed

+202
-3
lines changed

mkdocs/docs/api.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,13 @@ for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000):
362362
print(f"Buffer contains {len(buf)} rows")
363363
```
364364

365+
By default, each file's batches are materialized in memory before being yielded. For large files that may exceed available memory, use `streaming=True` to yield batches as they are produced without materializing entire files:
366+
367+
```python
368+
for buf in tbl.scan().to_arrow_batch_reader(streaming=True, batch_size=1000):
369+
print(f"Buffer contains {len(buf)} rows")
370+
```
371+
365372
To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow:
366373

367374
```python
@@ -1635,6 +1642,15 @@ table.scan(
16351642
).to_arrow_batch_reader(batch_size=1000)
16361643
```
16371644

1645+
Use `streaming=True` to avoid materializing entire files in memory. This yields batches as they are produced by PyArrow, one file at a time:
1646+
1647+
```python
1648+
table.scan(
1649+
row_filter=GreaterThanOrEqual("trip_distance", 10.0),
1650+
selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
1651+
).to_arrow_batch_reader(streaming=True)
1652+
```
1653+
16381654
### Pandas
16391655

16401656
<!-- prettier-ignore-start -->

pyiceberg/io/pyarrow.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1761,7 +1761,9 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
17611761

17621762
return result
17631763

1764-
def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
1764+
def to_record_batches(
1765+
self, tasks: Iterable[FileScanTask], batch_size: int | None = None, streaming: bool = False
1766+
) -> Iterator[pa.RecordBatch]:
17651767
"""Scan the Iceberg table and return an Iterator[pa.RecordBatch].
17661768
17671769
Returns an Iterator of pa.RecordBatch with data from the Iceberg table
@@ -1770,6 +1772,9 @@ def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | Non
17701772
17711773
Args:
17721774
tasks: FileScanTasks representing the data files and delete files to read from.
1775+
batch_size: The number of rows per batch. If None, PyArrow's default is used.
1776+
streaming: If True, yield batches as they are produced without materializing
1777+
entire files into memory. Files are still processed sequentially.
17731778
17741779
Returns:
17751780
An Iterator of PyArrow RecordBatches.
@@ -1781,6 +1786,14 @@ def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | Non
17811786
"""
17821787
deletes_per_file = _read_all_delete_files(self._io, tasks)
17831788

1789+
if streaming:
1790+
# Streaming path: process all tasks sequentially, yielding batches as produced.
1791+
# _record_batches_from_scan_tasks_and_deletes handles the limit internally
1792+
# when called with all tasks, so no outer limit check is needed.
1793+
yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file, batch_size)
1794+
return
1795+
1796+
# Non-streaming path: existing behavior with executor.map + list()
17841797
total_row_count = 0
17851798
executor = ExecutorFactory.get_or_create()
17861799

pyiceberg/table/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,7 +2157,7 @@ def to_arrow(self) -> pa.Table:
21572157
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
21582158
).to_table(self.plan_files())
21592159

2160-
def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatchReader:
2160+
def to_arrow_batch_reader(self, batch_size: int | None = None, streaming: bool = False) -> pa.RecordBatchReader:
21612161
"""Return an Arrow RecordBatchReader from this DataScan.
21622162
21632163
For large results, using a RecordBatchReader requires less memory than
@@ -2166,6 +2166,8 @@ def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatch
21662166
21672167
Args:
21682168
batch_size: The number of rows per batch. If None, PyArrow's default is used.
2169+
streaming: If True, yield batches as they are produced without materializing
2170+
entire files into memory. Files are still processed sequentially.
21692171
21702172
Returns:
21712173
pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan
@@ -2178,7 +2180,7 @@ def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatch
21782180
target_schema = schema_to_pyarrow(self.projection())
21792181
batches = ArrowScan(
21802182
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
2181-
).to_record_batches(self.plan_files(), batch_size=batch_size)
2183+
).to_record_batches(self.plan_files(), batch_size=batch_size, streaming=streaming)
21822184

21832185
return pa.RecordBatchReader.from_batches(
21842186
target_schema,

tests/io/test_pyarrow.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3106,6 +3106,174 @@ def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None:
31063106
assert len(batches[0]) == num_rows
31073107

31083108

3109+
def _create_scan_and_tasks(
    tmpdir: str,
    num_files: int = 1,
    rows_per_file: int = 100,
    limit: int | None = None,
    delete_rows_per_file: list[list[int]] | None = None,
) -> tuple[ArrowScan, list[FileScanTask]]:
    """Build an ArrowScan plus one FileScanTask per generated parquet file.

    File ``i`` holds the consecutive values ``[i * rows_per_file, (i + 1) * rows_per_file)``
    in a single required long column named "col".

    Args:
        delete_rows_per_file: If provided, a list of lists of row positions to delete
            per file. Length must match num_files. Each inner list contains 0-based
            row positions within that file to mark as positionally deleted.
    """
    table_schema = Schema(NestedField(1, "col", LongType(), required=True))
    pa_schema = pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})])

    tasks: list[FileScanTask] = []
    for file_index in range(num_files):
        first_value = file_index * rows_per_file
        file_table = pa.table({"col": pa.array(range(first_value, first_value + rows_per_file))}, schema=pa_schema)
        data_file = _write_table_to_data_file(f"{tmpdir}/file_{file_index}.parquet", pa_schema, file_table)
        data_file.spec_id = 0

        # Optionally attach a positional-delete file covering this data file.
        positions = delete_rows_per_file[file_index] if delete_rows_per_file else None
        delete_files = set()
        if positions:
            delete_table = pa.table({
                "file_path": [data_file.file_path] * len(positions),
                "pos": positions,
            })
            delete_path = f"{tmpdir}/deletes_{file_index}.parquet"
            pq.write_table(delete_table, delete_path)
            delete_files.add(
                DataFile.from_args(
                    content=DataFileContent.POSITION_DELETES,
                    file_path=delete_path,
                    file_format=FileFormat.PARQUET,
                    partition={},
                    record_count=len(positions),
                    file_size_in_bytes=22,
                )
            )

        tasks.append(FileScanTask(data_file=data_file, delete_files=delete_files))

    scan = ArrowScan(
        table_metadata=TableMetadataV2(
            location="file://a/b/",
            last_column_id=1,
            format_version=2,
            schemas=[table_schema],
            partition_specs=[PartitionSpec()],
        ),
        io=PyArrowFileIO(),
        projected_schema=table_schema,
        row_filter=AlwaysTrue(),
        case_sensitive=True,
        limit=limit,
    )
    return scan, tasks
3168+
3169+
3170+
def test_streaming_false_produces_same_results(tmpdir: str) -> None:
    """Test that streaming=True produces the same results as the default (streaming=False).

    Bug fix: the second call previously also passed ``streaming=False``, so the
    comparison never exercised the streaming path despite the variable name
    ``batches_streaming``. It now uses ``streaming=True`` and compares the
    actual values, not just the row counts.
    """
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    batches_default = list(scan.to_record_batches(tasks, streaming=False))
    # Re-create tasks so the second scan reads fresh, independent inputs.
    _, tasks2 = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)
    batches_streaming = list(scan.to_record_batches(tasks2, streaming=True))

    total_default = sum(len(b) for b in batches_default)
    total_streaming = sum(len(b) for b in batches_streaming)
    assert total_default == 300
    assert total_streaming == 300
    # Both modes must yield the same multiset of values, not merely the same count.
    values_default = sorted(v for b in batches_default for v in b.column("col").to_pylist())
    values_streaming = sorted(v for b in batches_streaming for v in b.column("col").to_pylist())
    assert values_default == values_streaming
3183+
3184+
3185+
def test_streaming_true_yields_all_batches(tmpdir: str) -> None:
    """Verify streaming=True emits every row from every file."""
    scan, file_tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    produced = list(scan.to_record_batches(file_tasks, streaming=True))

    assert sum(len(batch) for batch in produced) == 300
    # Every value from every file must be present.
    observed = sorted(value for batch in produced for value in batch.column("col").to_pylist())
    assert observed == list(range(300))
3196+
3197+
3198+
def test_streaming_true_with_limit(tmpdir: str) -> None:
    """Verify the scan-level row limit is honored when streaming=True."""
    scan, file_tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100, limit=150)

    # 3 files x 100 rows = 300 candidates, but the limit caps the output at 150.
    row_count = 0
    for batch in scan.to_record_batches(file_tasks, streaming=True):
        row_count += len(batch)

    assert row_count == 150
3206+
3207+
3208+
def test_streaming_file_ordering_preserved(tmpdir: str) -> None:
    """Verify streaming=True yields rows in file order."""
    scan, file_tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    ordered_values: list[int] = []
    for batch in scan.to_record_batches(file_tasks, streaming=True):
        ordered_values.extend(batch.column("col").to_pylist())

    # File 0 holds 0-99, file 1 holds 100-199, file 2 holds 200-299;
    # sequential per-file processing must keep that ordering intact.
    assert ordered_values == list(range(300))
3217+
3218+
3219+
def test_streaming_with_positional_deletes(tmpdir: str) -> None:
    """Verify streaming=True correctly applies positional deletes."""
    # 3 files, 10 rows each; delete rows 0 and 5 from file 0, row 3 from file 1,
    # and nothing from file 2.
    scan, file_tasks = _create_scan_and_tasks(
        tmpdir,
        num_files=3,
        rows_per_file=10,
        delete_rows_per_file=[[0, 5], [3], []],
    )

    produced = list(scan.to_record_batches(file_tasks, streaming=True))

    assert sum(len(batch) for batch in produced) == 27  # 30 rows minus 3 deletes
    observed = sorted(value for batch in produced for value in batch.column("col").to_pylist())
    # File 0 (values 0-9) loses positions 0 and 5; file 1 (10-19) loses position 3;
    # file 2 (20-29) is untouched.
    expected = [1, 2, 3, 4, 6, 7, 8, 9] + [10, 11, 12, 14, 15, 16, 17, 18, 19] + list(range(20, 30))
    assert observed == sorted(expected)
3239+
3240+
3241+
def test_streaming_with_positional_deletes_and_limit(tmpdir: str) -> None:
    """Verify streaming=True honors the row limit after positional deletes."""
    # 3 files, 10 rows each, with the first row of every file deleted;
    # the limit applies to the rows that survive the deletes.
    scan, file_tasks = _create_scan_and_tasks(
        tmpdir,
        num_files=3,
        rows_per_file=10,
        limit=15,
        delete_rows_per_file=[[0], [0], [0]],
    )

    row_count = sum(len(batch) for batch in scan.to_record_batches(file_tasks, streaming=True))

    assert row_count == 15
3256+
3257+
3258+
def test_default_mode_with_positional_deletes(tmpdir: str) -> None:
    """Verify the default (non-streaming) mode correctly applies positional deletes."""
    # Mirror of the streaming deletes test: delete rows 0 and 5 from file 0,
    # row 3 from file 1, and nothing from file 2.
    scan, file_tasks = _create_scan_and_tasks(
        tmpdir,
        num_files=3,
        rows_per_file=10,
        delete_rows_per_file=[[0, 5], [3], []],
    )

    produced = list(scan.to_record_batches(file_tasks, streaming=False))

    assert sum(len(batch) for batch in produced) == 27  # 30 rows minus 3 deletes
    observed = sorted(value for batch in produced for value in batch.column("col").to_pylist())
    expected = [1, 2, 3, 4, 6, 7, 8, 9] + [10, 11, 12, 14, 15, 16, 17, 18, 19] + list(range(20, 30))
    assert observed == sorted(expected)
3275+
3276+
31093277
def test_parse_location_defaults() -> None:
31103278
"""Test that parse_location uses defaults."""
31113279

0 commit comments

Comments (0)