Skip to content

Commit 15929bd

Browse files
author
Test User
committed
feat: Support for masked arrays (fixes zarr-developers#3744)
1 parent 427df3b commit 15929bd

File tree

13 files changed

+242
-51
lines changed

13 files changed

+242
-51
lines changed

changes/3744.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add support for NumPy masked arrays (`numpy.ma.MaskedArray`) in `zarr.array()`. When a masked array is provided, it is automatically converted to a regular array by filling the masked entries with the array's fill value (`MaskedArray.filled()`), and a `UserWarning` is emitted because the mask itself is not preserved. Users who need to preserve the mask should store it in a separate array, or use a structured dtype that holds both the data and the mask.

src/zarr/abc/store.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,26 @@ def _check_writable(self) -> None:
187187
if self.read_only:
188188
raise ValueError("store was opened in read-only mode and does not support writing")
189189

190+
def _ensure_buffer(self, value: Buffer | bytes) -> Buffer:
191+
"""Convert bytes to Buffer if needed.
192+
193+
Parameters
194+
----------
195+
value : Buffer or bytes
196+
The value to ensure is a Buffer.
197+
198+
Returns
199+
-------
200+
Buffer
201+
The input value if it's already a Buffer, or a new Buffer created from bytes.
202+
"""
203+
# avoid circular import
204+
from zarr.core.buffer import Buffer
205+
206+
if isinstance(value, bytes):
207+
return Buffer.from_bytes(value)
208+
return value
209+
190210
@abstractmethod
191211
def __eq__(self, value: object) -> bool:
192212
"""Equality comparison."""
@@ -465,24 +485,28 @@ def supports_writes(self) -> bool:
465485
...
466486

467487
@abstractmethod
468-
async def set(self, key: str, value: Buffer) -> None:
488+
async def set(self, key: str, value: Buffer | bytes) -> None:
469489
"""Store a (key, value) pair.
470490
471491
Parameters
472492
----------
473493
key : str
474-
value : Buffer
494+
value : Buffer or bytes
495+
The value to store. If bytes are provided, they will be converted
496+
to a Buffer internally.
475497
"""
476498
...
477499

478-
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
500+
async def set_if_not_exists(self, key: str, value: Buffer | bytes) -> None:
479501
"""
480502
Store a key to ``value`` if the key is not already present.
481503
482504
Parameters
483505
----------
484506
key : str
485-
value : Buffer
507+
value : Buffer or bytes
508+
The value to store. If bytes are provided, they will be converted
509+
to a Buffer internally.
486510
"""
487511
# Note for implementers: the default implementation provided here
488512
# is not safe for concurrent writers. There's a race condition between
@@ -491,7 +515,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None:
491515
if not await self.exists(key):
492516
await self.set(key, value)
493517

494-
async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
518+
async def _set_many(self, values: Iterable[tuple[str, Buffer | bytes]]) -> None:
495519
"""
496520
Insert multiple (key, value) pairs into storage.
497521
"""

src/zarr/api/asynchronous.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,16 @@ async def array(data: npt.ArrayLike | AnyArray, **kwargs: Any) -> AnyAsyncArray:
621621
if isinstance(data, Array):
622622
return await from_array(data=data, **kwargs)
623623

624+
# Handle masked arrays by converting to filled array
625+
if isinstance(data, np.ma.MaskedArray):
626+
warnings.warn(
627+
"Masked arrays are not fully supported in Zarr. The mask will not be preserved. "
628+
"Consider using zarr's structured dtype or a separate array for the mask if you need to preserve it.",
629+
UserWarning,
630+
stacklevel=2,
631+
)
632+
data = cast(np.ndarray, data.filled())
633+
624634
# ensure data is array-like
625635
if not hasattr(data, "shape") or not hasattr(data, "dtype"):
626636
data = np.asanyarray(data)

src/zarr/experimental/cache_store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,18 +357,21 @@ async def get(
357357
else:
358358
return await self._get_try_cache(key, prototype, byte_range)
359359

360-
async def set(self, key: str, value: Buffer) -> None:
360+
async def set(self, key: str, value: Buffer | bytes) -> None:
361361
"""
362362
Store data in the underlying store and optionally in cache.
363363
364364
Parameters
365365
----------
366366
key : str
367367
The key to store under
368-
value : Buffer
369-
The data to store
368+
value : Buffer or bytes
369+
The data to store. If bytes are provided, they will be converted
370+
to a Buffer internally.
370371
"""
371372
await super().set(key, value)
373+
# Ensure value is a Buffer for caching
374+
value = self._ensure_buffer(value)
372375
# Invalidate all cached byte-range entries (source data changed)
373376
async with self._state.lock:
374377
self._invalidate_range_entries(key)

src/zarr/storage/_common.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,14 @@ async def get(
173173
prototype = default_buffer_prototype()
174174
return await self.store.get(self.path, prototype=prototype, byte_range=byte_range)
175175

176-
async def set(self, value: Buffer) -> None:
176+
async def set(self, value: Buffer | bytes) -> None:
177177
"""
178178
Write bytes to the store.
179179
180180
Parameters
181181
----------
182-
value : Buffer
183-
The buffer to write.
182+
value : Buffer or bytes
183+
The buffer or bytes to write.
184184
"""
185185
await self.store.set(self.path, value)
186186

@@ -201,14 +201,14 @@ async def delete_dir(self) -> None:
201201
"""
202202
await self.store.delete_dir(self.path)
203203

204-
async def set_if_not_exists(self, default: Buffer) -> None:
204+
async def set_if_not_exists(self, default: Buffer | bytes) -> None:
205205
"""
206206
Store ``default`` at this path if the key is not already present.
207207
208208
Parameters
209209
----------
210-
default : Buffer
211-
The buffer to store if the key is not already present.
210+
default : Buffer or bytes
211+
The buffer or bytes to store if the key is not already present.
212212
"""
213213
await self.store.set_if_not_exists(self.path, default)
214214

src/zarr/storage/_fsspec.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,17 +318,14 @@ async def get(
318318
async def set(
319319
self,
320320
key: str,
321-
value: Buffer,
321+
value: Buffer | bytes,
322322
byte_range: tuple[int, int] | None = None,
323323
) -> None:
324324
# docstring inherited
325325
if not self._is_open:
326326
await self._open()
327327
self._check_writable()
328-
if not isinstance(value, Buffer):
329-
raise TypeError(
330-
f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
331-
)
328+
value = self._ensure_buffer(value)
332329
path = _dereference_path(self.path, key)
333330
# write data
334331
if byte_range:

src/zarr/storage/_local.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,26 +269,23 @@ async def get_partial_values(
269269
args.append((_get, path, prototype, byte_range))
270270
return await concurrent_map(args, asyncio.to_thread, limit=None) # TODO: fix limit
271271

272-
async def set(self, key: str, value: Buffer) -> None:
272+
async def set(self, key: str, value: Buffer | bytes) -> None:
273273
# docstring inherited
274274
return await self._set(key, value)
275275

276-
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
276+
async def set_if_not_exists(self, key: str, value: Buffer | bytes) -> None:
277277
# docstring inherited
278278
try:
279279
return await self._set(key, value, exclusive=True)
280280
except FileExistsError:
281281
pass
282282

283-
async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
283+
async def _set(self, key: str, value: Buffer | bytes, exclusive: bool = False) -> None:
284284
if not self._is_open:
285285
await self._open()
286286
self._check_writable()
287287
assert isinstance(key, str)
288-
if not isinstance(value, Buffer):
289-
raise TypeError(
290-
f"LocalStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
291-
)
288+
value = self._ensure_buffer(value)
292289
path = self.root / key
293290
await asyncio.to_thread(_put, path, value, exclusive=exclusive)
294291

src/zarr/storage/_logging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,12 @@ async def exists(self, key: str) -> bool:
190190
with self.log(key):
191191
return await self._store.exists(key)
192192

193-
async def set(self, key: str, value: Buffer) -> None:
193+
async def set(self, key: str, value: Buffer | bytes) -> None:
194194
# docstring inherited
195195
with self.log(key):
196196
return await self._store.set(key=key, value=value)
197197

198-
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
198+
async def set_if_not_exists(self, key: str, value: Buffer | bytes) -> None:
199199
# docstring inherited
200200
with self.log(key):
201201
return await self._store.set_if_not_exists(key=key, value=value)

src/zarr/storage/_memory.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,26 +156,24 @@ async def exists(self, key: str) -> bool:
156156
# docstring inherited
157157
return key in self._store_dict
158158

159-
async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
159+
async def set(self, key: str, value: Buffer | bytes, byte_range: tuple[int, int] | None = None) -> None:
160160
# docstring inherited
161161
self._check_writable()
162162
await self._ensure_open()
163163
assert isinstance(key, str)
164-
if not isinstance(value, Buffer):
165-
raise TypeError(
166-
f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
167-
)
164+
value = self._ensure_buffer(value)
168165
if byte_range is not None:
169166
buf = self._store_dict[key]
170167
buf[byte_range[0] : byte_range[1]] = value
171168
self._store_dict[key] = buf
172169
else:
173170
self._store_dict[key] = value
174171

175-
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
172+
async def set_if_not_exists(self, key: str, value: Buffer | bytes) -> None:
176173
# docstring inherited
177174
self._check_writable()
178175
await self._ensure_open()
176+
value = self._ensure_buffer(value)
179177
self._store_dict.setdefault(key, value)
180178

181179
async def delete(self, key: str) -> None:
@@ -506,14 +504,11 @@ def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self:
506504
gpu_store_dict = {k: gpu.Buffer.from_buffer(v) for k, v in store_dict.items()}
507505
return cls(gpu_store_dict)
508506

509-
async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
507+
async def set(self, key: str, value: Buffer | bytes, byte_range: tuple[int, int] | None = None) -> None:
510508
# docstring inherited
511509
self._check_writable()
512510
assert isinstance(key, str)
513-
if not isinstance(value, Buffer):
514-
raise TypeError(
515-
f"GpuMemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
516-
)
511+
value = self._ensure_buffer(value)
517512
# Convert to gpu.Buffer
518513
gpu_value = value if isinstance(value, gpu.Buffer) else gpu.Buffer.from_buffer(value)
519514
await super().set(key, gpu_value, byte_range=byte_range)

src/zarr/storage/_obstore.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,20 +166,21 @@ def supports_writes(self) -> bool:
166166
# docstring inherited
167167
return True
168168

169-
async def set(self, key: str, value: Buffer) -> None:
169+
async def set(self, key: str, value: Buffer | bytes) -> None:
170170
# docstring inherited
171171
import obstore as obs
172172

173173
self._check_writable()
174-
174+
value = self._ensure_buffer(value)
175175
buf = value.as_buffer_like()
176176
await obs.put_async(self.store, key, buf)
177177

178-
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
178+
async def set_if_not_exists(self, key: str, value: Buffer | bytes) -> None:
179179
# docstring inherited
180180
import obstore as obs
181181

182182
self._check_writable()
183+
value = self._ensure_buffer(value)
183184
buf = value.as_buffer_like()
184185
with contextlib.suppress(obs.exceptions.AlreadyExistsError):
185186
await obs.put_async(self.store, key, buf, mode="create")

0 commit comments

Comments
 (0)