Skip to content

Commit b72b7ba

Browse files
sumedhsakdeo authored and claude committed
feat: add streaming flag to ArrowScan.to_record_batches
When streaming=True, batches are yielded as they are produced by PyArrow without materializing entire files into memory. Files are still processed sequentially, preserving file ordering. The inner method handles the global limit correctly when called with all tasks, avoiding double-counting. This addresses the OOM issue in apache#3036 for single-file-at-a-time streaming.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8f8a2d2 commit b72b7ba

File tree

4 files changed

+136
-26
lines changed

4 files changed

+136
-26
lines changed

mkdocs/docs/api.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,13 @@ for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000):
362362
print(f"Buffer contains {len(buf)} rows")
363363
```
364364

365+
By default, each file's batches are materialized in memory before being yielded. For large files that may exceed available memory, use `streaming=True` to yield batches as they are produced without materializing entire files:
366+
367+
```python
368+
for buf in tbl.scan().to_arrow_batch_reader(streaming=True, batch_size=1000):
369+
print(f"Buffer contains {len(buf)} rows")
370+
```
371+
365372
To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow:
366373

367374
```python
@@ -1635,6 +1642,15 @@ table.scan(
16351642
).to_arrow_batch_reader(batch_size=1000)
16361643
```
16371644

1645+
Use `streaming=True` to avoid materializing entire files in memory. This yields batches as they are produced by PyArrow, one file at a time:
1646+
1647+
```python
1648+
table.scan(
1649+
row_filter=GreaterThanOrEqual("trip_distance", 10.0),
1650+
selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
1651+
).to_arrow_batch_reader(streaming=True)
1652+
```
1653+
16381654
### Pandas
16391655

16401656
<!-- prettier-ignore-start -->

pyiceberg/io/pyarrow.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,7 +1761,9 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
17611761

17621762
return result
17631763

1764-
def to_record_batches(
    self, tasks: Iterable[FileScanTask], batch_size: int | None = None, streaming: bool = False
) -> Iterator[pa.RecordBatch]:
    """Scan the Iceberg table and return an Iterator[pa.RecordBatch].

    Returns an Iterator of pa.RecordBatch with data from the Iceberg table
    by resolving the right columns that match the current table schema.

    Args:
        tasks: FileScanTasks representing the data files and delete files to read from.
        batch_size: The number of rows per batch. If None, PyArrow's default is used.
        streaming: If True, yield batches as they are produced without materializing
            entire files into memory. Files are still processed sequentially.

    Returns:
        An Iterator of PyArrow RecordBatches.
    """
    # `tasks` is typed Iterable and may be a one-shot iterator (e.g. a generator from
    # plan_files()), but it is consumed twice: once by _read_all_delete_files and once
    # by the scan below. Materialize it so both passes see every task.
    tasks = list(tasks)
    deletes_per_file = _read_all_delete_files(self._io, tasks)

    if streaming:
        # Streaming path: process all tasks sequentially, yielding batches as produced.
        # _record_batches_from_scan_tasks_and_deletes handles the limit internally
        # when called with all tasks, so no outer limit check is needed.
        yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file, batch_size)
    else:
        # Non-streaming path: existing behavior with executor.map + list().
        total_row_count = 0
        executor = ExecutorFactory.get_or_create()

        def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
            # Materialize the iterator here to ensure execution happens within the executor.
            # Otherwise, the iterator would be lazily consumed later (in the main thread),
            # defeating the purpose of using executor.map.
            return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size))

        limit_reached = False
        for batches in executor.map(batches_for_task, tasks):
            for batch in batches:
                current_batch_size = len(batch)
                if self._limit is not None and total_row_count + current_batch_size >= self._limit:
                    # Trim the final batch so exactly self._limit rows are emitted overall.
                    yield batch.slice(0, self._limit - total_row_count)

                    limit_reached = True
                    break
                else:
                    yield batch
                    total_row_count += current_batch_size

            if limit_reached:
                # This break will also cancel all running tasks in the executor
                break

18101822
def _record_batches_from_scan_tasks_and_deletes(
18111823
self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None

pyiceberg/table/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,7 +2157,7 @@ def to_arrow(self) -> pa.Table:
21572157
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
21582158
).to_table(self.plan_files())
21592159

2160-
def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatchReader:
2160+
def to_arrow_batch_reader(self, batch_size: int | None = None, streaming: bool = False) -> pa.RecordBatchReader:
21612161
"""Return an Arrow RecordBatchReader from this DataScan.
21622162
21632163
For large results, using a RecordBatchReader requires less memory than
@@ -2166,6 +2166,8 @@ def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatch
21662166
21672167
Args:
21682168
batch_size: The number of rows per batch. If None, PyArrow's default is used.
2169+
streaming: If True, yield batches as they are produced without materializing
2170+
entire files into memory. Files are still processed sequentially.
21692171
21702172
Returns:
21712173
pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan
@@ -2178,7 +2180,7 @@ def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatch
21782180
target_schema = schema_to_pyarrow(self.projection())
21792181
batches = ArrowScan(
21802182
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
2181-
).to_record_batches(self.plan_files(), batch_size=batch_size)
2183+
).to_record_batches(self.plan_files(), batch_size=batch_size, streaming=streaming)
21822184

21832185
return pa.RecordBatchReader.from_batches(
21842186
target_schema,

tests/io/test_pyarrow.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3106,6 +3106,86 @@ def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None:
31063106
assert len(batches[0]) == num_rows
31073107

31083108

3109+
def _create_scan_and_tasks(
    tmpdir: str, num_files: int = 1, rows_per_file: int = 100, limit: int | None = None
) -> tuple[ArrowScan, list[FileScanTask]]:
    """Build an ArrowScan plus one FileScanTask per generated parquet file.

    Each file i holds the consecutive values [i * rows_per_file, (i + 1) * rows_per_file)
    in a single required long column named "col".
    """
    iceberg_schema = Schema(NestedField(1, "col", LongType(), required=True))
    arrow_schema = pa.schema(
        [pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]
    )

    scan_tasks: list[FileScanTask] = []
    for file_index in range(num_files):
        first_value = file_index * rows_per_file
        values = pa.array(range(first_value, first_value + rows_per_file))
        contents = pa.table({"col": values}, schema=arrow_schema)
        written = _write_table_to_data_file(f"{tmpdir}/file_{file_index}.parquet", arrow_schema, contents)
        written.spec_id = 0
        scan_tasks.append(FileScanTask(written))

    metadata = TableMetadataV2(
        location="file://a/b/",
        last_column_id=1,
        format_version=2,
        schemas=[iceberg_schema],
        partition_specs=[PartitionSpec()],
    )
    scan = ArrowScan(
        table_metadata=metadata,
        io=PyArrowFileIO(),
        projected_schema=iceberg_schema,
        row_filter=AlwaysTrue(),
        case_sensitive=True,
        limit=limit,
    )
    return scan, scan_tasks
3138+
3139+
3140+
def test_streaming_false_produces_same_results(tmpdir: str) -> None:
    """Test that streaming=True produces the same results as the default (streaming=False)."""
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    batches_default = list(scan.to_record_batches(tasks, streaming=False))
    # Re-create tasks since iterators are consumed
    _, tasks2 = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)
    # Fix: the second read must use streaming=True — previously this passed
    # streaming=False and compared the default path against itself.
    batches_streaming = list(scan.to_record_batches(tasks2, streaming=True))

    total_default = sum(len(b) for b in batches_default)
    total_streaming = sum(len(b) for b in batches_streaming)
    assert total_default == 300
    assert total_streaming == 300

    # Both modes must yield identical row values in identical order, not just counts.
    values_default = [v for b in batches_default for v in b.column("col").to_pylist()]
    values_streaming = [v for b in batches_streaming for v in b.column("col").to_pylist()]
    assert values_default == values_streaming
3153+
3154+
3155+
def test_streaming_true_yields_all_batches(tmpdir: str) -> None:
    """Test that streaming=True yields all batches correctly."""
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    produced = list(scan.to_record_batches(tasks, streaming=True))

    # Every row from all three files must come through.
    assert sum(len(rb) for rb in produced) == 300
    observed = sorted(value for rb in produced for value in rb.column("col").to_pylist())
    assert observed == list(range(300))
3166+
3167+
3168+
def test_streaming_true_with_limit(tmpdir: str) -> None:
    """Test that streaming=True respects the row limit."""
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100, limit=150)

    # Consume the stream batch by batch, counting rows as they arrive.
    row_total = 0
    for rb in scan.to_record_batches(tasks, streaming=True):
        row_total += len(rb)

    assert row_total == 150
3176+
3177+
3178+
def test_streaming_file_ordering_preserved(tmpdir: str) -> None:
    """Test that file ordering is preserved in both streaming modes."""
    scan, tasks = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)

    batches = list(scan.to_record_batches(tasks, streaming=True))
    all_values = [v for b in batches for v in b.column("col").to_pylist()]

    # Values should be in file order: 0-99 from file 0, 100-199 from file 1, 200-299 from file 2
    assert all_values == list(range(300))

    # The docstring promises both modes: also verify the non-streaming path,
    # whose executor.map yields results in task submission order.
    _, tasks2 = _create_scan_and_tasks(tmpdir, num_files=3, rows_per_file=100)
    batches_default = list(scan.to_record_batches(tasks2, streaming=False))
    values_default = [v for b in batches_default for v in b.column("col").to_pylist()]
    assert values_default == list(range(300))
3187+
3188+
31093189
def test_parse_location_defaults() -> None:
31103190
"""Test that parse_location uses defaults."""
31113191

0 commit comments

Comments
 (0)