Skip to content

Commit c35ee60

Browse files
committed
fix: update DataFrame iteration to use num_rows and correct RecordBatch type
1 parent 74c95e8 commit c35ee60

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

python/tests/test_dataframe.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1302,7 +1302,7 @@ def test_execution_plan(aggregate_df):
         try:
             batch = stream.next()
             assert batch is not None
-            rows_returned += len(batch.to_pyarrow()[0])
+            rows_returned += batch.num_rows
         except StopIteration:
             # This is one of the partitions with no values
             pass
@@ -1317,7 +1317,7 @@ async def test_async_iteration_of_df(aggregate_df):
     rows_returned = 0
     async for batch in aggregate_df:
         assert batch is not None
-        rows_returned += len(batch.to_pyarrow()[0])
+        rows_returned += batch.num_rows
 
     assert rows_returned == 5
 
@@ -1728,6 +1728,7 @@ def test_arrow_c_stream_capsule_released(ctx):
     get_ptr(capsule, b"arrow_array_stream")
     pyerr_clear()
 
+
 def test_to_pylist(df):
     # Convert datafusion dataframe to Python list
     pylist = df.to_pylist()

python/tests/test_dataframe_iter.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,26 +15,26 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pyarrow as pa
 import pytest
-import datafusion
 
 
 def test_iter_dataframe(ctx):
     df = ctx.from_pydict({"a": [1, 2]})
-    batches = [batch.to_pyarrow() for batch in df]
+    batches = list(df)
     assert len(batches) == 1
     assert batches[0].column(0).to_pylist() == [1, 2]
 
 
 def test_iter_returns_record_batch(ctx):
     df = ctx.from_pydict({"a": [1, 2]})
     batch = next(iter(df))
-    assert isinstance(batch, datafusion.RecordBatch)
+    assert isinstance(batch, pa.RecordBatch)
 
 
 @pytest.mark.asyncio
 async def test_async_iter_dataframe(ctx):
     df = ctx.from_pydict({"a": [1, 2]})
     batches = [batch async for batch in df]
     assert len(batches) == 1
-    assert batches[0].to_pyarrow().column(0).to_pylist() == [1, 2]
+    assert batches[0].column(0).to_pylist() == [1, 2]

0 commit comments

Comments
 (0)