Skip to content

Commit ed289fc

Browse files
Merge move_elementwise_binary_impl into move_elementwise_binary_impl_part_2
2 parents: d80285e + e6e179e; commit: ed289fc

File tree

109 files changed

+214
-293
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

109 files changed

+214
-293
lines changed

dpctl_ext/tensor/_linear_algebra_functions.py

Lines changed: 55 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434

3535
# TODO: revert to `import dpctl.tensor...`
3636
# when dpnp fully migrates dpctl/tensor
37+
import dpctl_ext.tensor as dpt_ext
3738
import dpctl_ext.tensor._tensor_elementwise_impl as tei
3839
import dpctl_ext.tensor._tensor_impl as ti
3940
import dpctl_ext.tensor._tensor_linalg_impl as tli
@@ -180,8 +181,8 @@ def tensordot(x1, x2, axes=2):
180181
axes2 = normalize_axis_tuple(axes2, x2_nd)
181182
perm1 = [i for i in range(x1_nd) if i not in axes1] + list(axes1)
182183
perm2 = list(axes2) + [i for i in range(x2_nd) if i not in axes2]
183-
arr1 = dpt.permute_dims(x1, perm1)
184-
arr2 = dpt.permute_dims(x2, perm2)
184+
arr1 = dpt_ext.permute_dims(x1, perm1)
185+
arr2 = dpt_ext.permute_dims(x2, perm2)
185186
arr1_outer_nd = arr1.ndim - n_axes1
186187
arr2_outer_nd = arr2.ndim - n_axes2
187188
res_shape = arr1.shape[:arr1_outer_nd] + arr2.shape[n_axes2:]
@@ -206,7 +207,7 @@ def tensordot(x1, x2, axes=2):
206207

207208
_manager = SequentialOrderManager[exec_q]
208209
if buf1_dt is None and buf2_dt is None:
209-
out = dpt.empty(
210+
out = dpt_ext.empty(
210211
res_shape,
211212
dtype=res_dt,
212213
usm_type=res_usm_type,
@@ -237,7 +238,7 @@ def tensordot(x1, x2, axes=2):
237238
src=arr2, dst=buf2, sycl_queue=exec_q, depends=dep_evs
238239
)
239240
_manager.add_event_pair(ht_copy_ev, copy_ev)
240-
out = dpt.empty(
241+
out = dpt_ext.empty(
241242
res_shape,
242243
dtype=res_dt,
243244
usm_type=res_usm_type,
@@ -266,7 +267,7 @@ def tensordot(x1, x2, axes=2):
266267
src=arr1, dst=buf1, sycl_queue=exec_q, depends=dep_evs
267268
)
268269
_manager.add_event_pair(ht_copy_ev, copy_ev)
269-
out = dpt.empty(
270+
out = dpt_ext.empty(
270271
res_shape,
271272
dtype=res_dt,
272273
usm_type=res_usm_type,
@@ -299,7 +300,7 @@ def tensordot(x1, x2, axes=2):
299300
src=arr2, dst=buf2, sycl_queue=exec_q, depends=deps_ev
300301
)
301302
_manager.add_event_pair(ht_copy2_ev, copy2_ev)
302-
out = dpt.empty(
303+
out = dpt_ext.empty(
303304
res_shape,
304305
dtype=res_dt,
305306
usm_type=res_usm_type,
@@ -434,12 +435,12 @@ def vecdot(x1, x2, axis=-1):
434435
_manager.add_event_pair(ht_conj_ev, conj_ev)
435436
x1 = x1_tmp
436437
if x1.shape != broadcast_sh:
437-
x1 = dpt.broadcast_to(x1, broadcast_sh)
438+
x1 = dpt_ext.broadcast_to(x1, broadcast_sh)
438439
if x2.shape != broadcast_sh:
439-
x2 = dpt.broadcast_to(x2, broadcast_sh)
440-
x1 = dpt.moveaxis(x1, contracted_axis, -1)
441-
x2 = dpt.moveaxis(x2, contracted_axis, -1)
442-
out = dpt.empty(
440+
x2 = dpt_ext.broadcast_to(x2, broadcast_sh)
441+
x1 = dpt_ext.moveaxis(x1, contracted_axis, -1)
442+
x2 = dpt_ext.moveaxis(x2, contracted_axis, -1)
443+
out = dpt_ext.empty(
443444
res_sh,
444445
dtype=res_dt,
445446
usm_type=res_usm_type,
@@ -459,7 +460,7 @@ def vecdot(x1, x2, axis=-1):
459460
depends=dep_evs,
460461
)
461462
_manager.add_event_pair(ht_dot_ev, dot_ev)
462-
return dpt.reshape(out, res_sh)
463+
return dpt_ext.reshape(out, res_sh)
463464

464465
elif buf1_dt is None:
465466
if x1.dtype.kind == "c":
@@ -477,12 +478,12 @@ def vecdot(x1, x2, axis=-1):
477478
)
478479
_manager.add_event_pair(ht_copy_ev, copy_ev)
479480
if x1.shape != broadcast_sh:
480-
x1 = dpt.broadcast_to(x1, broadcast_sh)
481+
x1 = dpt_ext.broadcast_to(x1, broadcast_sh)
481482
if buf2.shape != broadcast_sh:
482-
buf2 = dpt.broadcast_to(buf2, broadcast_sh)
483-
x1 = dpt.moveaxis(x1, contracted_axis, -1)
484-
buf2 = dpt.moveaxis(buf2, contracted_axis, -1)
485-
out = dpt.empty(
483+
buf2 = dpt_ext.broadcast_to(buf2, broadcast_sh)
484+
x1 = dpt_ext.moveaxis(x1, contracted_axis, -1)
485+
buf2 = dpt_ext.moveaxis(buf2, contracted_axis, -1)
486+
out = dpt_ext.empty(
486487
res_sh,
487488
dtype=res_dt,
488489
usm_type=res_usm_type,
@@ -501,7 +502,7 @@ def vecdot(x1, x2, axis=-1):
501502
depends=[copy_ev],
502503
)
503504
_manager.add_event_pair(ht_dot_ev, dot_ev)
504-
return dpt.reshape(out, res_sh)
505+
return dpt_ext.reshape(out, res_sh)
505506

506507
elif buf2_dt is None:
507508
buf1 = _empty_like_orderK(x1, buf1_dt)
@@ -516,12 +517,12 @@ def vecdot(x1, x2, axis=-1):
516517
)
517518
_manager.add_event_pair(ht_conj_ev, conj_ev)
518519
if buf1.shape != broadcast_sh:
519-
buf1 = dpt.broadcast_to(buf1, broadcast_sh)
520+
buf1 = dpt_ext.broadcast_to(buf1, broadcast_sh)
520521
if x2.shape != broadcast_sh:
521-
x2 = dpt.broadcast_to(x2, broadcast_sh)
522-
buf1 = dpt.moveaxis(buf1, contracted_axis, -1)
523-
x2 = dpt.moveaxis(x2, contracted_axis, -1)
524-
out = dpt.empty(
522+
x2 = dpt_ext.broadcast_to(x2, broadcast_sh)
523+
buf1 = dpt_ext.moveaxis(buf1, contracted_axis, -1)
524+
x2 = dpt_ext.moveaxis(x2, contracted_axis, -1)
525+
out = dpt_ext.empty(
525526
res_sh,
526527
dtype=res_dt,
527528
usm_type=res_usm_type,
@@ -541,7 +542,7 @@ def vecdot(x1, x2, axis=-1):
541542
depends=deps_ev,
542543
)
543544
_manager.add_event_pair(ht_dot_ev, dot_ev)
544-
return dpt.reshape(out, res_sh)
545+
return dpt_ext.reshape(out, res_sh)
545546

546547
buf1 = _empty_like_orderK(x1, buf1_dt)
547548
deps_ev = _manager.submitted_events
@@ -560,12 +561,12 @@ def vecdot(x1, x2, axis=-1):
560561
)
561562
_manager.add_event_pair(ht_copy2_ev, copy2_ev)
562563
if buf1.shape != broadcast_sh:
563-
buf1 = dpt.broadcast_to(buf1, broadcast_sh)
564+
buf1 = dpt_ext.broadcast_to(buf1, broadcast_sh)
564565
if buf2.shape != broadcast_sh:
565-
buf2 = dpt.broadcast_to(buf2, broadcast_sh)
566-
buf1 = dpt.moveaxis(buf1, contracted_axis, -1)
567-
buf2 = dpt.moveaxis(buf2, contracted_axis, -1)
568-
out = dpt.empty(
566+
buf2 = dpt_ext.broadcast_to(buf2, broadcast_sh)
567+
buf1 = dpt_ext.moveaxis(buf1, contracted_axis, -1)
568+
buf2 = dpt_ext.moveaxis(buf2, contracted_axis, -1)
569+
out = dpt_ext.empty(
569570
res_sh,
570571
dtype=res_dt,
571572
usm_type=res_usm_type,
@@ -732,7 +733,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
732733
res_dt = _to_device_supported_dtype(res_dt, sycl_dev)
733734
buf1_dt, buf2_dt = None, None
734735
if x1_dtype != res_dt:
735-
if dpt.can_cast(x1_dtype, res_dt, casting="same_kind"):
736+
if dpt_ext.can_cast(x1_dtype, res_dt, casting="same_kind"):
736737
buf1_dt = res_dt
737738
else:
738739
raise ValueError(
@@ -742,7 +743,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
742743
"''same_kind''."
743744
)
744745
if x2_dtype != res_dt:
745-
if dpt.can_cast(x2_dtype, res_dt, casting="same_kind"):
746+
if dpt_ext.can_cast(x2_dtype, res_dt, casting="same_kind"):
746747
buf2_dt = res_dt
747748
else:
748749
raise ValueError(
@@ -774,7 +775,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
774775
)
775776

776777
if appended_axes:
777-
out = dpt.expand_dims(out, axis=appended_axes)
778+
out = dpt_ext.expand_dims(out, axis=appended_axes)
778779
orig_out = out
779780

780781
if res_dt != out.dtype:
@@ -788,12 +789,12 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
788789
)
789790

790791
if ti._array_overlap(x1, out) and buf1_dt is None:
791-
out = dpt.empty_like(out)
792+
out = dpt_ext.empty_like(out)
792793

793794
if ti._array_overlap(x2, out) and buf2_dt is None:
794795
# should not reach if out is reallocated
795796
# after being checked against x1
796-
out = dpt.empty_like(out)
797+
out = dpt_ext.empty_like(out)
797798

798799
if order == "A":
799800
order = (
@@ -816,17 +817,17 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
816817
x1, x2, res_dt, res_shape, res_usm_type, exec_q
817818
)
818819
else:
819-
out = dpt.empty(
820+
out = dpt_ext.empty(
820821
res_shape,
821822
dtype=res_dt,
822823
usm_type=res_usm_type,
823824
sycl_queue=exec_q,
824825
order=order,
825826
)
826827
if x1.shape != x1_broadcast_shape:
827-
x1 = dpt.broadcast_to(x1, x1_broadcast_shape)
828+
x1 = dpt_ext.broadcast_to(x1, x1_broadcast_shape)
828829
if x2.shape != x2_broadcast_shape:
829-
x2 = dpt.broadcast_to(x2, x2_broadcast_shape)
830+
x2 = dpt_ext.broadcast_to(x2, x2_broadcast_shape)
830831
deps_evs = _manager.submitted_events
831832
ht_dot_ev, dot_ev = tli._dot(
832833
x1=x1,
@@ -851,13 +852,13 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
851852
_manager.add_event_pair(ht_copy_out_ev, cpy_ev)
852853
out = orig_out
853854
if appended_axes:
854-
out = dpt.squeeze(out, tuple(appended_axes))
855+
out = dpt_ext.squeeze(out, tuple(appended_axes))
855856
return out
856857
elif buf1_dt is None:
857858
if order == "K":
858859
buf2 = _empty_like_orderK(x2, buf2_dt)
859860
else:
860-
buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order)
861+
buf2 = dpt_ext.empty_like(x2, dtype=buf2_dt, order=order)
861862
deps_evs = _manager.submitted_events
862863
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
863864
src=x2, dst=buf2, sycl_queue=exec_q, depends=deps_evs
@@ -869,7 +870,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
869870
x1, buf2, res_dt, res_shape, res_usm_type, exec_q
870871
)
871872
else:
872-
out = dpt.empty(
873+
out = dpt_ext.empty(
873874
res_shape,
874875
dtype=res_dt,
875876
usm_type=res_usm_type,
@@ -878,9 +879,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
878879
)
879880

880881
if x1.shape != x1_broadcast_shape:
881-
x1 = dpt.broadcast_to(x1, x1_broadcast_shape)
882+
x1 = dpt_ext.broadcast_to(x1, x1_broadcast_shape)
882883
if buf2.shape != x2_broadcast_shape:
883-
buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape)
884+
buf2 = dpt_ext.broadcast_to(buf2, x2_broadcast_shape)
884885
ht_dot_ev, dot_ev = tli._dot(
885886
x1=x1,
886887
x2=buf2,
@@ -904,14 +905,14 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
904905
_manager.add_event_pair(ht_copy_out_ev, cpy_ev)
905906
out = orig_out
906907
if appended_axes:
907-
out = dpt.squeeze(out, tuple(appended_axes))
908+
out = dpt_ext.squeeze(out, tuple(appended_axes))
908909
return out
909910

910911
elif buf2_dt is None:
911912
if order == "K":
912913
buf1 = _empty_like_orderK(x1, buf1_dt)
913914
else:
914-
buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order)
915+
buf1 = dpt_ext.empty_like(x1, dtype=buf1_dt, order=order)
915916
deps_ev = _manager.submitted_events
916917
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
917918
src=x1, dst=buf1, sycl_queue=exec_q, depends=deps_ev
@@ -923,7 +924,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
923924
buf1, x2, res_dt, res_shape, res_usm_type, exec_q
924925
)
925926
else:
926-
out = dpt.empty(
927+
out = dpt_ext.empty(
927928
res_shape,
928929
dtype=res_dt,
929930
usm_type=res_usm_type,
@@ -932,9 +933,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
932933
)
933934

934935
if buf1.shape != x1_broadcast_shape:
935-
buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape)
936+
buf1 = dpt_ext.broadcast_to(buf1, x1_broadcast_shape)
936937
if x2.shape != x2_broadcast_shape:
937-
x2 = dpt.broadcast_to(x2, x2_broadcast_shape)
938+
x2 = dpt_ext.broadcast_to(x2, x2_broadcast_shape)
938939
ht_dot_ev, dot_ev = tli._dot(
939940
x1=buf1,
940941
x2=x2,
@@ -958,7 +959,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
958959
_manager.add_event_pair(ht_copy_out_ev, cpy_ev)
959960
out = orig_out
960961
if appended_axes:
961-
out = dpt.squeeze(out, tuple(appended_axes))
962+
out = dpt_ext.squeeze(out, tuple(appended_axes))
962963
return out
963964

964965
if order == "K":
@@ -969,7 +970,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
969970
if order == "K":
970971
buf1 = _empty_like_orderK(x1, buf1_dt)
971972
else:
972-
buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order)
973+
buf1 = dpt_ext.empty_like(x1, dtype=buf1_dt, order=order)
973974
deps_ev = _manager.submitted_events
974975
ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
975976
src=x1, dst=buf1, sycl_queue=exec_q, depends=deps_ev
@@ -978,7 +979,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
978979
if order == "K":
979980
buf2 = _empty_like_orderK(x2, buf2_dt)
980981
else:
981-
buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order)
982+
buf2 = dpt_ext.empty_like(x2, dtype=buf2_dt, order=order)
982983
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
983984
src=x2, dst=buf2, sycl_queue=exec_q, depends=deps_ev
984985
)
@@ -989,7 +990,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
989990
buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
990991
)
991992
else:
992-
out = dpt.empty(
993+
out = dpt_ext.empty(
993994
res_shape,
994995
dtype=res_dt,
995996
usm_type=res_usm_type,
@@ -998,9 +999,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
998999
)
9991000

10001001
if buf1.shape != x1_broadcast_shape:
1001-
buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape)
1002+
buf1 = dpt_ext.broadcast_to(buf1, x1_broadcast_shape)
10021003
if buf2.shape != x2_broadcast_shape:
1003-
buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape)
1004+
buf2 = dpt_ext.broadcast_to(buf2, x2_broadcast_shape)
10041005
ht_, dot_ev = tli._dot(
10051006
x1=buf1,
10061007
x2=buf2,
@@ -1014,5 +1015,5 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
10141015
)
10151016
_manager.add_event_pair(ht_, dot_ev)
10161017
if appended_axes:
1017-
out = dpt.squeeze(out, tuple(appended_axes))
1018+
out = dpt_ext.squeeze(out, tuple(appended_axes))
10181019
return out

dpctl_ext/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,8 @@
3333
//===---------------------------------------------------------------------===//
3434

3535
#pragma once
36+
37+
#include <algorithm>
3638
#include <array>
3739
#include <cstddef>
3840
#include <cstdint>

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@
4747
#include "kernels/dpctl_tensor_types.hpp"
4848
#include "kernels/elementwise_functions/common.hpp"
4949

50-
#include "utils/offset_utils.hpp"
5150
#include "utils/type_dispatch_building.hpp"
5251
#include "utils/type_utils.hpp"
5352

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@
4747
#include "kernels/dpctl_tensor_types.hpp"
4848
#include "kernels/elementwise_functions/common.hpp"
4949

50-
#include "utils/offset_utils.hpp"
5150
#include "utils/type_dispatch_building.hpp"
5251
#include "utils/type_utils.hpp"
5352

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,6 @@
4949
#include "kernels/dpctl_tensor_types.hpp"
5050
#include "kernels/elementwise_functions/common.hpp"
5151

52-
#include "utils/offset_utils.hpp"
5352
#include "utils/type_dispatch_building.hpp"
5453
#include "utils/type_utils.hpp"
5554

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@
3232
/// This file defines kernels for elementwise evaluation of ANGLE(x) function.
3333
//===---------------------------------------------------------------------===//
3434

35-
#include <cmath>
3635
#include <complex>
3736
#include <cstddef>
3837
#include <cstdint>
@@ -47,7 +46,6 @@
4746
#include "kernels/dpctl_tensor_types.hpp"
4847
#include "kernels/elementwise_functions/common.hpp"
4948

50-
#include "utils/offset_utils.hpp"
5149
#include "utils/type_dispatch_building.hpp"
5250
#include "utils/type_utils.hpp"
5351

0 commit comments

Comments
 (0)