Skip to content

Commit 7d4d2a0

Browse files
authored
Clarify axes keyword in tensordot function (#2733)
The PR extends docstrings to clarify behavior on repeated values passed in `axes` to `tensordot` functions. Also an explicit test was added to validate the exception raised on the repeated axes.
1 parent e9a79f4 commit 7d4d2a0

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
4444
* Compile indexing extension with `-fno-sycl-id-queries-fit-in-int` to support huge arrays [#2721](https://github.com/IntelPython/dpnp/pull/2721)
4545
* Updated `dpnp.fix` to reuse `dpnp.trunc` internally [#2722](https://github.com/IntelPython/dpnp/pull/2722)
4646
* Changed the build scripts and documentation due to `python setup.py develop` deprecation notice [#2716](https://github.com/IntelPython/dpnp/pull/2716)
47+
* Clarified behavior on repeated `axes` in `dpnp.tensordot` and `dpnp.linalg.tensordot` functions [#2733](https://github.com/IntelPython/dpnp/pull/2733)
4748

4849
### Deprecated
4950

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,7 @@ def outer(a, b, out=None):
11211121
return result
11221122

11231123

1124-
def tensordot(a, b, axes=2):
1124+
def tensordot(a, b, /, *, axes=2):
11251125
r"""
11261126
Compute tensor dot product along specified axes.
11271127
@@ -1148,7 +1148,10 @@ def tensordot(a, b, axes=2):
11481148
axes must match.
11491149
* (2,) array_like: A list of axes to be summed over, first sequence
11501150
applying to `a`, second to `b`. Both elements array_like must be of
1151-
the same length.
1151+
the same length. Each axis may appear at most once; repeated axes are
1152+
not allowed.
1153+
1154+
Default: ``2``.
11521155
11531156
Returns
11541157
-------
@@ -1178,6 +1181,13 @@ def tensordot(a, b, axes=2):
11781181
two sequences of the same length, with the first axis to sum over given
11791182
first in both sequences, the second axis second, and so forth.
11801183
1184+
For example, if ``a.shape == (2, 3, 4)`` and ``b.shape == (3, 4, 5)``, then
1185+
``axes=([1, 2], [0, 1])`` sums over the ``(3, 4)`` dimensions of both
1186+
arrays and produces an output of shape ``(2, 5)``.
1187+
1188+
Each summation axis corresponds to a distinct contraction index; repeating
1189+
an axis (for example ``axes=([1, 1], [0, 0])``) is invalid.
1190+
11811191
The shape of the result consists of the non-contracted axes of the
11821192
first tensor, followed by the non-contracted axes of the second.
11831193

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1975,9 +1975,10 @@ def tensordot(a, b, /, *, axes=2):
19751975
axes must match.
19761976
* (2,) array_like: A list of axes to be summed over, first sequence
19771977
applying to `a`, second to `b`. Both elements array_like must be of
1978-
the same length.
1978+
the same length. Each axis may appear at most once; repeated axes are
1979+
not allowed.
19791980
1980-
Default: ``2``.
1981+
Default: ``2``.
19811982
19821983
Returns
19831984
-------
@@ -2007,6 +2008,13 @@ def tensordot(a, b, /, *, axes=2):
20072008
two sequences of the same length, with the first axis to sum over given
20082009
first in both sequences, the second axis second, and so forth.
20092010
2011+
For example, if ``a.shape == (2, 3, 4)`` and ``b.shape == (3, 4, 5)``, then
2012+
``axes=([1, 2], [0, 1])`` sums over the ``(3, 4)`` dimensions of both
2013+
arrays and produces an output of shape ``(2, 5)``.
2014+
2015+
Each summation axis corresponds to a distinct contraction index; repeating
2016+
an axis (for example ``axes=([1, 1], [0, 0])``) is invalid.
2017+
20102018
The shape of the result consists of the non-contracted axes of the
20112019
first tensor, followed by the non-contracted axes of the second.
20122020

dpnp/tests/test_product.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1842,6 +1842,13 @@ def test_error(self):
18421842
with pytest.raises(ValueError):
18431843
dpnp.tensordot(dpnp.arange(4), dpnp.array(5), axes=-1)
18441844

1845+
@pytest.mark.parametrize("xp", [numpy, dpnp])
1846+
def test_repeated_axes(self, xp):
1847+
a = xp.ones((2, 3, 3))
1848+
b = xp.ones((3, 3, 4))
1849+
with pytest.raises(ValueError):
1850+
xp.tensordot(a, b, axes=([1, 1], [0, 0]))
1851+
18451852

18461853
class TestVdot:
18471854
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))

0 commit comments

Comments
 (0)