Skip to content

Commit 7cd11d6

Browse files
committed
Fix bug in tensordot and optimise slightly
1 parent da27d58 commit 7cd11d6

2 files changed

Lines changed: 127 additions & 44 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 95 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4651,6 +4651,7 @@ def matmul(x1: NDArray | np.ndarray, x2: NDArray, **kwargs: Any) -> NDArray | np
46514651
return result
46524652

46534653

4654+
# @profile
46544655
def tensordot(
46554656
x1: NDArray, x2: NDArray, axes: int | tuple[Sequence[int], Sequence[int]] = 2, **kwargs: Any
46564657
) -> NDArray:
@@ -4688,74 +4689,133 @@ def tensordot(
46884689
out: NDArray
46894690
An array containing the tensor contraction whose shape consists of the non-contracted axes (dimensions) of the first array x1, followed by the non-contracted axes (dimensions) of the second array x2. The returned array must have a data type determined by Type Promotion Rules.
46904691
"""
4692+
fast_path = kwargs.pop("fast_path", None) # for testing purposes
4693+
46914694
# Added this to pass array-api tests (which use internal getitem to check results)
46924695
if isinstance(x1, np.ndarray) and isinstance(x2, np.ndarray):
46934696
return np.tensordot(x1, x2, axes=axes)
46944697

4698+
x1, x2 = blosc2.asarray(x1), blosc2.asarray(x2)
4699+
46954700
if isinstance(axes, tuple):
46964701
a_axes, b_axes = axes
4702+
a_axes = list(a_axes)
4703+
b_axes = list(b_axes)
46974704
if len(a_axes) != len(b_axes):
46984705
raise ValueError("Lengths of reduction axes for x1 and x2 must be equal!")
4699-
order = np.argsort(a_axes)
4700-
a_red_axes = [(i - x1.ndim in a_axes) or (i in a_axes) for i in range(x1.ndim)]
4701-
b_red_axes = [(i - x2.ndim in b_axes) or (i in b_axes) for i in range(x2.ndim)]
4706+
# need to track order of b_axes; later we cycle through a_axes sorted for op_chunk
4707+
# a_sorted[inv_sort][b_sort] matches b_sorted since b_axes matches a_axes
4708+
inv_sort = np.argsort(np.argsort(a_axes))
4709+
b_sort = np.argsort(b_axes)
4710+
order = inv_sort[b_sort]
4711+
a_keep, b_keep = [True] * x1.ndim, [True] * x2.ndim
4712+
for i, j in zip(a_axes, b_axes, strict=False):
4713+
i = x1.ndim + i if i < 0 else i
4714+
j = x2.ndim + j if j < 0 else j
4715+
a_keep[i] = False
4716+
b_keep[j] = False
4717+
a_axes = [] if a_axes == () else a_axes # handle no reduction
4718+
b_axes = [] if b_axes == () else b_axes # handle no reduction
47024719
elif isinstance(axes, int):
47034720
if axes < 0:
47044721
raise ValueError("Integer axes argument must be nonnegative!")
4705-
order = np.arange(axes, dtype=int)
4706-
a_red_axes = [i + axes >= x1.ndim for i in range(x1.ndim)]
4707-
b_red_axes = [i < axes for i in range(x2.ndim)]
4722+
order = np.arange(axes, dtype=int) # no reordering required
4723+
a_axes = list(range(x1.ndim - axes, x1.ndim))
4724+
b_axes = list(range(0, axes))
4725+
a_keep = [i + axes < x1.ndim for i in range(x1.ndim)]
4726+
b_keep = [i >= axes for i in range(x2.ndim)]
47084727
else:
47094728
raise ValueError("Axes argument must be two element tuple of sequences or an integer.")
47104729
x1shape = np.array(x1.shape)
47114730
x2shape = np.array(x2.shape)
4712-
a_chunks_red = tuple(c for i, c in enumerate(x1.chunks) if a_red_axes[i])
4713-
if np.any(x1shape[a_red_axes] != x2shape[b_red_axes][order]):
4731+
a_chunks_red = tuple(c for i, c in enumerate(x1.chunks) if not a_keep[i])
4732+
a_shape_red = tuple(c for i, c in enumerate(x1.shape) if not a_keep[i])
4733+
4734+
if np.any(x1shape[a_axes] != x2shape[b_axes]):
47144735
raise ValueError("x1 and x2 must have same shapes along reduction dimensions")
47154736

4716-
a_axes = [not i for i in a_red_axes]
4717-
b_axes = [not i for i in b_red_axes]
4718-
result_shape = tuple(x1shape[a_axes]) + tuple(x2shape[b_axes])
4737+
result_shape = tuple(x1shape[a_keep]) + tuple(x2shape[b_keep])
47194738
result = blosc2.zeros(result_shape, dtype=np.result_type(x1, x2), **kwargs)
47204739

47214740
op_chunks = [
4722-
slice_to_chunktuple(slice(0, s, 1), c)
4723-
for s, c in zip(x1shape[a_red_axes], a_chunks_red, strict=True)
4741+
slice_to_chunktuple(slice(0, s, 1), c) for s, c in zip(x1shape[a_axes], a_chunks_red, strict=True)
47244742
]
47254743
res_chunks = [
47264744
slice_to_chunktuple(s, c)
47274745
for s, c in zip([slice(0, r, 1) for r in result.shape], result.chunks, strict=True)
47284746
]
47294747
a_selection = (slice(None, None, 1),) * x1.ndim
47304748
b_selection = (slice(None, None, 1),) * x2.ndim
4749+
4750+
chunk_memory = np.prod(result.chunks) * (
4751+
np.prod(x1shape[a_axes]) * x1.dtype.itemsize + np.prod(x2shape[b_axes]) * x2.dtype.itemsize
4752+
)
4753+
if chunk_memory < blosc2.MAX_FAST_PATH_SIZE:
4754+
fast_path = True if fast_path is None else fast_path
4755+
fast_path = False if fast_path is None else fast_path # fast_path set via kwargs for testing
4756+
4757+
# adapted from numpy.tensordot
4758+
a_keep_axes = [i for i, k in enumerate(a_keep) if k]
4759+
b_keep_axes = [i for i, k in enumerate(b_keep) if k]
4760+
newaxes_a = a_keep_axes + a_axes
4761+
newaxes_b = b_axes + b_keep_axes
4762+
47314763
for rchunk in product(*res_chunks):
47324764
res_chunk = tuple(
4733-
slice(rc * rcs, (rc + 1) * rcs, 1) for rc, rcs in zip(rchunk, result.chunks, strict=True)
4765+
slice(rc * rcs, builtins.min((rc + 1) * rcs, rshape), 1)
4766+
for rc, rcs, rshape in zip(rchunk, result.chunks, result.shape, strict=True)
47344767
)
47354768
rchunk_iter = iter(res_chunk)
4736-
a_selection = tuple(
4737-
next(rchunk_iter) if a else as_ for as_, a in zip(a_selection, a_axes, strict=True)
4738-
)
4739-
b_selection = tuple(
4740-
next(rchunk_iter) if b else bs_ for bs_, b in zip(b_selection, b_axes, strict=True)
4741-
)
4742-
for ochunk in product(*op_chunks):
4743-
op_chunk = tuple(
4744-
slice(rc * rcs, (rc + 1) * rcs, 1) for rc, rcs in zip(ochunk, a_chunks_red, strict=True)
4745-
) # use x1 chunk shape to iterate over reduction axes
4746-
ochunk_iter = iter(op_chunk)
4747-
a_selection = tuple(
4748-
next(ochunk_iter) if not a else as_ for as_, a in zip(a_selection, a_axes, strict=True)
4749-
)
4750-
# have to permute to match order of a_axes
4751-
order_iter = iter(order)
4752-
b_selection = tuple(
4753-
op_chunk[next(order_iter)] if not b else bs_
4754-
for bs_, b in zip(b_selection, b_axes, strict=True)
4755-
)
4769+
a_selection = tuple(next(rchunk_iter) if a else slice(None, None, 1) for a in a_keep)
4770+
b_selection = tuple(next(rchunk_iter) if b else slice(None, None, 1) for b in b_keep)
4771+
res_chunks = tuple(s.stop - s.start for s in res_chunk)
4772+
4773+
if fast_path: # just load everything
47564774
bx1 = x1[a_selection]
47574775
bx2 = x2[b_selection]
4758-
result[res_chunk] += np.tensordot(bx1, bx2, axes=axes)
4776+
newshape_a = (
4777+
math.prod([bx1.shape[i] for i in a_keep_axes]),
4778+
math.prod([bx1.shape[a] for a in a_axes]),
4779+
)
4780+
newshape_b = (
4781+
math.prod([bx2.shape[b] for b in b_axes]),
4782+
math.prod([bx2.shape[i] for i in b_keep_axes]),
4783+
)
4784+
at = bx1.transpose(newaxes_a).reshape(newshape_a)
4785+
bt = bx2.transpose(newaxes_b).reshape(newshape_b)
4786+
res = np.dot(at, bt)
4787+
result[res_chunk] += res.reshape(res_chunks)
4788+
else: # operands too big, have to go chunk-by-chunk
4789+
for ochunk in product(*op_chunks):
4790+
op_chunk = tuple(
4791+
slice(rc * rcs, builtins.min((rc + 1) * rcs, x1s), 1)
4792+
for rc, rcs, x1s in zip(ochunk, a_chunks_red, a_shape_red, strict=True)
4793+
) # use x1 chunk shape to iterate over reduction axes
4794+
ochunk_iter = iter(op_chunk)
4795+
a_selection = tuple(
4796+
next(ochunk_iter) if not a else as_ for as_, a in zip(a_selection, a_keep, strict=True)
4797+
)
4798+
# have to permute to match order of a_axes
4799+
order_iter = iter(order)
4800+
b_selection = tuple(
4801+
op_chunk[next(order_iter)] if not b else bs_
4802+
for bs_, b in zip(b_selection, b_keep, strict=True)
4803+
)
4804+
bx1 = x1[a_selection]
4805+
bx2 = x2[b_selection]
4806+
# adapted from numpy tensordot
4807+
newshape_a = (
4808+
math.prod([bx1.shape[i] for i in a_keep_axes]),
4809+
math.prod([bx1.shape[a] for a in a_axes]),
4810+
)
4811+
newshape_b = (
4812+
math.prod([bx2.shape[b] for b in b_axes]),
4813+
math.prod([bx2.shape[i] for i in b_keep_axes]),
4814+
)
4815+
at = bx1.transpose(newaxes_a).reshape(newshape_a)
4816+
bt = bx2.transpose(newaxes_b).reshape(newshape_b)
4817+
res = np.dot(at, bt)
4818+
result[res_chunk] += res.reshape(res_chunks)
47594819
return result
47604820

47614821

tests/ndarray/test_matmul.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,10 @@ def test_disk():
222222

223223

224224
@pytest.mark.parametrize(
225-
("shape1", "chunk1", "block1", "shape2", "chunk2", "block2", "axes"),
225+
("shape1", "chunk1", "block1", "shape2", "chunk2", "block2", "chunkres", "axes"),
226226
[
227227
# 1Dx1D->scalar (uneven chunks)
228-
((50,), (17,), (5,), (50,), (13,), (5,), 1),
228+
((50,), (17,), (5,), (50,), (13,), (5,), (), 1),
229229
# 2Dx2D->matrix multiplication
230230
(
231231
(30, 40),
@@ -234,6 +234,7 @@ def test_disk():
234234
(40, 20),
235235
(19, 20),
236236
(9, 10),
237+
(10, 5),
237238
([1], [0]),
238239
),
239240
# 3Dx3D->contraction along last/first
@@ -244,10 +245,20 @@ def test_disk():
244245
(30, 15, 5),
245246
(16, 15, 5),
246247
(8, 15, 5),
248+
(7, 6, 3, 1),
247249
([2], [0]),
248250
),
249251
# 4Dx3D->contraction along two axes
250-
((6, 7, 8, 9), (5, 6, 7, 8), (3, 3, 3, 3), (8, 9, 5), (7, 9, 5), (3, 5, 5), ([2, 3], [0, 1])),
252+
(
253+
(6, 7, 8, 9),
254+
(5, 6, 7, 8),
255+
(3, 3, 3, 3),
256+
(8, 9, 5),
257+
(7, 9, 5),
258+
(3, 5, 5),
259+
(4, 5, 2),
260+
([2, 3], [0, 1]),
261+
),
251262
# 2Dx1D->matrix-vector multiplication
252263
(
253264
(12, 7),
@@ -256,6 +267,7 @@ def test_disk():
256267
(7,),
257268
(5,),
258269
(5,),
270+
(5,),
259271
([1], [0]),
260272
),
261273
# 3Dx2D->like batched matmul
@@ -266,10 +278,11 @@ def test_disk():
266278
(7, 4),
267279
(6, 4),
268280
(3, 4),
281+
(2, 5, 3),
269282
([2], [0]),
270283
),
271284
# 1Dx3D->tensor contraction
272-
((20,), (9,), (4,), (20, 4, 5), (19, 3, 5), (10, 2, 5), ([0], [0])),
285+
((20,), (9,), (4,), (20, 4, 5), (19, 3, 5), (10, 2, 5), (3, 3), ([0], [0])),
273286
# 4Dx4D->reduce over 3 axes
274287
(
275288
(5, 6, 7, 8),
@@ -278,7 +291,8 @@ def test_disk():
278291
(7, 8, 6, 10),
279292
(6, 7, 5, 9),
280293
(3, 4, 3, 5),
281-
([1, 2, 3], [0, 1, 2]),
294+
(3, 7),
295+
([1, 2, 3], [2, 0, 1]),
282296
),
283297
# 5Dx5D->reduce over 4 axes
284298
(
@@ -288,7 +302,8 @@ def test_disk():
288302
(5, 6, 7, 4, 8),
289303
(4, 5, 6, 3, 7),
290304
(2, 3, 3, 2, 4),
291-
([1, 2, 3, 4], [0, 1, 2, 3]),
305+
(2, 5),
306+
([1, 2, 3, 4], [3, 0, 1, 2]),
292307
),
293308
],
294309
)
@@ -301,7 +316,7 @@ def test_disk():
301316
np.float64,
302317
],
303318
)
304-
def test_tensordot(shape1, chunk1, block1, shape2, chunk2, block2, axes, dtype):
319+
def test_tensordot(shape1, chunk1, block1, shape2, chunk2, block2, chunkres, axes, dtype):
305320
# Create operands with requested dtype
306321
a_b2 = blosc2.arange(0, np.prod(shape1), shape=shape1, chunks=chunk1, blocks=block1, dtype=dtype)
307322
a_np = a_b2[()] # decompress
@@ -318,11 +333,11 @@ def test_tensordot(shape1, chunk1, block1, shape2, chunk2, block2, axes, dtype):
318333
if np_raised is not None:
319334
# Expect Blosc2 to raise the same type
320335
with pytest.raises(np_raised):
321-
blosc2.tensordot(a_b2, b_b2, axes=axes)
336+
blosc2.tensordot(a_b2, b_b2, axes=axes, chunks=chunkres)
322337
else:
323338
# Both should succeed
324339
res_np = np.tensordot(a_np, b_np, axes=axes)
325-
res_b2 = blosc2.tensordot(a_b2, b_b2, axes=axes)
340+
res_b2 = blosc2.tensordot(a_b2, b_b2, axes=axes, chunks=chunkres, fast_path=False) # test slow path
326341
res_b2_np = res_b2[...]
327342

328343
# Assertions
@@ -331,3 +346,11 @@ def test_tensordot(shape1, chunk1, block1, shape2, chunk2, block2, axes, dtype):
331346
np.testing.assert_allclose(res_b2_np, res_np, rtol=1e-5, atol=1e-6)
332347
else:
333348
np.testing.assert_array_equal(res_b2_np, res_np)
349+
350+
res_b2 = blosc2.tensordot(a_b2, b_b2, axes=axes, chunks=chunkres, fast_path=True) # test fast path
351+
# Assertions
352+
assert res_b2_np.shape == res_np.shape
353+
if np.issubdtype(dtype, np.floating):
354+
np.testing.assert_allclose(res_b2_np, res_np, rtol=1e-5, atol=1e-6)
355+
else:
356+
np.testing.assert_array_equal(res_b2_np, res_np)

0 commit comments

Comments
 (0)