1515 Store ,
1616 SuffixByteRequest ,
1717)
18- from zarr .storage ._utils import _relativize_path , with_concurrency_limit
18+ from zarr .storage ._utils import ConcurrencyLimiter , _relativize_path , with_concurrency_limit
1919
2020if TYPE_CHECKING :
2121 from collections .abc import AsyncGenerator , Coroutine , Iterable , Sequence
3838T_Store = TypeVar ("T_Store" , bound = "_UpstreamObjectStore" )
3939
4040
41- class ObjectStore (Store , Generic [T_Store ]):
41+ class ObjectStore (Store , ConcurrencyLimiter , Generic [T_Store ]):
4242 """
4343 Store that uses obstore for fast read/write from AWS, GCP, Azure.
4444
@@ -60,7 +60,6 @@ class ObjectStore(Store, Generic[T_Store]):
6060
6161 store : T_Store
6262 """The underlying obstore instance."""
63- _semaphore : asyncio .Semaphore | None
6463
6564 def __eq__ (self , value : object ) -> bool :
6665 if not isinstance (value , ObjectStore ):
@@ -80,23 +79,16 @@ def __init__(
8079 ) -> None :
8180 if not store .__class__ .__module__ .startswith ("obstore" ):
8281 raise TypeError (f"expected ObjectStore class, got { store !r} " )
83- super ().__init__ (read_only = read_only )
82+ Store .__init__ (self , read_only = read_only )
83+ ConcurrencyLimiter .__init__ (self , concurrency_limit )
8484 self .store = store
85- self ._semaphore = (
86- asyncio .Semaphore (concurrency_limit ) if concurrency_limit is not None else None
87- )
88-
89- def get_semaphore (self ) -> asyncio .Semaphore | None :
90- return self ._semaphore
9185
9286 def with_read_only (self , read_only : bool = False ) -> Self :
9387 # docstring inherited
94- sem = self .get_semaphore ()
95- concurrency_limit = sem ._value if sem else None
9688 return type (self )(
9789 store = self .store ,
9890 read_only = read_only ,
99- concurrency_limit = concurrency_limit ,
91+ concurrency_limit = self . concurrency_limit ,
10092 )
10193
10294 def __str__ (self ) -> str :
@@ -114,7 +106,7 @@ def __setstate__(self, state: dict[Any, Any]) -> None:
114106 state ["store" ] = pickle .loads (state ["store" ])
115107 self .__dict__ .update (state )
116108
117- @with_concurrency_limit ()
109+ @with_concurrency_limit
118110 async def get (
119111 self , key : str , prototype : BufferPrototype , byte_range : ByteRequest | None = None
120112 ) -> Buffer | None :
@@ -138,7 +130,6 @@ async def get_partial_values(
138130 import obstore as obs
139131
140132 key_ranges = list (key_ranges )
141- semaphore = self .get_semaphore ()
142133 # Group bounded range requests by path for batched fetching
143134 per_file_bounded : dict [str , list [tuple [int , RangeByteRequest ]]] = defaultdict (list )
144135 other_requests : list [tuple [int , str , ByteRequest | None ]] = []
@@ -155,12 +146,7 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]])
155146 """Batch multiple range requests for the same file using get_ranges_async."""
156147 starts = [r .start for _ , r in requests ]
157148 ends = [r .end for _ , r in requests ]
158- if semaphore :
159- async with semaphore :
160- responses = await obs .get_ranges_async (
161- self .store , path = path , starts = starts , ends = ends
162- )
163- else :
149+ async with self ._limit ():
164150 responses = await obs .get_ranges_async (
165151 self .store , path = path , starts = starts , ends = ends
166152 )
@@ -170,10 +156,7 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]])
170156 async def _fetch_one (idx : int , path : str , byte_range : ByteRequest | None ) -> None :
171157 """Fetch a single non-range request with semaphore limiting."""
172158 try :
173- if semaphore :
174- async with semaphore :
175- buffers [idx ] = await self ._get_impl (path , prototype , byte_range , obs )
176- else :
159+ async with self ._limit ():
177160 buffers [idx ] = await self ._get_impl (path , prototype , byte_range , obs )
178161 except _ALLOWED_EXCEPTIONS :
179162 pass # buffers[idx] stays None
@@ -240,7 +223,7 @@ def supports_writes(self) -> bool:
240223 # docstring inherited
241224 return True
242225
243- @with_concurrency_limit ()
226+ @with_concurrency_limit
244227 async def set (self , key : str , value : Buffer ) -> None :
245228 # docstring inherited
246229 import obstore as obs
@@ -255,31 +238,22 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
255238 import obstore as obs
256239
257240 self ._check_writable ()
258- semaphore = self .get_semaphore ()
259241
260242 async def _set_with_limit (key : str , value : Buffer ) -> None :
261243 buf = value .as_buffer_like ()
262- if semaphore :
263- async with semaphore :
264- await obs .put_async (self .store , key , buf )
265- else :
244+ async with self ._limit ():
266245 await obs .put_async (self .store , key , buf )
267246
268247 await asyncio .gather (* [_set_with_limit (key , value ) for key , value in values ])
269248
270249 async def set_if_not_exists (self , key : str , value : Buffer ) -> None :
271250 # docstring inherited
272- # Note: Not decorated to avoid deadlock when called in batch via gather()
251+ # Not decorated to avoid deadlock when called in batch via gather()
273252 import obstore as obs
274253
275254 self ._check_writable ()
276255 buf = value .as_buffer_like ()
277- semaphore = self .get_semaphore ()
278- if semaphore :
279- async with semaphore :
280- with contextlib .suppress (obs .exceptions .AlreadyExistsError ):
281- await obs .put_async (self .store , key , buf , mode = "create" )
282- else :
256+ async with self ._limit ():
283257 with contextlib .suppress (obs .exceptions .AlreadyExistsError ):
284258 await obs .put_async (self .store , key , buf , mode = "create" )
285259
@@ -288,7 +262,7 @@ def supports_deletes(self) -> bool:
288262 # docstring inherited
289263 return True
290264
291- @with_concurrency_limit ()
265+ @with_concurrency_limit
292266 async def delete (self , key : str ) -> None :
293267 # docstring inherited
294268 import obstore as obs
@@ -311,15 +285,9 @@ async def delete_dir(self, prefix: str) -> None:
311285 prefix += "/"
312286
313287 metas = await obs .list (self .store , prefix ).collect_async ()
314- semaphore = self .get_semaphore ()
315288
316- # Delete with semaphore limiting to avoid deadlock
317289 async def _delete_with_limit (path : str ) -> None :
318- if semaphore :
319- async with semaphore :
320- with contextlib .suppress (FileNotFoundError ):
321- await obs .delete_async (self .store , path )
322- else :
290+ async with self ._limit ():
323291 with contextlib .suppress (FileNotFoundError ):
324292 await obs .delete_async (self .store , path )
325293
0 commit comments