4040import warnings
4141from abc import ABC , abstractmethod
4242from collections .abc import Callable , Generator , Iterable , Iterator
43+ from concurrent .futures import ThreadPoolExecutor
4344from copy import copy
4445from dataclasses import dataclass
4546from enum import Enum
@@ -1695,9 +1696,9 @@ def _bounded_concurrent_batches(
16951696) -> Generator [pa .RecordBatch , None , None ]:
16961697 """Read batches from multiple files concurrently with bounded memory.
16971698
1698- Workers read from files in parallel (up to concurrent_files at a time) and push
1699- batches into a shared queue. The consumer yields batches from the queue.
1700- A sentinel value signals completion, avoiding timeout-based polling .
1699+ Uses a per-scan ThreadPoolExecutor(max_workers= concurrent_files) to naturally
1700+ bound concurrency. Workers push batches into a bounded queue which provides
1701+ backpressure when the consumer is slower than the producers .
17011702
17021703 Args:
17031704 tasks: The file scan tasks to process.
@@ -1709,60 +1710,49 @@ def _bounded_concurrent_batches(
17091710 return
17101711
17111712 batch_queue : queue .Queue [pa .RecordBatch | BaseException | object ] = queue .Queue (maxsize = max_buffered_batches )
1712- cancel_event = threading .Event ()
1713- pending_count = len (tasks )
1714- pending_lock = threading .Lock ()
1715- file_semaphore = threading .Semaphore (concurrent_files )
1713+ cancel = threading .Event ()
1714+ remaining = len (tasks )
1715+ remaining_lock = threading .Lock ()
17161716
17171717 def worker (task : FileScanTask ) -> None :
1718- nonlocal pending_count
1718+ nonlocal remaining
17191719 try :
1720- # Blocking acquire — on cancellation, extra permits are released to unblock.
1721- file_semaphore .acquire ()
1722- if cancel_event .is_set ():
1723- return
1724-
17251720 for batch in batch_fn (task ):
1726- if cancel_event .is_set ():
1721+ if cancel .is_set ():
17271722 return
17281723 batch_queue .put (batch )
17291724 except BaseException as e :
1730- if not cancel_event .is_set ():
1725+ if not cancel .is_set ():
17311726 batch_queue .put (e )
17321727 finally :
1733- file_semaphore .release ()
1734- with pending_lock :
1735- pending_count -= 1
1736- if pending_count == 0 :
1728+ with remaining_lock :
1729+ remaining -= 1
1730+ if remaining == 0 :
17371731 batch_queue .put (_QUEUE_SENTINEL )
17381732
1739- executor = ExecutorFactory .get_or_create ()
1740- futures = [executor .submit (worker , task ) for task in tasks ]
1733+ with ThreadPoolExecutor (max_workers = concurrent_files ) as executor :
1734+ for task in tasks :
1735+ executor .submit (worker , task )
17411736
1742- try :
1743- while True :
1744- item = batch_queue .get ()
1737+ try :
1738+ while True :
1739+ item = batch_queue .get ()
17451740
1746- if item is _QUEUE_SENTINEL :
1747- break
1741+ if item is _QUEUE_SENTINEL :
1742+ break
17481743
1749- if isinstance (item , BaseException ):
1750- raise item
1751-
1752- yield item
1753- finally :
1754- cancel_event .set ()
1755- # Release semaphore permits to unblock any workers waiting on acquire()
1756- for _ in range (len (tasks )):
1757- file_semaphore .release ()
1758- # Drain the queue to unblock any workers stuck on put()
1759- while not batch_queue .empty ():
1760- try :
1761- batch_queue .get_nowait ()
1762- except queue .Empty :
1763- break
1764- for future in futures :
1765- future .cancel ()
1744+ if isinstance (item , BaseException ):
1745+ raise item
1746+
1747+ yield item
1748+ finally :
1749+ cancel .set ()
1750+ # Drain the queue to unblock any workers stuck on put()
1751+ while not batch_queue .empty ():
1752+ try :
1753+ batch_queue .get_nowait ()
1754+ except queue .Empty :
1755+ break
17661756
17671757
17681758class ArrowScan :
@@ -1884,52 +1874,66 @@ def to_record_batches(
18841874 if concurrent_files < 1 :
18851875 raise ValueError (f"concurrent_files must be >= 1, got { concurrent_files } " )
18861876
1887- deletes_per_file = _read_all_delete_files ( self ._io , tasks )
1877+ task_list , deletes_per_file = self ._prepare_tasks_and_deletes ( tasks )
18881878
18891879 if streaming :
1890- # Streaming path: read files with bounded concurrency, yielding batches as produced.
1891- # When concurrent_files=1, this is sequential. When >1, batches may interleave across files.
1892- task_list = list (tasks )
1893-
1894- def batch_fn (task : FileScanTask ) -> Iterator [pa .RecordBatch ]:
1895- return self ._record_batches_from_scan_tasks_and_deletes ([task ], deletes_per_file , batch_size )
1896-
1897- total_row_count = 0
1898- for batch in _bounded_concurrent_batches (task_list , batch_fn , concurrent_files ):
1899- current_batch_size = len (batch )
1900- if self ._limit is not None and total_row_count + current_batch_size >= self ._limit :
1901- yield batch .slice (0 , self ._limit - total_row_count )
1902- return
1903- yield batch
1904- total_row_count += current_batch_size
1905- return
1880+ return self ._apply_limit (self ._iter_batches_streaming (task_list , deletes_per_file , batch_size , concurrent_files ))
19061881
1907- # Non-streaming path: existing behavior with executor.map + list()
1908- total_row_count = 0
1882+ return self ._apply_limit (self ._iter_batches_materialized (task_list , deletes_per_file , batch_size ))
1883+
1884+ def _prepare_tasks_and_deletes (
1885+ self , tasks : Iterable [FileScanTask ]
1886+ ) -> tuple [list [FileScanTask ], dict [str , list [ChunkedArray ]]]:
1887+ """Resolve delete files and return tasks as a list."""
1888+ task_list = list (tasks )
1889+ deletes_per_file = _read_all_delete_files (self ._io , task_list )
1890+ return task_list , deletes_per_file
1891+
1892+ def _iter_batches_streaming (
1893+ self ,
1894+ task_list : list [FileScanTask ],
1895+ deletes_per_file : dict [str , list [ChunkedArray ]],
1896+ batch_size : int | None ,
1897+ concurrent_files : int ,
1898+ ) -> Iterator [pa .RecordBatch ]:
1899+ """Yield batches using bounded concurrent streaming."""
1900+
1901+ def batch_fn (task : FileScanTask ) -> Iterator [pa .RecordBatch ]:
1902+ return self ._record_batches_from_scan_tasks_and_deletes ([task ], deletes_per_file , batch_size )
1903+
1904+ yield from _bounded_concurrent_batches (task_list , batch_fn , concurrent_files )
1905+
1906+ def _iter_batches_materialized (
1907+ self ,
1908+ task_list : list [FileScanTask ],
1909+ deletes_per_file : dict [str , list [ChunkedArray ]],
1910+ batch_size : int | None ,
1911+ ) -> Iterator [pa .RecordBatch ]:
1912+ """Yield batches using executor.map with full file materialization."""
19091913 executor = ExecutorFactory .get_or_create ()
19101914
19111915 def batches_for_task (task : FileScanTask ) -> list [pa .RecordBatch ]:
1912- # Materialize the iterator here to ensure execution happens within the executor.
1913- # Otherwise, the iterator would be lazily consumed later (in the main thread),
1914- # defeating the purpose of using executor.map.
19151916 return list (self ._record_batches_from_scan_tasks_and_deletes ([task ], deletes_per_file , batch_size ))
19161917
1917- limit_reached = False
1918- for batches in executor .map (batches_for_task , tasks ):
1919- for batch in batches :
1920- current_batch_size = len (batch )
1921- if self ._limit is not None and total_row_count + current_batch_size >= self ._limit :
1922- yield batch .slice (0 , self ._limit - total_row_count )
1918+ for batches in executor .map (batches_for_task , task_list ):
1919+ yield from batches
19231920
1924- limit_reached = True
1925- break
1926- else :
1927- yield batch
1928- total_row_count += current_batch_size
1921+ def _apply_limit ( self , batches : Iterator [ pa . RecordBatch ]) -> Iterator [ pa . RecordBatch ]:
1922+ """Apply row limit across batches."""
1923+ if self . _limit is None :
1924+ yield from batches
1925+ return
19291926
1930- if limit_reached :
1931- # This break will also cancel all running tasks in the executor
1932- break
1927+ total_row_count = 0
1928+ for batch in batches :
1929+ remaining = self ._limit - total_row_count
1930+ if remaining <= 0 :
1931+ return
1932+ if len (batch ) > remaining :
1933+ yield batch .slice (0 , remaining )
1934+ return
1935+ yield batch
1936+ total_row_count += len (batch )
19331937
19341938 def _record_batches_from_scan_tasks_and_deletes (
19351939 self , tasks : Iterable [FileScanTask ], deletes_per_file : dict [str , list [ChunkedArray ]], batch_size : int | None = None
0 commit comments