Skip to content

Commit cab0b36

Browse files
Move ti.copysign()/remainder()/subtract() and reuse them
1 parent 2d5d2ea commit cab0b36

File tree

14 files changed

+2367
-59
lines changed

14 files changed

+2367
-59
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ set(_elementwise_sources
9191
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cbrt.cpp
9292
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/ceil.cpp
9393
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/conj.cpp
94-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/copysign.cpp
94+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/copysign.cpp
9595
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
9696
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
9797
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/equal.cpp
@@ -129,7 +129,7 @@ set(_elementwise_sources
129129
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
130130
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
131131
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp
132-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
132+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
133133
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
134134
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp
135135
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sign.cpp
@@ -138,7 +138,7 @@ set(_elementwise_sources
138138
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sinh.cpp
139139
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sqrt.cpp
140140
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/square.cpp
141-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/subtract.cpp
141+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/subtract.cpp
142142
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tan.cpp
143143
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tanh.cpp
144144
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
cbrt,
110110
ceil,
111111
conj,
112+
copysign,
112113
cos,
113114
cosh,
114115
divide,
@@ -146,6 +147,7 @@
146147
proj,
147148
real,
148149
reciprocal,
150+
remainder,
149151
round,
150152
rsqrt,
151153
sign,
@@ -154,6 +156,7 @@
154156
sinh,
155157
sqrt,
156158
square,
159+
subtract,
157160
tan,
158161
tanh,
159162
trunc,
@@ -214,6 +217,7 @@
214217
"concat",
215218
"conj",
216219
"copy",
220+
"copysign",
217221
"cos",
218222
"cosh",
219223
"count_nonzero",
@@ -287,6 +291,7 @@
287291
"real",
288292
"reciprocal",
289293
"reduce_hypot",
294+
"remainder",
290295
"repeat",
291296
"reshape",
292297
"result_type",
@@ -303,6 +308,7 @@
303308
"square",
304309
"squeeze",
305310
"stack",
311+
"subtract",
306312
"sum",
307313
"swapaxes",
308314
"take",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_acceptance_fn_divide,
3636
_acceptance_fn_negative,
3737
_acceptance_fn_reciprocal,
38+
_acceptance_fn_subtract,
3839
_resolve_weak_types_all_py_ints,
3940
)
4041

@@ -1660,6 +1661,43 @@
16601661
)
16611662
del _real_docstring
16621663

1664+
# B22: ==== REMAINDER (x1, x2)
1665+
_remainder_docstring_ = r"""
1666+
remainder(x1, x2, /, \*, out=None, order='K')
1667+
1668+
Calculates the remainder of division for each element `x1_i` of the input array
1669+
`x1` with the respective element `x2_i` of the input array `x2`.
1670+
1671+
This function is equivalent to the Python modulus operator.
1672+
1673+
Args:
1674+
x1 (usm_ndarray):
1675+
First input array, expected to have a real-valued data type.
1676+
x2 (usm_ndarray):
1677+
Second input array, also expected to have a real-valued data type.
1678+
out (Union[usm_ndarray, None], optional):
1679+
Output array to populate.
1680+
Array must have the correct shape and the expected data type.
1681+
order ("C","F","A","K", optional):
1682+
Memory layout of the new output array, if parameter
1683+
`out` is ``None``.
1684+
Default: "K".
1685+
1686+
Returns:
1687+
usm_ndarray:
1688+
An array containing the element-wise remainders. Each remainder has the
1689+
same sign as respective element `x2_i`. The data type of the returned
1690+
array is determined by the Type Promotion Rules.
1691+
"""
1692+
remainder = BinaryElementwiseFunc(
1693+
"remainder",
1694+
ti._remainder_result_type,
1695+
ti._remainder,
1696+
_remainder_docstring_,
1697+
binary_inplace_fn=ti._remainder_inplace,
1698+
)
1699+
del _remainder_docstring_
1700+
16631701
# U28: ==== ROUND (x)
16641702
_round_docstring = r"""
16651703
round(x, /, \*, out=None, order='K')
@@ -1835,6 +1873,41 @@
18351873
)
18361874
del _sqrt_docstring_
18371875

1876+
# B23: ==== SUBTRACT (x1, x2)
1877+
_subtract_docstring_ = r"""
1878+
subtract(x1, x2, /, \*, out=None, order='K')
1879+
1880+
Calculates the difference between each element `x1_i` of the input
1881+
array `x1` and the respective element `x2_i` of the input array `x2`.
1882+
1883+
Args:
1884+
x1 (usm_ndarray):
1885+
First input array, expected to have a numeric data type.
1886+
x2 (usm_ndarray):
1887+
Second input array, also expected to have a numeric data type.
1888+
out (Union[usm_ndarray, None], optional):
1889+
Output array to populate.
1890+
Array must have the correct shape and the expected data type.
1891+
order ("C","F","A","K", optional):
1892+
Memory layout of the new output array, if parameter
1893+
`out` is ``None``.
1894+
Default: "K".
1895+
1896+
Returns:
1897+
usm_ndarray:
1898+
An array containing the element-wise differences. The data type
1899+
of the returned array is determined by the Type Promotion Rules.
1900+
"""
1901+
subtract = BinaryElementwiseFunc(
1902+
"subtract",
1903+
ti._subtract_result_type,
1904+
ti._subtract,
1905+
_subtract_docstring_,
1906+
binary_inplace_fn=ti._subtract_inplace,
1907+
acceptance_fn=_acceptance_fn_subtract,
1908+
)
1909+
del _subtract_docstring_
1910+
18381911
# U34: ==== TAN (x)
18391912
_tan_docstring = r"""
18401913
tan(x, /, \*, out=None, order='K')
@@ -2011,6 +2084,41 @@
20112084
)
20122085
del _exp2_docstring_
20132086

2087+
# B25: ==== COPYSIGN (x1, x2)
2088+
_copysign_docstring_ = r"""
2089+
copysign(x1, x2, /, \*, out=None, order='K')
2090+
2091+
Composes a floating-point value with the magnitude of `x1_i` and the sign of
2092+
`x2_i` for each element of input arrays `x1` and `x2`.
2093+
2094+
Args:
2095+
x1 (usm_ndarray):
2096+
First input array, expected to have a real-valued floating-point data
2097+
type.
2098+
x2 (usm_ndarray):
2099+
Second input array, also expected to have a real-valued floating-point
2100+
data type.
2101+
out (Union[usm_ndarray, None], optional):
2102+
Output array to populate.
2103+
Array have the correct shape and the expected data type.
2104+
order ("C","F","A","K", optional):
2105+
Memory layout of the new output array, if parameter
2106+
`out` is ``None``.
2107+
Default: "K".
2108+
2109+
Returns:
2110+
usm_ndarray:
2111+
An array containing the element-wise results. The data type
2112+
of the returned array is determined by the Type Promotion Rules.
2113+
"""
2114+
copysign = BinaryElementwiseFunc(
2115+
"copysign",
2116+
ti._copysign_result_type,
2117+
ti._copysign,
2118+
_copysign_docstring_,
2119+
)
2120+
del _copysign_docstring_
2121+
20142122
# U39: ==== RSQRT (x)
20152123
_rsqrt_docstring_ = r"""
20162124
rsqrt(x, /, \*, out=None, order='K')

0 commit comments

Comments
 (0)