Skip to content

Commit fbd3b06

Browse files
committed
Working on get/setitem masking
1 parent fdb9957 commit fbd3b06

1 file changed

Lines changed: 47 additions & 19 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,11 @@ def __rpow__(self, value: int | float | NDArray | NDField | blosc2.C2Array, /) -
963963
_check_allowed_dtypes(value)
964964
return blosc2.LazyExpr(new_op=(value, "**", self))
965965

966+
def __bool__(self) -> bool:
967+
if math.prod(self.shape) != 1:
968+
raise ValueError("The truth value of an array of shape {self.shape} is ambiguous.")
969+
return bool(self[()])
970+
966971
@is_documented_by(sum)
967972
def sum(self, axis=None, dtype=None, keepdims=False, **kwargs):
968973
expr = blosc2.LazyExpr(new_op=(self, None, None))
@@ -1579,19 +1584,39 @@ def get_fselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
15791584
return out.reshape(out_shape) # should have filled in correct order, just need to reshape
15801585

15811586
# Default when there are booleans
1582-
out = np.empty(out_shape, dtype=self.dtype)
1583-
chunk_size = ndindex.ChunkSize(chunks)
1587+
return self._get_set_findex_default(_slice, out_shape)
1588+
1589+
def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
1590+
_get = False
1591+
if not ((out_shape is None) or (updater is None)):
1592+
raise ValueError("Cannot provide both out_shape and updater.")
1593+
# we have a getitem
1594+
if out_shape is not None:
1595+
_get = True
1596+
out = np.empty(out_shape, dtype=self.dtype)
1597+
elif updater is None:
1598+
raise ValueError("Must provide one of out_shape or updater.")
1599+
else:
1600+
out = self # default return for no intersecting chunks
1601+
if 0 in self.shape:
1602+
return out
1603+
chunk_size = ndindex.ChunkSize(self.chunks)
15841604
# repeated indices are grouped together
1585-
intersecting_chunks = chunk_size.as_subchunks(_slice, shape) # if _slice is (), returns all chunks
1605+
intersecting_chunks = chunk_size.as_subchunks(
1606+
_slice, self.shape
1607+
) # if _slice is (), returns all chunks
15861608
for c in intersecting_chunks:
15871609
sub_idx = _slice.as_subindex(c).raw
15881610
sel_idx = c.as_subindex(_slice)
1589-
new_shape = sel_idx.newshape(out_shape)
1590-
start, stop, step = get_ndarray_start_stop(self.ndim, c.raw, shape)
1611+
start, stop, step, _ = get_ndarray_start_stop(self.ndim, c.raw, self.shape)
15911612
chunk = np.empty(tuple(sp - st for st, sp in zip(start, stop, strict=True)), dtype=self.dtype)
15921613
super().get_slice_numpy(chunk, (start, stop))
1593-
out[sel_idx.raw] = chunk[sub_idx].reshape(new_shape)
1594-
1614+
if _get:
1615+
new_shape = sel_idx.newshape(out_shape)
1616+
out[sel_idx.raw] = chunk[sub_idx].reshape(new_shape)
1617+
else:
1618+
chunk[sub_idx] = updater(sel_idx.raw)
1619+
out = super().set_slice((start, stop), chunk)
15951620
return out
15961621

15971622
def get_oselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
@@ -1765,10 +1790,22 @@ def __setitem__( # noqa : C901
17651790
blosc2_ext.check_access_mode(self.schunk.urlpath, self.schunk.mode)
17661791

17671792
key_ = (key,) if not isinstance(key, tuple) else key
1768-
key_ = tuple(k[:] if isinstance(k, NDArray) else k for k in key_) # decompress NDArrays
1793+
key_ = tuple(k[()] if isinstance(k, NDArray) else k for k in key_) # decompress NDArrays
17691794
key_, mask = process_key(key_, self.shape) # internally handles key an integer
1770-
if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_):
1771-
raise ValueError("Fancy indexing not supported for __setitem__.")
1795+
1796+
def updater(sel_idx):
1797+
return value[sel_idx]
1798+
1799+
if np.isscalar(value): # overwrite updater function for simple cases (faster)
1800+
1801+
def updater(sel_idx):
1802+
return value
1803+
1804+
if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_): # fancy indexing
1805+
_slice = ndindex.ndindex(key_).expand(
1806+
self.shape
1807+
) # handles negative indices -> positive internally
1808+
return self._get_set_findex_default(_slice, updater=updater)
17721809

17731810
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
17741811
if isinstance(value, NDArray):
@@ -1783,15 +1820,6 @@ def __setitem__( # noqa : C901
17831820
intersecting_chunks = [
17841821
slice_to_chunktuple(s, c) for s, c in zip(_slice, chunks, strict=True)
17851822
] # internally handles negative steps
1786-
if np.isscalar(value): # overwrite updater function for simple cases (faster)
1787-
1788-
def updater(sel_idx):
1789-
return value
1790-
else:
1791-
1792-
def updater(sel_idx):
1793-
return value[sel_idx]
1794-
17951823
out = self # for when shape has 0 (i.e. arr is empty, as then skip loop)
17961824
for c in product(*intersecting_chunks):
17971825
sel_idx, glob_selection, sub_idx = _get_selection(c, _slice, chunks)

0 commit comments

Comments
 (0)