Skip to content

Commit 01e1f73

Browse files
committed
fix type hints, prevent thread pool leakage, make codec pipeline introspection more efficient
1 parent e9db616 commit 01e1f73

3 files changed

Lines changed: 21 additions & 13 deletions

File tree

changes/3715.misc.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ different codecs. This algorithm is aware of the latency required for setting up
99
single-chunk workloads we skip the thread pool entirely.
1010

1111
Use of the thread pool can be disabled in the global configuration. The minimum number of threads
12-
and the maximum number of threads can be set via the configuration as well.
12+
and the maximum number of threads can be set via the configuration as well.

src/zarr/codecs/sharding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,12 @@ async def get(
100100
return value[start:stop]
101101

102102
def get_sync(
103-
self, prototype: BufferPrototype | None = None, byte_range: ByteRequest | None = None
103+
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
104104
) -> Buffer | None:
105105
# Sync equivalent of get() — just a dict lookup, no IO.
106+
assert prototype == default_buffer_prototype(), (
107+
f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
108+
)
106109
value = self.shard_dict.get(self.chunk_coords)
107110
if value is None:
108111
return None

src/zarr/core/codec_pipeline.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
from concurrent.futures import ThreadPoolExecutor
5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, field
66
from itertools import pairwise
77
from typing import TYPE_CHECKING, Any, TypeVar, cast
88
from warnings import warn
@@ -182,7 +182,10 @@ def _get_pool(max_workers: int) -> ThreadPoolExecutor:
182182
"""Get a thread pool with at most *max_workers* threads."""
183183
global _pool
184184
if _pool is None or _pool._max_workers < max_workers:
185+
old = _pool
185186
_pool = ThreadPoolExecutor(max_workers=max_workers)
187+
if old is not None:
188+
old.shutdown(wait=False)
186189
return _pool
187190

188191

@@ -214,6 +217,8 @@ class BatchedCodecPipeline(CodecPipeline):
214217
bytes_bytes_codecs: tuple[BytesBytesCodec, ...]
215218
batch_size: int | None = None
216219

220+
_all_sync: bool = field(default=False, init=False, repr=False, compare=False)
221+
217222
def __post_init__(self) -> None:
218223
if self.batch_size is not None:
219224
warn(
@@ -222,11 +227,12 @@ def __post_init__(self) -> None:
222227
FutureWarning,
223228
stacklevel=2,
224229
)
225-
226-
@property
227-
def _all_sync(self) -> bool:
228-
"""True when every codec in the chain implements SupportsSyncCodec."""
229-
return all(isinstance(c, SupportsSyncCodec) for c in self)
230+
# Compute once; frozen dataclass requires object.__setattr__.
231+
object.__setattr__(
232+
self,
233+
"_all_sync",
234+
all(isinstance(c, SupportsSyncCodec) for c in self),
235+
)
230236

231237
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
232238
return type(self).from_codecs(c.evolve_from_array_spec(array_spec=array_spec) for c in self)
@@ -710,7 +716,7 @@ async def _write_chunk(
710716
)
711717

712718
# 3) Write result
713-
if chunk_bytes is _DELETED or chunk_bytes is None:
719+
if chunk_bytes is _DELETED:
714720
await byte_setter.delete()
715721
else:
716722
await byte_setter.set(chunk_bytes) # type: ignore[arg-type]
@@ -1020,22 +1026,21 @@ def write_sync(
10201026
for encoded, (byte_setter, *_) in zip(encoded_list, batch_info_list, strict=False):
10211027
if encoded is _DELETED:
10221028
byte_setter.delete_sync()
1023-
elif encoded is not None:
1024-
byte_setter.set_sync(encoded)
10251029
else:
1026-
byte_setter.delete_sync()
1030+
byte_setter.set_sync(encoded)
10271031

10281032

10291033
def codecs_from_list(
10301034
codecs: Iterable[Codec],
10311035
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
10321036
from zarr.codecs.sharding import ShardingCodec
10331037

1038+
codecs = list(codecs)
10341039
array_array: tuple[ArrayArrayCodec, ...] = ()
10351040
array_bytes_maybe: ArrayBytesCodec | None = None
10361041
bytes_bytes: tuple[BytesBytesCodec, ...] = ()
10371042

1038-
if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(tuple(codecs)) > 1:
1043+
if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(codecs) > 1:
10391044
warn(
10401045
"Combining a `sharding_indexed` codec disables partial reads and "
10411046
"writes, which may lead to inefficient performance.",

0 commit comments

Comments
 (0)