3434
3535# TODO: revert to `import dpctl.tensor...`
3636# when dpnp fully migrates dpctl/tensor
37+ import dpctl_ext .tensor as dpt_ext
3738import dpctl_ext .tensor ._tensor_elementwise_impl as tei
3839import dpctl_ext .tensor ._tensor_impl as ti
3940import dpctl_ext .tensor ._tensor_linalg_impl as tli
@@ -180,8 +181,8 @@ def tensordot(x1, x2, axes=2):
180181 axes2 = normalize_axis_tuple (axes2 , x2_nd )
181182 perm1 = [i for i in range (x1_nd ) if i not in axes1 ] + list (axes1 )
182183 perm2 = list (axes2 ) + [i for i in range (x2_nd ) if i not in axes2 ]
183- arr1 = dpt .permute_dims (x1 , perm1 )
184- arr2 = dpt .permute_dims (x2 , perm2 )
184+ arr1 = dpt_ext .permute_dims (x1 , perm1 )
185+ arr2 = dpt_ext .permute_dims (x2 , perm2 )
185186 arr1_outer_nd = arr1 .ndim - n_axes1
186187 arr2_outer_nd = arr2 .ndim - n_axes2
187188 res_shape = arr1 .shape [:arr1_outer_nd ] + arr2 .shape [n_axes2 :]
@@ -206,7 +207,7 @@ def tensordot(x1, x2, axes=2):
206207
207208 _manager = SequentialOrderManager [exec_q ]
208209 if buf1_dt is None and buf2_dt is None :
209- out = dpt .empty (
210+ out = dpt_ext .empty (
210211 res_shape ,
211212 dtype = res_dt ,
212213 usm_type = res_usm_type ,
@@ -237,7 +238,7 @@ def tensordot(x1, x2, axes=2):
237238 src = arr2 , dst = buf2 , sycl_queue = exec_q , depends = dep_evs
238239 )
239240 _manager .add_event_pair (ht_copy_ev , copy_ev )
240- out = dpt .empty (
241+ out = dpt_ext .empty (
241242 res_shape ,
242243 dtype = res_dt ,
243244 usm_type = res_usm_type ,
@@ -266,7 +267,7 @@ def tensordot(x1, x2, axes=2):
266267 src = arr1 , dst = buf1 , sycl_queue = exec_q , depends = dep_evs
267268 )
268269 _manager .add_event_pair (ht_copy_ev , copy_ev )
269- out = dpt .empty (
270+ out = dpt_ext .empty (
270271 res_shape ,
271272 dtype = res_dt ,
272273 usm_type = res_usm_type ,
@@ -299,7 +300,7 @@ def tensordot(x1, x2, axes=2):
299300 src = arr2 , dst = buf2 , sycl_queue = exec_q , depends = deps_ev
300301 )
301302 _manager .add_event_pair (ht_copy2_ev , copy2_ev )
302- out = dpt .empty (
303+ out = dpt_ext .empty (
303304 res_shape ,
304305 dtype = res_dt ,
305306 usm_type = res_usm_type ,
@@ -434,12 +435,12 @@ def vecdot(x1, x2, axis=-1):
434435 _manager .add_event_pair (ht_conj_ev , conj_ev )
435436 x1 = x1_tmp
436437 if x1 .shape != broadcast_sh :
437- x1 = dpt .broadcast_to (x1 , broadcast_sh )
438+ x1 = dpt_ext .broadcast_to (x1 , broadcast_sh )
438439 if x2 .shape != broadcast_sh :
439- x2 = dpt .broadcast_to (x2 , broadcast_sh )
440- x1 = dpt .moveaxis (x1 , contracted_axis , - 1 )
441- x2 = dpt .moveaxis (x2 , contracted_axis , - 1 )
442- out = dpt .empty (
440+ x2 = dpt_ext .broadcast_to (x2 , broadcast_sh )
441+ x1 = dpt_ext .moveaxis (x1 , contracted_axis , - 1 )
442+ x2 = dpt_ext .moveaxis (x2 , contracted_axis , - 1 )
443+ out = dpt_ext .empty (
443444 res_sh ,
444445 dtype = res_dt ,
445446 usm_type = res_usm_type ,
@@ -459,7 +460,7 @@ def vecdot(x1, x2, axis=-1):
459460 depends = dep_evs ,
460461 )
461462 _manager .add_event_pair (ht_dot_ev , dot_ev )
462- return dpt .reshape (out , res_sh )
463+ return dpt_ext .reshape (out , res_sh )
463464
464465 elif buf1_dt is None :
465466 if x1 .dtype .kind == "c" :
@@ -477,12 +478,12 @@ def vecdot(x1, x2, axis=-1):
477478 )
478479 _manager .add_event_pair (ht_copy_ev , copy_ev )
479480 if x1 .shape != broadcast_sh :
480- x1 = dpt .broadcast_to (x1 , broadcast_sh )
481+ x1 = dpt_ext .broadcast_to (x1 , broadcast_sh )
481482 if buf2 .shape != broadcast_sh :
482- buf2 = dpt .broadcast_to (buf2 , broadcast_sh )
483- x1 = dpt .moveaxis (x1 , contracted_axis , - 1 )
484- buf2 = dpt .moveaxis (buf2 , contracted_axis , - 1 )
485- out = dpt .empty (
483+ buf2 = dpt_ext .broadcast_to (buf2 , broadcast_sh )
484+ x1 = dpt_ext .moveaxis (x1 , contracted_axis , - 1 )
485+ buf2 = dpt_ext .moveaxis (buf2 , contracted_axis , - 1 )
486+ out = dpt_ext .empty (
486487 res_sh ,
487488 dtype = res_dt ,
488489 usm_type = res_usm_type ,
@@ -501,7 +502,7 @@ def vecdot(x1, x2, axis=-1):
501502 depends = [copy_ev ],
502503 )
503504 _manager .add_event_pair (ht_dot_ev , dot_ev )
504- return dpt .reshape (out , res_sh )
505+ return dpt_ext .reshape (out , res_sh )
505506
506507 elif buf2_dt is None :
507508 buf1 = _empty_like_orderK (x1 , buf1_dt )
@@ -516,12 +517,12 @@ def vecdot(x1, x2, axis=-1):
516517 )
517518 _manager .add_event_pair (ht_conj_ev , conj_ev )
518519 if buf1 .shape != broadcast_sh :
519- buf1 = dpt .broadcast_to (buf1 , broadcast_sh )
520+ buf1 = dpt_ext .broadcast_to (buf1 , broadcast_sh )
520521 if x2 .shape != broadcast_sh :
521- x2 = dpt .broadcast_to (x2 , broadcast_sh )
522- buf1 = dpt .moveaxis (buf1 , contracted_axis , - 1 )
523- x2 = dpt .moveaxis (x2 , contracted_axis , - 1 )
524- out = dpt .empty (
522+ x2 = dpt_ext .broadcast_to (x2 , broadcast_sh )
523+ buf1 = dpt_ext .moveaxis (buf1 , contracted_axis , - 1 )
524+ x2 = dpt_ext .moveaxis (x2 , contracted_axis , - 1 )
525+ out = dpt_ext .empty (
525526 res_sh ,
526527 dtype = res_dt ,
527528 usm_type = res_usm_type ,
@@ -541,7 +542,7 @@ def vecdot(x1, x2, axis=-1):
541542 depends = deps_ev ,
542543 )
543544 _manager .add_event_pair (ht_dot_ev , dot_ev )
544- return dpt .reshape (out , res_sh )
545+ return dpt_ext .reshape (out , res_sh )
545546
546547 buf1 = _empty_like_orderK (x1 , buf1_dt )
547548 deps_ev = _manager .submitted_events
@@ -560,12 +561,12 @@ def vecdot(x1, x2, axis=-1):
560561 )
561562 _manager .add_event_pair (ht_copy2_ev , copy2_ev )
562563 if buf1 .shape != broadcast_sh :
563- buf1 = dpt .broadcast_to (buf1 , broadcast_sh )
564+ buf1 = dpt_ext .broadcast_to (buf1 , broadcast_sh )
564565 if buf2 .shape != broadcast_sh :
565- buf2 = dpt .broadcast_to (buf2 , broadcast_sh )
566- buf1 = dpt .moveaxis (buf1 , contracted_axis , - 1 )
567- buf2 = dpt .moveaxis (buf2 , contracted_axis , - 1 )
568- out = dpt .empty (
566+ buf2 = dpt_ext .broadcast_to (buf2 , broadcast_sh )
567+ buf1 = dpt_ext .moveaxis (buf1 , contracted_axis , - 1 )
568+ buf2 = dpt_ext .moveaxis (buf2 , contracted_axis , - 1 )
569+ out = dpt_ext .empty (
569570 res_sh ,
570571 dtype = res_dt ,
571572 usm_type = res_usm_type ,
@@ -732,7 +733,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
732733 res_dt = _to_device_supported_dtype (res_dt , sycl_dev )
733734 buf1_dt , buf2_dt = None , None
734735 if x1_dtype != res_dt :
735- if dpt .can_cast (x1_dtype , res_dt , casting = "same_kind" ):
736+ if dpt_ext .can_cast (x1_dtype , res_dt , casting = "same_kind" ):
736737 buf1_dt = res_dt
737738 else :
738739 raise ValueError (
@@ -742,7 +743,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
742743 "''same_kind''."
743744 )
744745 if x2_dtype != res_dt :
745- if dpt .can_cast (x2_dtype , res_dt , casting = "same_kind" ):
746+ if dpt_ext .can_cast (x2_dtype , res_dt , casting = "same_kind" ):
746747 buf2_dt = res_dt
747748 else :
748749 raise ValueError (
@@ -774,7 +775,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
774775 )
775776
776777 if appended_axes :
777- out = dpt .expand_dims (out , axis = appended_axes )
778+ out = dpt_ext .expand_dims (out , axis = appended_axes )
778779 orig_out = out
779780
780781 if res_dt != out .dtype :
@@ -788,12 +789,12 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
788789 )
789790
790791 if ti ._array_overlap (x1 , out ) and buf1_dt is None :
791- out = dpt .empty_like (out )
792+ out = dpt_ext .empty_like (out )
792793
793794 if ti ._array_overlap (x2 , out ) and buf2_dt is None :
794795 # should not reach if out is reallocated
795796 # after being checked against x1
796- out = dpt .empty_like (out )
797+ out = dpt_ext .empty_like (out )
797798
798799 if order == "A" :
799800 order = (
@@ -816,17 +817,17 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
816817 x1 , x2 , res_dt , res_shape , res_usm_type , exec_q
817818 )
818819 else :
819- out = dpt .empty (
820+ out = dpt_ext .empty (
820821 res_shape ,
821822 dtype = res_dt ,
822823 usm_type = res_usm_type ,
823824 sycl_queue = exec_q ,
824825 order = order ,
825826 )
826827 if x1 .shape != x1_broadcast_shape :
827- x1 = dpt .broadcast_to (x1 , x1_broadcast_shape )
828+ x1 = dpt_ext .broadcast_to (x1 , x1_broadcast_shape )
828829 if x2 .shape != x2_broadcast_shape :
829- x2 = dpt .broadcast_to (x2 , x2_broadcast_shape )
830+ x2 = dpt_ext .broadcast_to (x2 , x2_broadcast_shape )
830831 deps_evs = _manager .submitted_events
831832 ht_dot_ev , dot_ev = tli ._dot (
832833 x1 = x1 ,
@@ -851,13 +852,13 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
851852 _manager .add_event_pair (ht_copy_out_ev , cpy_ev )
852853 out = orig_out
853854 if appended_axes :
854- out = dpt .squeeze (out , tuple (appended_axes ))
855+ out = dpt_ext .squeeze (out , tuple (appended_axes ))
855856 return out
856857 elif buf1_dt is None :
857858 if order == "K" :
858859 buf2 = _empty_like_orderK (x2 , buf2_dt )
859860 else :
860- buf2 = dpt .empty_like (x2 , dtype = buf2_dt , order = order )
861+ buf2 = dpt_ext .empty_like (x2 , dtype = buf2_dt , order = order )
861862 deps_evs = _manager .submitted_events
862863 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
863864 src = x2 , dst = buf2 , sycl_queue = exec_q , depends = deps_evs
@@ -869,7 +870,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
869870 x1 , buf2 , res_dt , res_shape , res_usm_type , exec_q
870871 )
871872 else :
872- out = dpt .empty (
873+ out = dpt_ext .empty (
873874 res_shape ,
874875 dtype = res_dt ,
875876 usm_type = res_usm_type ,
@@ -878,9 +879,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
878879 )
879880
880881 if x1 .shape != x1_broadcast_shape :
881- x1 = dpt .broadcast_to (x1 , x1_broadcast_shape )
882+ x1 = dpt_ext .broadcast_to (x1 , x1_broadcast_shape )
882883 if buf2 .shape != x2_broadcast_shape :
883- buf2 = dpt .broadcast_to (buf2 , x2_broadcast_shape )
884+ buf2 = dpt_ext .broadcast_to (buf2 , x2_broadcast_shape )
884885 ht_dot_ev , dot_ev = tli ._dot (
885886 x1 = x1 ,
886887 x2 = buf2 ,
@@ -904,14 +905,14 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
904905 _manager .add_event_pair (ht_copy_out_ev , cpy_ev )
905906 out = orig_out
906907 if appended_axes :
907- out = dpt .squeeze (out , tuple (appended_axes ))
908+ out = dpt_ext .squeeze (out , tuple (appended_axes ))
908909 return out
909910
910911 elif buf2_dt is None :
911912 if order == "K" :
912913 buf1 = _empty_like_orderK (x1 , buf1_dt )
913914 else :
914- buf1 = dpt .empty_like (x1 , dtype = buf1_dt , order = order )
915+ buf1 = dpt_ext .empty_like (x1 , dtype = buf1_dt , order = order )
915916 deps_ev = _manager .submitted_events
916917 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
917918 src = x1 , dst = buf1 , sycl_queue = exec_q , depends = deps_ev
@@ -923,7 +924,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
923924 buf1 , x2 , res_dt , res_shape , res_usm_type , exec_q
924925 )
925926 else :
926- out = dpt .empty (
927+ out = dpt_ext .empty (
927928 res_shape ,
928929 dtype = res_dt ,
929930 usm_type = res_usm_type ,
@@ -932,9 +933,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
932933 )
933934
934935 if buf1 .shape != x1_broadcast_shape :
935- buf1 = dpt .broadcast_to (buf1 , x1_broadcast_shape )
936+ buf1 = dpt_ext .broadcast_to (buf1 , x1_broadcast_shape )
936937 if x2 .shape != x2_broadcast_shape :
937- x2 = dpt .broadcast_to (x2 , x2_broadcast_shape )
938+ x2 = dpt_ext .broadcast_to (x2 , x2_broadcast_shape )
938939 ht_dot_ev , dot_ev = tli ._dot (
939940 x1 = buf1 ,
940941 x2 = x2 ,
@@ -958,7 +959,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
958959 _manager .add_event_pair (ht_copy_out_ev , cpy_ev )
959960 out = orig_out
960961 if appended_axes :
961- out = dpt .squeeze (out , tuple (appended_axes ))
962+ out = dpt_ext .squeeze (out , tuple (appended_axes ))
962963 return out
963964
964965 if order == "K" :
@@ -969,7 +970,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
969970 if order == "K" :
970971 buf1 = _empty_like_orderK (x1 , buf1_dt )
971972 else :
972- buf1 = dpt .empty_like (x1 , dtype = buf1_dt , order = order )
973+ buf1 = dpt_ext .empty_like (x1 , dtype = buf1_dt , order = order )
973974 deps_ev = _manager .submitted_events
974975 ht_copy1_ev , copy1_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
975976 src = x1 , dst = buf1 , sycl_queue = exec_q , depends = deps_ev
@@ -978,7 +979,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
978979 if order == "K" :
979980 buf2 = _empty_like_orderK (x2 , buf2_dt )
980981 else :
981- buf2 = dpt .empty_like (x2 , dtype = buf2_dt , order = order )
982+ buf2 = dpt_ext .empty_like (x2 , dtype = buf2_dt , order = order )
982983 ht_copy2_ev , copy2_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
983984 src = x2 , dst = buf2 , sycl_queue = exec_q , depends = deps_ev
984985 )
@@ -989,7 +990,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
989990 buf1 , buf2 , res_dt , res_shape , res_usm_type , exec_q
990991 )
991992 else :
992- out = dpt .empty (
993+ out = dpt_ext .empty (
993994 res_shape ,
994995 dtype = res_dt ,
995996 usm_type = res_usm_type ,
@@ -998,9 +999,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
998999 )
9991000
10001001 if buf1 .shape != x1_broadcast_shape :
1001- buf1 = dpt .broadcast_to (buf1 , x1_broadcast_shape )
1002+ buf1 = dpt_ext .broadcast_to (buf1 , x1_broadcast_shape )
10021003 if buf2 .shape != x2_broadcast_shape :
1003- buf2 = dpt .broadcast_to (buf2 , x2_broadcast_shape )
1004+ buf2 = dpt_ext .broadcast_to (buf2 , x2_broadcast_shape )
10041005 ht_ , dot_ev = tli ._dot (
10051006 x1 = buf1 ,
10061007 x2 = buf2 ,
@@ -1014,5 +1015,5 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
10141015 )
10151016 _manager .add_event_pair (ht_ , dot_ev )
10161017 if appended_axes :
1017- out = dpt .squeeze (out , tuple (appended_axes ))
1018+ out = dpt_ext .squeeze (out , tuple (appended_axes ))
10181019 return out
0 commit comments