Skip to content

Commit babeba2

Browse files
fix tests
1 parent f43130c commit babeba2

File tree

4 files changed

14 additions and 20 deletions (+14 / -20 lines changed)

4 files changed

+14
-20
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,9 +1804,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
18041804
# This break will also cancel all running tasks in the executor
18051805
break
18061806

1807-
def to_record_batch_stream(
1808-
self, tasks: Iterable[FileScanTask], batch_size: int | None = None
1809-
) -> Iterator[pa.RecordBatch]:
1807+
def to_record_batch_stream(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
18101808
"""Scan the Iceberg table and return an Iterator[pa.RecordBatch] in a streaming fashion.
18111809
18121810
Files are read sequentially and batches are yielded one at a time

pyiceberg/table/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2182,7 +2182,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
21822182
batches,
21832183
).cast(target_schema)
21842184

2185-
def to_record_batches(self, batch_size: int | None = None) -> Iterator["pa.RecordBatch"]:
2185+
def to_record_batches(self, batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
21862186
"""Read record batches in a streaming fashion from this DataScan.
21872187
21882188
Files are read sequentially and batches are yielded one at a time

tests/integration/test_reads.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,9 +1290,7 @@ def test_datascan_to_record_batches(catalog: Catalog) -> None:
12901290

12911291
scan = table.scan()
12921292
streaming_batches = list(scan.to_record_batches())
1293-
streaming_result = pa.concat_tables(
1294-
[pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive"
1295-
)
1293+
streaming_result = pa.concat_tables([pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive")
12961294

12971295
eager_result = scan.to_arrow()
12981296

tests/io/test_pyarrow.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4917,7 +4917,7 @@ def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None:
49174917
assert len(batch) <= 100
49184918

49194919

4920-
def test_to_record_batches_streaming_basic(tmpdir: str) -> None:
4920+
def test_to_record_batch_stream_basic(tmpdir: str) -> None:
49214921
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
49224922
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
49234923

@@ -4941,7 +4941,7 @@ def test_to_record_batches_streaming_basic(tmpdir: str) -> None:
49414941
case_sensitive=True,
49424942
)
49434943

4944-
result = scan.to_record_batches_streaming([task])
4944+
result = scan.to_record_batch_stream([task])
49454945
# Should be a generator/iterator, not a list
49464946
import types
49474947

@@ -4952,7 +4952,7 @@ def test_to_record_batches_streaming_basic(tmpdir: str) -> None:
49524952
assert total_rows == 100
49534953

49544954

4955-
def test_to_record_batches_streaming_with_batch_size(tmpdir: str) -> None:
4955+
def test_to_record_batch_stream_with_batch_size(tmpdir: str) -> None:
49564956
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
49574957
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
49584958

@@ -4976,15 +4976,15 @@ def test_to_record_batches_streaming_with_batch_size(tmpdir: str) -> None:
49764976
case_sensitive=True,
49774977
)
49784978

4979-
batches = list(scan.to_record_batches_streaming([task], batch_size=50))
4979+
batches = list(scan.to_record_batch_stream([task], batch_size=50))
49804980

49814981
total_rows = sum(len(b) for b in batches)
49824982
assert total_rows == 500
49834983
for batch in batches:
49844984
assert len(batch) <= 50
49854985

49864986

4987-
def test_to_record_batches_streaming_with_limit(tmpdir: str) -> None:
4987+
def test_to_record_batch_stream_with_limit(tmpdir: str) -> None:
49884988
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
49894989
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
49904990

@@ -5009,13 +5009,13 @@ def test_to_record_batches_streaming_with_limit(tmpdir: str) -> None:
50095009
limit=100,
50105010
)
50115011

5012-
batches = list(scan.to_record_batches_streaming([task]))
5012+
batches = list(scan.to_record_batch_stream([task]))
50135013

50145014
total_rows = sum(len(b) for b in batches)
50155015
assert total_rows == 100
50165016

50175017

5018-
def test_to_record_batches_streaming_with_deletes(
5018+
def test_to_record_batch_stream_with_deletes(
50195019
deletes_file: str, request: pytest.FixtureRequest, table_schema_simple: Schema
50205020
) -> None:
50215021
file_format = FileFormat.PARQUET if deletes_file.endswith(".parquet") else FileFormat.ORC
@@ -5052,17 +5052,15 @@ def test_to_record_batches_streaming_with_deletes(
50525052
)
50535053

50545054
# Compare streaming path to table path
5055-
streaming_batches = list(scan.to_record_batches_streaming([example_task_with_delete]))
5056-
streaming_table = pa.concat_tables(
5057-
[pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive"
5058-
)
5055+
streaming_batches = list(scan.to_record_batch_stream([example_task_with_delete]))
5056+
streaming_table = pa.concat_tables([pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive")
50595057
eager_table = scan.to_table(tasks=[example_task_with_delete])
50605058

50615059
assert streaming_table.num_rows == eager_table.num_rows
50625060
assert streaming_table.column_names == eager_table.column_names
50635061

50645062

5065-
def test_to_record_batches_streaming_multiple_files(tmpdir: str) -> None:
5063+
def test_to_record_batch_stream_multiple_files(tmpdir: str) -> None:
50665064
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
50675065
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})
50685066

@@ -5090,6 +5088,6 @@ def test_to_record_batches_streaming_multiple_files(tmpdir: str) -> None:
50905088
case_sensitive=True,
50915089
)
50925090

5093-
batches = list(scan.to_record_batches_streaming(tasks))
5091+
batches = list(scan.to_record_batch_stream(tasks))
50945092
total_rows = sum(len(b) for b in batches)
50955093
assert total_rows == total_expected # 600 rows total

0 commit comments

Comments
 (0)