Skip to content

Commit fefaa17

Browse files
Move ti.less()/less_equal() and reuse them
1 parent c73df9c commit fefaa17

File tree

11 files changed

+1098
-10
lines changed

11 files changed

+1098
-10
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ set(_elementwise_sources
107107
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isfinite.cpp
108108
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isinf.cpp
109109
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isnan.cpp
110-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less_equal.cpp
111-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less.cpp
110+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less_equal.cpp
111+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less.cpp
112112
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log.cpp
113113
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log1p.cpp
114114
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log2.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@
125125
isfinite,
126126
isinf,
127127
isnan,
128+
less,
129+
less_equal,
128130
log,
129131
log1p,
130132
log2,
@@ -238,6 +240,8 @@
238240
"isdtype",
239241
"isin",
240242
"isnan",
243+
"less",
244+
"less_equal",
241245
"linspace",
242246
"log",
243247
"logical_not",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,77 @@
10211021
)
10221022
del _isnan_docstring_
10231023

1024+
# B13: ==== LESS (x1, x2)
1025+
_less_docstring_ = r"""
1026+
less(x1, x2, /, \*, out=None, order='K')
1027+
1028+
Computes the less-than test results for each element `x1_i` of
1029+
the input array `x1` with the respective element `x2_i` of the input array `x2`.
1030+
1031+
Args:
1032+
x1 (usm_ndarray):
1033+
First input array. May have any data type.
1034+
x2 (usm_ndarray):
1035+
Second input array. May have any data type.
1036+
out (Union[usm_ndarray, None], optional):
1037+
Output array to populate.
1038+
Array must have the correct shape and the expected data type.
1039+
order ("C","F","A","K", optional):
1040+
Memory layout of the new output array, if parameter
1041+
`out` is ``None``.
1042+
Default: "K".
1043+
1044+
Returns:
1045+
usm_ndarray:
1046+
An array containing the result of element-wise less-than comparison.
1047+
The returned array has a data type of `bool`.
1048+
"""
1049+
1050+
less = BinaryElementwiseFunc(
1051+
"less",
1052+
ti._less_result_type,
1053+
ti._less,
1054+
_less_docstring_,
1055+
weak_type_resolver=_resolve_weak_types_all_py_ints,
1056+
)
1057+
del _less_docstring_
1058+
1059+
1060+
# B14: ==== LESS_EQUAL (x1, x2)
1061+
_less_equal_docstring_ = r"""
1062+
less_equal(x1, x2, /, \*, out=None, order='K')
1063+
1064+
Computes the less-than or equal-to test results for each element `x1_i` of
1065+
the input array `x1` with the respective element `x2_i` of the input array `x2`.
1066+
1067+
Args:
1068+
x1 (usm_ndarray):
1069+
First input array. May have any data type.
1070+
x2 (usm_ndarray):
1071+
Second input array. May have any data type.
1072+
out (Union[usm_ndarray, None], optional):
1073+
Output array to populate.
1074+
Array must have the correct shape and the expected data type.
1075+
order ("C","F","A","K", optional):
1076+
Memory layout of the new output array, if parameter
1077+
`out` is ``None``.
1078+
Default: "K".
1079+
1080+
Returns:
1081+
usm_ndarray:
1082+
An array containing the result of element-wise less-than or equal-to
1083+
comparison. The returned array has a data type of `bool`.
1084+
"""
1085+
1086+
less_equal = BinaryElementwiseFunc(
1087+
"less_equal",
1088+
ti._less_equal_result_type,
1089+
ti._less_equal,
1090+
_less_equal_docstring_,
1091+
weak_type_resolver=_resolve_weak_types_all_py_ints,
1092+
)
1093+
del _less_equal_docstring_
1094+
10241095
# U20: ==== LOG (x)
10251096
_log_docstring = r"""
10261097
log(x, /, \*, out=None, order='K')

0 commit comments

Comments
 (0)