Skip to content

Commit 0cdc3e4

Browse files
Move ti.cbrt()/exp2()/reciprocal()/rsqrt()/trunc() and reuse them
1 parent a1707b2 commit 0cdc3e4

File tree

21 files changed

+2191
-25
lines changed

21 files changed

+2191
-25
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ set(_elementwise_sources
8888
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_or.cpp
8989
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_right_shift.cpp
9090
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_xor.cpp
91-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cbrt.cpp
91+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cbrt.cpp
9292
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/ceil.cpp
9393
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/conj.cpp
9494
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/copysign.cpp
9595
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
9696
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
9797
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/equal.cpp
9898
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp.cpp
99-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp2.cpp
99+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp2.cpp
100100
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/expm1.cpp
101101
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor_divide.cpp
102102
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor.cpp
@@ -128,10 +128,10 @@ set(_elementwise_sources
128128
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
129129
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
130130
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
131-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp
131+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp
132132
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
133133
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
134-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp
134+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp
135135
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sign.cpp
136136
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/signbit.cpp
137137
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sin.cpp
@@ -142,7 +142,7 @@ set(_elementwise_sources
142142
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tan.cpp
143143
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tanh.cpp
144144
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp
145-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp
145+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp
146146
)
147147
set(_reduction_sources
148148
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,13 @@
9393
atan,
9494
atanh,
9595
bitwise_invert,
96+
cbrt,
9697
ceil,
9798
conj,
9899
cos,
99100
cosh,
100101
exp,
102+
exp2,
101103
expm1,
102104
floor,
103105
imag,
@@ -113,7 +115,9 @@
113115
positive,
114116
proj,
115117
real,
118+
reciprocal,
116119
round,
120+
rsqrt,
117121
sign,
118122
signbit,
119123
sin,
@@ -122,6 +126,7 @@
122126
square,
123127
tan,
124128
tanh,
129+
trunc,
125130
)
126131
from ._reduction import (
127132
argmax,
@@ -167,6 +172,7 @@
167172
"broadcast_arrays",
168173
"broadcast_to",
169174
"can_cast",
175+
"cbrt",
170176
"ceil",
171177
"concat",
172178
"conj",
@@ -185,6 +191,7 @@
185191
"expand_dims",
186192
"eye",
187193
"exp",
194+
"exp2",
188195
"expm1",
189196
"finfo",
190197
"flip",
@@ -222,12 +229,14 @@
222229
"put",
223230
"put_along_axis",
224231
"real",
232+
"reciprocal",
225233
"reduce_hypot",
226234
"repeat",
227235
"reshape",
228236
"result_type",
229237
"roll",
230238
"round",
239+
"rsqrt",
231240
"searchsorted",
232241
"sign",
233242
"signbit",
@@ -249,6 +258,7 @@
249258
"to_numpy",
250259
"tril",
251260
"triu",
261+
"trunc",
252262
"unique_all",
253263
"unique_counts",
254264
"unique_inverse",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ._elementwise_common import UnaryElementwiseFunc
3434
from ._type_utils import (
3535
_acceptance_fn_negative,
36+
_acceptance_fn_reciprocal,
3637
)
3738

3839
# U01: ==== ABS (x)
@@ -1042,6 +1043,124 @@
10421043
)
10431044
del _tanh_docstring
10441045

1046+
# U36: ==== TRUNC (x)
1047+
_trunc_docstring = r"""
1048+
trunc(x, /, \*, out=None, order='K')
1049+
1050+
Returns the truncated value for each element `x_i` for input array `x`.
1051+
1052+
The truncated value of the scalar `x` is the nearest integer i which is
1053+
closer to zero than `x` is. In short, the fractional part of the
1054+
signed number `x` is discarded.
1055+
1056+
Args:
1057+
x (usm_ndarray):
1058+
Input array, expected to have a boolean or real-valued data type.
1059+
out (Union[usm_ndarray, None], optional):
1060+
Output array to populate.
1061+
Array must have the correct shape and the expected data type.
1062+
order ("C","F","A","K", optional):
1063+
Memory layout of the new output array, if parameter
1064+
`out` is ``None``.
1065+
Default: "K".
1066+
1067+
Returns:
1068+
usm_ndarray:
1069+
An array containing the result of element-wise division. The data type
1070+
of the returned array is determined by the Type Promotion Rules.
1071+
"""
1072+
trunc = UnaryElementwiseFunc(
1073+
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
1074+
)
1075+
del _trunc_docstring
1076+
1077+
# U37: ==== CBRT (x)
1078+
_cbrt_docstring_ = r"""
1079+
cbrt(x, /, \*, out=None, order='K')
1080+
1081+
Computes the cube-root for each element `x_i` for input array `x`.
1082+
1083+
Args:
1084+
x (usm_ndarray):
1085+
Input array, expected to have a real-valued floating-point data type.
1086+
out (Union[usm_ndarray, None], optional):
1087+
Output array to populate.
1088+
Array have the correct shape and the expected data type.
1089+
order ("C","F","A","K", optional):
1090+
Memory layout of the new output array, if parameter
1091+
`out` is ``None``.
1092+
Default: "K".
1093+
1094+
Returns:
1095+
usm_ndarray:
1096+
An array containing the element-wise cube-root.
1097+
The data type of the returned array is determined by
1098+
the Type Promotion Rules.
1099+
"""
1100+
1101+
cbrt = UnaryElementwiseFunc(
1102+
"cbrt", ti._cbrt_result_type, ti._cbrt, _cbrt_docstring_
1103+
)
1104+
del _cbrt_docstring_
1105+
1106+
# U38: ==== EXP2 (x)
1107+
_exp2_docstring_ = r"""
1108+
exp2(x, /, \*, out=None, order='K')
1109+
1110+
Computes the base-2 exponential for each element `x_i` for input array `x`.
1111+
1112+
Args:
1113+
x (usm_ndarray):
1114+
Input array, expected to have a floating-point data type.
1115+
out (Union[usm_ndarray, None], optional):
1116+
Output array to populate.
1117+
Array have the correct shape and the expected data type.
1118+
order ("C","F","A","K", optional):
1119+
Memory layout of the new output array, if parameter
1120+
`out` is ``None``.
1121+
Default: "K".
1122+
1123+
Returns:
1124+
usm_ndarray:
1125+
An array containing the element-wise base-2 exponentials.
1126+
The data type of the returned array is determined by
1127+
the Type Promotion Rules.
1128+
"""
1129+
1130+
exp2 = UnaryElementwiseFunc(
1131+
"exp2", ti._exp2_result_type, ti._exp2, _exp2_docstring_
1132+
)
1133+
del _exp2_docstring_
1134+
1135+
# U39: ==== RSQRT (x)
1136+
_rsqrt_docstring_ = r"""
1137+
rsqrt(x, /, \*, out=None, order='K')
1138+
1139+
Computes the reciprocal square-root for each element `x_i` for input array `x`.
1140+
1141+
Args:
1142+
x (usm_ndarray):
1143+
Input array, expected to have a real-valued floating-point data type.
1144+
out (Union[usm_ndarray, None], optional):
1145+
Output array to populate.
1146+
Array have the correct shape and the expected data type.
1147+
order ("C","F","A","K", optional):
1148+
Memory layout of the new output array, if parameter
1149+
`out` is ``None``.
1150+
Default: "K".
1151+
1152+
Returns:
1153+
usm_ndarray:
1154+
An array containing the element-wise reciprocal square-root.
1155+
The returned array has a floating-point data type determined by
1156+
the Type Promotion Rules.
1157+
"""
1158+
1159+
rsqrt = UnaryElementwiseFunc(
1160+
"rsqrt", ti._rsqrt_result_type, ti._rsqrt, _rsqrt_docstring_
1161+
)
1162+
del _rsqrt_docstring_
1163+
10451164
# U40: ==== PROJ (x)
10461165
_proj_docstring = r"""
10471166
proj(x, /, \*, out=None, order='K')
@@ -1098,6 +1217,39 @@
10981217
)
10991218
del _signbit_docstring
11001219

1220+
# U42: ==== RECIPROCAL (x)
1221+
_reciprocal_docstring = r"""
1222+
reciprocal(x, /, \*, out=None, order='K')
1223+
1224+
Computes the reciprocal of each element `x_i` for input array `x`.
1225+
1226+
Args:
1227+
x (usm_ndarray):
1228+
Input array, expected to have a floating-point data type.
1229+
out (Union[usm_ndarray, None], optional):
1230+
Output array to populate.
1231+
Array have the correct shape and the expected data type.
1232+
order ("C","F","A","K", optional):
1233+
Memory layout of the new output array, if parameter
1234+
`out` is ``None``.
1235+
Default: "K".
1236+
1237+
Returns:
1238+
usm_ndarray:
1239+
An array containing the element-wise reciprocals.
1240+
The returned array has a floating-point data type determined
1241+
by the Type Promotion Rules.
1242+
"""
1243+
1244+
reciprocal = UnaryElementwiseFunc(
1245+
"reciprocal",
1246+
ti._reciprocal_result_type,
1247+
ti._reciprocal,
1248+
_reciprocal_docstring,
1249+
acceptance_fn=_acceptance_fn_reciprocal,
1250+
)
1251+
del _reciprocal_docstring
1252+
11011253
# U43: ==== ANGLE (x)
11021254
_angle_docstring = r"""
11031255
angle(x, /, \*, out=None, order='K')

0 commit comments

Comments
 (0)