@@ -972,3 +972,100 @@ def swapaxes(X, axis1, axis2):
972972 ind [axis1 ] = axis2
973973 ind [axis2 ] = axis1
974974 return dpt_ext .permute_dims (X , tuple (ind ))
975+
976+
def tile(x, repetitions, /):
    """tile(x, repetitions)

    Repeat an input array `x` along each axis a number of times given by
    `repetitions`.

    For `N` = len(`repetitions`) and `M` = len(`x.shape`):

        * If `M < N`, `x` will have `N - M` new axes prepended to its shape
        * If `M > N`, `repetitions` will have `M - N` ones prepended to it

    Args:
        x (usm_ndarray): input array

        repetitions (Union[int, Tuple[int, ...]]):
            The number of repetitions along each dimension of `x`.

    Returns:
        usm_ndarray:
            tiled output array.

            The returned array will have rank `max(M, N)`. If `S` is the
            shape of `x` after prepending dimensions and `R` is
            `repetitions` after prepending ones, then the shape of the
            result will be `S[i] * R[i]` for each dimension `i`.

            The returned array will have the same data type as `x`.
            The returned array will be located on the same device as `x` and
            have the same USM allocation type as `x`.

    Raises:
        TypeError:
            if `x` is not a `usm_ndarray`, or if `repetitions` is neither
            a tuple nor an integer.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    if not isinstance(repetitions, tuple):
        if isinstance(repetitions, int):
            repetitions = (repetitions,)
        else:
            raise TypeError(
                f"Expected tuple or integer type, got {type(repetitions)}."
            )

    rep_dims = len(repetitions)
    x_dims = x.ndim
    if rep_dims < x_dims:
        # left-pad repetitions with ones so that every axis of x has a factor
        repetitions = (x_dims - rep_dims) * (1,) + repetitions
    elif x_dims < rep_dims:
        # left-pad the shape of x with unit axes to match repetitions
        x = dpt_ext.reshape(x, (rep_dims - x_dims) * (1,) + x.shape)

    in_sh = x.shape
    res_shape = tuple(sh * rep for sh, rep in zip(in_sh, repetitions))
    # allocation parameters shared by every output path below
    exec_q = x.sycl_queue
    xdt = x.dtype
    xut = x.usm_type

    # case of empty input
    if x.size == 0:
        return dpt_ext.empty(
            res_shape,
            dtype=xdt,
            usm_type=xut,
            sycl_queue=exec_q,
        )
    if res_shape == in_sh:
        # nothing is repeated: the result is just a copy of the input
        return dpt_ext.copy(x)

    # Build two parallel shapes: `expanded_sh` is the shape `x` is reshaped
    # to, and `broadcast_sh` is the shape that view is then broadcast to.
    expanded_sh = []
    broadcast_sh = []
    out_sz = 1
    for reps, sh, res_sh in zip(repetitions, in_sh, res_shape):
        out_sz *= res_sh
        if reps == 1:
            # dimension will be unchanged
            broadcast_sh.append(sh)
            expanded_sh.append(sh)
        elif sh == 1:
            # dimension will be broadcast
            broadcast_sh.append(reps)
            expanded_sh.append(sh)
        else:
            # insert a unit axis in front of this axis and broadcast it to
            # `reps`; the (reps, sh) pair collapses to `reps * sh` in the
            # final reshape to `res_shape`
            broadcast_sh.extend([reps, sh])
            expanded_sh.extend([1, sh])

    res = dpt_ext.empty((out_sz,), dtype=xdt, usm_type=xut, sycl_queue=exec_q)
    # no need to copy data for empty output
    if out_sz > 0:
        x = dpt_ext.broadcast_to(
            # this reshape should never copy
            dpt_ext.reshape(x, expanded_sh),
            broadcast_sh,
        )
        # copy broadcast input into flat array
        _manager = dputils.SequentialOrderManager[exec_q]
        dep_evs = _manager.submitted_events
        hev, cp_ev = ti._copy_usm_ndarray_for_reshape(
            src=x, dst=res, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(hev, cp_ev)
    return dpt_ext.reshape(res, res_shape)
0 commit comments