Skip to content

Commit b4a2add

Browse files
committed
Further fixes for setitem
1 parent 5e00ba5 commit b4a2add

4 files changed

Lines changed: 15 additions & 8 deletions

File tree

src/blosc2/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1467,7 +1467,7 @@ def compute_chunks_blocks( # noqa: C901
14671467

14681468
# Return an arbitrary value for chunks and blocks when shape has any 0 dim
14691469
if 0 in shape:
1470-
return (1,) * len(shape), (1,) * len(shape)
1470+
return shape, shape
14711471

14721472
if blocks:
14731473
if not isinstance(blocks, tuple | list):

src/blosc2/lazyexpr.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2882,7 +2882,7 @@ def sort(self, order: str | list[str] | None = None) -> blosc2.LazyArray:
28822882
lazy_expr._order = order
28832883
return lazy_expr
28842884

2885-
def compute(self, item=(), **kwargs) -> blosc2.NDArray:
2885+
def compute(self, item=(), **kwargs) -> blosc2.NDArray: # noqa : C901
28862886
# When NumPy ufuncs are called, the user may add an `out` parameter to kwargs
28872887
if "out" in kwargs:
28882888
kwargs["_output"] = kwargs.pop("out")
@@ -2899,7 +2899,15 @@ def compute(self, item=(), **kwargs) -> blosc2.NDArray:
28992899
kwargs["_indices"] = self._indices
29002900
if hasattr(self, "_order"):
29012901
kwargs["_order"] = self._order
2902-
result = self._compute_expr(item, kwargs)
2902+
# handle empty arrays
2903+
if 0 in self.shape:
2904+
result = (
2905+
np.empty(self.shape, dtype=self.dtype)
2906+
if "_getitem" in kwargs
2907+
else blosc2.empty(self.shape, dtype=self.dtype)
2908+
)
2909+
else:
2910+
result = self._compute_expr(item, kwargs)
29032911
if "_order" in kwargs and "_indices" not in kwargs:
29042912
# We still need to apply the index in result
29052913
x = self._where_args["_where_x"]

src/blosc2/ndarray.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,10 +1783,7 @@ def __setitem__( # noqa : C901
17831783
intersecting_chunks = [
17841784
slice_to_chunktuple(s, c) for s, c in zip(_slice, chunks, strict=True)
17851785
] # internally handles negative steps
1786-
intersecting_chunks = [
1787-
(0,) if i == () else i for i in intersecting_chunks
1788-
] # special case of dims with 0 length
1789-
if isinstance(value, int | float | bool): # overwrite updater function for simple cases (faster)
1786+
if np.isscalar(value): # overwrite updater function for simple cases (faster)
17901787

17911788
def updater(sel_idx):
17921789
return value
@@ -1795,6 +1792,7 @@ def updater(sel_idx):
17951792
def updater(sel_idx):
17961793
return value[sel_idx]
17971794

1795+
out = self # for when shape has 0 (i.e. arr is empty, as then skip loop)
17981796
for c in product(*intersecting_chunks):
17991797
sel_idx, glob_selection, sub_idx = _get_selection(c, _slice, chunks)
18001798
sel_idx = tuple(s for s, m in zip(sel_idx, mask, strict=True) if not m)
@@ -1813,7 +1811,7 @@ def updater(sel_idx):
18131811
return out
18141812

18151813
shape = [sp - st for sp, st in zip(stop, start, strict=False)]
1816-
if isinstance(value, int | float | bool):
1814+
if np.isscalar(value):
18171815
value = np.full(shape, value, dtype=self.dtype)
18181816
elif isinstance(value, np.ndarray): # handles decompressed NDArray too
18191817
if value.dtype != self.dtype:

tests/ndarray/test_setitem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
([12, 13], [5, 5], [2, 2], (slice(11, 2, -1), slice(6, 2, -1)), np.float32),
2121
([25, 13, 22], [5, 5, 3], [2, 2, 1], (slice(17, 2, -3), 0, slice(6, 2, -1)), np.float32),
2222
([25, 13, 22], [5, 5, 3], [2, 2, 1], (np.s_[-5:-15:-1], np.s_[-3:-11:-2], slice(6, 2, -1)), np.float32),
23+
([0, 13, 22], [0, 5, 3], [0, 2, 1], (np.s_[:], np.s_[-5:-15:-1], slice(6, 2, -1)), np.float32),
2324
]
2425

2526

0 commit comments

Comments
 (0)