Skip to content

Commit 7ee5924

Browse files
committed
feat: add to_record_batch_stream function and update DataFrame iteration methods
1 parent 6e85080 commit 7ee5924

File tree

5 files changed

+52
-11
lines changed

5 files changed

+52
-11
lines changed

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454
from .io import read_avro, read_csv, read_json, read_parquet
5555
from .plan import ExecutionPlan, LogicalPlan
56-
from .record_batch import RecordBatch, RecordBatchStream
56+
from .record_batch import RecordBatch, RecordBatchStream, to_record_batch_stream
5757
from .user_defined import (
5858
Accumulator,
5959
AggregateUDF,
@@ -107,6 +107,7 @@
107107
"read_json",
108108
"read_parquet",
109109
"substrait",
110+
"to_record_batch_stream",
110111
"udaf",
111112
"udf",
112113
"udtf",

python/datafusion/dataframe.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing import (
2626
TYPE_CHECKING,
2727
Any,
28+
AsyncIterator,
2829
Iterable,
2930
Iterator,
3031
Literal,
@@ -43,7 +44,11 @@
4344
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
4445
from datafusion.expr import Expr, SortExpr, sort_or_default
4546
from datafusion.plan import ExecutionPlan, LogicalPlan
46-
from datafusion.record_batch import RecordBatchStream
47+
from datafusion.record_batch import (
48+
RecordBatch,
49+
RecordBatchStream,
50+
to_record_batch_stream,
51+
)
4752

4853
if TYPE_CHECKING:
4954
import pathlib
@@ -1123,15 +1128,20 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
11231128
return self.df.__arrow_c_stream__(requested_schema)
11241129

11251130
def __iter__(self) -> Iterator[RecordBatch]:
    """Yield record batches from the DataFrame without materializing results.

    The DataFrame is executed via :func:`to_record_batch_stream`, which
    returns a :class:`RecordBatchStream`; the stream's own iterator is
    handed straight back to the caller.
    """
    return iter(to_record_batch_stream(self))

def __aiter__(self) -> AsyncIterator[RecordBatch]:
    """Asynchronously yield record batches from the DataFrame.

    Obtains a :class:`RecordBatchStream` via
    :func:`to_record_batch_stream` and returns that stream's
    asynchronous iterator.
    """
    stream = to_record_batch_stream(self)
    return stream.__aiter__()
11351145

11361146
def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:
11371147
"""Apply a function to the current DataFrame which returns another DataFrame.

python/datafusion/record_batch.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@
2525

2626
from typing import TYPE_CHECKING
2727

28+
import datafusion._internal as df_internal
29+
2830
if TYPE_CHECKING:
2931
import pyarrow as pa
3032
import typing_extensions
3133

32-
import datafusion._internal as df_internal
34+
from datafusion.dataframe import DataFrame
3335

3436

3537
class RecordBatch:
@@ -79,3 +81,15 @@ def __aiter__(self) -> typing_extensions.Self:
7981
def __iter__(self) -> typing_extensions.Self:
    """Return the stream itself as a synchronous iterator over record batches."""
    return self
84+
85+
86+
def to_record_batch_stream(df: DataFrame) -> RecordBatchStream:
    """Convert a DataFrame into a RecordBatchStream.

    Executes the DataFrame and exposes its results as a stream of record
    batches rather than a fully materialized collection.

    Args:
        df: DataFrame to convert.

    Returns:
        A RecordBatchStream representing the DataFrame.
    """
    stream = df.execute_stream()
    return stream

python/tests/test_dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1314,7 +1314,7 @@ def test_execution_plan(aggregate_df):
13141314
@pytest.mark.asyncio
13151315
async def test_async_iteration_of_df(aggregate_df):
13161316
rows_returned = 0
1317-
async for batch in aggregate_df.execute_stream():
1317+
async for batch in aggregate_df:
13181318
assert batch is not None
13191319
rows_returned += len(batch.to_pyarrow()[0])
13201320

python/tests/test_dataframe_iter_stream.py renamed to python/tests/test_dataframe_iter.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,26 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import pytest
1819
import datafusion
1920

2021

22+
def test_iter_dataframe(ctx):
    """Plain iteration over a DataFrame yields its record batches."""
    df = ctx.from_pydict({"a": [1, 2]})
    collected = [b.to_pyarrow() for b in df]
    assert len(collected) == 1
    assert collected[0].column(0).to_pylist() == [1, 2]
27+
28+
2129
def test_iter_returns_record_batch(ctx):
    """The first element produced by iterating a DataFrame is a RecordBatch."""
    frame = ctx.from_pydict({"a": [1, 2]})
    first = next(iter(frame))
    assert isinstance(first, datafusion.RecordBatch)
33+
34+
35+
@pytest.mark.asyncio
async def test_async_iter_dataframe(ctx):
    """Async iteration yields the same single batch as synchronous iteration."""
    frame = ctx.from_pydict({"a": [1, 2]})
    collected = []
    async for batch in frame:
        collected.append(batch)
    assert len(collected) == 1
    assert collected[0].to_pyarrow().column(0).to_pylist() == [1, 2]

0 commit comments

Comments
 (0)