Skip to content

Commit af0d5e3

Browse files
feat: add support for streaming the batches using ArrowScan
1 parent b98de51 commit af0d5e3

File tree

4 files changed

+315
-7
lines changed

4 files changed

+315
-7
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,7 @@ def _task_to_record_batches(
15811581
partition_spec: PartitionSpec | None = None,
15821582
format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
15831583
downcast_ns_timestamp_to_us: bool | None = None,
1584+
batch_size: int | None = None,
15841585
) -> Iterator[pa.RecordBatch]:
15851586
arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
15861587
with io.new_input(task.file.file_path).open() as fin:
@@ -1612,14 +1613,17 @@ def _task_to_record_batches(
16121613

16131614
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
16141615

1615-
fragment_scanner = ds.Scanner.from_fragment(
1616-
fragment=fragment,
1617-
schema=physical_schema,
1616+
scanner_kwargs: dict[str, Any] = {
1617+
"fragment": fragment,
1618+
"schema": physical_schema,
16181619
# This will push down the query to Arrow.
16191620
# But in case there are positional deletes, we have to apply them first
1620-
filter=pyarrow_filter if not positional_deletes else None,
1621-
columns=[col.name for col in file_project_schema.columns],
1622-
)
1621+
"filter": pyarrow_filter if not positional_deletes else None,
1622+
"columns": [col.name for col in file_project_schema.columns],
1623+
}
1624+
if batch_size is not None:
1625+
scanner_kwargs["batch_size"] = batch_size
1626+
fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs)
16231627

16241628
next_index = 0
16251629
batches = fragment_scanner.to_batches()
@@ -1802,8 +1806,32 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
18021806
# This break will also cancel all running tasks in the executor
18031807
break
18041808

1809+
def to_record_batch_stream(
1810+
self, tasks: Iterable[FileScanTask], batch_size: int | None = None
1811+
) -> Iterator[pa.RecordBatch]:
1812+
"""Scan the Iceberg table and return an Iterator[pa.RecordBatch] in a streaming fashion.
1813+
1814+
Files are read sequentially and batches are yielded one at a time
1815+
without materializing all batches in memory. Use this when memory
1816+
efficiency is more important than throughput.
1817+
1818+
Args:
1819+
tasks: FileScanTasks representing the data files and delete files to read from.
1820+
batch_size: Maximum number of rows per RecordBatch. If None,
1821+
uses PyArrow's default (131,072 rows).
1822+
1823+
Yields:
1824+
pa.RecordBatch: Record batches from the scan, one at a time.
1825+
"""
1826+
tasks = list(tasks) if not isinstance(tasks, list) else tasks
1827+
deletes_per_file = _read_all_delete_files(self._io, tasks)
1828+
yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file, batch_size)
1829+
18051830
def _record_batches_from_scan_tasks_and_deletes(
1806-
self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]]
1831+
self,
1832+
tasks: Iterable[FileScanTask],
1833+
deletes_per_file: dict[str, list[ChunkedArray]],
1834+
batch_size: int | None = None,
18071835
) -> Iterator[pa.RecordBatch]:
18081836
total_row_count = 0
18091837
for task in tasks:
@@ -1822,6 +1850,7 @@ def _record_batches_from_scan_tasks_and_deletes(
18221850
self._table_metadata.specs().get(task.file.spec_id),
18231851
self._table_metadata.format_version,
18241852
self._downcast_ns_timestamp_to_us,
1853+
batch_size,
18251854
)
18261855
for batch in batches:
18271856
if self._limit is not None:

pyiceberg/table/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2182,6 +2182,26 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
21822182
batches,
21832183
).cast(target_schema)
21842184

2185+
def to_record_batches(self, batch_size: int | None = None) -> Iterator["pa.RecordBatch"]:
2186+
"""Read record batches in a streaming fashion from this DataScan.
2187+
2188+
Files are read sequentially and batches are yielded one at a time
2189+
without materializing all batches in memory. Use this when memory
2190+
efficiency is more important than throughput.
2191+
2192+
Args:
2193+
batch_size: Maximum number of rows per RecordBatch. If None,
2194+
uses PyArrow's default (131,072 rows).
2195+
2196+
Yields:
2197+
pa.RecordBatch: Record batches from the scan, one at a time.
2198+
"""
2199+
from pyiceberg.io.pyarrow import ArrowScan
2200+
2201+
yield from ArrowScan(
2202+
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
2203+
).to_record_batch_stream(self.plan_files(), batch_size)
2204+
21852205
def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
21862206
"""Read a Pandas DataFrame eagerly from this Iceberg table.
21872207

tests/integration/test_reads.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,3 +1272,53 @@ def test_scan_source_field_missing_in_spec(catalog: Catalog, spark: SparkSession
12721272

12731273
table = catalog.load_table(identifier)
12741274
assert len(list(table.scan().plan_files())) == 3
1275+
1276+
1277+
@pytest.mark.integration
1278+
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
1279+
def test_datascan_to_record_batches(catalog: Catalog) -> None:
1280+
table = create_table(catalog)
1281+
1282+
arrow_table = pa.Table.from_pydict(
1283+
{
1284+
"str": ["a", "b", "c"],
1285+
"int": [1, 2, 3],
1286+
},
1287+
schema=pa.schema([pa.field("str", pa.large_string()), pa.field("int", pa.int32())]),
1288+
)
1289+
table.append(arrow_table)
1290+
1291+
scan = table.scan()
1292+
streaming_batches = list(scan.to_record_batches())
1293+
streaming_result = pa.concat_tables(
1294+
[pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive"
1295+
)
1296+
1297+
eager_result = scan.to_arrow()
1298+
1299+
assert streaming_result.num_rows == eager_result.num_rows
1300+
assert streaming_result.column_names == eager_result.column_names
1301+
assert streaming_result.sort_by("int").equals(eager_result.sort_by("int"))
1302+
1303+
1304+
@pytest.mark.integration
1305+
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
1306+
def test_datascan_to_record_batches_with_batch_size(catalog: Catalog) -> None:
1307+
table = create_table(catalog)
1308+
1309+
arrow_table = pa.Table.from_pydict(
1310+
{
1311+
"str": [f"val_{i}" for i in range(100)],
1312+
"int": list(range(100)),
1313+
},
1314+
schema=pa.schema([pa.field("str", pa.large_string()), pa.field("int", pa.int32())]),
1315+
)
1316+
table.append(arrow_table)
1317+
1318+
scan = table.scan()
1319+
batches = list(scan.to_record_batches(batch_size=10))
1320+
1321+
total_rows = sum(len(b) for b in batches)
1322+
assert total_rows == 100
1323+
for batch in batches:
1324+
assert len(batch) <= 10

tests/io/test_pyarrow.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4884,3 +4884,212 @@ def test_partition_column_projection_with_schema_evolution(catalog: InMemoryCata
48844884
result_sorted = result.sort_by("name")
48854885
assert result_sorted["name"].to_pylist() == ["Alice", "Bob", "Charlie", "David"]
48864886
assert result_sorted["new_column"].to_pylist() == [None, None, "new1", "new2"]
4887+
4888+
4889+
def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None:
4890+
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
4891+
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
4892+
4893+
# Create a parquet file with 1000 rows
4894+
table = pa.Table.from_arrays([pa.array(list(range(1000)))], schema=pyarrow_schema)
4895+
data_file = _write_table_to_data_file(f"{tmpdir}/batch_size_test.parquet", pyarrow_schema, table)
4896+
data_file.spec_id = 0
4897+
4898+
task = FileScanTask(data_file=data_file)
4899+
4900+
batches = list(
4901+
_task_to_record_batches(
4902+
PyArrowFileIO(),
4903+
task,
4904+
bound_row_filter=AlwaysTrue(),
4905+
projected_schema=schema,
4906+
table_schema=schema,
4907+
projected_field_ids={1},
4908+
positional_deletes=None,
4909+
case_sensitive=True,
4910+
batch_size=100,
4911+
)
4912+
)
4913+
4914+
total_rows = sum(len(b) for b in batches)
4915+
assert total_rows == 1000
4916+
for batch in batches:
4917+
assert len(batch) <= 100
4918+
4919+
4920+
def test_to_record_batches_streaming_basic(tmpdir: str) -> None:
4921+
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
4922+
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
4923+
4924+
table = pa.Table.from_arrays([pa.array(list(range(100)))], schema=pyarrow_schema)
4925+
data_file = _write_table_to_data_file(f"{tmpdir}/streaming_basic.parquet", pyarrow_schema, table)
4926+
data_file.spec_id = 0
4927+
4928+
task = FileScanTask(data_file=data_file)
4929+
4930+
scan = ArrowScan(
4931+
table_metadata=TableMetadataV2(
4932+
location="file://a/b/",
4933+
last_column_id=1,
4934+
format_version=2,
4935+
schemas=[schema],
4936+
partition_specs=[PartitionSpec()],
4937+
),
4938+
io=PyArrowFileIO(),
4939+
projected_schema=schema,
4940+
row_filter=AlwaysTrue(),
4941+
case_sensitive=True,
4942+
)
4943+
4944+
result = scan.to_record_batch_stream([task])
4945+
# Should be a generator/iterator, not a list
4946+
import types
4947+
4948+
assert isinstance(result, types.GeneratorType)
4949+
4950+
batches = list(result)
4951+
total_rows = sum(len(b) for b in batches)
4952+
assert total_rows == 100
4953+
4954+
4955+
def test_to_record_batches_streaming_with_batch_size(tmpdir: str) -> None:
4956+
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
4957+
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
4958+
4959+
table = pa.Table.from_arrays([pa.array(list(range(500)))], schema=pyarrow_schema)
4960+
data_file = _write_table_to_data_file(f"{tmpdir}/streaming_batch_size.parquet", pyarrow_schema, table)
4961+
data_file.spec_id = 0
4962+
4963+
task = FileScanTask(data_file=data_file)
4964+
4965+
scan = ArrowScan(
4966+
table_metadata=TableMetadataV2(
4967+
location="file://a/b/",
4968+
last_column_id=1,
4969+
format_version=2,
4970+
schemas=[schema],
4971+
partition_specs=[PartitionSpec()],
4972+
),
4973+
io=PyArrowFileIO(),
4974+
projected_schema=schema,
4975+
row_filter=AlwaysTrue(),
4976+
case_sensitive=True,
4977+
)
4978+
4979+
batches = list(scan.to_record_batch_stream([task], batch_size=50))
4980+
4981+
total_rows = sum(len(b) for b in batches)
4982+
assert total_rows == 500
4983+
for batch in batches:
4984+
assert len(batch) <= 50
4985+
4986+
4987+
def test_to_record_batches_streaming_with_limit(tmpdir: str) -> None:
4988+
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
4989+
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
4990+
4991+
table = pa.Table.from_arrays([pa.array(list(range(500)))], schema=pyarrow_schema)
4992+
data_file = _write_table_to_data_file(f"{tmpdir}/streaming_limit.parquet", pyarrow_schema, table)
4993+
data_file.spec_id = 0
4994+
4995+
task = FileScanTask(data_file=data_file)
4996+
4997+
scan = ArrowScan(
4998+
table_metadata=TableMetadataV2(
4999+
location="file://a/b/",
5000+
last_column_id=1,
5001+
format_version=2,
5002+
schemas=[schema],
5003+
partition_specs=[PartitionSpec()],
5004+
),
5005+
io=PyArrowFileIO(),
5006+
projected_schema=schema,
5007+
row_filter=AlwaysTrue(),
5008+
case_sensitive=True,
5009+
limit=100,
5010+
)
5011+
5012+
batches = list(scan.to_record_batch_stream([task]))
5013+
5014+
total_rows = sum(len(b) for b in batches)
5015+
assert total_rows == 100
5016+
5017+
5018+
def test_to_record_batches_streaming_with_deletes(
5019+
deletes_file: str, request: pytest.FixtureRequest, table_schema_simple: Schema
5020+
) -> None:
5021+
file_format = FileFormat.PARQUET if deletes_file.endswith(".parquet") else FileFormat.ORC
5022+
5023+
if file_format == FileFormat.PARQUET:
5024+
example_task = request.getfixturevalue("example_task")
5025+
else:
5026+
example_task = request.getfixturevalue("example_task_orc")
5027+
5028+
example_task_with_delete = FileScanTask(
5029+
data_file=example_task.file,
5030+
delete_files={
5031+
DataFile.from_args(
5032+
content=DataFileContent.POSITION_DELETES,
5033+
file_path=deletes_file,
5034+
file_format=file_format,
5035+
)
5036+
},
5037+
)
5038+
5039+
metadata_location = "file://a/b/c.json"
5040+
scan = ArrowScan(
5041+
table_metadata=TableMetadataV2(
5042+
location=metadata_location,
5043+
last_column_id=1,
5044+
format_version=2,
5045+
current_schema_id=1,
5046+
schemas=[table_schema_simple],
5047+
partition_specs=[PartitionSpec()],
5048+
),
5049+
io=load_file_io(),
5050+
projected_schema=table_schema_simple,
5051+
row_filter=AlwaysTrue(),
5052+
)
5053+
5054+
# Compare streaming path to table path
5055+
streaming_batches = list(scan.to_record_batch_stream([example_task_with_delete]))
5056+
streaming_table = pa.concat_tables(
5057+
[pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive"
5058+
)
5059+
eager_table = scan.to_table(tasks=[example_task_with_delete])
5060+
5061+
assert streaming_table.num_rows == eager_table.num_rows
5062+
assert streaming_table.column_names == eager_table.column_names
5063+
5064+
5065+
def test_to_record_batches_streaming_multiple_files(tmpdir: str) -> None:
5066+
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
5067+
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
5068+
5069+
tasks = []
5070+
total_expected = 0
5071+
for i in range(3):
5072+
num_rows = (i + 1) * 100 # 100, 200, 300
5073+
total_expected += num_rows
5074+
table = pa.Table.from_arrays([pa.array(list(range(num_rows)))], schema=pyarrow_schema)
5075+
data_file = _write_table_to_data_file(f"{tmpdir}/multi_{i}.parquet", pyarrow_schema, table)
5076+
data_file.spec_id = 0
5077+
tasks.append(FileScanTask(data_file=data_file))
5078+
5079+
scan = ArrowScan(
5080+
table_metadata=TableMetadataV2(
5081+
location="file://a/b/",
5082+
last_column_id=1,
5083+
format_version=2,
5084+
schemas=[schema],
5085+
partition_specs=[PartitionSpec()],
5086+
),
5087+
io=PyArrowFileIO(),
5088+
projected_schema=schema,
5089+
row_filter=AlwaysTrue(),
5090+
case_sensitive=True,
5091+
)
5092+
5093+
batches = list(scan.to_record_batch_stream(tasks))
5094+
total_rows = sum(len(b) for b in batches)
5095+
assert total_rows == total_expected # 600 rows total

0 commit comments

Comments
 (0)