Skip to content

Commit b5e3541

Browse files
Move ti.tile() to dpctl_ext.tensor and reuse it
1 parent ccad5f0 commit b5e3541

3 files changed

Lines changed: 100 additions & 1 deletion

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
squeeze,
7676
stack,
7777
swapaxes,
78+
tile,
7879
)
7980
from dpctl_ext.tensor._reshape import reshape
8081

@@ -123,6 +124,7 @@
123124
"swapaxes",
124125
"take",
125126
"take_along_axis",
127+
"tile",
126128
"to_numpy",
127129
"tril",
128130
"triu",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,3 +972,100 @@ def swapaxes(X, axis1, axis2):
972972
ind[axis1] = axis2
973973
ind[axis2] = axis1
974974
return dpt_ext.permute_dims(X, tuple(ind))
975+
976+
977+
def tile(x, repetitions, /):
978+
"""tile(x, repetitions)
979+
980+
Repeat an input array `x` along each axis a number of times given by
981+
`repetitions`.
982+
983+
For `N` = len(`repetitions`) and `M` = len(`x.shape`):
984+
985+
* If `M < N`, `x` will have `N - M` new axes prepended to its shape
986+
* If `M > N`, `repetitions` will have `M - N` ones prepended to it
987+
988+
Args:
989+
x (usm_ndarray): input array
990+
991+
repetitions (Union[int, Tuple[int, ...]]):
992+
The number of repetitions along each dimension of `x`.
993+
994+
Returns:
995+
usm_ndarray:
996+
tiled output array.
997+
998+
The returned array will have rank `max(M, N)`. If `S` is the
999+
shape of `x` after prepending dimensions and `R` is
1000+
`repetitions` after prepending ones, then the shape of the
1001+
result will be `S[i] * R[i]` for each dimension `i`.
1002+
1003+
The returned array will have the same data type as `x`.
1004+
The returned array will be located on the same device as `x` and
1005+
have the same USM allocation type as `x`.
1006+
"""
1007+
if not isinstance(x, dpt.usm_ndarray):
1008+
raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")
1009+
1010+
if not isinstance(repetitions, tuple):
1011+
if isinstance(repetitions, int):
1012+
repetitions = (repetitions,)
1013+
else:
1014+
raise TypeError(
1015+
f"Expected tuple or integer type, got {type(repetitions)}."
1016+
)
1017+
1018+
rep_dims = len(repetitions)
1019+
x_dims = x.ndim
1020+
if rep_dims < x_dims:
1021+
repetitions = (x_dims - rep_dims) * (1,) + repetitions
1022+
elif x_dims < rep_dims:
1023+
x = dpt_ext.reshape(x, (rep_dims - x_dims) * (1,) + x.shape)
1024+
res_shape = tuple(map(lambda sh, rep: sh * rep, x.shape, repetitions))
1025+
# case of empty input
1026+
if x.size == 0:
1027+
return dpt_ext.empty(
1028+
res_shape,
1029+
dtype=x.dtype,
1030+
usm_type=x.usm_type,
1031+
sycl_queue=x.sycl_queue,
1032+
)
1033+
in_sh = x.shape
1034+
if res_shape == in_sh:
1035+
return dpt_ext.copy(x)
1036+
expanded_sh = []
1037+
broadcast_sh = []
1038+
out_sz = 1
1039+
for i in range(len(res_shape)):
1040+
out_sz *= res_shape[i]
1041+
reps, sh = repetitions[i], in_sh[i]
1042+
if reps == 1:
1043+
# dimension will be unchanged
1044+
broadcast_sh.append(sh)
1045+
expanded_sh.append(sh)
1046+
elif sh == 1:
1047+
# dimension will be broadcast
1048+
broadcast_sh.append(reps)
1049+
expanded_sh.append(sh)
1050+
else:
1051+
broadcast_sh.extend([reps, sh])
1052+
expanded_sh.extend([1, sh])
1053+
exec_q = x.sycl_queue
1054+
xdt = x.dtype
1055+
xut = x.usm_type
1056+
res = dpt_ext.empty((out_sz,), dtype=xdt, usm_type=xut, sycl_queue=exec_q)
1057+
# no need to copy data for empty output
1058+
if out_sz > 0:
1059+
x = dpt_ext.broadcast_to(
1060+
# this reshape should never copy
1061+
dpt_ext.reshape(x, expanded_sh),
1062+
broadcast_sh,
1063+
)
1064+
# copy broadcast input into flat array
1065+
_manager = dputils.SequentialOrderManager[exec_q]
1066+
dep_evs = _manager.submitted_events
1067+
hev, cp_ev = ti._copy_usm_ndarray_for_reshape(
1068+
src=x, dst=res, sycl_queue=exec_q, depends=dep_evs
1069+
)
1070+
_manager.add_event_pair(hev, cp_ev)
1071+
return dpt_ext.reshape(res, res_shape)

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3892,7 +3892,7 @@ def tile(A, reps):
38923892
"""
38933893

38943894
usm_a = dpnp.get_usm_ndarray(A)
3895-
usm_res = dpt.tile(usm_a, reps)
3895+
usm_res = dpt_ext.tile(usm_a, reps)
38963896
return dpnp_array._create_from_usm_ndarray(usm_res)
38973897

38983898

0 commit comments

Comments
 (0)