
Commit 65d1230

fix perf regressions
1 parent f427898 commit 65d1230

1 file changed: 110 additions & 122 deletions

src/zarr/experimental/sync_codecs.py
@@ -2,29 +2,18 @@
 
 The standard zarr codec pipeline (``BatchedCodecPipeline``) wraps fundamentally
 synchronous operations (e.g. gzip compress/decompress) in ``asyncio.to_thread``.
-The ``SyncCodecPipeline`` in this module eliminates that overhead by dispatching
-the full codec chain for each chunk via ``ThreadPoolExecutor.map``, achieving
-2-11x throughput improvements.
+The ``SyncCodecPipeline`` in this module eliminates that overhead by running
+per-chunk codec chains synchronously, achieving 2-11x throughput improvements.
 
 Usage::
 
     import zarr
-    from zarr.experimental.sync_codecs import SyncCodecPipeline
-
-    arr = zarr.create_array(
-        store,
-        shape=(100, 100),
-        chunks=(32, 32),
-        dtype="float64",
-        codec_pipeline_class=SyncCodecPipeline,
-    )
+
+    zarr.config.set({"codec_pipeline.path": "zarr.experimental.sync_codecs.SyncCodecPipeline"})
 """
 
 from __future__ import annotations
 
-import asyncio
-import os
-from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from itertools import islice
 from typing import TYPE_CHECKING, TypeVar
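Note: the usage block above replaces the old per-array `codec_pipeline_class` argument with a global config switch. A minimal sketch of opting in end to end, assuming an in-memory store and illustrative array parameters (only the config key and pipeline path come from this module's docstring):

```python
import zarr

# Route newly created arrays through the experimental pipeline via the global config.
zarr.config.set({"codec_pipeline.path": "zarr.experimental.sync_codecs.SyncCodecPipeline"})

# Illustrative store and array parameters -- not part of this commit.
store = zarr.storage.MemoryStore()
arr = zarr.create_array(store, shape=(100, 100), chunks=(32, 32), dtype="float64")

# Round-trip some data through the (now synchronous) codec chain.
arr[:] = 1.0
assert float(arr[10, 10]) == 1.0
```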
@@ -64,6 +53,7 @@
 # Pipeline helpers
 # ---------------------------------------------------------------------------
 
+
 def _batched(iterable: Iterable[T], n: int) -> Iterable[tuple[T, ...]]:
     if n < 1:
         raise ValueError("n must be at least one")
@@ -79,33 +69,22 @@ def _fill_value_or_default(chunk_spec: ArraySpec) -> Any:
     return fill_value
 
 
-def _get_pool() -> ThreadPoolExecutor:
-    """Lazily get or create the module-level thread pool."""
-    global _POOL
-    if _POOL is None:
-        _POOL = ThreadPoolExecutor(max_workers=os.cpu_count())
-    return _POOL
-
-
-_POOL: ThreadPoolExecutor | None = None
-
-
 # ---------------------------------------------------------------------------
 # SyncCodecPipeline
 # ---------------------------------------------------------------------------
 
+
 @dataclass(frozen=True)
 class SyncCodecPipeline(CodecPipeline):
-    """A codec pipeline that runs full per-chunk codec chains in a thread pool.
+    """A codec pipeline that runs per-chunk codec chains synchronously.
 
     When all codecs implement ``_decode_sync`` / ``_encode_sync`` (i.e.
-    ``supports_sync`` is ``True``), the entire per-chunk codec chain is
-    dispatched as a single work item via ``ThreadPoolExecutor.map``.
+    ``supports_sync`` is ``True``), the per-chunk codec chain runs synchronously
+    without any ``asyncio.to_thread`` overhead.
 
     When a codec does *not* support sync (e.g. ``ShardingCodec``), the pipeline
-    falls back to the standard async ``decode`` / ``encode`` path from the base
-    class for that batch, preserving correctness while still benefiting from
-    sync dispatch for the inner pipeline.
+    falls back to the standard async ``decode`` / ``encode`` path, preserving
+    correctness while still benefiting from sync dispatch for the inner pipeline.
     """
 
     array_array_codecs: tuple[ArrayArrayCodec, ...]
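Note: the `decode` / `encode` hunks further down branch on a `self._all_sync` flag whose definition is not visible in this diff. A hedged sketch of how such a flag could be derived from the `supports_sync` capability named in the docstring above (the property body here is an assumption, not code from this commit):

```python
class _AllSyncSketch:
    """Illustration only: deriving an ``_all_sync`` flag from codec capabilities."""

    def __init__(self, codecs: tuple[object, ...]) -> None:
        self.codecs = codecs

    @property
    def _all_sync(self) -> bool:
        # True only when every codec in the chain advertises a synchronous path.
        return all(getattr(codec, "supports_sync", False) for codec in self.codecs)


assert _AllSyncSketch(())._all_sync  # vacuously true for an empty chain
```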
@@ -165,10 +144,12 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
         return byte_length
 
     # -------------------------------------------------------------------
-    # Per-chunk codec chain (for pool.map dispatch)
+    # Per-chunk sync codec chain
     # -------------------------------------------------------------------
 
-    def _resolve_metadata_chain(self, chunk_spec: ArraySpec) -> tuple[
+    def _resolve_metadata_chain(
+        self, chunk_spec: ArraySpec
+    ) -> tuple[
         list[tuple[ArrayArrayCodec, ArraySpec]],
         tuple[ArrayBytesCodec, ArraySpec],
         list[tuple[BytesBytesCodec, ArraySpec]],
@@ -244,7 +225,7 @@ def _encode_one(
         return chunk_bytes
 
     # -------------------------------------------------------------------
-    # Top-level decode / encode (pool.map over full chain per chunk)
+    # Async fallback for codecs that don't support sync (e.g. sharding)
     # -------------------------------------------------------------------
 
     async def _decode_async(
@@ -255,18 +236,18 @@ async def _decode_async(
         chunk_bytes_batch, chunk_specs = _unzip2(chunk_bytes_and_specs)
 
         for bb_codec in self.bytes_bytes_codecs[::-1]:
-            chunk_bytes_batch = list(await bb_codec.decode(
-                zip(chunk_bytes_batch, chunk_specs, strict=False)
-            ))
+            chunk_bytes_batch = list(
+                await bb_codec.decode(zip(chunk_bytes_batch, chunk_specs, strict=False))
+            )
 
-        chunk_array_batch: list[NDBuffer | None] = list(await self.array_bytes_codec.decode(
-            zip(chunk_bytes_batch, chunk_specs, strict=False)
-        ))
+        chunk_array_batch: list[NDBuffer | None] = list(
+            await self.array_bytes_codec.decode(zip(chunk_bytes_batch, chunk_specs, strict=False))
+        )
 
         for aa_codec in self.array_array_codecs[::-1]:
-            chunk_array_batch = list(await aa_codec.decode(
-                zip(chunk_array_batch, chunk_specs, strict=False)
-            ))
+            chunk_array_batch = list(
+                await aa_codec.decode(zip(chunk_array_batch, chunk_specs, strict=False))
+            )
 
         return chunk_array_batch
 
@@ -278,24 +259,28 @@ async def _encode_async(
         chunk_array_batch, chunk_specs = _unzip2(chunk_arrays_and_specs)
 
         for aa_codec in self.array_array_codecs:
-            chunk_array_batch = list(await aa_codec.encode(
-                zip(chunk_array_batch, chunk_specs, strict=False)
-            ))
+            chunk_array_batch = list(
+                await aa_codec.encode(zip(chunk_array_batch, chunk_specs, strict=False))
+            )
             chunk_specs = list(resolve_batched(aa_codec, chunk_specs))
 
-        chunk_bytes_batch: list[Buffer | None] = list(await self.array_bytes_codec.encode(
-            zip(chunk_array_batch, chunk_specs, strict=False)
-        ))
+        chunk_bytes_batch: list[Buffer | None] = list(
+            await self.array_bytes_codec.encode(zip(chunk_array_batch, chunk_specs, strict=False))
+        )
         chunk_specs = list(resolve_batched(self.array_bytes_codec, chunk_specs))
 
         for bb_codec in self.bytes_bytes_codecs:
-            chunk_bytes_batch = list(await bb_codec.encode(
-                zip(chunk_bytes_batch, chunk_specs, strict=False)
-            ))
+            chunk_bytes_batch = list(
+                await bb_codec.encode(zip(chunk_bytes_batch, chunk_specs, strict=False))
+            )
             chunk_specs = list(resolve_batched(bb_codec, chunk_specs))
 
         return chunk_bytes_batch
 
+    # -------------------------------------------------------------------
+    # Top-level decode / encode
+    # -------------------------------------------------------------------
+
     async def decode(
         self,
         chunk_bytes_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
@@ -307,22 +292,14 @@ async def decode(
         if not self._all_sync:
             return await self._decode_async(items)
 
-        # Precompute the metadata chain once (same for all chunks in a batch)
+        # All codecs support sync -- run the full chain inline (no threading).
         _, first_spec = items[0]
         aa_chain, ab_pair, bb_chain = self._resolve_metadata_chain(first_spec)
 
-        pool = _get_pool()
-        loop = asyncio.get_running_loop()
-
-        # Submit each chunk to the pool and wrap each Future for asyncio.
-        async_futures = [
-            asyncio.wrap_future(
-                pool.submit(self._decode_one, item[0], item[1], aa_chain, ab_pair, bb_chain),
-                loop=loop,
-            )
-            for item in items
+        return [
+            self._decode_one(chunk_bytes, chunk_spec, aa_chain, ab_pair, bb_chain)
+            for chunk_bytes, chunk_spec in items
         ]
-        return await asyncio.gather(*async_futures)
 
     async def encode(
         self,
@@ -335,21 +312,11 @@ async def encode(
         if not self._all_sync:
             return await self._encode_async(items)
 
-        pool = _get_pool()
-        loop = asyncio.get_running_loop()
-
-        # Submit each chunk to the pool and wrap each Future for asyncio.
-        async_futures = [
-            asyncio.wrap_future(
-                pool.submit(self._encode_one, item[0], item[1]),
-                loop=loop,
-            )
-            for item in items
-        ]
-        return await asyncio.gather(*async_futures)
+        # All codecs support sync -- run the full chain inline (no threading).
+        return [self._encode_one(chunk_array, chunk_spec) for chunk_array, chunk_spec in items]
 
     # -------------------------------------------------------------------
-    # read / write (IO stays async, compute goes through pool.map)
+    # read / write (IO stays async, compute runs inline)
     # -------------------------------------------------------------------
 
     async def read(
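Note: the decode/encode changes above drop the per-chunk `pool.submit` + `asyncio.wrap_future` dispatch in favour of plain list comprehensions on the event-loop thread. A plausible reading of the "perf regressions" in the commit message is that scheduling one pool task per chunk costs more than the codec work it wraps when chunks are small. The micro-benchmark below is an illustrative sketch of that overhead; the toy `work` function and item count are assumptions, not part of this commit:

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def work(x: int) -> int:
    # Stand-in for a cheap per-chunk codec chain.
    return x * 2


async def via_pool(items: list[int], pool: ThreadPoolExecutor) -> list[int]:
    # Old-style dispatch: one pool task per chunk, each wrapped for asyncio.
    loop = asyncio.get_running_loop()
    futures = [asyncio.wrap_future(pool.submit(work, item), loop=loop) for item in items]
    return list(await asyncio.gather(*futures))


async def inline(items: list[int]) -> list[int]:
    # New-style dispatch: run the chain directly, no thread hop per chunk.
    return [work(item) for item in items]


async def main() -> None:
    items = list(range(10_000))
    with ThreadPoolExecutor() as pool:
        t0 = time.perf_counter()
        await via_pool(items, pool)
        t1 = time.perf_counter()
        await inline(items)
        t2 = time.perf_counter()
    print(f"pool dispatch: {t1 - t0:.4f}s  inline: {t2 - t1:.4f}s")


asyncio.run(main())
```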
@@ -381,16 +348,22 @@ async def _read_batch(
             config.get("async.concurrency"),
         )
 
-        # Phase 2: Compute -- decode via pool.map
+        # Phase 2: Compute -- decode + scatter
         decode_items = [
             (chunk_bytes, chunk_spec)
-            for chunk_bytes, (_, chunk_spec, *_) in zip(
-                chunk_bytes_batch, batch_info, strict=False
-            )
+            for chunk_bytes, (_, chunk_spec, *_) in zip(chunk_bytes_batch, batch_info, strict=False)
         ]
+
         chunk_array_batch: Iterable[NDBuffer | None] = await self.decode(decode_items)
+        self._scatter(chunk_array_batch, batch_info, out, drop_axes)
 
-        # Phase 3: Scatter into output buffer
+    @staticmethod
+    def _scatter(
+        chunk_array_batch: Iterable[NDBuffer | None],
+        batch_info: list[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        out: NDBuffer,
+        drop_axes: tuple[int, ...],
+    ) -> None:
         for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
             chunk_array_batch, batch_info, strict=False
         ):
@@ -450,8 +423,7 @@ def _merge_chunk_array(
             chunk_value = value[out_selection]
             if drop_axes != ():
                 item = tuple(
-                    None if idx in drop_axes else slice(None)
-                    for idx in range(chunk_spec.ndim)
+                    None if idx in drop_axes else slice(None) for idx in range(chunk_spec.ndim)
                 )
                 chunk_value = chunk_value[item]
         chunk_array[chunk_selection] = chunk_value
@@ -473,7 +445,7 @@ async def _read_key(
                 return None
             return await byte_setter.get(prototype=prototype)
 
-        chunk_bytes_batch: Iterable[Buffer | None]
+        chunk_bytes_batch: list[Buffer | None]
         chunk_bytes_batch = await concurrent_map(
             [
                 (
@@ -486,16 +458,58 @@ async def _read_key(
             config.get("async.concurrency"),
         )
 
-        # Phase 2: Compute -- decode existing chunks via pool.map
+        # Phase 2: Compute -- decode, merge, encode
         decode_items = [
             (chunk_bytes, chunk_spec)
-            for chunk_bytes, (_, chunk_spec, *_) in zip(
-                chunk_bytes_batch, batch_info, strict=False
-            )
+            for chunk_bytes, (_, chunk_spec, *_) in zip(chunk_bytes_batch, batch_info, strict=False)
         ]
+
+        encoded_batch = await self._write_batch_compute(decode_items, batch_info, value, drop_axes)
+
+        # Phase 3: IO -- write to store
+        async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None:
+            if chunk_bytes is None:
+                await byte_setter.delete()
+            else:
+                await byte_setter.set(chunk_bytes)
+
+        await concurrent_map(
+            [
+                (byte_setter, chunk_bytes)
+                for chunk_bytes, (byte_setter, *_) in zip(encoded_batch, batch_info, strict=False)
+            ],
+            _write_key,
+            config.get("async.concurrency"),
+        )
+
+    async def _write_batch_compute(
+        self,
+        decode_items: list[tuple[Buffer | None, ArraySpec]],
+        batch_info: list[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        value: NDBuffer,
+        drop_axes: tuple[int, ...],
+    ) -> list[Buffer | None]:
+        """Async fallback for compute phase of _write_batch."""
         chunk_array_decoded: Iterable[NDBuffer | None] = await self.decode(decode_items)
 
-        # Phase 3: Merge (pure compute, single-threaded -- touches shared `value` buffer)
+        chunk_array_batch = self._merge_and_filter(
+            chunk_array_decoded, batch_info, value, drop_axes
+        )
+
+        encode_items = [
+            (chunk_array, chunk_spec)
+            for chunk_array, (_, chunk_spec, *_) in zip(chunk_array_batch, batch_info, strict=False)
+        ]
+        return list(await self.encode(encode_items))
+
+    def _merge_and_filter(
+        self,
+        chunk_array_decoded: Iterable[NDBuffer | None],
+        batch_info: list[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        value: NDBuffer,
+        drop_axes: tuple[int, ...],
+    ) -> list[NDBuffer | None]:
+        """Merge decoded chunks with new data and filter empty chunks."""
         chunk_array_merged = [
             self._merge_chunk_array(
                 chunk_array,
@@ -515,44 +529,18 @@ async def _read_key(
             ) in zip(chunk_array_decoded, batch_info, strict=False)
         ]
 
-        chunk_array_batch: list[NDBuffer | None] = []
+        result: list[NDBuffer | None] = []
         for chunk_array, (_, chunk_spec, *_) in zip(chunk_array_merged, batch_info, strict=False):
             if chunk_array is None:
-                chunk_array_batch.append(None)
+                result.append(None)
             else:
                 if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
                     _fill_value_or_default(chunk_spec)
                 ):
-                    chunk_array_batch.append(None)
+                    result.append(None)
                 else:
-                    chunk_array_batch.append(chunk_array)
-
-        # Phase 4: Compute -- encode via pool.map
-        encode_items = [
-            (chunk_array, chunk_spec)
-            for chunk_array, (_, chunk_spec, *_) in zip(
-                chunk_array_batch, batch_info, strict=False
-            )
-        ]
-        chunk_bytes_batch = await self.encode(encode_items)
-
-        # Phase 5: IO -- write to store
-        async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None:
-            if chunk_bytes is None:
-                await byte_setter.delete()
-            else:
-                await byte_setter.set(chunk_bytes)
-
-        await concurrent_map(
-            [
-                (byte_setter, chunk_bytes)
-                for chunk_bytes, (byte_setter, *_) in zip(
-                    chunk_bytes_batch, batch_info, strict=False
-                )
-            ],
-            _write_key,
-            config.get("async.concurrency"),
-        )
+                    result.append(chunk_array)
+        return result
 
 
 register_pipeline(SyncCodecPipeline)
