
Commit 850bbe4

d-v-b and claude committed
feat: implement thread-pool parallelism for sync read/write
read_sync and write_sync now support an n_workers parameter. When > 0, the decode (read) or decode+merge+encode (write) compute steps are parallelized across threads via ThreadPoolExecutor.map. IO remains sequential. This helps when codecs release the GIL (gzip, blosc, zstd); e.g. gzip decompression is 41% of read time and runs entirely in C.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 01f4445 commit 850bbe4
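Why threads help at all here: CPython's gzip module, like the blosc and zstd codecs, releases the GIL during (de)compression, so a thread pool can overlap that C-level work without multiprocessing. A minimal, self-contained sketch of the effect (not part of this commit; standard library only):

import gzip
from concurrent.futures import ThreadPoolExecutor

# Toy stand-ins for compressed chunk payloads; real chunks would come from a store.
payloads = [gzip.compress(bytes(range(256)) * 4000) for _ in range(8)]

def decode(buf: bytes) -> bytes:
    # gzip.decompress runs in C and releases the GIL, so worker threads can overlap.
    return gzip.decompress(buf)

sequential = [decode(p) for p in payloads]

# Threaded decode of the same payloads; map preserves input order.
with ThreadPoolExecutor(max_workers=4) as pool:
    threaded = list(pool.map(decode, payloads))

assert sequential == threaded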

1 file changed (66 additions & 22 deletions)

src/zarr/core/codec_pipeline.py
@@ -940,7 +940,12 @@ def read_sync(
         drop_axes: tuple[int, ...] = (),
         n_workers: int = 0,
     ) -> tuple[GetResult, ...]:
-        """Synchronous read: fetch -> decode -> scatter, per chunk."""
+        """Synchronous read: fetch -> decode -> scatter, per chunk.
+
+        When ``n_workers > 0`` and there are multiple chunks, the decode
+        step is parallelized across threads. This helps when codecs
+        release the GIL (e.g. gzip, blosc, zstd).
+        """
         assert self._sync_transform is not None
         transform = self._sync_transform
 
@@ -951,20 +956,39 @@ def read_sync(
         fill = fill_value_or_default(batch[0][1])
         _missing = GetResult(status="missing")
 
-        results: list[GetResult] = []
-        for bg, chunk_spec, chunk_selection, out_selection, _ in batch:
-            raw = bg.get_sync(prototype=chunk_spec.prototype)  # type: ignore[attr-defined]
-            if raw is None:
-                out[out_selection] = fill
-                results.append(_missing)
-                continue
+        # Phase 1: fetch all chunks (IO, sequential)
+        raw_buffers: list[Buffer | None] = [
+            bg.get_sync(prototype=cs.prototype)  # type: ignore[attr-defined]
+            for bg, cs, *_ in batch
+        ]
 
+        # Phase 2: decode (compute, optionally threaded)
+        def _decode_one(raw: Buffer | None, chunk_spec: ArraySpec) -> NDBuffer | None:
+            if raw is None:
+                return None
             chunk_shape = (
                 chunk_spec.shape
                 if chunk_spec.shape != transform.array_spec.shape
                 else None
             )
-            decoded = transform.decode_chunk(raw, chunk_shape=chunk_shape)
+            return transform.decode_chunk(raw, chunk_shape=chunk_shape)
+
+        specs = [cs for _, cs, *_ in batch]
+        if n_workers > 0 and len(batch) > 1:
+            with ThreadPoolExecutor(max_workers=n_workers) as pool:
+                decoded_list = list(pool.map(_decode_one, raw_buffers, specs))
+        else:
+            decoded_list = [_decode_one(raw, spec) for raw, spec in zip(raw_buffers, specs, strict=True)]
+
+        # Phase 3: scatter (sequential, writes to shared output buffer)
+        results: list[GetResult] = []
+        for (_, _chunk_spec, chunk_selection, out_selection, _), decoded in zip(
+            batch, decoded_list, strict=True
+        ):
+            if decoded is None:
+                out[out_selection] = fill
+                results.append(_missing)
+                continue
 
             selected = decoded[chunk_selection]
             if drop_axes:
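The read path above splits the work into three phases: sequential fetch (IO), optionally threaded decode (compute), and sequential scatter into the shared output buffer. A minimal sketch of the same split, using hypothetical fetch/decode/scatter callables in place of the store, the codec transform, and the output assignment:

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Sequence

def read_batch(
    keys: Sequence[Any],
    fetch: Callable[[Any], bytes | None],
    decode: Callable[[bytes | None], Any],
    scatter: Callable[[Any, Any], None],
    n_workers: int = 0,
) -> None:
    # Phase 1: IO stays sequential.
    raw = [fetch(k) for k in keys]
    # Phase 2: compute is threaded only when requested and there is more than one chunk.
    if n_workers > 0 and len(keys) > 1:
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            decoded = list(pool.map(decode, raw))
    else:
        decoded = [decode(r) for r in raw]
    # Phase 3: write into the shared output sequentially, in input order.
    for key, chunk in zip(keys, decoded, strict=True):
        scatter(key, chunk)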
@@ -981,31 +1005,42 @@ def write_sync(
         drop_axes: tuple[int, ...] = (),
         n_workers: int = 0,
     ) -> None:
-        """Synchronous write: merge -> encode -> store, per chunk."""
+        """Synchronous write: fetch existing -> merge+encode -> store.
+
+        When ``n_workers > 0`` and there are multiple chunks, the
+        merge+encode step is parallelized across threads.
+        """
         assert self._sync_transform is not None
         transform = self._sync_transform
 
         batch = list(batch_info)
         if not batch:
             return
 
-        for bs, chunk_spec, chunk_selection, out_selection, is_complete in batch:
+        # Phase 1: fetch existing chunks (IO, sequential)
+        existing_buffers: list[Buffer | None] = [
+            None if ic else bs.get_sync(prototype=cs.prototype)  # type: ignore[attr-defined]
+            for bs, cs, _, _, ic in batch
+        ]
+
+        # Phase 2: decode + merge + encode (compute, optionally threaded)
+        def _process_one(
+            idx: int,
+        ) -> Buffer | None:
+            _, chunk_spec, chunk_selection, out_selection, is_complete = batch[idx]
+            existing_bytes = existing_buffers[idx]
             chunk_shape = (
                 chunk_spec.shape
                 if chunk_spec.shape != transform.array_spec.shape
                 else None
             )
 
-            # Decode existing chunk if partial write
             existing_chunk_array: NDBuffer | None = None
-            if not is_complete:
-                existing_bytes = bs.get_sync(prototype=chunk_spec.prototype)  # type: ignore[attr-defined]
-                if existing_bytes is not None:
-                    existing_chunk_array = transform.decode_chunk(
-                        existing_bytes, chunk_shape=chunk_shape
-                    )
+            if existing_bytes is not None:
+                existing_chunk_array = transform.decode_chunk(
+                    existing_bytes, chunk_shape=chunk_shape
+                )
 
-            # Merge
             chunk_array = self._merge_chunk_array(
                 existing_chunk_array,
                 value,
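Where the read path maps _decode_one over two parallel lists, the write path maps its worker over integer indices so one function can pick up both the batch tuple and the pre-fetched existing buffer. A tiny, generic sketch of that index-based pool.map pattern (hypothetical names, not the commit's code):

from concurrent.futures import ThreadPoolExecutor

items = ["a", "b", "c"]
extras = [1, 2, 3]

def process_one(idx: int) -> str:
    # Indexing keeps several per-item sequences together without zipping
    # multiple iterables into pool.map.
    return f"{items[idx]}:{extras[idx]}"

with ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(process_one, range(len(items))))

assert results == ["a:1", "b:2", "c:3"]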
@@ -1020,10 +1055,19 @@ def write_sync(
             if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
                 fill_value_or_default(chunk_spec)
             ):
-                bs.delete_sync()  # type: ignore[attr-defined]
-                continue
+                return None
+
+            return transform.encode_chunk(chunk_array, chunk_shape=chunk_shape)
+
+        indices = list(range(len(batch)))
+        if n_workers > 0 and len(batch) > 1:
+            with ThreadPoolExecutor(max_workers=n_workers) as pool:
+                encoded_list = list(pool.map(_process_one, indices))
+        else:
+            encoded_list = [_process_one(i) for i in indices]
 
-            encoded = transform.encode_chunk(chunk_array, chunk_shape=chunk_shape)
+        # Phase 3: store results (IO, sequential)
+        for (bs, *_rest), encoded in zip(batch, encoded_list, strict=True):
             if encoded is None:
                 bs.delete_sync()  # type: ignore[attr-defined]
             else:
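Phase 3 in both methods depends on ThreadPoolExecutor.map returning results in input order rather than completion order; that is what lets zip(batch, decoded_list) and zip(batch, encoded_list) pair each chunk with its own result. A short demonstration of that ordering guarantee:

import time
from concurrent.futures import ThreadPoolExecutor

def slow_identity(x: int) -> int:
    time.sleep(0.05 * (3 - x))  # later inputs finish first
    return x

with ThreadPoolExecutor(max_workers=3) as pool:
    # Results come back in input order, not completion order.
    assert list(pool.map(slow_identity, [0, 1, 2])) == [0, 1, 2]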
