Skip to content

Commit 0edd3b1

Browse files
Extend ._tensor_impl with where(), clip() and type utils functions (#2778)
This PR extends `_tensor_impl` in `dpctl_ext.tensor` with the `_where, _clip` and repeat functions (`_repeat_by_sequence, _repeat_by_scalar`) It also adds `repeat(), where(), clip()` and `can_cast, finfo, iinfo, isdtype, result_type` from `_type_utils.py` `to dpctl_ext.tensor and updates the corresponding dpnp functions to use these implementations internally
1 parent ecd4991 commit 0edd3b1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+5517
-142
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ set(_tensor_impl_sources
5858
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
5959
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
6060
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
61-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
61+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
6262
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
63-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
64-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
63+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
64+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
6565
)
6666

6767
set(_static_lib_trgt simplify_iteration_space)
@@ -92,10 +92,10 @@ endif()
9292

9393
set(_no_fast_math_sources
9494
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
95-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
95+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
9696
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
97-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
98-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
97+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
98+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
9999
)
100100
#list(
101101
#APPEND _no_fast_math_sources

dpctl_ext/tensor/__init__.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,21 @@
2727
# *****************************************************************************
2828

2929

30-
from dpctl_ext.tensor._copy_utils import (
30+
from ._clip import clip
31+
from ._copy_utils import (
3132
asnumpy,
3233
astype,
3334
copy,
3435
from_numpy,
3536
to_numpy,
3637
)
37-
from dpctl_ext.tensor._ctors import (
38+
from ._ctors import (
3839
eye,
3940
full,
4041
tril,
4142
triu,
4243
)
43-
from dpctl_ext.tensor._indexing_functions import (
44+
from ._indexing_functions import (
4445
extract,
4546
nonzero,
4647
place,
@@ -49,28 +50,39 @@
4950
take,
5051
take_along_axis,
5152
)
52-
from dpctl_ext.tensor._manipulation_functions import (
53+
from ._manipulation_functions import (
54+
repeat,
5355
roll,
5456
)
55-
from dpctl_ext.tensor._reshape import reshape
57+
from ._reshape import reshape
58+
from ._search_functions import where
59+
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
5660

5761
__all__ = [
5862
"asnumpy",
5963
"astype",
64+
"can_cast",
6065
"copy",
66+
"clip",
6167
"extract",
6268
"eye",
69+
"finfo",
6370
"from_numpy",
6471
"full",
72+
"iinfo",
73+
"isdtype",
6574
"nonzero",
6675
"place",
6776
"put",
6877
"put_along_axis",
78+
"repeat",
6979
"reshape",
80+
"result_type",
7081
"roll",
7182
"take",
7283
"take_along_axis",
7384
"to_numpy",
7485
"tril",
7586
"triu",
87+
"where",
7688
]

0 commit comments

Comments
 (0)