33import os
44from concurrent .futures import ThreadPoolExecutor
55from dataclasses import dataclass
6- from itertools import islice , pairwise
6+ from itertools import pairwise
77from typing import TYPE_CHECKING , Any , TypeVar , cast
88from warnings import warn
99
@@ -46,14 +46,6 @@ def _unzip2(iterable: Iterable[tuple[T, U]]) -> tuple[list[T], list[U]]:
4646 return (out0 , out1 )
4747
4848
49- def batched (iterable : Iterable [T ], n : int ) -> Iterable [tuple [T , ...]]:
50- if n < 1 :
51- raise ValueError ("n must be at least one" )
52- it = iter (iterable )
53- while batch := tuple (islice (it , n )):
54- yield batch
55-
56-
def resolve_batched(codec: Codec, chunk_specs: Iterable[ArraySpec]) -> Iterable[ArraySpec]:
    """Resolve *codec*'s metadata against each chunk spec.

    Applies ``codec.resolve_metadata`` to every spec in *chunk_specs*,
    preserving input order, and returns the resolved specs as a list.
    """
    resolved: list[ArraySpec] = []
    for spec in chunk_specs:
        resolved.append(codec.resolve_metadata(spec))
    return resolved
5951
@@ -153,25 +145,37 @@ def _choose_workers(
153145 * ,
154146 is_encode : bool = False ,
155147) -> int :
156- """Decide how many thread pool workers to use (0 = don't use pool)."""
157- if n_chunks < 2 :
148+ """Decide how many thread pool workers to use (0 = don't use pool).
149+
150+ Respects ``threading.codec_workers`` config:
151+ - ``enabled``: if False, always returns 0.
152+ - ``min``: floor for the number of workers.
153+ - ``max``: ceiling for the number of workers (default: ``os.cpu_count()``).
154+ """
155+ codec_workers = config .get ("threading.codec_workers" )
156+ if not codec_workers .get ("enabled" , True ):
158157 return 0
159158
159+ min_workers : int = codec_workers .get ("min" , 0 )
160+ max_workers : int = codec_workers .get ("max" ) or os .cpu_count () or 4
161+
162+ if n_chunks < 2 :
163+ return min_workers
164+
160165 per_chunk_ns = _estimate_chunk_work_ns (chunk_nbytes , codecs , is_encode = is_encode )
161166
162- if per_chunk_ns < _POOL_OVERHEAD_NS :
167+ if per_chunk_ns < _POOL_OVERHEAD_NS and min_workers == 0 :
163168 return 0
164169
165170 total_work_ns = per_chunk_ns * n_chunks
166171 total_dispatch_ns = n_chunks * 50_000 # ~50us per task
167- if total_work_ns < total_dispatch_ns * 3 :
172+ if total_work_ns < total_dispatch_ns * 3 and min_workers == 0 :
168173 return 0
169174
170175 target_per_worker_ns = 1_000_000 # 1ms
171176 workers = max (1 , int (total_work_ns / target_per_worker_ns ))
172177
173- cpu_count = os .cpu_count () or 4
174- return min (workers , n_chunks , cpu_count )
178+ return max (min_workers , min (workers , n_chunks , max_workers ))
175179
176180
177181def _get_pool (max_workers : int ) -> ThreadPoolExecutor :
@@ -208,7 +212,6 @@ class BatchedCodecPipeline(CodecPipeline):
208212 array_array_codecs : tuple [ArrayArrayCodec , ...]
209213 array_bytes_codec : ArrayBytesCodec
210214 bytes_bytes_codecs : tuple [BytesBytesCodec , ...]
211- batch_size : int
212215
213216 @property
214217 def _all_sync (self ) -> bool :
@@ -219,14 +222,13 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
219222 return type (self ).from_codecs (c .evolve_from_array_spec (array_spec = array_spec ) for c in self )
220223
221224 @classmethod
222- def from_codecs (cls , codecs : Iterable [Codec ], * , batch_size : int | None = None ) -> Self :
225+ def from_codecs (cls , codecs : Iterable [Codec ]) -> Self :
223226 array_array_codecs , array_bytes_codec , bytes_bytes_codecs = codecs_from_list (list (codecs ))
224227
225228 return cls (
226229 array_array_codecs = array_array_codecs ,
227230 array_bytes_codec = array_bytes_codec ,
228231 bytes_bytes_codecs = bytes_bytes_codecs ,
229- batch_size = batch_size or config .get ("codec_pipeline.batch_size" ),
230232 )
231233
232234 @property
@@ -478,10 +480,7 @@ async def decode(
478480 ]
479481
480482 # Async fallback: layer-by-layer across all chunks.
481- output : list [NDBuffer | None ] = []
482- for batch_info in batched (items , self .batch_size ):
483- output .extend (await self .decode_batch (batch_info ))
484- return output
483+ return list (await self .decode_batch (items ))
485484
486485 async def encode (
487486 self ,
@@ -496,10 +495,7 @@ async def encode(
496495 return [self ._encode_one (chunk_array , chunk_spec ) for chunk_array , chunk_spec in items ]
497496
498497 # Async fallback: layer-by-layer across all chunks.
499- output : list [Buffer | None ] = []
500- for single_batch_info in batched (items , self .batch_size ):
501- output .extend (await self .encode_batch (single_batch_info ))
502- return output
498+ return list (await self .encode_batch (items ))
503499
504500 # -------------------------------------------------------------------
505501 # Async read / write (IO overlap via concurrent_map)
@@ -610,14 +606,7 @@ async def read(
610606 out : NDBuffer ,
611607 drop_axes : tuple [int , ...] = (),
612608 ) -> None :
613- await concurrent_map (
614- [
615- (single_batch_info , out , drop_axes )
616- for single_batch_info in batched (batch_info , self .batch_size )
617- ],
618- self .read_batch ,
619- config .get ("async.concurrency" ),
620- )
609+ await self .read_batch (batch_info , out , drop_axes )
621610
622611 def _merge_chunk_array (
623612 self ,
@@ -840,14 +829,7 @@ async def write(
840829 value : NDBuffer ,
841830 drop_axes : tuple [int , ...] = (),
842831 ) -> None :
843- await concurrent_map (
844- [
845- (single_batch_info , value , drop_axes )
846- for single_batch_info in batched (batch_info , self .batch_size )
847- ],
848- self .write_batch ,
849- config .get ("async.concurrency" ),
850- )
832+ await self .write_batch (batch_info , value , drop_axes )
851833
852834 # -------------------------------------------------------------------
853835 # Fully synchronous read / write (no event loop)
0 commit comments