Skip to content

Commit 8f8a2d2

Browse files
sumedhsakdeo and claude
committed
feat: forward batch_size parameter to PyArrow Scanner
Add batch_size parameter to _task_to_record_batches, _record_batches_from_scan_tasks_and_deletes, ArrowScan.to_record_batches, and DataScan.to_arrow_batch_reader so users can control the number of rows per RecordBatch returned by PyArrow's Scanner. Partially closes #3036. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7d4a8ef commit 8f8a2d2

File tree

4 files changed

+94
-11
lines changed

4 files changed

+94
-11
lines changed

mkdocs/docs/api.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,13 @@ for buf in tbl.scan().to_arrow_batch_reader():
355355
print(f"Buffer contains {len(buf)} rows")
356356
```
357357

358+
You can control the number of rows per batch using the `batch_size` parameter:
359+
360+
```python
361+
for buf in tbl.scan().to_arrow_batch_reader(batch_size=1000):
362+
print(f"Buffer contains {len(buf)} rows")
363+
```
364+
358365
To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow:
359366

360367
```python
@@ -1619,6 +1626,15 @@ table.scan(
16191626
).to_arrow_batch_reader()
16201627
```
16211628

1629+
The `batch_size` parameter controls the maximum number of rows per RecordBatch (default is PyArrow's 131,072 rows):
1630+
1631+
```python
1632+
table.scan(
1633+
row_filter=GreaterThanOrEqual("trip_distance", 10.0),
1634+
selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
1635+
).to_arrow_batch_reader(batch_size=1000)
1636+
```
1637+
16221638
### Pandas
16231639

16241640
<!-- prettier-ignore-start -->

pyiceberg/io/pyarrow.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,7 @@ def _task_to_record_batches(
15811581
partition_spec: PartitionSpec | None = None,
15821582
format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
15831583
downcast_ns_timestamp_to_us: bool | None = None,
1584+
batch_size: int | None = None,
15841585
) -> Iterator[pa.RecordBatch]:
15851586
arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
15861587
with io.new_input(task.file.file_path).open() as fin:
@@ -1612,14 +1613,18 @@ def _task_to_record_batches(
16121613

16131614
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
16141615

1615-
fragment_scanner = ds.Scanner.from_fragment(
1616-
fragment=fragment,
1617-
schema=physical_schema,
1616+
scanner_kwargs: dict[str, Any] = {
1617+
"fragment": fragment,
1618+
"schema": physical_schema,
16181619
# This will push down the query to Arrow.
16191620
# But in case there are positional deletes, we have to apply them first
1620-
filter=pyarrow_filter if not positional_deletes else None,
1621-
columns=[col.name for col in file_project_schema.columns],
1622-
)
1621+
"filter": pyarrow_filter if not positional_deletes else None,
1622+
"columns": [col.name for col in file_project_schema.columns],
1623+
}
1624+
if batch_size is not None:
1625+
scanner_kwargs["batch_size"] = batch_size
1626+
1627+
fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs)
16231628

16241629
next_index = 0
16251630
batches = fragment_scanner.to_batches()
@@ -1756,7 +1761,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
17561761

17571762
return result
17581763

1759-
def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
1764+
def to_record_batches(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
17601765
"""Scan the Iceberg table and return an Iterator[pa.RecordBatch].
17611766
17621767
Returns an Iterator of pa.RecordBatch with data from the Iceberg table
@@ -1783,7 +1788,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
17831788
# Materialize the iterator here to ensure execution happens within the executor.
17841789
# Otherwise, the iterator would be lazily consumed later (in the main thread),
17851790
# defeating the purpose of using executor.map.
1786-
return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
1791+
return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size))
17871792

17881793
limit_reached = False
17891794
for batches in executor.map(batches_for_task, tasks):
@@ -1803,7 +1808,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
18031808
break
18041809

18051810
def _record_batches_from_scan_tasks_and_deletes(
1806-
self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]]
1811+
self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None
18071812
) -> Iterator[pa.RecordBatch]:
18081813
total_row_count = 0
18091814
for task in tasks:
@@ -1822,6 +1827,7 @@ def _record_batches_from_scan_tasks_and_deletes(
18221827
self._table_metadata.specs().get(task.file.spec_id),
18231828
self._table_metadata.format_version,
18241829
self._downcast_ns_timestamp_to_us,
1830+
batch_size,
18251831
)
18261832
for batch in batches:
18271833
if self._limit is not None:

pyiceberg/table/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,13 +2157,16 @@ def to_arrow(self) -> pa.Table:
21572157
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
21582158
).to_table(self.plan_files())
21592159

2160-
def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
2160+
def to_arrow_batch_reader(self, batch_size: int | None = None) -> pa.RecordBatchReader:
21612161
"""Return an Arrow RecordBatchReader from this DataScan.
21622162
21632163
For large results, using a RecordBatchReader requires less memory than
21642164
loading an Arrow Table for the same DataScan, because a RecordBatch
21652165
is read one at a time.
21662166
2167+
Args:
2168+
batch_size: The number of rows per batch. If None, PyArrow's default is used.
2169+
21672170
Returns:
21682171
pa.RecordBatchReader: Arrow RecordBatchReader from the Iceberg table's DataScan
21692172
which can be used to read a stream of record batches one by one.
@@ -2175,7 +2178,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
21752178
target_schema = schema_to_pyarrow(self.projection())
21762179
batches = ArrowScan(
21772180
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
2178-
).to_record_batches(self.plan_files())
2181+
).to_record_batches(self.plan_files(), batch_size=batch_size)
21792182

21802183
return pa.RecordBatchReader.from_batches(
21812184
target_schema,

tests/io/test_pyarrow.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3048,6 +3048,64 @@ def _expected_batch(unit: str) -> pa.RecordBatch:
30483048
assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result)
30493049

30503050

3051+
def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None:
3052+
"""Test that batch_size controls the number of rows per batch."""
3053+
num_rows = 1000
3054+
arrow_table = pa.table(
3055+
{"col": pa.array(range(num_rows))},
3056+
schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]),
3057+
)
3058+
data_file = _write_table_to_data_file(f"{tmpdir}/test_batch_size.parquet", arrow_table.schema, arrow_table)
3059+
table_schema = Schema(NestedField(1, "col", LongType(), required=True))
3060+
3061+
batches = list(
3062+
_task_to_record_batches(
3063+
PyArrowFileIO(),
3064+
FileScanTask(data_file),
3065+
bound_row_filter=AlwaysTrue(),
3066+
projected_schema=table_schema,
3067+
table_schema=table_schema,
3068+
projected_field_ids={1},
3069+
positional_deletes=None,
3070+
case_sensitive=True,
3071+
batch_size=100,
3072+
)
3073+
)
3074+
3075+
assert len(batches) > 1
3076+
for batch in batches:
3077+
assert len(batch) <= 100
3078+
assert sum(len(b) for b in batches) == num_rows
3079+
3080+
3081+
def test_task_to_record_batches_default_batch_size(tmpdir: str) -> None:
3082+
"""Test that batch_size=None uses PyArrow default (single batch for small files)."""
3083+
num_rows = 100
3084+
arrow_table = pa.table(
3085+
{"col": pa.array(range(num_rows))},
3086+
schema=pa.schema([pa.field("col", pa.int64(), nullable=False, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]),
3087+
)
3088+
data_file = _write_table_to_data_file(f"{tmpdir}/test_default_batch_size.parquet", arrow_table.schema, arrow_table)
3089+
table_schema = Schema(NestedField(1, "col", LongType(), required=True))
3090+
3091+
batches = list(
3092+
_task_to_record_batches(
3093+
PyArrowFileIO(),
3094+
FileScanTask(data_file),
3095+
bound_row_filter=AlwaysTrue(),
3096+
projected_schema=table_schema,
3097+
table_schema=table_schema,
3098+
projected_field_ids={1},
3099+
positional_deletes=None,
3100+
case_sensitive=True,
3101+
)
3102+
)
3103+
3104+
# With default batch_size, a small file should produce a single batch
3105+
assert len(batches) == 1
3106+
assert len(batches[0]) == num_rows
3107+
3108+
30513109
def test_parse_location_defaults() -> None:
30523110
"""Test that parse_location uses defaults."""
30533111

0 commit comments

Comments
 (0)