Skip to content

Commit 325729b

Browse files
Move ti.unstack() to dpctl_ext.tensor and reuse it
1 parent b5e3541 commit 325729b

3 files changed

Lines changed: 36 additions & 3 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
stack,
7777
swapaxes,
7878
tile,
79+
unstack,
7980
)
8081
from dpctl_ext.tensor._reshape import reshape
8182

@@ -128,6 +129,7 @@
128129
"to_numpy",
129130
"tril",
130131
"triu",
132+
"unstack",
131133
"where",
132134
"zeros",
133135
"zeros_like",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,35 @@ def swapaxes(X, axis1, axis2):
974974
return dpt_ext.permute_dims(X, tuple(ind))
975975

976976

977+
def unstack(X, /, *, axis=0):
978+
"""unstack(x, axis=0)
979+
980+
Splits an array in a sequence of arrays along the given axis.
981+
982+
Args:
983+
x (usm_ndarray): input array
984+
985+
axis (int, optional): axis along which `x` is unstacked.
986+
If `x` has rank (i.e, number of dimensions) `N`,
987+
a valid `axis` must reside in the half-open interval `[-N, N)`.
988+
Default: `0`.
989+
990+
Returns:
991+
Tuple[usm_ndarray,...]:
992+
Output sequence of arrays which are views into the input array.
993+
994+
Raises:
995+
AxisError: if the `axis` value is invalid.
996+
"""
997+
if not isinstance(X, dpt.usm_ndarray):
998+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
999+
1000+
axis = normalize_axis_index(axis, X.ndim)
1001+
Y = dpt_ext.moveaxis(X, axis, 0)
1002+
1003+
return tuple(Y[i] for i in range(Y.shape[0]))
1004+
1005+
9771006
def tile(x, repetitions, /):
9781007
"""tile(x, repetitions)
9791008

dpnp/dpnp_iface_manipulation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,7 +1093,9 @@ def broadcast_arrays(*args, subok=False):
10931093
if len(args) == 0:
10941094
return []
10951095

1096-
usm_arrays = dpt.broadcast_arrays(*[dpnp.get_usm_ndarray(a) for a in args])
1096+
usm_arrays = dpt_ext.broadcast_arrays(
1097+
*[dpnp.get_usm_ndarray(a) for a in args]
1098+
)
10971099
return [dpnp_array._create_from_usm_ndarray(a) for a in usm_arrays]
10981100

10991101

@@ -1521,7 +1523,7 @@ def copyto(dst, src, casting="same_kind", where=True):
15211523
f"but got {where.dtype}"
15221524
)
15231525

1524-
dst_usm, src_usm, mask_usm = dpt.broadcast_arrays(
1526+
dst_usm, src_usm, mask_usm = dpt_ext.broadcast_arrays(
15251527
dpnp.get_usm_ndarray(dst),
15261528
dpnp.get_usm_ndarray(src),
15271529
dpnp.get_usm_ndarray(where),
@@ -4522,7 +4524,7 @@ def unstack(x, /, *, axis=0):
45224524
if usm_x.ndim == 0:
45234525
raise ValueError("Input array must be at least 1-d.")
45244526

4525-
res = dpt.unstack(usm_x, axis=axis)
4527+
res = dpt_ext.unstack(usm_x, axis=axis)
45264528
return tuple(dpnp_array._create_from_usm_ndarray(a) for a in res)
45274529

45284530

0 commit comments

Comments
 (0)