Skip to content

Commit 5c652fd

Browse files
committed
refactor: add to_record_batch_stream method and improve iteration support in DataFrame
1 parent 2a839d9 commit 5c652fd

3 files changed

Lines changed: 34 additions & 65 deletions

File tree

python/datafusion/dataframe.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing import (
2626
TYPE_CHECKING,
2727
Any,
28+
AsyncIterator,
2829
Iterable,
2930
Iterator,
3031
Literal,
@@ -1043,6 +1044,15 @@ def execute_stream_partitioned(self) -> list[RecordBatchStream]:
10431044
streams = self.df.execute_stream_partitioned()
10441045
return [RecordBatchStream(rbs) for rbs in streams]
10451046

1047+
def to_record_batch_stream(self) -> RecordBatchStream:
1048+
"""Return a :py:class:`RecordBatchStream` over this DataFrame's results.
1049+
1050+
Returns:
1051+
A ``RecordBatchStream`` representing the lazily generated record
1052+
batches for this DataFrame.
1053+
"""
1054+
return self.execute_stream()
1055+
10461056
def to_pandas(self) -> pd.DataFrame:
10471057
"""Execute the :py:class:`DataFrame` and convert it into a Pandas DataFrame.
10481058
@@ -1126,12 +1136,12 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
11261136
return self.df.__arrow_c_stream__(requested_schema)
11271137

11281138
def __iter__(self) -> Iterator[RecordBatch]:
1129-
"""Yield :class:`RecordBatch` objects by streaming execution."""
1130-
yield from self.to_record_batch_stream()
1139+
"""Return an iterator over this DataFrame's record batches."""
1140+
return iter(self.to_record_batch_stream())
11311141

1132-
async def __aiter__(self) -> RecordBatchStream:
1133-
"""Return an asynchronous iterator over streamed ``RecordBatch`` objects."""
1134-
return await self.to_record_batch_stream().__aiter__()
1142+
def __aiter__(self) -> AsyncIterator[RecordBatch]:
1143+
"""Return an async iterator over this DataFrame's record batches."""
1144+
return self.to_record_batch_stream().__aiter__()
11351145

11361146
def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:
11371147
"""Apply a function to the current DataFrame which returns another DataFrame.

python/tests/test_dataframe.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DataFrame,
3030
ParquetColumnOptions,
3131
ParquetWriterOptions,
32+
RecordBatch,
3233
SessionContext,
3334
WindowFrame,
3435
column,
@@ -390,10 +391,23 @@ def test_iter_batches(df):
390391
assert len(batches) == 1
391392

392393
batch = batches[0]
393-
assert isinstance(batch, pa.RecordBatch)
394-
assert batch.column(0).to_pylist() == [1, 2, 3]
395-
assert batch.column(1).to_pylist() == [4, 5, 6]
396-
assert batch.column(2).to_pylist() == [8, 5, 8]
394+
assert isinstance(batch, RecordBatch)
395+
pa_batch = batch.to_pyarrow()
396+
assert pa_batch.column(0).to_pylist() == [1, 2, 3]
397+
assert pa_batch.column(1).to_pylist() == [4, 5, 6]
398+
assert pa_batch.column(2).to_pylist() == [8, 5, 8]
399+
400+
401+
def test_to_record_batch_stream(df):
402+
stream = df.to_record_batch_stream()
403+
batches = list(stream)
404+
405+
assert len(batches) == 1
406+
assert isinstance(batches[0], RecordBatch)
407+
pa_batch = batches[0].to_pyarrow()
408+
assert pa_batch.column(0).to_pylist() == [1, 2, 3]
409+
assert pa_batch.column(1).to_pylist() == [4, 5, 6]
410+
assert pa_batch.column(2).to_pylist() == [8, 5, 8]
397411

398412

399413
def test_with_column_renamed(df):
@@ -1331,7 +1345,7 @@ def test_execution_plan(aggregate_df):
13311345
@pytest.mark.asyncio
13321346
async def test_async_iteration_of_df(aggregate_df):
13331347
rows_returned = 0
1334-
async for batch in aggregate_df.execute_stream():
1348+
async for batch in aggregate_df:
13351349
assert batch is not None
13361350
rows_returned += len(batch.to_pyarrow()[0])
13371351

python/tests/test_dataframe_iter_stream.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

0 commit comments

Comments (0)