Skip to content

Commit 4186713

Browse files
sumedhsakdeo and claude committed
refactor: simplify _bounded_concurrent_batches with per-scan executor
Replace the shared ExecutorFactory + Semaphore with a per-scan ThreadPoolExecutor(max_workers=concurrent_files) for deterministic shutdown and simpler concurrency control. Refactor to_record_batches into helpers: - _prepare_tasks_and_deletes: resolve delete files - _iter_batches_arrival: bounded concurrent streaming path - _iter_batches_materialized: executor.map materialization path - _apply_limit: unified row limit logic (was duplicated) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b360ae8 commit 4186713

File tree

2 files changed

+86
-82
lines changed

2 files changed

+86
-82
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 83 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import warnings
4141
from abc import ABC, abstractmethod
4242
from collections.abc import Callable, Generator, Iterable, Iterator
43+
from concurrent.futures import ThreadPoolExecutor
4344
from copy import copy
4445
from dataclasses import dataclass
4546
from enum import Enum
@@ -1695,9 +1696,9 @@ def _bounded_concurrent_batches(
16951696
) -> Generator[pa.RecordBatch, None, None]:
16961697
"""Read batches from multiple files concurrently with bounded memory.
16971698
1698-
Workers read from files in parallel (up to concurrent_files at a time) and push
1699-
batches into a shared queue. The consumer yields batches from the queue.
1700-
A sentinel value signals completion, avoiding timeout-based polling.
1699+
Uses a per-scan ThreadPoolExecutor(max_workers=concurrent_files) to naturally
1700+
bound concurrency. Workers push batches into a bounded queue which provides
1701+
backpressure when the consumer is slower than the producers.
17011702
17021703
Args:
17031704
tasks: The file scan tasks to process.
@@ -1709,60 +1710,49 @@ def _bounded_concurrent_batches(
17091710
return
17101711

17111712
batch_queue: queue.Queue[pa.RecordBatch | BaseException | object] = queue.Queue(maxsize=max_buffered_batches)
1712-
cancel_event = threading.Event()
1713-
pending_count = len(tasks)
1714-
pending_lock = threading.Lock()
1715-
file_semaphore = threading.Semaphore(concurrent_files)
1713+
cancel = threading.Event()
1714+
remaining = len(tasks)
1715+
remaining_lock = threading.Lock()
17161716

17171717
def worker(task: FileScanTask) -> None:
1718-
nonlocal pending_count
1718+
nonlocal remaining
17191719
try:
1720-
# Blocking acquire — on cancellation, extra permits are released to unblock.
1721-
file_semaphore.acquire()
1722-
if cancel_event.is_set():
1723-
return
1724-
17251720
for batch in batch_fn(task):
1726-
if cancel_event.is_set():
1721+
if cancel.is_set():
17271722
return
17281723
batch_queue.put(batch)
17291724
except BaseException as e:
1730-
if not cancel_event.is_set():
1725+
if not cancel.is_set():
17311726
batch_queue.put(e)
17321727
finally:
1733-
file_semaphore.release()
1734-
with pending_lock:
1735-
pending_count -= 1
1736-
if pending_count == 0:
1728+
with remaining_lock:
1729+
remaining -= 1
1730+
if remaining == 0:
17371731
batch_queue.put(_QUEUE_SENTINEL)
17381732

1739-
executor = ExecutorFactory.get_or_create()
1740-
futures = [executor.submit(worker, task) for task in tasks]
1733+
with ThreadPoolExecutor(max_workers=concurrent_files) as executor:
1734+
for task in tasks:
1735+
executor.submit(worker, task)
17411736

1742-
try:
1743-
while True:
1744-
item = batch_queue.get()
1737+
try:
1738+
while True:
1739+
item = batch_queue.get()
17451740

1746-
if item is _QUEUE_SENTINEL:
1747-
break
1741+
if item is _QUEUE_SENTINEL:
1742+
break
17481743

1749-
if isinstance(item, BaseException):
1750-
raise item
1751-
1752-
yield item
1753-
finally:
1754-
cancel_event.set()
1755-
# Release semaphore permits to unblock any workers waiting on acquire()
1756-
for _ in range(len(tasks)):
1757-
file_semaphore.release()
1758-
# Drain the queue to unblock any workers stuck on put()
1759-
while not batch_queue.empty():
1760-
try:
1761-
batch_queue.get_nowait()
1762-
except queue.Empty:
1763-
break
1764-
for future in futures:
1765-
future.cancel()
1744+
if isinstance(item, BaseException):
1745+
raise item
1746+
1747+
yield item
1748+
finally:
1749+
cancel.set()
1750+
# Drain the queue to unblock any workers stuck on put()
1751+
while not batch_queue.empty():
1752+
try:
1753+
batch_queue.get_nowait()
1754+
except queue.Empty:
1755+
break
17661756

17671757

17681758
class ArrowScan:
@@ -1889,52 +1879,66 @@ def to_record_batches(
18891879
if concurrent_files < 1:
18901880
raise ValueError(f"concurrent_files must be >= 1, got {concurrent_files}")
18911881

1892-
deletes_per_file = _read_all_delete_files(self._io, tasks)
1882+
task_list, deletes_per_file = self._prepare_tasks_and_deletes(tasks)
18931883

18941884
if order == ScanOrder.ARRIVAL:
1895-
# Arrival order: read files with bounded concurrency, yielding batches as produced.
1896-
# When concurrent_files=1, this is sequential. When >1, batches may interleave across files.
1897-
task_list = list(tasks)
1898-
1899-
def batch_fn(task: FileScanTask) -> Iterator[pa.RecordBatch]:
1900-
return self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size)
1901-
1902-
total_row_count = 0
1903-
for batch in _bounded_concurrent_batches(task_list, batch_fn, concurrent_files):
1904-
current_batch_size = len(batch)
1905-
if self._limit is not None and total_row_count + current_batch_size >= self._limit:
1906-
yield batch.slice(0, self._limit - total_row_count)
1907-
return
1908-
yield batch
1909-
total_row_count += current_batch_size
1910-
return
1885+
return self._apply_limit(self._iter_batches_arrival(task_list, deletes_per_file, batch_size, concurrent_files))
19111886

1912-
# Task order: existing behavior with executor.map + list()
1913-
total_row_count = 0
1887+
return self._apply_limit(self._iter_batches_materialized(task_list, deletes_per_file, batch_size))
1888+
1889+
def _prepare_tasks_and_deletes(
1890+
self, tasks: Iterable[FileScanTask]
1891+
) -> tuple[list[FileScanTask], dict[str, list[ChunkedArray]]]:
1892+
"""Resolve delete files and return tasks as a list."""
1893+
task_list = list(tasks)
1894+
deletes_per_file = _read_all_delete_files(self._io, task_list)
1895+
return task_list, deletes_per_file
1896+
1897+
def _iter_batches_arrival(
1898+
self,
1899+
task_list: list[FileScanTask],
1900+
deletes_per_file: dict[str, list[ChunkedArray]],
1901+
batch_size: int | None,
1902+
concurrent_files: int,
1903+
) -> Iterator[pa.RecordBatch]:
1904+
"""Yield batches using bounded concurrent streaming in arrival order."""
1905+
1906+
def batch_fn(task: FileScanTask) -> Iterator[pa.RecordBatch]:
1907+
return self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size)
1908+
1909+
yield from _bounded_concurrent_batches(task_list, batch_fn, concurrent_files)
1910+
1911+
def _iter_batches_materialized(
1912+
self,
1913+
task_list: list[FileScanTask],
1914+
deletes_per_file: dict[str, list[ChunkedArray]],
1915+
batch_size: int | None,
1916+
) -> Iterator[pa.RecordBatch]:
1917+
"""Yield batches using executor.map with full file materialization."""
19141918
executor = ExecutorFactory.get_or_create()
19151919

19161920
def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
1917-
# Materialize the iterator here to ensure execution happens within the executor.
1918-
# Otherwise, the iterator would be lazily consumed later (in the main thread),
1919-
# defeating the purpose of using executor.map.
19201921
return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, batch_size))
19211922

1922-
limit_reached = False
1923-
for batches in executor.map(batches_for_task, tasks):
1924-
for batch in batches:
1925-
current_batch_size = len(batch)
1926-
if self._limit is not None and total_row_count + current_batch_size >= self._limit:
1927-
yield batch.slice(0, self._limit - total_row_count)
1923+
for batches in executor.map(batches_for_task, task_list):
1924+
yield from batches
19281925

1929-
limit_reached = True
1930-
break
1931-
else:
1932-
yield batch
1933-
total_row_count += current_batch_size
1926+
def _apply_limit(self, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
1927+
"""Apply row limit across batches."""
1928+
if self._limit is None:
1929+
yield from batches
1930+
return
19341931

1935-
if limit_reached:
1936-
# This break will also cancel all running tasks in the executor
1937-
break
1932+
total_row_count = 0
1933+
for batch in batches:
1934+
remaining = self._limit - total_row_count
1935+
if remaining <= 0:
1936+
return
1937+
if len(batch) > remaining:
1938+
yield batch.slice(0, remaining)
1939+
return
1940+
yield batch
1941+
total_row_count += len(batch)
19381942

19391943
def _record_batches_from_scan_tasks_and_deletes(
19401944
self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]], batch_size: int | None = None

tests/io/test_bounded_concurrent_batches.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import pytest
2626

2727
from pyiceberg.io.pyarrow import _bounded_concurrent_batches
28-
from pyiceberg.table import FileScanTask
28+
from pyiceberg.table import FileScanTask, ScanOrder
2929

3030

3131
def _make_task() -> FileScanTask:
@@ -72,7 +72,7 @@ def batch_fn(t: FileScanTask) -> Iterator[pa.RecordBatch]:
7272
assert total_rows == batches_per_file * len(tasks) * 10 # 3 batches * 4 files * 10 rows
7373

7474

75-
def test_streaming_yields_incrementally() -> None:
75+
def test_arrival_order_yields_incrementally() -> None:
7676
"""Test that batches are yielded incrementally, not all at once."""
7777
barrier = threading.Event()
7878
tasks = [_make_task(), _make_task()]
@@ -253,6 +253,6 @@ def test_concurrent_with_limit_via_arrowscan(tmpdir: str) -> None:
253253
limit=150,
254254
)
255255

256-
batches = list(scan.to_record_batches(tasks, streaming=True, concurrent_files=2))
256+
batches = list(scan.to_record_batches(tasks, order=ScanOrder.ARRIVAL, concurrent_files=2))
257257
total_rows = sum(len(b) for b in batches)
258258
assert total_rows == 150

0 commit comments

Comments
 (0)