Skip to content

Commit 78f8f54

Browse files
committed
Add matmul support for N-D arrays, modify squeeze to comply with array-api
1 parent 607ba6d commit 78f8f54

4 files changed

Lines changed: 123 additions & 64 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,13 @@
4040
import blosc2
4141
from blosc2 import compute_chunks_blocks
4242
from blosc2.info import InfoReporter
43-
from blosc2.ndarray import _check_allowed_dtypes, get_chunks_idx, is_inside_new_expr, process_key
43+
from blosc2.ndarray import (
44+
_check_allowed_dtypes,
45+
get_chunks_idx,
46+
get_intersecting_chunks,
47+
is_inside_new_expr,
48+
process_key,
49+
)
4450

4551
if not blosc2.IS_WASM:
4652
import numexpr
@@ -1575,7 +1581,7 @@ def slices_eval( # noqa: C901
15751581
elif isinstance(out, blosc2.NDArray):
15761582
# It *seems* better to choose an automatic chunks and blocks for the output array
15771583
# out = out.slice(_slice, chunks=out.chunks, blocks=out.blocks)
1578-
out = out.squeeze(mask_slice)
1584+
out = out.squeeze(np.where(mask_slice)[0])
15791585
else:
15801586
raise ValueError("The output array is not a NumPy array or a NDArray")
15811587

@@ -3629,16 +3635,6 @@ def evaluate(
36293635
return lexpr[()]
36303636

36313637

3632-
def get_intersecting_chunks(_slice, shape, chunks):
3633-
if 0 not in chunks:
3634-
chunk_size = ndindex.ChunkSize(chunks)
3635-
return chunk_size.as_subchunks(_slice, shape) # if _slice is (), returns all chunks
3636-
else:
3637-
return (
3638-
ndindex.ndindex(...).expand(shape),
3639-
) # chunk is whole array so just return full tuple to do loop once
3640-
3641-
36423638
if __name__ == "__main__":
36433639
from time import time
36443640

src/blosc2/ndarray.py

Lines changed: 70 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -2288,7 +2288,7 @@ def slice(self, key: int | slice | Sequence[slice], **kwargs: Any) -> NDArray:
22882288
for order, nchunk in enumerate(aligned_chunks):
22892289
chunk = self.schunk.get_chunk(nchunk)
22902290
newarr.schunk.update_chunk(order, chunk)
2291-
newarr.squeeze(mask=mask) # remove any dummy dims introduced
2291+
newarr.squeeze(axis=np.where(mask)[0]) # remove any dummy dims introduced
22922292
return newarr
22932293

22942294
key = (start, stop)
@@ -2307,11 +2307,11 @@ def slice(self, key: int | slice | Sequence[slice], **kwargs: Any) -> NDArray:
23072307

23082308
return ndslice
23092309

2310-
def squeeze(self, mask=None) -> NDArray:
2310+
def squeeze(self, axis=None) -> NDArray:
23112311
"""Remove single-dimensional entries from the shape of the array.
23122312
23132313
This method modifies the array in-place. If axis is None, removes any dimensions with size 1.
2314-
If mask is provided, it should be a boolean array of the same shape as the array, and the corresponding
2314+
If axis is provided, it should be an int or tuple of ints and the corresponding
23152315
dimensions (of size 1) will be removed.
23162316
23172317
Returns
@@ -2331,7 +2331,18 @@ def squeeze(self, mask=None) -> NDArray:
23312331
>>> a.shape
23322332
(23, 11)
23332333
"""
2334-
super().squeeze(mask=mask)
2334+
if axis is None:
2335+
super().squeeze()
2336+
else:
2337+
axis = [axis] if isinstance(axis, int) else axis
2338+
mask = [False for i in range(self.ndim)]
2339+
for a in axis:
2340+
if a < 0:
2341+
a += self.ndim # Adjust axis to be within the array's dimensions
2342+
if mask[a]:
2343+
raise ValueError("Axis values must be unique.")
2344+
mask[a] = True
2345+
super().squeeze(mask=mask)
23352346
return self
23362347

23372348
def indices(self, order: str | list[str] | None = None, **kwargs: Any) -> NDArray:
@@ -4312,9 +4323,8 @@ def asarray(
43124323
else:
43134324
if not isinstance(array, NDArray):
43144325
raise ValueError("Must always do a copy for asarray unless NDArray provided.")
4315-
mask = [True] + [False for i in range(array.ndim)]
43164326
# TODO: make a direct view possible
4317-
return blosc2.expand_dims(array, axis=0).squeeze(mask) # way to get a view
4327+
return blosc2.expand_dims(array, axis=0).squeeze(axis=0) # way to get a view
43184328

43194329
return ndarr
43204330

@@ -4515,15 +4525,15 @@ def sort(array: NDArray, order: str | list[str] | None = None, **kwargs: Any) ->
45154525
return larr.sort(order).compute(**kwargs)
45164526

45174527

4518-
def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
4528+
def matmul(x1: NDArray | np.ndarray, x2: NDArray, **kwargs: Any) -> NDArray | np.ndarray: # noqa : C901
45194529
"""
45204530
Computes the matrix product between two Blosc2 NDArrays.
45214531
45224532
Parameters
45234533
----------
4524-
x1: :ref:`NDArray`
4534+
x1: :ref:`NDArray` | np.ndarray
45254535
The first input array.
4526-
x2: :ref:`NDArray`
4536+
x2: :ref:`NDArray` | np.ndarray
45274537
The second input array.
45284538
kwargs: Any, optional
45294539
Keyword arguments that are supported by the :func:`empty` constructor.
@@ -4575,51 +4585,70 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
45754585
array([1, 5])
45764586
45774587
"""
4588+
# Added this to pass array-api tests (which use internal getitem to check results)
4589+
if isinstance(x1, np.ndarray) and isinstance(x2, np.ndarray):
4590+
return np.matmul(x1, x2)
45784591

45794592
# Validate arguments are not scalars
45804593
if np.isscalar(x1) or np.isscalar(x2):
45814594
raise ValueError("Arguments can't be scalars.")
45824595

4583-
# Validate arguments are dimension 1 or 2
4584-
if x1.ndim > 2 or x2.ndim > 2:
4585-
raise ValueError("Multiplication of arrays with dimension greater than 2 is not supported yet.")
4596+
# Validate matrix multiplication compatibility
4597+
if x1.shape[-1] != x2.shape[builtins.max(-2, -len(x2.shape))]:
4598+
raise ValueError("Shapes are not aligned for matrix multiplication.")
45864599

45874600
# Promote 1D arrays to 2D if necessary
45884601
x1_is_vector = False
45894602
x2_is_vector = False
45904603
if x1.ndim == 1:
4591-
x1 = x1.reshape((1, x1.shape[0])) # (N,) -> (1, N)
4604+
x1 = blosc2.expand_dims(x1, axis=0) # (N,) -> (1, N)
45924605
x1_is_vector = True
45934606
if x2.ndim == 1:
4594-
x2 = x2.reshape((x2.shape[0], 1)) # (M,) -> (M, 1)
4607+
x2 = blosc2.expand_dims(x2, axis=1) # (M,) -> (M, 1)
45954608
x2_is_vector = True
45964609

4597-
# Validate matrix multiplication compatibility
4598-
if x1.shape[-1] != x2.shape[-2]:
4599-
raise ValueError("Shapes are not aligned for matrix multiplication.")
4600-
46014610
n, k = x1.shape[-2:]
46024611
m = x2.shape[-1]
4612+
result_shape = np.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + (n, m)
4613+
result = blosc2.zeros(result_shape, dtype=np.result_type(x1, x2), **kwargs)
46034614

4604-
result = blosc2.zeros((n, m), dtype=np.result_type(x1, x2), **kwargs)
4615+
if 0 in result.shape + x1.shape + x2.shape: # if any array is empty, return array of 0s
4616+
if x1_is_vector:
4617+
result.squeeze(axis=-2)
4618+
if x2_is_vector:
4619+
result.squeeze(axis=-1)
4620+
return result
46054621

46064622
p, q = result.chunks[-2:]
46074623
r = x2.chunks[-1]
46084624

4609-
for row in range(0, n, p):
4610-
row_end = builtins.min(row + p, n)
4611-
for col in range(0, m, q):
4612-
col_end = builtins.min(col + q, m)
4613-
for aux in range(0, k, r):
4614-
aux_end = builtins.min(aux + r, k)
4615-
bx1 = x1[row:row_end, aux:aux_end]
4616-
bx2 = x2[aux:aux_end, col:col_end]
4617-
result[row:row_end, col:col_end] += np.matmul(bx1, bx2)
4625+
intersecting_chunks = get_intersecting_chunks((), result.shape[:-2], result.chunks[:-2])
4626+
for chunk in intersecting_chunks:
4627+
chunk = chunk.raw
4628+
for row in range(0, n, p):
4629+
row_end = builtins.min(row + p, n)
4630+
for col in range(0, m, q):
4631+
col_end = builtins.min(col + q, m)
4632+
for aux in range(0, k, r):
4633+
aux_end = builtins.min(aux + r, k)
4634+
bx1 = (
4635+
x1[chunk[-x1.ndim + 2 :] + (slice(row, row_end), slice(aux, aux_end))]
4636+
if x1.ndim > 2
4637+
else x1[row:row_end, aux:aux_end]
4638+
)
4639+
bx2 = (
4640+
x2[chunk[-x2.ndim + 2 :] + (slice(aux, aux_end), slice(col, col_end))]
4641+
if x2.ndim > 2
4642+
else x2[aux:aux_end, col:col_end]
4643+
)
4644+
result[chunk + (slice(row, row_end), slice(col, col_end))] += np.matmul(bx1, bx2)
46184645

4619-
if x1_is_vector and x2_is_vector:
4620-
return result[0][0]
4646+
if x1_is_vector:
4647+
result.squeeze(axis=-2)
4648+
if x2_is_vector:
4649+
result.squeeze(axis=-1)
46214650

4622-
return result.squeeze()
4651+
return result
46234652

46244653

46254654
def permute_dims(
@@ -5178,6 +5207,16 @@ def _get_local_slice(prior_selection, post_selection, chunk_bounds):
51785207
return locbegin, locend
51795208

51805209

5210+
def get_intersecting_chunks(_slice, shape, chunks):
5211+
if 0 not in chunks:
5212+
chunk_size = ndindex.ChunkSize(chunks)
5213+
return chunk_size.as_subchunks(_slice, shape) # if _slice is (), returns all chunks
5214+
else:
5215+
return (
5216+
ndindex.ndindex(...).expand(shape),
5217+
) # chunk is whole array so just return full tuple to do loop once
5218+
5219+
51815220
def broadcast_to(arr, shape):
51825221
"""
51835222
Broadcast an array to a new shape.

tests/ndarray/test_matmul.py

Lines changed: 36 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
{
1010
((12, 10), (7, 5), (3, 3)),
1111
((10,), (9,), (7,)),
12+
((0,), (0,), (0,)),
13+
((40, 10, 10), (2, 3, 4), (1, 2, 2)),
1214
},
1315
)
1416
@pytest.mark.parametrize(
@@ -17,6 +19,9 @@
1719
((10,), (4,), (2,)),
1820
((10, 5), (3, 4), (1, 3)),
1921
((10, 12), (2, 4), (1, 2)),
22+
((200, 10, 22), (23, 2, 4), (4, 1, 2)),
23+
((0,), (0,), (0,)),
24+
((20, 40, 10, 10), (5, 2, 3, 4), (2, 1, 2, 2)),
2025
},
2126
)
2227
@pytest.mark.parametrize(
@@ -26,13 +31,21 @@
2631
def test_matmul(ashape, achunks, ablocks, bshape, bchunks, bblocks, dtype):
2732
a = blosc2.linspace(0, 1, dtype=dtype, shape=ashape, chunks=achunks, blocks=ablocks)
2833
b = blosc2.linspace(0, 1, dtype=dtype, shape=bshape, chunks=bchunks, blocks=bblocks)
29-
c = blosc2.matmul(a, b)
30-
31-
na = a[:]
32-
nb = b[:]
33-
nc = np.matmul(na, nb)
34-
35-
np.testing.assert_allclose(c, nc, rtol=1e-6)
34+
a_np = a[:]
35+
b_np = b[:]
36+
try:
37+
np_res = np.matmul(a_np, b_np)
38+
np_error = None
39+
except ValueError as e:
40+
np_res = None
41+
np_error = e
42+
43+
if np_error is not None:
44+
with pytest.raises(type(np_error)):
45+
blosc2.matmul(a, b)
46+
else:
47+
b2_res = blosc2.matmul(a, b)
48+
np.testing.assert_allclose(b2_res[()], np_res, rtol=1e-6)
3649

3750

3851
@pytest.mark.parametrize(
@@ -147,12 +160,22 @@ def test_scalars(scalar):
147160
def test_dims(ashape, bshape):
148161
a = blosc2.linspace(0, 10, shape=ashape)
149162
b = blosc2.linspace(0, 1, shape=bshape)
150-
151-
with pytest.raises(ValueError):
152-
blosc2.matmul(a, b)
153-
154-
with pytest.raises(ValueError):
155-
blosc2.matmul(b, a)
163+
a_np = a[:]
164+
b_np = b[:]
165+
166+
try:
167+
np_res = np.matmul(a_np, b_np)
168+
np_error = None
169+
except ValueError as e:
170+
np_res = None
171+
np_error = e
172+
173+
if np_error is not None:
174+
with pytest.raises(type(np_error)):
175+
blosc2.matmul(a, b)
176+
else:
177+
b2_res = blosc2.matmul(a, b)
178+
np.testing.assert_allclose(b2_res[:], np_res)
156179

157180

158181
@pytest.mark.parametrize(

tests/ndarray/test_squeeze.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -13,19 +13,20 @@
1313

1414

1515
@pytest.mark.parametrize(
16-
("shape", "chunks", "blocks", "fill_value", "mask"),
16+
("shape", "chunks", "blocks", "fill_value", "axis"),
1717
[
18-
((1, 1230), (1, 100), (1, 3), b"0123", [True, False]),
19-
((23, 1, 1, 34), (20, 1, 1, 20), None, 1234, [False, False, True, False]),
20-
((80, 1, 51, 60, 1), None, (6, 1, 6, 26, 1), 3.333, [False] * 4 + [True]),
21-
((1, 1, 1), None, None, True, [False, True, True]),
18+
((1, 1230), (1, 100), (1, 3), b"0123", 0),
19+
((23, 1, 1, 34), (20, 1, 1, 20), None, 1234, 2),
20+
((80, 1, 51, 60, 1), None, (6, 1, 6, 26, 1), 3.333, 4),
21+
((1, 1, 1), None, None, True, (1, 2)),
22+
((1, 1, 1), None, None, True, None),
2223
],
2324
)
24-
def test_squeeze(shape, chunks, blocks, fill_value, mask):
25+
def test_squeeze(shape, chunks, blocks, fill_value, axis):
2526
a = blosc2.full(shape, fill_value=fill_value, chunks=chunks, blocks=blocks)
2627

27-
b = np.squeeze(a[...], tuple(i for i, m in enumerate(mask) if m))
28-
a_ = a.squeeze(mask)
28+
b = np.squeeze(a[...], axis)
29+
a_ = a.squeeze(axis)
2930

3031
assert a_.shape == b.shape
3132
# TODO: this would work if squeeze returns a view

0 commit comments

Comments
 (0)