Skip to content

Commit 0e38757

Browse files
authored
ENH: add delegation for kron (#516)
1 parent 1a5ad8c commit 0e38757

3 files changed

Lines changed: 102 additions & 80 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
expand_dims,
1010
isclose,
1111
isin,
12+
kron,
1213
nan_to_num,
1314
one_hot,
1415
pad,
@@ -23,7 +24,6 @@
2324
angle,
2425
apply_where,
2526
default_dtype,
26-
kron,
2727
nunique,
2828
)
2929
from ._lib._lazy import lazy_apply

src/array_api_extra/_delegation.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"create_diagonal",
2626
"expand_dims",
2727
"isclose",
28+
"kron",
2829
"nan_to_num",
2930
"one_hot",
3031
"pad",
@@ -479,6 +480,101 @@ def isclose(
479480
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
480481

481482

483+
def kron(
484+
a: Array | complex,
485+
b: Array | complex,
486+
/,
487+
*,
488+
xp: ModuleType | None = None,
489+
) -> Array:
490+
"""
491+
Kronecker product of two arrays.
492+
493+
Computes the Kronecker product, a composite array made of blocks of the
494+
second array scaled by the first.
495+
496+
Equivalent to ``numpy.kron`` for NumPy arrays.
497+
498+
Parameters
499+
----------
500+
a, b : Array | int | float | complex
501+
Input arrays or scalars. At least one must be an array.
502+
xp : array_namespace, optional
503+
The standard-compatible namespace for `a` and `b`. Default: infer.
504+
505+
Returns
506+
-------
507+
array
508+
The Kronecker product of `a` and `b`.
509+
510+
Notes
511+
-----
512+
The function assumes that the number of dimensions of `a` and `b`
513+
are the same, if necessary prepending the smallest with ones.
514+
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
515+
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
516+
The elements are products of elements from `a` and `b`, organized
517+
explicitly by::
518+
519+
kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
520+
521+
where::
522+
523+
kt = it * st + jt, t = 0,...,N
524+
525+
In the common 2-D case (N=1), the block structure can be visualized::
526+
527+
[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
528+
[ ... ... ],
529+
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
530+
531+
Examples
532+
--------
533+
>>> import array_api_strict as xp
534+
>>> import array_api_extra as xpx
535+
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
536+
Array([ 5, 6, 7, 50, 60, 70, 500,
537+
600, 700], dtype=array_api_strict.int64)
538+
539+
>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
540+
Array([ 5, 50, 500, 6, 60, 600, 7,
541+
70, 700], dtype=array_api_strict.int64)
542+
543+
>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
544+
Array([[1., 1., 0., 0.],
545+
[1., 1., 0., 0.],
546+
[0., 0., 1., 1.],
547+
[0., 0., 1., 1.]], dtype=array_api_strict.float64)
548+
549+
>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
550+
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
551+
>>> c = xpx.kron(a, b, xp=xp)
552+
>>> c.shape
553+
(2, 10, 6, 20)
554+
>>> I = (1, 3, 0, 2)
555+
>>> J = (0, 2, 1)
556+
>>> J1 = (0,) + J # extend to ndim=4
557+
>>> S1 = (1,) + b.shape
558+
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
559+
>>> c[K] == a[I]*b[J]
560+
Array(True, dtype=array_api_strict.bool)
561+
"""
562+
if xp is None:
563+
xp = array_namespace(a, b)
564+
565+
a, b = asarrays(a, b, xp=xp)
566+
567+
if (
568+
is_cupy_namespace(xp)
569+
or is_jax_namespace(xp)
570+
or is_numpy_namespace(xp)
571+
or is_torch_namespace(xp)
572+
):
573+
return xp.kron(a, b)
574+
575+
return _funcs.kron(a, b, xp=xp)
576+
577+
482578
def nan_to_num(
483579
x: Array | float | complex,
484580
/,

src/array_api_extra/_lib/_funcs.py

Lines changed: 5 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -405,87 +405,13 @@ def isclose(
405405

406406

407407
def kron(
408-
a: Array | complex,
409-
b: Array | complex,
408+
a: Array,
409+
b: Array,
410410
/,
411411
*,
412-
xp: ModuleType | None = None,
413-
) -> Array:
414-
"""
415-
Kronecker product of two arrays.
416-
417-
Computes the Kronecker product, a composite array made of blocks of the
418-
second array scaled by the first.
419-
420-
Equivalent to ``numpy.kron`` for NumPy arrays.
421-
422-
Parameters
423-
----------
424-
a, b : Array | int | float | complex
425-
Input arrays or scalars. At least one must be an array.
426-
xp : array_namespace, optional
427-
The standard-compatible namespace for `a` and `b`. Default: infer.
428-
429-
Returns
430-
-------
431-
array
432-
The Kronecker product of `a` and `b`.
433-
434-
Notes
435-
-----
436-
The function assumes that the number of dimensions of `a` and `b`
437-
are the same, if necessary prepending the smallest with ones.
438-
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
439-
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
440-
The elements are products of elements from `a` and `b`, organized
441-
explicitly by::
442-
443-
kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
444-
445-
where::
446-
447-
kt = it * st + jt, t = 0,...,N
448-
449-
In the common 2-D case (N=1), the block structure can be visualized::
450-
451-
[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
452-
[ ... ... ],
453-
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
454-
455-
Examples
456-
--------
457-
>>> import array_api_strict as xp
458-
>>> import array_api_extra as xpx
459-
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
460-
Array([ 5, 6, 7, 50, 60, 70, 500,
461-
600, 700], dtype=array_api_strict.int64)
462-
463-
>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
464-
Array([ 5, 50, 500, 6, 60, 600, 7,
465-
70, 700], dtype=array_api_strict.int64)
466-
467-
>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
468-
Array([[1., 1., 0., 0.],
469-
[1., 1., 0., 0.],
470-
[0., 0., 1., 1.],
471-
[0., 0., 1., 1.]], dtype=array_api_strict.float64)
472-
473-
>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
474-
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
475-
>>> c = xpx.kron(a, b, xp=xp)
476-
>>> c.shape
477-
(2, 10, 6, 20)
478-
>>> I = (1, 3, 0, 2)
479-
>>> J = (0, 2, 1)
480-
>>> J1 = (0,) + J # extend to ndim=4
481-
>>> S1 = (1,) + b.shape
482-
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
483-
>>> c[K] == a[I]*b[J]
484-
Array(True, dtype=array_api_strict.bool)
485-
"""
486-
if xp is None:
487-
xp = array_namespace(a, b)
488-
a, b = asarrays(a, b, xp=xp)
412+
xp: ModuleType,
413+
) -> Array: # numpydoc ignore=PR01,RT01
414+
"""See docstring in array_api_extra._delegation."""
489415

490416
singletons = (1,) * (b.ndim - a.ndim)
491417
a = cast(Array, xp.broadcast_to(a, singletons + a.shape))

0 commit comments

Comments
 (0)