@@ -288,12 +288,7 @@ def roll(X, shift, axes=None):
288288 return res
289289
290290
291- def concat (arrays , axis = 0 ):
292- """
293- concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
294-
295- Joins a sequence of arrays along an existing axis.
296- """
291+ def _arrays_validation (arrays ):
297292 n = len (arrays )
298293 if n == 0 :
299294 raise TypeError ("Missing 1 required positional argument: 'arrays'" )
@@ -324,11 +319,23 @@ def concat(arrays, axis=0):
324319 for i in range (1 , n ):
325320 if X0 .ndim != arrays [i ].ndim :
326321 raise ValueError (
327- "All the input arrays must have same number of "
328- "dimensions, but the array at index 0 has "
329- f"{ X0 .ndim } dimension(s) and the array at index "
330- f"{ i } has { arrays [i ].ndim } dimension(s)"
322+ "All the input arrays must have same number of dimensions, "
323+ f"but the array at index 0 has { X0 .ndim } dimension(s) and the "
324+ f"array at index { i } has { arrays [i ].ndim } dimension(s)"
331325 )
326+ return res_dtype , res_usm_type , exec_q
327+
328+
329+ def concat (arrays , axis = 0 ):
330+ """
331+ concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
332+
333+ Joins a sequence of arrays along an existing axis.
334+ """
335+ res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
336+
337+ n = len (arrays )
338+ X0 = arrays [0 ]
332339
333340 axis = normalize_axis_index (axis , X0 .ndim )
334341 X0_shape = X0 .shape
@@ -337,11 +344,10 @@ def concat(arrays, axis=0):
337344 for j in range (X0 .ndim ):
338345 if X0_shape [j ] != Xi_shape [j ] and j != axis :
339346 raise ValueError (
340- "All the input array dimensions for the "
341- "concatenation axis must match exactly, but "
342- f"along dimension { j } , the array at index 0 "
343- f"has size { X0_shape [j ]} and the array at "
344- f"index { i } has size { Xi_shape [j ]} "
347+ "All the input array dimensions for the concatenation "
348+ f"axis must match exactly, but along dimension { j } , the "
349+ f"array at index 0 has size { X0_shape [j ]} and the array "
350+ f"at index { i } has size { Xi_shape [j ]} "
345351 )
346352
347353 res_shape_axis = 0
@@ -373,3 +379,45 @@ def concat(arrays, axis=0):
373379 dpctl .SyclEvent .wait_for (hev_list )
374380
375381 return res
382+
383+
384+ def stack (arrays , axis = 0 ):
385+ """
386+ stack(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
387+
388+ Joins a sequence of arrays along a new axis.
389+ """
390+ res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
391+
392+ n = len (arrays )
393+ X0 = arrays [0 ]
394+ res_ndim = X0 .ndim + 1
395+ axis = normalize_axis_index (axis , res_ndim )
396+ X0_shape = X0 .shape
397+
398+ for i in range (1 , n ):
399+ if X0_shape != arrays [i ].shape :
400+ raise ValueError ("All input arrays must have the same shape" )
401+
402+ res_shape = tuple (
403+ X0_shape [i - 1 * (i >= axis )] if i != axis else n
404+ for i in range (res_ndim )
405+ )
406+
407+ res = dpt .empty (
408+ res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
409+ )
410+
411+ hev_list = []
412+ for i in range (n ):
413+ c_shapes_copy = tuple (
414+ i if j == axis else np .s_ [:] for j in range (res_ndim )
415+ )
416+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
417+ src = arrays [i ], dst = res [c_shapes_copy ], sycl_queue = exec_q
418+ )
419+ hev_list .append (hev )
420+
421+ dpctl .SyclEvent .wait_for (hev_list )
422+
423+ return res
0 commit comments