 
 The standard zarr codec pipeline (``BatchedCodecPipeline``) wraps fundamentally
 synchronous operations (e.g. gzip compress/decompress) in ``asyncio.to_thread``.
-The ``SyncCodecPipeline`` in this module eliminates that overhead by dispatching
-the full codec chain for each chunk via ``ThreadPoolExecutor.map``, achieving
-2-11x throughput improvements.
+The ``SyncCodecPipeline`` in this module eliminates that overhead by running
+per-chunk codec chains synchronously, achieving 2-11x throughput improvements.
 
 Usage::
 
     import zarr
-    from zarr.experimental.sync_codecs import SyncCodecPipeline
-
-    arr = zarr.create_array(
-        store,
-        shape=(100, 100),
-        chunks=(32, 32),
-        dtype="float64",
-        codec_pipeline_class=SyncCodecPipeline,
-    )
+
+    zarr.config.set({"codec_pipeline.path": "zarr.experimental.sync_codecs.SyncCodecPipeline"})
 """
 
 from __future__ import annotations
 
-import asyncio
-import os
-from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from itertools import islice
 from typing import TYPE_CHECKING, TypeVar
 # Pipeline helpers
 # ---------------------------------------------------------------------------
 
+
 def _batched(iterable: Iterable[T], n: int) -> Iterable[tuple[T, ...]]:
     if n < 1:
         raise ValueError("n must be at least one")
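
The remainder of ``_batched`` is elided by this hunk. It presumably follows the
standard ``itertools`` batching recipe, using the ``islice`` import kept above
(a sketch, not necessarily the committed body)::

    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch
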
@@ -79,33 +69,22 @@ def _fill_value_or_default(chunk_spec: ArraySpec) -> Any:
     return fill_value
 
 
-def _get_pool() -> ThreadPoolExecutor:
-    """Lazily get or create the module-level thread pool."""
-    global _POOL
-    if _POOL is None:
-        _POOL = ThreadPoolExecutor(max_workers=os.cpu_count())
-    return _POOL
-
-
-_POOL: ThreadPoolExecutor | None = None
-
-
 # ---------------------------------------------------------------------------
 # SyncCodecPipeline
 # ---------------------------------------------------------------------------
 
+
 @dataclass(frozen=True)
 class SyncCodecPipeline(CodecPipeline):
-    """A codec pipeline that runs full per-chunk codec chains in a thread pool.
+    """A codec pipeline that runs per-chunk codec chains synchronously.
 
     When all codecs implement ``_decode_sync`` / ``_encode_sync`` (i.e.
-    ``supports_sync`` is ``True``), the entire per-chunk codec chain is
-    dispatched as a single work item via ``ThreadPoolExecutor.map``.
+    ``supports_sync`` is ``True``), the per-chunk codec chain runs synchronously
+    without any ``asyncio.to_thread`` overhead.
 
     When a codec does *not* support sync (e.g. ``ShardingCodec``), the pipeline
-    falls back to the standard async ``decode`` / ``encode`` path from the base
-    class for that batch, preserving correctness while still benefiting from
-    sync dispatch for the inner pipeline.
+    falls back to the standard async ``decode`` / ``encode`` path, preserving
+    correctness while still benefiting from sync dispatch for the inner pipeline.
     """
 
     array_array_codecs: tuple[ArrayArrayCodec, ...]
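
The ``_all_sync`` flag that gates the fast path is defined outside these hunks.
A plausible shape, given the ``supports_sync`` convention the docstring
describes (hypothetical, not the committed definition)::

    @property
    def _all_sync(self) -> bool:
        # True only when every codec in the chain can run without the async path.
        return all(
            getattr(codec, "supports_sync", False)
            for codec in (
                *self.array_array_codecs,
                self.array_bytes_codec,
                *self.bytes_bytes_codecs,
            )
        )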
@@ -165,10 +144,12 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
         return byte_length
 
     # -------------------------------------------------------------------
-    # Per-chunk codec chain (for pool.map dispatch)
+    # Per-chunk sync codec chain
     # -------------------------------------------------------------------
 
-    def _resolve_metadata_chain(self, chunk_spec: ArraySpec) -> tuple[
+    def _resolve_metadata_chain(
+        self, chunk_spec: ArraySpec
+    ) -> tuple[
         list[tuple[ArrayArrayCodec, ArraySpec]],
         tuple[ArrayBytesCodec, ArraySpec],
         list[tuple[BytesBytesCodec, ArraySpec]],
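
For orientation, with a typical chain of one ``BytesCodec`` and one
``GzipCodec`` (and no array-to-array transforms), the resolved chain would
look roughly like this (illustrative values, not actual output)::

    aa_chain == []                            # no transpose/filter codecs
    ab_pair  == (BytesCodec(...), <spec>)     # array -> raw bytes
    bb_chain == [(GzipCodec(...), <spec>)]    # raw bytes -> compressed bytes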
@@ -244,7 +225,7 @@ def _encode_one(
         return chunk_bytes
 
     # -------------------------------------------------------------------
-    # Top-level decode / encode (pool.map over full chain per chunk)
+    # Async fallback for codecs that don't support sync (e.g. sharding)
     # -------------------------------------------------------------------
 
     async def _decode_async(
@@ -255,18 +236,18 @@ async def _decode_async(
         chunk_bytes_batch, chunk_specs = _unzip2(chunk_bytes_and_specs)
 
         for bb_codec in self.bytes_bytes_codecs[::-1]:
-            chunk_bytes_batch = list(await bb_codec.decode(
-                zip(chunk_bytes_batch, chunk_specs, strict=False)
-            ))
+            chunk_bytes_batch = list(
+                await bb_codec.decode(zip(chunk_bytes_batch, chunk_specs, strict=False))
+            )
 
-        chunk_array_batch: list[NDBuffer | None] = list(await self.array_bytes_codec.decode(
-            zip(chunk_bytes_batch, chunk_specs, strict=False)
-        ))
+        chunk_array_batch: list[NDBuffer | None] = list(
+            await self.array_bytes_codec.decode(zip(chunk_bytes_batch, chunk_specs, strict=False))
+        )
 
         for aa_codec in self.array_array_codecs[::-1]:
-            chunk_array_batch = list(await aa_codec.decode(
-                zip(chunk_array_batch, chunk_specs, strict=False)
-            ))
+            chunk_array_batch = list(
+                await aa_codec.decode(zip(chunk_array_batch, chunk_specs, strict=False))
+            )
 
         return chunk_array_batch
 
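Note the ``[::-1]`` iteration: decode unwinds the chain in the reverse of
encode order. For a hypothetical ``BytesCodec`` + ``GzipCodec`` chain::

    encode: ndarray --BytesCodec--> raw bytes --GzipCodec--> compressed bytes
    decode: compressed bytes --GzipCodec--> raw bytes --BytesCodec--> ndarray
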
@@ -278,24 +259,28 @@ async def _encode_async(
         chunk_array_batch, chunk_specs = _unzip2(chunk_arrays_and_specs)
 
         for aa_codec in self.array_array_codecs:
-            chunk_array_batch = list(await aa_codec.encode(
-                zip(chunk_array_batch, chunk_specs, strict=False)
-            ))
+            chunk_array_batch = list(
+                await aa_codec.encode(zip(chunk_array_batch, chunk_specs, strict=False))
+            )
             chunk_specs = list(resolve_batched(aa_codec, chunk_specs))
 
-        chunk_bytes_batch: list[Buffer | None] = list(await self.array_bytes_codec.encode(
-            zip(chunk_array_batch, chunk_specs, strict=False)
-        ))
+        chunk_bytes_batch: list[Buffer | None] = list(
+            await self.array_bytes_codec.encode(zip(chunk_array_batch, chunk_specs, strict=False))
+        )
         chunk_specs = list(resolve_batched(self.array_bytes_codec, chunk_specs))
 
         for bb_codec in self.bytes_bytes_codecs:
-            chunk_bytes_batch = list(await bb_codec.encode(
-                zip(chunk_bytes_batch, chunk_specs, strict=False)
-            ))
+            chunk_bytes_batch = list(
+                await bb_codec.encode(zip(chunk_bytes_batch, chunk_specs, strict=False))
+            )
             chunk_specs = list(resolve_batched(bb_codec, chunk_specs))
 
         return chunk_bytes_batch
 
+    # -------------------------------------------------------------------
+    # Top-level decode / encode
+    # -------------------------------------------------------------------
+
     async def decode(
         self,
         chunk_bytes_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
@@ -307,22 +292,14 @@ async def decode(
         if not self._all_sync:
             return await self._decode_async(items)
 
-        # Precompute the metadata chain once (same for all chunks in a batch)
+        # All codecs support sync -- run the full chain inline (no threading).
         _, first_spec = items[0]
         aa_chain, ab_pair, bb_chain = self._resolve_metadata_chain(first_spec)
 
-        pool = _get_pool()
-        loop = asyncio.get_running_loop()
-
-        # Submit each chunk to the pool and wrap each Future for asyncio.
-        async_futures = [
-            asyncio.wrap_future(
-                pool.submit(self._decode_one, item[0], item[1], aa_chain, ab_pair, bb_chain),
-                loop=loop,
-            )
-            for item in items
+        return [
+            self._decode_one(chunk_bytes, chunk_spec, aa_chain, ab_pair, bb_chain)
+            for chunk_bytes, chunk_spec in items
         ]
-        return await asyncio.gather(*async_futures)
 
     async def encode(
         self,
@@ -335,21 +312,11 @@ async def encode(
         if not self._all_sync:
             return await self._encode_async(items)
 
-        pool = _get_pool()
-        loop = asyncio.get_running_loop()
-
-        # Submit each chunk to the pool and wrap each Future for asyncio.
-        async_futures = [
-            asyncio.wrap_future(
-                pool.submit(self._encode_one, item[0], item[1]),
-                loop=loop,
-            )
-            for item in items
-        ]
-        return await asyncio.gather(*async_futures)
+        # All codecs support sync -- run the full chain inline (no threading).
+        return [self._encode_one(chunk_array, chunk_spec) for chunk_array, chunk_spec in items]
 
     # -------------------------------------------------------------------
-    # read / write (IO stays async, compute goes through pool.map)
+    # read / write (IO stays async, compute runs inline)
     # -------------------------------------------------------------------
 
     async def read(
@@ -381,16 +348,22 @@ async def _read_batch(
             config.get("async.concurrency"),
         )
 
-        # Phase 2: Compute -- decode via pool.map
+        # Phase 2: Compute -- decode + scatter
         decode_items = [
             (chunk_bytes, chunk_spec)
-            for chunk_bytes, (_, chunk_spec, *_) in zip(
-                chunk_bytes_batch, batch_info, strict=False
-            )
+            for chunk_bytes, (_, chunk_spec, *_) in zip(chunk_bytes_batch, batch_info, strict=False)
         ]
+
         chunk_array_batch: Iterable[NDBuffer | None] = await self.decode(decode_items)
+        self._scatter(chunk_array_batch, batch_info, out, drop_axes)
 
-        # Phase 3: Scatter into output buffer
+    @staticmethod
+    def _scatter(
+        chunk_array_batch: Iterable[NDBuffer | None],
+        batch_info: list[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        out: NDBuffer,
+        drop_axes: tuple[int, ...],
+    ) -> None:
         for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
             chunk_array_batch, batch_info, strict=False
         ):
@@ -450,8 +423,7 @@ def _merge_chunk_array(
             chunk_value = value[out_selection]
             if drop_axes != ():
                 item = tuple(
-                    None if idx in drop_axes else slice(None)
-                    for idx in range(chunk_spec.ndim)
+                    None if idx in drop_axes else slice(None) for idx in range(chunk_spec.ndim)
                 )
                 chunk_value = chunk_value[item]
         chunk_array[chunk_selection] = chunk_value
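
A worked example of the ``drop_axes`` indexing above: with
``chunk_spec.ndim == 2`` and ``drop_axes == (0,)``, ``item`` becomes
``(None, slice(None))``, and ``None`` (``np.newaxis``) reinserts the dropped
leading axis before assignment::

    import numpy as np

    chunk_value = np.arange(4)        # shape (4,): axis 0 was dropped
    item = (None, slice(None))        # ndim == 2, drop_axes == (0,)
    assert chunk_value[item].shape == (1, 4)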
@@ -473,7 +445,7 @@ async def _read_key(
                 return None
             return await byte_setter.get(prototype=prototype)
 
-        chunk_bytes_batch: Iterable[Buffer | None]
+        chunk_bytes_batch: list[Buffer | None]
         chunk_bytes_batch = await concurrent_map(
             [
                 (
@@ -486,16 +458,58 @@ async def _read_key(
486458 config .get ("async.concurrency" ),
487459 )
488460
489- # Phase 2: Compute -- decode existing chunks via pool.map
461+ # Phase 2: Compute -- decode, merge, encode
490462 decode_items = [
491463 (chunk_bytes , chunk_spec )
492- for chunk_bytes , (_ , chunk_spec , * _ ) in zip (
493- chunk_bytes_batch , batch_info , strict = False
494- )
464+ for chunk_bytes , (_ , chunk_spec , * _ ) in zip (chunk_bytes_batch , batch_info , strict = False )
495465 ]
466+
467+ encoded_batch = await self ._write_batch_compute (decode_items , batch_info , value , drop_axes )
468+
469+ # Phase 3: IO -- write to store
470+ async def _write_key (byte_setter : ByteSetter , chunk_bytes : Buffer | None ) -> None :
471+ if chunk_bytes is None :
472+ await byte_setter .delete ()
473+ else :
474+ await byte_setter .set (chunk_bytes )
475+
476+ await concurrent_map (
477+ [
478+ (byte_setter , chunk_bytes )
479+ for chunk_bytes , (byte_setter , * _ ) in zip (encoded_batch , batch_info , strict = False )
480+ ],
481+ _write_key ,
482+ config .get ("async.concurrency" ),
483+ )
484+
485+ async def _write_batch_compute (
486+ self ,
487+ decode_items : list [tuple [Buffer | None , ArraySpec ]],
488+ batch_info : list [tuple [ByteSetter , ArraySpec , SelectorTuple , SelectorTuple , bool ]],
489+ value : NDBuffer ,
490+ drop_axes : tuple [int , ...],
491+ ) -> list [Buffer | None ]:
492+ """Async fallback for compute phase of _write_batch."""
496493 chunk_array_decoded : Iterable [NDBuffer | None ] = await self .decode (decode_items )
497494
498- # Phase 3: Merge (pure compute, single-threaded -- touches shared `value` buffer)
495+ chunk_array_batch = self ._merge_and_filter (
496+ chunk_array_decoded , batch_info , value , drop_axes
497+ )
498+
499+ encode_items = [
500+ (chunk_array , chunk_spec )
501+ for chunk_array , (_ , chunk_spec , * _ ) in zip (chunk_array_batch , batch_info , strict = False )
502+ ]
503+ return list (await self .encode (encode_items ))
504+
505+ def _merge_and_filter (
506+ self ,
507+ chunk_array_decoded : Iterable [NDBuffer | None ],
508+ batch_info : list [tuple [ByteSetter , ArraySpec , SelectorTuple , SelectorTuple , bool ]],
509+ value : NDBuffer ,
510+ drop_axes : tuple [int , ...],
511+ ) -> list [NDBuffer | None ]:
512+ """Merge decoded chunks with new data and filter empty chunks."""
499513 chunk_array_merged = [
500514 self ._merge_chunk_array (
501515 chunk_array ,
@@ -515,44 +529,18 @@ async def _read_key(
             ) in zip(chunk_array_decoded, batch_info, strict=False)
         ]
 
-        chunk_array_batch: list[NDBuffer | None] = []
+        result: list[NDBuffer | None] = []
         for chunk_array, (_, chunk_spec, *_) in zip(chunk_array_merged, batch_info, strict=False):
             if chunk_array is None:
-                chunk_array_batch.append(None)
+                result.append(None)
             else:
                 if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
                     _fill_value_or_default(chunk_spec)
                 ):
-                    chunk_array_batch.append(None)
+                    result.append(None)
                 else:
-                    chunk_array_batch.append(chunk_array)
-
-        # Phase 4: Compute -- encode via pool.map
-        encode_items = [
-            (chunk_array, chunk_spec)
-            for chunk_array, (_, chunk_spec, *_) in zip(
-                chunk_array_batch, batch_info, strict=False
-            )
-        ]
-        chunk_bytes_batch = await self.encode(encode_items)
-
-        # Phase 5: IO -- write to store
-        async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None:
-            if chunk_bytes is None:
-                await byte_setter.delete()
-            else:
-                await byte_setter.set(chunk_bytes)
-
-        await concurrent_map(
-            [
-                (byte_setter, chunk_bytes)
-                for chunk_bytes, (byte_setter, *_) in zip(
-                    chunk_bytes_batch, batch_info, strict=False
-                )
-            ],
-            _write_key,
-            config.get("async.concurrency"),
-        )
+                    result.append(chunk_array)
+        return result
 
 
 register_pipeline(SyncCodecPipeline)
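
``register_pipeline`` makes the class discoverable through the
``codec_pipeline.path`` setting shown in the module docstring. A quick way to
confirm the setting took effect (``zarr.config`` is a donfig object, so
``get`` mirrors ``set``)::

    import zarr

    zarr.config.set({"codec_pipeline.path": "zarr.experimental.sync_codecs.SyncCodecPipeline"})
    assert zarr.config.get("codec_pipeline.path").endswith("SyncCodecPipeline")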