Skip to content

Commit 84b975c

Browse files
committed
Last fixes for ndarray_object array_api tests
1 parent 9282a91 commit 84b975c

4 files changed

Lines changed: 97 additions & 20 deletions

File tree

src/blosc2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def __array_namespace_info__() -> Info:
308308
are_partitions_aligned,
309309
are_partitions_behaved,
310310
arange,
311+
broadcast_to,
311312
linspace,
312313
eye,
313314
asarray,
@@ -482,7 +483,7 @@ def __array_namespace_info__() -> Info:
482483
"are_partitions_behaved",
483484
"asarray",
484485
"astypeclib_info",
485-
"compress",
486+
"broadcast_tocompress",
486487
"compress2",
487488
"compressor_list",
488489
"compute_chunks_blocks",

src/blosc2/lazyexpr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,8 +1506,11 @@ def slices_eval( # noqa: C901
15061506
if out is None:
15071507
shape_ = shape_slice if shape_slice is not None else shape
15081508
if where is not None and len(where) < 2:
1509-
# The result is a linear array
1510-
shape_ = math.prod(shape_)
1509+
# The result is a linear array in the first ndims
1510+
try: # call from NDArray.__getitem__
1511+
shape_ = (math.prod(operands["key"].shape),) + shape[operands["key"].ndim :]
1512+
except KeyError:
1513+
shape_ = math.prod(shape_)
15111514
if getitem or _order:
15121515
out = np.empty(shape_, dtype=dtype_)
15131516
if _order:
@@ -2899,7 +2902,6 @@ def compute(self, item=(), **kwargs) -> blosc2.NDArray:
28992902
kwargs["_indices"] = self._indices
29002903
if hasattr(self, "_order"):
29012904
kwargs["_order"] = self._order
2902-
# handle empty arrays
29032905
result = self._compute_expr(item, kwargs)
29042906
if "_order" in kwargs and "_indices" not in kwargs:
29052907
# We still need to apply the index in result

src/blosc2/ndarray.py

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,7 +1517,7 @@ def get_fselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
15171517
flat_shape = tuple(
15181518
(i.stop - i.start - i.step // builtins.abs(i.step)) // i.step + 1 for i in prior_tuple
15191519
)
1520-
idx_dim = np.prod(_slice[begin].shape)
1520+
idx_dim = np.prod(_slice[begin].shape, dtype=np.int32)
15211521

15221522
# TODO: find a nicer way to do the copy maybe
15231523
arr = np.empty((idx_dim, end - begin), dtype=_slice[begin].dtype)
@@ -1718,15 +1718,28 @@ def __getitem__( # noqa: C901
17181718
expr = blosc2.LazyExpr._new_expr(key, self.fields, guess=False)
17191719
return expr.where(self)
17201720

1721-
key_ = (key,) if not isinstance(key, tuple) else key
1722-
key_ = tuple(k[()] if isinstance(k, NDArray) else k for k in key_) # decompress NDArrays
1723-
key_, mask = process_key(key_, self.shape) # internally handles key an integer
1721+
# key not iterable
1722+
key = key[()] if isinstance(key, NDArray) else key
1723+
key = tuple(k[()] if isinstance(k, NDArray) else k for k in key) if isinstance(key, tuple) else key
17241724

1725-
if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_):
1726-
if hasattr(key, "dtype") and np.issubdtype(key.dtype, np.bool_): # check ORIGINAL key
1725+
# decompress NDArrays
1726+
key_, mask = process_key(key, self.shape) # internally handles key an integer
1727+
key = key[()] if hasattr(key, "shape") and key.shape == () else key # convert to scalar
1728+
# fancy indexing
1729+
if isinstance(key_, (list, np.ndarray)) or builtins.any(
1730+
isinstance(k, (list, np.ndarray)) for k in key_
1731+
):
1732+
# check scalar booleans, which add 1 dim to beginning but which cause problems for ndindex.as_subindex
1733+
if np.issubdtype(type(key), bool) and np.isscalar(key):
1734+
if key:
1735+
_slice = ndindex.ndindex(()).expand(self.shape) # just get whole array
1736+
out_shape = _slice.newshape(self.shape)
1737+
return np.expand_dims(self._get_set_findex_default(_slice, out_shape=out_shape), 0)
1738+
else: # do nothing
1739+
return np.empty((0,) + self.shape, dtype=self.dtype)
1740+
elif hasattr(key, "dtype") and np.issubdtype(key.dtype, np.bool_): # check ORIGINAL key
17271741
# This can be interpreted as a boolean expression
1728-
if key.shape != self.shape:
1729-
raise ValueError("The shape of the boolean expression should match the array shape")
1742+
# elif key.shape == self.shape: # This is faster than the fancy indexing path
17301743
# expr = blosc2.lazyexpr(f"(key)")
17311744
# The next should be a bit faster
17321745
expr = blosc2.LazyExpr._new_expr("key", {"key": key}, guess=False)
@@ -1766,7 +1779,7 @@ def __getitem__( # noqa: C901
17661779
inmutable_key = make_key_hashable(key)
17671780
self._last_read[inmutable_key] = nparr
17681781

1769-
return nparr
1782+
return nparr[()] # [()] does nothing except for 0-dim arrays, for which returns a scalar
17701783

17711784
def __setitem__( # noqa : C901
17721785
self,
@@ -1804,9 +1817,13 @@ def __setitem__( # noqa : C901
18041817
"""
18051818
blosc2_ext.check_access_mode(self.schunk.urlpath, self.schunk.mode)
18061819

1807-
key_ = (key,) if not isinstance(key, tuple) else key
1808-
key_ = tuple(k[()] if isinstance(k, NDArray) else k for k in key_) # decompress NDArrays
1809-
key_, mask = process_key(key_, self.shape) # internally handles key an integer
1820+
# key not iterable
1821+
key = key[()] if isinstance(key, NDArray) else key
1822+
key = tuple(k[()] if isinstance(k, NDArray) else k for k in key) if isinstance(key, tuple) else key
1823+
1824+
key_, mask = process_key(key, self.shape) # internally handles key an integer
1825+
if hasattr(value, "shape") and value.shape == ():
1826+
value = value[()]
18101827

18111828
def updater(sel_idx):
18121829
return value[sel_idx]
@@ -1820,11 +1837,17 @@ def updater(sel_idx):
18201837
_slice = ndindex.ndindex(key_).expand(
18211838
self.shape
18221839
) # handles negative indices -> positive internally
1840+
# check scalar booleans, which add 1 dim to beginning but which cause problems for ndindex.as_subindex
1841+
if (
1842+
key.shape == () and hasattr(key, "dtype") and np.issubdtype(key.dtype, np.bool_)
1843+
): # check ORIGINAL key after decompression
1844+
if key:
1845+
_slice = ndindex.ndindex(()).expand(self.shape) # just get whole array
1846+
else: # do nothing
1847+
return self
18231848
return self._get_set_findex_default(_slice, updater=updater)
18241849

18251850
start, stop, step, none_mask = get_ndarray_start_stop(self.ndim, key_, self.shape)
1826-
if isinstance(value, NDArray):
1827-
value = value[...] # convert to numpy
18281851

18291852
if step != (1,) * self.ndim: # handle non-unit or negative steps
18301853
if np.any(none_mask):
@@ -1848,12 +1871,14 @@ def updater(sel_idx):
18481871
chunk = np.empty(
18491872
tuple(sp - st for st, sp in zip(locstart, locstop, strict=True)), dtype=self.dtype
18501873
)
1851-
super().get_slice_numpy(chunk, (locstart, locstop)) # copy whole chunk
1874+
super().get_slice_numpy(chunk, (locstart, locstop)) # copy relevant slice of chunk
18521875
chunk[sub_idx] = updater(sel_idx) # update relevant parts of chunk
1853-
out = super().set_slice((locstart, locstop), chunk) # load updated chunk into array
1876+
out = super().set_slice((locstart, locstop), chunk) # load updated partial chunk into array
18541877
return out
18551878

18561879
shape = [sp - st for sp, st in zip(stop, start, strict=False)]
1880+
if isinstance(value, NDArray):
1881+
value = value[...] # convert to numpy
18571882
if np.isscalar(value):
18581883
value = np.full(shape, value, dtype=self.dtype)
18591884
elif isinstance(value, np.ndarray): # handles decompressed NDArray too
@@ -4174,6 +4199,30 @@ def astype(
41744199
copy: bool = True,
41754200
**kwargs: Any,
41764201
) -> NDArray:
4202+
"""
4203+
Copy of the array, cast to a specified type. Does not support copy = False.
4204+
4205+
Parameters
4206+
----------
4207+
array: Sequence | np.ndarray | NDArray | blosc2.C2Array
4208+
The array to be cast to a different type.
4209+
dtype: DType-like
4210+
The desired data type to cast to.
4211+
casting: str = 'unsafe'
4212+
Controls what kind of data casting may occur. Defaults to 'unsafe' for backwards compatibility.
4213+
* 'no' means the data types should not be cast at all.
4214+
* 'equiv' means only byte-order changes are allowed.
4215+
* 'safe' means only casts which can preserve values are allowed.
4216+
* 'same_kind' means only safe casts or casts within a kind, like float64 to float32, are allowed.
4217+
* 'unsafe' means any data conversions may be done.
4218+
copy: bool = True
4219+
Must always be True as copy is made by default. Will be changed in a future version
4220+
4221+
Returns
4222+
-------
4223+
out: NDArray
4224+
New array with specified data type.
4225+
"""
41774226
return asarray(array, dtype=dtype, casting=casting, copy=copy, **kwargs)
41784227

41794228

@@ -4884,3 +4933,24 @@ def _get_local_slice(prior_selection, post_selection, chunk_bounds):
48844933
dtype="int64",
48854934
)
48864935
return locbegin, locend
4936+
4937+
4938+
def broadcast_to(arr, shape):
4939+
"""
4940+
Broadcast an array to a new shape.
4941+
Warning: Computes a lazyexpr, so probably a bit suboptimal
4942+
4943+
Parameters
4944+
----------
4945+
array: NDArray
4946+
The array to broadcast.
4947+
4948+
shape: tuple
4949+
The shape of the desired array.
4950+
4951+
Returns
4952+
-------
4953+
broadcast: NDArray
4954+
A new array with the given shape.
4955+
"""
4956+
return (arr + blosc2.zeros(shape, dtype=arr.dtype)).compute() # return lazyexpr quickly

tests/ndarray/test_getitem.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ def bool_array(shape):
115115
((10, 10), (5, 5), (2, 2), bool_array((10, 10))),
116116
((8, 8, 8), (4, 4, 4), (2, 2, 2), bool_array((8, 8, 8))),
117117
((6, 5, 4, 3), (3, 2, 2, 1), (1, 1, 1, 1), bool_array((6, 5, 4, 3))),
118+
((6, 5, 4, 3), (3, 2, 2, 1), (1, 1, 1, 1), bool_array((6, 5))),
119+
((6, 5, 4, 3), (3, 2, 2, 1), (1, 1, 1, 1), bool_array((6, 0, 4))),
120+
((6, 5, 4, 3), (3, 2, 2, 1), (1, 1, 1, 1), True),
121+
((6, 5, 4, 3), (3, 2, 2, 1), (1, 1, 1, 1), False),
118122
],
119123
)
120124
def test_bool_values(shape, chunks, blocks, idx):

0 commit comments

Comments
 (0)