Skip to content

Commit 127a024

Browse files
committed
use base class instead of protocol
1 parent f3d92f6 commit 127a024

6 files changed

Lines changed: 100 additions & 164 deletions

File tree

src/zarr/storage/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
from zarr.storage._logging import LoggingStore
1111
from zarr.storage._memory import GpuMemoryStore, MemoryStore
1212
from zarr.storage._obstore import ObjectStore
13+
from zarr.storage._utils import ConcurrencyLimiter
1314
from zarr.storage._wrapper import WrapperStore
1415
from zarr.storage._zip import ZipStore
1516

1617
__all__ = [
18+
"ConcurrencyLimiter",
1719
"FsspecStore",
1820
"GpuMemoryStore",
1921
"LocalStore",

src/zarr/storage/_fsspec.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from zarr.core.buffer import Buffer
1919
from zarr.errors import ZarrUserWarning
2020
from zarr.storage._common import _dereference_path
21-
from zarr.storage._utils import with_concurrency_limit
21+
from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit
2222

2323
if TYPE_CHECKING:
2424
from collections.abc import AsyncIterator, Iterable
@@ -69,7 +69,7 @@ def _make_async(fs: AbstractFileSystem) -> AsyncFileSystem:
6969
return AsyncFileSystemWrapper(fs, asynchronous=True)
7070

7171

72-
class FsspecStore(Store):
72+
class FsspecStore(Store, ConcurrencyLimiter):
7373
"""
7474
Store for remote data based on FSSpec.
7575
@@ -122,7 +122,6 @@ class FsspecStore(Store):
122122
fs: AsyncFileSystem
123123
allowed_exceptions: tuple[type[Exception], ...]
124124
path: str
125-
_semaphore: asyncio.Semaphore | None
126125

127126
def __init__(
128127
self,
@@ -133,13 +132,11 @@ def __init__(
133132
allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS,
134133
concurrency_limit: int | None = 50,
135134
) -> None:
136-
super().__init__(read_only=read_only)
135+
Store.__init__(self, read_only=read_only)
136+
ConcurrencyLimiter.__init__(self, concurrency_limit)
137137
self.fs = fs
138138
self.path = path
139139
self.allowed_exceptions = allowed_exceptions
140-
self._semaphore = (
141-
asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
142-
)
143140

144141
if not self.fs.async_impl:
145142
raise TypeError("Filesystem needs to support async operations.")
@@ -255,19 +252,14 @@ def from_url(
255252

256253
return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
257254

258-
def get_semaphore(self) -> asyncio.Semaphore | None:
259-
return self._semaphore
260-
261255
def with_read_only(self, read_only: bool = False) -> FsspecStore:
262256
# docstring inherited
263-
sem = self.get_semaphore()
264-
concurrency_limit = sem._value if sem else None
265257
return type(self)(
266258
fs=self.fs,
267259
path=self.path,
268260
allowed_exceptions=self.allowed_exceptions,
269261
read_only=read_only,
270-
concurrency_limit=concurrency_limit,
262+
concurrency_limit=self.concurrency_limit,
271263
)
272264

273265
async def clear(self) -> None:
@@ -290,7 +282,7 @@ def __eq__(self, other: object) -> bool:
290282
and self.fs == other.fs
291283
)
292284

293-
@with_concurrency_limit()
285+
@with_concurrency_limit
294286
async def get(
295287
self,
296288
key: str,
@@ -333,7 +325,7 @@ async def get(
333325
else:
334326
return value
335327

336-
@with_concurrency_limit()
328+
@with_concurrency_limit
337329
async def set(
338330
self,
339331
key: str,
@@ -359,23 +351,19 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
359351
if not self._is_open:
360352
await self._open()
361353
self._check_writable()
362-
semaphore = self.get_semaphore()
363354

364355
async def _set_with_limit(key: str, value: Buffer) -> None:
365356
if not isinstance(value, Buffer):
366357
raise TypeError(
367358
f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
368359
)
369360
path = _dereference_path(self.path, key)
370-
if semaphore:
371-
async with semaphore:
372-
await self.fs._pipe_file(path, value.to_bytes())
373-
else:
361+
async with self._limit():
374362
await self.fs._pipe_file(path, value.to_bytes())
375363

376364
await asyncio.gather(*[_set_with_limit(key, value) for key, value in values])
377365

378-
@with_concurrency_limit()
366+
@with_concurrency_limit
379367
async def delete(self, key: str) -> None:
380368
# docstring inherited
381369
self._check_writable()

src/zarr/storage/_local.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from zarr.core.buffer import Buffer
2121
from zarr.core.buffer.core import default_buffer_prototype
22-
from zarr.storage._utils import with_concurrency_limit
22+
from zarr.storage._utils import ConcurrencyLimiter, with_concurrency_limit
2323

2424
if TYPE_CHECKING:
2525
from collections.abc import AsyncIterator, Iterable, Iterator
@@ -86,7 +86,7 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int:
8686
return f.write(view)
8787

8888

89-
class LocalStore(Store):
89+
class LocalStore(Store, ConcurrencyLimiter):
9090
"""
9191
Store for the local file system.
9292
@@ -113,7 +113,6 @@ class LocalStore(Store):
113113
supports_listing: bool = True
114114

115115
root: Path
116-
_semaphore: asyncio.Semaphore | None
117116

118117
def __init__(
119118
self,
@@ -122,29 +121,22 @@ def __init__(
122121
read_only: bool = False,
123122
concurrency_limit: int | None = 100,
124123
) -> None:
125-
super().__init__(read_only=read_only)
126124
if isinstance(root, str):
127125
root = Path(root)
128126
if not isinstance(root, Path):
129127
raise TypeError(
130128
f"'root' must be a string or Path instance. Got an instance of {type(root)} instead."
131129
)
130+
Store.__init__(self, read_only=read_only)
131+
ConcurrencyLimiter.__init__(self, concurrency_limit)
132132
self.root = root
133-
self._semaphore = (
134-
asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
135-
)
136-
137-
def get_semaphore(self) -> asyncio.Semaphore | None:
138-
return self._semaphore
139133

140134
def with_read_only(self, read_only: bool = False) -> Self:
141135
# docstring inherited
142-
sem = self.get_semaphore()
143-
concurrency_limit = sem._value if sem else None
144136
return type(self)(
145137
root=self.root,
146138
read_only=read_only,
147-
concurrency_limit=concurrency_limit,
139+
concurrency_limit=self.concurrency_limit,
148140
)
149141

150142
@classmethod
@@ -207,7 +199,7 @@ def __repr__(self) -> str:
207199
def __eq__(self, other: object) -> bool:
208200
return isinstance(other, type(self)) and self.root == other.root
209201

210-
@with_concurrency_limit()
202+
@with_concurrency_limit
211203
async def get(
212204
self,
213205
key: str,
@@ -233,18 +225,13 @@ async def get_partial_values(
233225
key_ranges: Iterable[tuple[str, ByteRequest | None]],
234226
) -> list[Buffer | None]:
235227
# docstring inherited
236-
# Note: We directly call the I/O functions here, wrapped with semaphore
237-
# to avoid deadlock from calling the decorated get() method
238-
239-
semaphore = self.get_semaphore()
228+
# We directly call the I/O functions here, wrapped with the semaphore,
229+
# to avoid deadlock from calling the decorated get() method.
240230

241231
async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None:
242232
path = self.root / key
243233
try:
244-
if semaphore:
245-
async with semaphore:
246-
return await asyncio.to_thread(_get, path, prototype, byte_range)
247-
else:
234+
async with self._limit():
248235
return await asyncio.to_thread(_get, path, prototype, byte_range)
249236
except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
250237
return None
@@ -264,7 +251,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None:
264251
except FileExistsError:
265252
pass
266253

267-
@with_concurrency_limit()
254+
@with_concurrency_limit
268255
async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
269256
if not self._is_open:
270257
await self._open()
@@ -277,7 +264,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
277264
path = self.root / key
278265
await asyncio.to_thread(_put, path, value, exclusive=exclusive)
279266

280-
@with_concurrency_limit()
267+
@with_concurrency_limit
281268
async def delete(self, key: str) -> None:
282269
"""
283270
Remove a key from the store.

src/zarr/storage/_obstore.py

Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
Store,
1616
SuffixByteRequest,
1717
)
18-
from zarr.storage._utils import _relativize_path, with_concurrency_limit
18+
from zarr.storage._utils import ConcurrencyLimiter, _relativize_path, with_concurrency_limit
1919

2020
if TYPE_CHECKING:
2121
from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence
@@ -38,7 +38,7 @@
3838
T_Store = TypeVar("T_Store", bound="_UpstreamObjectStore")
3939

4040

41-
class ObjectStore(Store, Generic[T_Store]):
41+
class ObjectStore(Store, ConcurrencyLimiter, Generic[T_Store]):
4242
"""
4343
Store that uses obstore for fast read/write from AWS, GCP, Azure.
4444
@@ -60,7 +60,6 @@ class ObjectStore(Store, Generic[T_Store]):
6060

6161
store: T_Store
6262
"""The underlying obstore instance."""
63-
_semaphore: asyncio.Semaphore | None
6463

6564
def __eq__(self, value: object) -> bool:
6665
if not isinstance(value, ObjectStore):
@@ -80,23 +79,16 @@ def __init__(
8079
) -> None:
8180
if not store.__class__.__module__.startswith("obstore"):
8281
raise TypeError(f"expected ObjectStore class, got {store!r}")
83-
super().__init__(read_only=read_only)
82+
Store.__init__(self, read_only=read_only)
83+
ConcurrencyLimiter.__init__(self, concurrency_limit)
8484
self.store = store
85-
self._semaphore = (
86-
asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
87-
)
88-
89-
def get_semaphore(self) -> asyncio.Semaphore | None:
90-
return self._semaphore
9185

9286
def with_read_only(self, read_only: bool = False) -> Self:
9387
# docstring inherited
94-
sem = self.get_semaphore()
95-
concurrency_limit = sem._value if sem else None
9688
return type(self)(
9789
store=self.store,
9890
read_only=read_only,
99-
concurrency_limit=concurrency_limit,
91+
concurrency_limit=self.concurrency_limit,
10092
)
10193

10294
def __str__(self) -> str:
@@ -114,7 +106,7 @@ def __setstate__(self, state: dict[Any, Any]) -> None:
114106
state["store"] = pickle.loads(state["store"])
115107
self.__dict__.update(state)
116108

117-
@with_concurrency_limit()
109+
@with_concurrency_limit
118110
async def get(
119111
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
120112
) -> Buffer | None:
@@ -138,7 +130,6 @@ async def get_partial_values(
138130
import obstore as obs
139131

140132
key_ranges = list(key_ranges)
141-
semaphore = self.get_semaphore()
142133
# Group bounded range requests by path for batched fetching
143134
per_file_bounded: dict[str, list[tuple[int, RangeByteRequest]]] = defaultdict(list)
144135
other_requests: list[tuple[int, str, ByteRequest | None]] = []
@@ -155,12 +146,7 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]])
155146
"""Batch multiple range requests for the same file using get_ranges_async."""
156147
starts = [r.start for _, r in requests]
157148
ends = [r.end for _, r in requests]
158-
if semaphore:
159-
async with semaphore:
160-
responses = await obs.get_ranges_async(
161-
self.store, path=path, starts=starts, ends=ends
162-
)
163-
else:
149+
async with self._limit():
164150
responses = await obs.get_ranges_async(
165151
self.store, path=path, starts=starts, ends=ends
166152
)
@@ -170,10 +156,7 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]])
170156
async def _fetch_one(idx: int, path: str, byte_range: ByteRequest | None) -> None:
171157
"""Fetch a single non-range request with semaphore limiting."""
172158
try:
173-
if semaphore:
174-
async with semaphore:
175-
buffers[idx] = await self._get_impl(path, prototype, byte_range, obs)
176-
else:
159+
async with self._limit():
177160
buffers[idx] = await self._get_impl(path, prototype, byte_range, obs)
178161
except _ALLOWED_EXCEPTIONS:
179162
pass # buffers[idx] stays None
@@ -240,7 +223,7 @@ def supports_writes(self) -> bool:
240223
# docstring inherited
241224
return True
242225

243-
@with_concurrency_limit()
226+
@with_concurrency_limit
244227
async def set(self, key: str, value: Buffer) -> None:
245228
# docstring inherited
246229
import obstore as obs
@@ -255,31 +238,22 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
255238
import obstore as obs
256239

257240
self._check_writable()
258-
semaphore = self.get_semaphore()
259241

260242
async def _set_with_limit(key: str, value: Buffer) -> None:
261243
buf = value.as_buffer_like()
262-
if semaphore:
263-
async with semaphore:
264-
await obs.put_async(self.store, key, buf)
265-
else:
244+
async with self._limit():
266245
await obs.put_async(self.store, key, buf)
267246

268247
await asyncio.gather(*[_set_with_limit(key, value) for key, value in values])
269248

270249
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
271250
# docstring inherited
272-
# Note: Not decorated to avoid deadlock when called in batch via gather()
251+
# Not decorated to avoid deadlock when called in batch via gather()
273252
import obstore as obs
274253

275254
self._check_writable()
276255
buf = value.as_buffer_like()
277-
semaphore = self.get_semaphore()
278-
if semaphore:
279-
async with semaphore:
280-
with contextlib.suppress(obs.exceptions.AlreadyExistsError):
281-
await obs.put_async(self.store, key, buf, mode="create")
282-
else:
256+
async with self._limit():
283257
with contextlib.suppress(obs.exceptions.AlreadyExistsError):
284258
await obs.put_async(self.store, key, buf, mode="create")
285259

@@ -288,7 +262,7 @@ def supports_deletes(self) -> bool:
288262
# docstring inherited
289263
return True
290264

291-
@with_concurrency_limit()
265+
@with_concurrency_limit
292266
async def delete(self, key: str) -> None:
293267
# docstring inherited
294268
import obstore as obs
@@ -311,15 +285,9 @@ async def delete_dir(self, prefix: str) -> None:
311285
prefix += "/"
312286

313287
metas = await obs.list(self.store, prefix).collect_async()
314-
semaphore = self.get_semaphore()
315288

316-
# Delete with semaphore limiting to avoid deadlock
317289
async def _delete_with_limit(path: str) -> None:
318-
if semaphore:
319-
async with semaphore:
320-
with contextlib.suppress(FileNotFoundError):
321-
await obs.delete_async(self.store, path)
322-
else:
290+
async with self._limit():
323291
with contextlib.suppress(FileNotFoundError):
324292
await obs.delete_async(self.store, path)
325293

0 commit comments

Comments (0)