Skip to content

Commit ccad5f0

Browse files
Move ti.swapaxes() to dpctl_ext.tensor and reuse it
1 parent bb16c19 commit ccad5f0

3 files changed

Lines changed: 41 additions & 1 deletion

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
roll,
7575
squeeze,
7676
stack,
77+
swapaxes,
7778
)
7879
from dpctl_ext.tensor._reshape import reshape
7980

@@ -119,6 +120,7 @@
119120
"roll",
120121
"squeeze",
121122
"stack",
123+
"swapaxes",
122124
"take",
123125
"take_along_axis",
124126
"to_numpy",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,3 +934,41 @@ def stack(arrays, /, *, axis=0):
934934
_manager.add_event_pair(hev, cpy_ev)
935935

936936
return res
937+
938+
939+
def swapaxes(X, axis1, axis2):
940+
"""swapaxes(x, axis1, axis2)
941+
942+
Interchanges two axes of an array.
943+
944+
Args:
945+
x (usm_ndarray): input array
946+
947+
axis1 (int): First axis.
948+
If `x` has rank (i.e., number of dimensions) `N`,
949+
a valid `axis` must be in the half-open interval `[-N, N)`.
950+
951+
axis2 (int): Second axis.
952+
If `x` has rank (i.e., number of dimensions) `N`,
953+
a valid `axis` must be in the half-open interval `[-N, N)`.
954+
955+
Returns:
956+
usm_ndarray:
957+
Array with swapped axes.
958+
The returned array must has the same data type as `x`,
959+
is created on the same device as `x` and has the same USM
960+
allocation type as `x`.
961+
962+
Raises:
963+
AxisError: if `axis` value is invalid.
964+
"""
965+
if not isinstance(X, dpt.usm_ndarray):
966+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
967+
968+
axis1 = normalize_axis_index(axis1, X.ndim, "axis1")
969+
axis2 = normalize_axis_index(axis2, X.ndim, "axis2")
970+
971+
ind = list(range(0, X.ndim))
972+
ind[axis1] = axis2
973+
ind[axis2] = axis1
974+
return dpt_ext.permute_dims(X, tuple(ind))

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3812,7 +3812,7 @@ def swapaxes(a, axis1, axis2):
38123812
"""
38133813

38143814
usm_a = dpnp.get_usm_ndarray(a)
3815-
usm_res = dpt.swapaxes(usm_a, axis1=axis1, axis2=axis2)
3815+
usm_res = dpt_ext.swapaxes(usm_a, axis1=axis1, axis2=axis2)
38163816
return dpnp_array._create_from_usm_ndarray(usm_res)
38173817

38183818

0 commit comments

Comments
 (0)