4040import warnings
4141from abc import ABC , abstractmethod
4242from collections .abc import Callable , Generator , Iterable , Iterator
43+ from concurrent .futures import ThreadPoolExecutor
4344from copy import copy
4445from dataclasses import dataclass
4546from enum import Enum
@@ -1695,9 +1696,9 @@ def _bounded_concurrent_batches(
16951696) -> Generator [pa .RecordBatch , None , None ]:
16961697 """Read batches from multiple files concurrently with bounded memory.
16971698
1698- Workers read from files in parallel (up to concurrent_files at a time) and push
1699- batches into a shared queue. The consumer yields batches from the queue.
1700- A sentinel value signals completion, avoiding timeout-based polling .
1699+ Uses a per-scan ThreadPoolExecutor(max_workers= concurrent_files) to naturally
1700+ bound concurrency. Workers push batches into a bounded queue which provides
1701+ backpressure when the consumer is slower than the producers .
17011702
17021703 Args:
17031704 tasks: The file scan tasks to process.
@@ -1709,60 +1710,49 @@ def _bounded_concurrent_batches(
17091710 return
17101711
17111712 batch_queue : queue .Queue [pa .RecordBatch | BaseException | object ] = queue .Queue (maxsize = max_buffered_batches )
1712- cancel_event = threading .Event ()
1713- pending_count = len (tasks )
1714- pending_lock = threading .Lock ()
1715- file_semaphore = threading .Semaphore (concurrent_files )
1713+ cancel = threading .Event ()
1714+ remaining = len (tasks )
1715+ remaining_lock = threading .Lock ()
17161716
17171717 def worker (task : FileScanTask ) -> None :
1718- nonlocal pending_count
1718+ nonlocal remaining
17191719 try :
1720- # Blocking acquire — on cancellation, extra permits are released to unblock.
1721- file_semaphore .acquire ()
1722- if cancel_event .is_set ():
1723- return
1724-
17251720 for batch in batch_fn (task ):
1726- if cancel_event .is_set ():
1721+ if cancel .is_set ():
17271722 return
17281723 batch_queue .put (batch )
17291724 except BaseException as e :
1730- if not cancel_event .is_set ():
1725+ if not cancel .is_set ():
17311726 batch_queue .put (e )
17321727 finally :
1733- file_semaphore .release ()
1734- with pending_lock :
1735- pending_count -= 1
1736- if pending_count == 0 :
1728+ with remaining_lock :
1729+ remaining -= 1
1730+ if remaining == 0 :
17371731 batch_queue .put (_QUEUE_SENTINEL )
17381732
1739- executor = ExecutorFactory .get_or_create ()
1740- futures = [executor .submit (worker , task ) for task in tasks ]
1733+ with ThreadPoolExecutor (max_workers = concurrent_files ) as executor :
1734+ for task in tasks :
1735+ executor .submit (worker , task )
17411736
1742- try :
1743- while True :
1744- item = batch_queue .get ()
1737+ try :
1738+ while True :
1739+ item = batch_queue .get ()
17451740
1746- if item is _QUEUE_SENTINEL :
1747- break
1741+ if item is _QUEUE_SENTINEL :
1742+ break
17481743
1749- if isinstance (item , BaseException ):
1750- raise item
1751-
1752- yield item
1753- finally :
1754- cancel_event .set ()
1755- # Release semaphore permits to unblock any workers waiting on acquire()
1756- for _ in range (len (tasks )):
1757- file_semaphore .release ()
1758- # Drain the queue to unblock any workers stuck on put()
1759- while not batch_queue .empty ():
1760- try :
1761- batch_queue .get_nowait ()
1762- except queue .Empty :
1763- break
1764- for future in futures :
1765- future .cancel ()
1744+ if isinstance (item , BaseException ):
1745+ raise item
1746+
1747+ yield item
1748+ finally :
1749+ cancel .set ()
1750+ # Drain the queue to unblock any workers stuck on put()
1751+ while not batch_queue .empty ():
1752+ try :
1753+ batch_queue .get_nowait ()
1754+ except queue .Empty :
1755+ break
17661756
17671757
17681758class ArrowScan :
@@ -1889,52 +1879,66 @@ def to_record_batches(
18891879 if concurrent_files < 1 :
18901880 raise ValueError (f"concurrent_files must be >= 1, got { concurrent_files } " )
18911881
1892- deletes_per_file = _read_all_delete_files ( self ._io , tasks )
1882+ task_list , deletes_per_file = self ._prepare_tasks_and_deletes ( tasks )
18931883
18941884 if order == ScanOrder .ARRIVAL :
1895- # Arrival order: read files with bounded concurrency, yielding batches as produced.
1896- # When concurrent_files=1, this is sequential. When >1, batches may interleave across files.
1897- task_list = list (tasks )
1898-
1899- def batch_fn (task : FileScanTask ) -> Iterator [pa .RecordBatch ]:
1900- return self ._record_batches_from_scan_tasks_and_deletes ([task ], deletes_per_file , batch_size )
1901-
1902- total_row_count = 0
1903- for batch in _bounded_concurrent_batches (task_list , batch_fn , concurrent_files ):
1904- current_batch_size = len (batch )
1905- if self ._limit is not None and total_row_count + current_batch_size >= self ._limit :
1906- yield batch .slice (0 , self ._limit - total_row_count )
1907- return
1908- yield batch
1909- total_row_count += current_batch_size
1910- return
1885+ return self ._apply_limit (self ._iter_batches_arrival (task_list , deletes_per_file , batch_size , concurrent_files ))
19111886
1912- # Task order: existing behavior with executor.map + list()
1913- total_row_count = 0
1887+ return self ._apply_limit (self ._iter_batches_materialized (task_list , deletes_per_file , batch_size ))
1888+
1889+ def _prepare_tasks_and_deletes (
1890+ self , tasks : Iterable [FileScanTask ]
1891+ ) -> tuple [list [FileScanTask ], dict [str , list [ChunkedArray ]]]:
1892+ """Resolve delete files and return tasks as a list."""
1893+ task_list = list (tasks )
1894+ deletes_per_file = _read_all_delete_files (self ._io , task_list )
1895+ return task_list , deletes_per_file
1896+
1897+ def _iter_batches_arrival (
1898+ self ,
1899+ task_list : list [FileScanTask ],
1900+ deletes_per_file : dict [str , list [ChunkedArray ]],
1901+ batch_size : int | None ,
1902+ concurrent_files : int ,
1903+ ) -> Iterator [pa .RecordBatch ]:
1904+ """Yield batches using bounded concurrent streaming in arrival order."""
1905+
1906+ def batch_fn (task : FileScanTask ) -> Iterator [pa .RecordBatch ]:
1907+ return self ._record_batches_from_scan_tasks_and_deletes ([task ], deletes_per_file , batch_size )
1908+
1909+ yield from _bounded_concurrent_batches (task_list , batch_fn , concurrent_files )
1910+
1911+ def _iter_batches_materialized (
1912+ self ,
1913+ task_list : list [FileScanTask ],
1914+ deletes_per_file : dict [str , list [ChunkedArray ]],
1915+ batch_size : int | None ,
1916+ ) -> Iterator [pa .RecordBatch ]:
1917+ """Yield batches using executor.map with full file materialization."""
19141918 executor = ExecutorFactory .get_or_create ()
19151919
19161920 def batches_for_task (task : FileScanTask ) -> list [pa .RecordBatch ]:
1917- # Materialize the iterator here to ensure execution happens within the executor.
1918- # Otherwise, the iterator would be lazily consumed later (in the main thread),
1919- # defeating the purpose of using executor.map.
19201921 return list (self ._record_batches_from_scan_tasks_and_deletes ([task ], deletes_per_file , batch_size ))
19211922
1922- limit_reached = False
1923- for batches in executor .map (batches_for_task , tasks ):
1924- for batch in batches :
1925- current_batch_size = len (batch )
1926- if self ._limit is not None and total_row_count + current_batch_size >= self ._limit :
1927- yield batch .slice (0 , self ._limit - total_row_count )
1923+ for batches in executor .map (batches_for_task , task_list ):
1924+ yield from batches
19281925
1929- limit_reached = True
1930- break
1931- else :
1932- yield batch
1933- total_row_count += current_batch_size
1926+ def _apply_limit ( self , batches : Iterator [ pa . RecordBatch ]) -> Iterator [ pa . RecordBatch ]:
1927+ """Apply row limit across batches."""
1928+ if self . _limit is None :
1929+ yield from batches
1930+ return
19341931
1935- if limit_reached :
1936- # This break will also cancel all running tasks in the executor
1937- break
1932+ total_row_count = 0
1933+ for batch in batches :
1934+ remaining = self ._limit - total_row_count
1935+ if remaining <= 0 :
1936+ return
1937+ if len (batch ) > remaining :
1938+ yield batch .slice (0 , remaining )
1939+ return
1940+ yield batch
1941+ total_row_count += len (batch )
19381942
19391943 def _record_batches_from_scan_tasks_and_deletes (
19401944 self , tasks : Iterable [FileScanTask ], deletes_per_file : dict [str , list [ChunkedArray ]], batch_size : int | None = None
0 commit comments