Skip to content

Commit 6186d74

Browse files
committed
Add support for negative steps for setitem
1 parent 5b16a66 commit 6186d74

2 files changed

Lines changed: 132 additions & 38 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 119 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,6 @@ def get_ndarray_start_stop(ndim, key, shape):
9191
stop[i] = shape[i] + stop[i]
9292
if stop[i] > shape[i]:
9393
stop[i] = shape[i]
94-
if step[i] < 0: # (start, stop, -1) => stop < start
95-
temp = start[i]
96-
start[i] = stop[i] + 1 # don't want to include stop
97-
stop[i] = temp + 1 # want to include start
9894

9995
return start, stop, tuple(step), none_mask
10096

@@ -234,7 +230,9 @@ def get_flat_slices(
234230
"""
235231
ndim = len(shape)
236232
if ndim == 0:
237-
return None #### something
233+
# this will likely cause failure since expected output is tuple of slices
234+
# however, the list conversion in the last line causes the process to be killed for some reason if shape = ()
235+
return ()
238236
start = [s[i].start if s[i].start is not None else 0 for i in range(ndim)]
239237
stop = [builtins.min(s[i].stop if s[i].stop is not None else shape[i], shape[i]) for i in range(ndim)]
240238
# Steps are not used in the computation, so raise an error if they are not None or 1
@@ -320,6 +318,10 @@ def reshape(
320318
# We already have the dtype and shape, so return immediately
321319
return dst
322320

321+
if shape == (): # get_flat_slices fails for this case so just return directly
322+
dst[()] = src[()] if src.shape == () else src[0]
323+
return dst
324+
323325
# Copy the data chunk by chunk
324326
for dst_chunk in dst.iterchunks_info():
325327
dst_slice = tuple(
@@ -1505,11 +1507,13 @@ def get_fselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
15051507
chunks = self.chunks
15061508

15071509
# TODO: try to optimise and avoid this expand which seems to copy - maybe np.broadcast
1508-
_slice = ndindex.ndindex(key).expand(shape)
1510+
_slice = ndindex.ndindex(key).expand(shape) # handles negative indices -> positive internally
15091511
out_shape = _slice.newshape(shape)
15101512
_slice = _slice.raw
15111513
# now all indices are slices or arrays of integers (or booleans)
15121514
# moreover, all arrays are consecutive (otherwise an error is raised)
1515+
if builtins.any(k.step < 0 for k in _slice if isinstance(k, slice)):
1516+
raise ValueError("Fancy indexing not supported for slices with negative steps.")
15131517

15141518
if np.all([isinstance(s, (slice, np.ndarray)) for s in _slice]) and np.all(
15151519
[s.dtype is not bool for s in _slice if isinstance(s, np.ndarray)]
@@ -1720,6 +1724,11 @@ def __getitem__( # noqa: C901
17201724
return self.get_fselection_numpy(key_) # fancy index
17211725

17221726
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
1727+
for i, s in enumerate(step): # (start, stop, -1) => stop < start
1728+
if s < 0:
1729+
temp = start[i]
1730+
start[i] = stop[i] + 1 # don't want to include stop
1731+
stop[i] = temp + 1 # want to include start
17231732
shape = np.array([sp - st for st, sp in zip(start, stop, strict=True)])
17241733
if mask is not None: # there are some dummy dims from ints
17251734
# only get mask for not Nones in key to have nm_ same length as shape
@@ -1746,7 +1755,7 @@ def __getitem__( # noqa: C901
17461755

17471756
return nparr
17481757

1749-
def __setitem__(
1758+
def __setitem__( # noqa : C901
17501759
self,
17511760
key: None | int | slice | Sequence[slice | int | np.bool_ | np.ndarray[int | np.bool_] | None],
17521761
value: object,
@@ -1785,15 +1794,54 @@ def __setitem__(
17851794
key_ = (key,) if not isinstance(key, tuple) else key
17861795
key_ = tuple(k[:] if isinstance(k, NDArray) else k for k in key_) # decompress NDArrays
17871796
key_, mask = process_key(key_, self.shape) # internally handles key an integer
1788-
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key, self.shape)
1789-
if step != (1,) * self.ndim:
1790-
raise ValueError("Step parameter is not supported yet")
1791-
key = (start, stop)
1797+
if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_):
1798+
raise ValueError("Fancy indexing not supported for __setitem__.")
1799+
1800+
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
1801+
if isinstance(value, NDArray):
1802+
value = value[...] # convert to numpy
1803+
1804+
if step != (1,) * self.ndim: # handle non-unit or negative steps
1805+
if np.any(none_mask):
1806+
raise ValueError("Cannot mix non-unit steps and None indexing for __setitem__.")
1807+
chunks = self.chunks
1808+
shape = self.shape
1809+
pos_key = tuple(
1810+
slice(s, st, stp) if stp > 0 else slice(st + 1, s + 1, -stp)
1811+
for s, st, stp in zip(start, stop, step, strict=True)
1812+
) # get positive steps
1813+
_slice = tuple(slice(s, st, stp) for s, st, stp in zip(start, stop, step, strict=True))
1814+
# this will work only for positive steps
1815+
intersecting_chunks = [slice_to_chunktuple(s, c) for s, c in zip(pos_key, chunks, strict=True)]
1816+
if isinstance(value, int | float | bool): # overwrite updater function for simple cases (faster)
1817+
1818+
def updater(sel_idx):
1819+
return value
1820+
else:
1821+
1822+
def updater(sel_idx):
1823+
return value[sel_idx]
1824+
1825+
for c in product(*intersecting_chunks):
1826+
sel_idx, _, sub_idx = _get_selection(c, _slice, chunks, shape, load_full=True)
1827+
sel_idx = tuple(s for s, m in zip(sel_idx, mask, strict=True) if not m)
1828+
sub_idx = tuple(s if not m else k.start for s, m, k in zip(sub_idx, mask, key_, strict=True))
1829+
locstart, locstop = (
1830+
tuple(c_ * cs for c_, cs in zip(c, chunks, strict=True)),
1831+
tuple((c_ + 1) * cs for c_, cs in zip(c, chunks, strict=True)),
1832+
)
1833+
chunk = np.empty(
1834+
tuple(sp - st for st, sp in zip(locstart, locstop, strict=True)), dtype=self.dtype
1835+
)
1836+
super().get_slice_numpy(chunk, (locstart, locstop)) # copy whole chunk
1837+
chunk[sub_idx] = updater(sel_idx) # update relevant parts of chunk
1838+
out = super().set_slice((locstart, locstop), chunk) # load updated chunk into array
1839+
return out
17921840

17931841
shape = [sp - st for sp, st in zip(stop, start, strict=False)]
17941842
if isinstance(value, int | float | bool):
17951843
value = np.full(shape, value, dtype=self.dtype)
1796-
elif isinstance(value, np.ndarray):
1844+
elif isinstance(value, np.ndarray): # handles decompressed NDArray too
17971845
if value.dtype != self.dtype:
17981846
try:
17991847
value = value.astype(self.dtype)
@@ -1803,10 +1851,8 @@ def __setitem__(
18031851
value = value.real.astype(self.dtype)
18041852
if value.shape == ():
18051853
value = np.full(shape, value, dtype=self.dtype)
1806-
elif isinstance(value, NDArray):
1807-
value = value[...]
18081854

1809-
return super().set_slice(key, value)
1855+
return super().set_slice((start, stop), value)
18101856

18111857
def __iter__(self):
18121858
"""Iterate over the (outer) elements of the array.
@@ -4718,38 +4764,74 @@ def slice_to_chunktuple(s, n):
47184764
return tuple(range(start // n, ceiling(stop, n)))
47194765

47204766

4721-
def _get_selection(ctuple, ptuple, chunks):
4767+
def _get_selection(ctuple, ptuple, chunks, shape, load_full=False):
47224768
# we assume that at least one element of chunk intersects with the slice
47234769
# (as a consequence of only looping over intersecting chunks)
4770+
# ptuple is global slice, ctuple is chunk coords (in units of chunks)
47244771
pselection = ()
47254772
for i, s, csize in zip(ctuple, ptuple, chunks, strict=True):
47264773
# we need to advance to first element within chunk that intersects with slice, not
47274774
# necessarily the first element of chunk
47284775
# i * csize = s.start + n*step + k, already added n+1 elements, k in [1, step]
4729-
np1 = (i * csize - s.start + s.step - 1) // s.step # gives (n + 1)
4730-
# can have n = -1 if s.start > i * csize, but never < -1 since have to intersect with chunk
4731-
pselection += (
4776+
if s.step > 0:
4777+
np1 = (i * csize - s.start + s.step - 1) // s.step # gives (n + 1)
4778+
# can have n = -1 if s.start > i * csize, but never < -1 since have to intersect with chunk
4779+
pselection += (
4780+
slice(
4781+
builtins.max(
4782+
s.start, s.start + np1 * s.step
4783+
), # start+(n+1)*step gives i*csize if k=step
4784+
builtins.min(csize * (i + 1), s.stop),
4785+
s.step,
4786+
),
4787+
)
4788+
else:
4789+
# (i + 1) * csize = s.start + n*step + k, already added n+1 elements, k in [step+1, 0]
4790+
np1 = ((i + 1) * csize - s.start + s.step) // s.step # gives (n + 1)
4791+
# can have n = -1 if s.start < (i + 1) * csize, but never < -1 since have to intersect with chunk
4792+
pselection += (
4793+
slice(
4794+
builtins.min(s.start, s.start + np1 * s.step), # start+n*step gives (i+1)*csize if k=0
4795+
builtins.max(csize * i - 1, s.stop), # want to include csize * i
4796+
s.step,
4797+
),
4798+
)
4799+
4800+
# selection relative to coordinates of out (necessarily out_step = +-1)
4801+
# when added n + 1 elements
4802+
# ps.start = pt.start + step * n + k => n = (ps.start - pt.start - sign) // step
4803+
# hence, out_start = n + 1 or shape(out) - 1 - (n + 1) if step < 0
4804+
# ps.stop = pt.start + step * out_stop + k
4805+
# => out_stop = (ps.stop - pt.start - sign) // step
4806+
out_pselection = ()
4807+
i = 0
4808+
for ps, pt in zip(pselection, ptuple, strict=True):
4809+
sign_ = pt.step // builtins.abs(pt.step)
4810+
n = (ps.start - pt.start - sign_) // pt.step
4811+
out_start = n + 1 if sign_ > 0 else shape[i] - (n + 1) - 1
4812+
# ps.stop always positive except for case where get full array (it is then -1 since desire 0th element)
4813+
out_stop = None if ps.stop == -1 else (ps.stop - pt.start - sign_) // pt.step
4814+
out_pselection += (
47324815
slice(
4733-
builtins.max(s.start, s.start + np1 * s.step), # start+(n+1)*step gives i*csize if k=step
4734-
builtins.min(csize * (i + 1), s.stop),
4735-
s.step,
4816+
out_start,
4817+
out_stop,
4818+
sign_,
47364819
),
47374820
)
4738-
4739-
# selection relative to coordinates of out (necessarily step = 1)
4740-
# stop = start + step * n + k => n = (stop - start - 1) // step
4741-
# hence, out_stop = out_start + n + 1
4742-
# ps.start = pt.start + out_start * step
4743-
out_pselection = tuple(
4744-
slice(
4745-
(ps.start - pt.start + pt.step - 1) // pt.step,
4746-
(ps.start - pt.start + pt.step - 1) // pt.step + (ps.stop - ps.start + ps.step - 1) // ps.step,
4747-
1,
4748-
)
4749-
for ps, pt in zip(pselection, ptuple, strict=True)
4750-
)
4751-
4752-
loc_selection = tuple(slice(0, s.stop - s.start, s.step) for s in pselection)
4821+
i += 1
4822+
if load_full:
4823+
for i, s, csize in zip(ctuple, pselection, chunks, strict=True):
4824+
loc_selection = tuple(
4825+
slice(s.start - i * csize, s.stop - i * csize, s.step)
4826+
if s.step > 0
4827+
else slice(s.stop - i * csize + 1, s.start - i * csize + 1, -s.step)
4828+
for s in pselection
4829+
) # local coords of full chunk (note not reversed for negative steps!)
4830+
else:
4831+
loc_selection = tuple(
4832+
slice(0, s.stop - s.start, s.step) if s.step > 0 else slice(s.start - s.stop, None, s.step)
4833+
for s in pselection
4834+
) # local coords of loaded part of chunk
47534835

47544836
return out_pselection, pselection, loc_selection
47554837

tests/ndarray/test_ndarray.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def test_ndarray_cframe(contiguous, urlpath, cparams, dparams, nchunks, copy):
5757
((200, 10), 2),
5858
((200, 10, 10), 2),
5959
((200, 10, 10), 40),
60+
((200, 10, 10), -1),
61+
((200, 10, 10), -3),
6062
((200, 10, 10, 10), 9),
6163
],
6264
)
@@ -321,7 +323,7 @@ def test_fancy_index(c):
321323
chunks = (c,) * ndim if c is not None else None
322324
arr = blosc2.linspace(0, 100, num=np.prod(shape), shape=shape, dtype=dtype, chunks=chunks)
323325
rng = np.random.default_rng()
324-
idx = rng.integers(low=0, high=d, size=(100,))
326+
idx = rng.integers(low=-d, high=d, size=(100,)) # mix of +ve and -ve indices
325327

326328
row = idx
327329
col = rng.permutation(idx)
@@ -358,6 +360,16 @@ def test_fancy_index(c):
358360
b = arr[row[:, None], mask]
359361
n = nparr[row[:, None], mask]
360362
np.testing.assert_allclose(b, n)
363+
364+
# indices and negative slice steps
365+
# TODO: these currently fail
366+
# b = arr[row, d//2::-1]
367+
# n = nparr[row, d//2::-1]
368+
# np.testing.assert_allclose(b, n)
369+
# b = arr[row, d//2::-3]
370+
# n = nparr[row, d//2::-3]
371+
# np.testing.assert_allclose(b, n)
372+
361373
# Transposition test (3rd example is transposed)
362374
b1 = arr[:, [0, 1], 0]
363375
b2 = arr[[0, 1], 0, :]

0 commit comments

Comments
 (0)