Skip to content

Commit 6c84cee

Browse files
committed
Fix vecdot bug, add vecdot tests, rename test_matmul file
1 parent 515d597 commit 6c84cee

2 files changed

Lines changed: 159 additions & 8 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,15 @@ def reshape(
367367

368368

369369
def _check_allowed_dtypes(
370-
value: bool | int | float | str | blosc2.NDArray | blosc2.NDField | blosc2.C2Array | blosc2.Proxy,
370+
value: bool
371+
| int
372+
| float
373+
| str
374+
| blosc2.NDArray
375+
| blosc2.NDField
376+
| blosc2.C2Array
377+
| blosc2.Proxy
378+
| blosc2.LazyExpr,
371379
):
372380
if not (
373381
isinstance(
@@ -4925,23 +4933,31 @@ def vecdot(x1: NDArray, x2: NDArray, axis: int = -1, **kwargs) -> NDArray:
49254933
slice(rc * rcs, builtins.min((rc + 1) * rcs, rshape), 1)
49264934
for rc, rcs, rshape in zip(rchunk, result.chunks, result.shape, strict=True)
49274935
)
4928-
rchunk_iter = iter(res_chunk)
4936+
# handle broadcasting - if x1, x2 different ndim, could have to prepend 1s
4937+
rchunk_iter = (
4938+
slice(0, 1, 1) if s == 1 else r
4939+
for r, s in zip(res_chunk[-x1.ndim + 1 :], x1shape[a_keep], strict=True)
4940+
)
49294941
a_selection = tuple(next(rchunk_iter) if a else slice(None, None, 1) for a in a_keep)
4930-
rchunk_iter = iter(res_chunk)
4942+
rchunk_iter = (
4943+
slice(0, 1, 1) if s == 1 else r
4944+
for r, s in zip(res_chunk[-x2.ndim + 1 :], x2shape[b_keep], strict=True)
4945+
)
49314946
b_selection = tuple(next(rchunk_iter) if b else slice(None, None, 1) for b in b_keep)
49324947

4933-
if fast_path: # just load everything
4948+
if fast_path: # just load everything, also handles case of 0 in shapes
49344949
bx1 = x1[a_selection]
49354950
bx2 = x2[b_selection]
49364951
result[res_chunk] += np.vecdot(bx1, bx2, axis=axis)
49374952
else: # operands too big, have to go chunk-by-chunk
49384953
for ochunk in range(0, a_shape_red, a_chunks_red):
4939-
op_chunk = slice(ochunk, builtins.min(ochunk + a_chunks_red, x1.shape[a_axes]), 1)
4940-
a_selection[a_axes] = op_chunk
4941-
b_selection[b_axes] = op_chunk
4954+
op_chunk = (slice(ochunk, builtins.min(ochunk + a_chunks_red, x1.shape[a_axes]), 1),)
4955+
a_selection = a_selection[:a_axes] + op_chunk + a_selection[a_axes + 1 :]
4956+
b_selection = b_selection[:b_axes] + op_chunk + b_selection[b_axes + 1 :]
49424957
bx1 = x1[a_selection]
49434958
bx2 = x2[b_selection]
4944-
result[res_chunk] += np.vecdot(bx1, bx2, axis=axis)
4959+
res = np.vecdot(bx1, bx2, axis=axis)
4960+
result[res_chunk] += res
49454961
return result
49464962

49474963

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,138 @@ def test_tensordot(shape1, chunk1, block1, shape2, chunk2, block2, chunkres, axe
354354
np.testing.assert_allclose(res_b2_np, res_np, rtol=1e-5, atol=1e-6)
355355
else:
356356
np.testing.assert_array_equal(res_b2_np, res_np)
357+
358+
359+
@pytest.mark.parametrize(
360+
("shape1", "chunk1", "block1", "shape2", "chunk2", "block2", "chunkres", "axis"),
361+
[
362+
# 1Dx1D->scalar
363+
((50,), (17,), (5,), (50,), (13,), (5,), (), -1),
364+
# 2Dx2D
365+
(
366+
(30, 40),
367+
(17, 21),
368+
(8, 10),
369+
(30, 40),
370+
(19, 20),
371+
(9, 10),
372+
(10,),
373+
-1,
374+
),
375+
# 3Dx3D
376+
(
377+
(10, 1, 5),
378+
(9, 1, 1),
379+
(5, 1, 1),
380+
(10, 1, 1),
381+
(4, 1, 1),
382+
(3, 1, 1),
383+
(3, 3),
384+
-2,
385+
),
386+
# 4Dx3D
387+
(
388+
(6, 7, 8, 9),
389+
(5, 6, 7, 8),
390+
(3, 3, 3, 3),
391+
(1, 7, 8, 1),
392+
(1, 7, 3, 1),
393+
(1, 3, 2, 1),
394+
(4, 5, 2),
395+
-2,
396+
),
397+
# 2Dx1D->broadcastable to (12, 7)
398+
(
399+
(12, 7),
400+
(11, 7),
401+
(5, 7),
402+
(7,),
403+
(5,),
404+
(2,),
405+
(5,),
406+
-1,
407+
),
408+
# 3Dx2D->broadcastable to (1, 6, 7)
409+
(
410+
(5, 6, 7),
411+
(4, 5, 6),
412+
(2, 3, 3),
413+
(6, 7),
414+
(6, 4),
415+
(3, 4),
416+
(3, 2),
417+
-2,
418+
),
419+
# 1Dx3D -> broadcastable to (1, 1, 20)
420+
((20,), (9,), (4,), (20, 4, 20), (19, 3, 5), (10, 2, 5), (10, 2), -1),
421+
# 4Dx4D
422+
(
423+
(5, 8, 1, 8),
424+
(4, 5, 1, 7),
425+
(2, 3, 1, 4),
426+
(1, 8, 6, 8),
427+
(1, 7, 5, 5),
428+
(1, 4, 3, 5),
429+
(2, 2, 2),
430+
-3,
431+
),
432+
# 5Dx5D
433+
(
434+
(3, 4, 5, 6, 7),
435+
(2, 3, 4, 5, 6),
436+
(1, 2, 2, 3, 3),
437+
(3, 1, 1, 6, 7),
438+
(2, 1, 1, 3, 5),
439+
(2, 1, 1, 2, 4),
440+
(2, 2, 2, 5),
441+
-2,
442+
),
443+
],
444+
)
445+
@pytest.mark.parametrize(
446+
"dtype",
447+
[
448+
np.int32,
449+
np.int64,
450+
np.float32,
451+
np.float64,
452+
],
453+
)
454+
def test_vecdot(shape1, chunk1, block1, shape2, chunk2, block2, chunkres, axis, dtype):
455+
# Create operands with requested dtype
456+
a_b2 = blosc2.arange(0, np.prod(shape1), shape=shape1, chunks=chunk1, blocks=block1, dtype=dtype)
457+
a_np = a_b2[()] # decompress
458+
b_b2 = blosc2.arange(0, np.prod(shape2), shape=shape2, chunks=chunk2, blocks=block2, dtype=dtype)
459+
b_np = b_b2[()] # decompress
460+
461+
# NumPy reference and Blosc2 comparison
462+
np_raised = None
463+
try:
464+
res_np = np.vecdot(a_np, b_np, axis=axis)
465+
except Exception as e:
466+
np_raised = type(e)
467+
468+
if np_raised is not None:
469+
# Expect Blosc2 to raise the same type
470+
with pytest.raises(np_raised):
471+
blosc2.vecdot(a_b2, b_b2, axis=axis, chunks=chunkres)
472+
else:
473+
# Both should succeed
474+
res_np = np.vecdot(a_np, b_np, axis=axis)
475+
res_b2 = blosc2.vecdot(a_b2, b_b2, axis=axis, chunks=chunkres, fast_path=False) # test slow path
476+
res_b2_np = res_b2[...]
477+
478+
# Assertions
479+
assert res_b2_np.shape == res_np.shape
480+
if np.issubdtype(dtype, np.floating):
481+
np.testing.assert_allclose(res_b2_np, res_np, rtol=1e-5, atol=1e-6)
482+
else:
483+
np.testing.assert_array_equal(res_b2_np, res_np)
484+
485+
res_b2 = blosc2.vecdot(a_b2, b_b2, axis=axis, chunks=chunkres, fast_path=True) # test fast path
486+
# Assertions
487+
assert res_b2_np.shape == res_np.shape
488+
if np.issubdtype(dtype, np.floating):
489+
np.testing.assert_allclose(res_b2_np, res_np, rtol=1e-5, atol=1e-6)
490+
else:
491+
np.testing.assert_array_equal(res_b2_np, res_np)

0 commit comments

Comments
 (0)