Skip to content

Commit ac8add8

Browse files
sumedhsakdeo authored and claude committed
Fix mypy errors: change concurrent_files to concurrent_streams
- Update all test function calls to use the concurrent_streams parameter
- Fix parameter name mismatch with the _bounded_concurrent_batches function signature
- Update variable names and comments to match the new parameter name

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent de9f3c2 commit ac8add8

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

tests/io/test_bounded_concurrent_batches.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import pytest
2626

2727
from pyiceberg.io.pyarrow import _bounded_concurrent_batches
28-
from pyiceberg.table import FileScanTask, ScanOrder, TaskOrder, ArrivalOrder
28+
from pyiceberg.table import ArrivalOrder, FileScanTask
2929

3030

3131
def _make_task() -> FileScanTask:
@@ -50,7 +50,7 @@ def test_correctness_single_file() -> None:
5050
def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
5151
yield from expected_batches
5252

53-
result = list(_bounded_concurrent_batches([task], batch_fn, concurrent_files=1, max_buffered_batches=16))
53+
result = list(_bounded_concurrent_batches([task], batch_fn, concurrent_streams=1, max_buffered_batches=16))
5454

5555
assert len(result) == 3
5656
total_rows = sum(len(b) for b in result)
@@ -66,7 +66,7 @@ def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
6666
idx = tasks.index(t)
6767
yield from _make_batches(batches_per_file, start=idx * 100)
6868

69-
result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_files=2, max_buffered_batches=16))
69+
result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_streams=2, max_buffered_batches=16))
7070

7171
total_rows = sum(len(b) for b in result)
7272
assert total_rows == batches_per_file * len(tasks) * 10 # 3 batches * 4 files * 10 rows
@@ -82,7 +82,7 @@ def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
8282
barrier.wait(timeout=5.0)
8383
yield pa.record_batch({"col": [4, 5, 6]})
8484

85-
gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=2, max_buffered_batches=16)
85+
gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_streams=2, max_buffered_batches=16)
8686

8787
# Should get at least one batch before all are done
8888
first = next(gen)
@@ -110,7 +110,7 @@ def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
110110
produced_count += 1
111111
yield pa.record_batch({"col": [i]})
112112

113-
gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=1, max_buffered_batches=max_buffered)
113+
gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_streams=1, max_buffered_batches=max_buffered)
114114

115115
# Consume slowly and check that not all batches are produced immediately
116116
first = next(gen)
@@ -131,7 +131,7 @@ def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
131131
yield pa.record_batch({"col": [1]})
132132
raise ValueError("test error")
133133

134-
gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=1, max_buffered_batches=16)
134+
gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_streams=1, max_buffered_batches=16)
135135

136136
# Should get the first batch
137137
first = next(gen)
@@ -153,7 +153,7 @@ def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
153153
yield pa.record_batch({"col": [i]})
154154
time.sleep(0.01)
155155

156-
gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_files=3, max_buffered_batches=4)
156+
gen = _bounded_concurrent_batches(tasks, batch_fn, concurrent_streams=3, max_buffered_batches=4)
157157

158158
# Consume a few batches then stop
159159
worker_started.wait(timeout=5.0)
@@ -168,8 +168,8 @@ def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
168168

169169

170170
def test_concurrency_limit() -> None:
171-
"""Test that at most concurrent_files files are read concurrently."""
172-
concurrent_files = 2
171+
"""Test that at most concurrent_streams files are read concurrently."""
172+
concurrent_streams = 2
173173
tasks = [_make_task() for _ in range(6)]
174174
active_count = 0
175175
max_active = 0
@@ -187,10 +187,10 @@ def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
187187
with active_lock:
188188
active_count -= 1
189189

190-
result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_files=concurrent_files, max_buffered_batches=16))
190+
result = list(_bounded_concurrent_batches(tasks, batch_fn, concurrent_streams=concurrent_streams, max_buffered_batches=16))
191191

192192
assert len(result) == 6
193-
assert max_active <= concurrent_files
193+
assert max_active <= concurrent_streams
194194

195195

196196
def test_empty_tasks() -> None:
@@ -199,12 +199,12 @@ def test_empty_tasks() -> None:
199199
def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
200200
yield from []
201201

202-
result = list(_bounded_concurrent_batches([], batch_fn, concurrent_files=2, max_buffered_batches=16))
202+
result = list(_bounded_concurrent_batches([], batch_fn, concurrent_streams=2, max_buffered_batches=16))
203203
assert result == []
204204

205205

206206
def test_concurrent_with_limit_via_arrowscan(tmpdir: str) -> None:
207-
"""Test concurrent_files with limit through ArrowScan integration."""
207+
"""Test concurrent_streams with limit through ArrowScan integration."""
208208
from pyiceberg.expressions import AlwaysTrue
209209
from pyiceberg.io.pyarrow import ArrowScan, PyArrowFileIO
210210
from pyiceberg.manifest import DataFileContent, FileFormat

0 commit comments

Comments (0)