Skip to content

Commit 2d5d2ea

Browse files
Move ti.nextafter()/not_equal()/pow() and reuse them
1 parent 87f5529 commit 2d5d2ea

File tree

15 files changed

+1934
-41
lines changed

15 files changed

+1934
-41
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ set(_elementwise_sources
122122
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/minimum.cpp
123123
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/multiply.cpp
124124
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/negative.cpp
125-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/nextafter.cpp
126-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp
125+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/nextafter.cpp
126+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp
127127
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/positive.cpp
128-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
128+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
129129
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
130130
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
131131
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@
139139
minimum,
140140
multiply,
141141
negative,
142+
nextafter,
143+
not_equal,
142144
positive,
145+
pow,
143146
proj,
144147
real,
145148
reciprocal,
@@ -269,11 +272,14 @@
269272
"matmul",
270273
"matrix_transpose",
271274
"negative",
275+
"nextafter",
272276
"nonzero",
277+
"not_equal",
273278
"ones",
274279
"ones_like",
275280
"place",
276281
"positive",
282+
"pow",
277283
"prod",
278284
"proj",
279285
"put",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,77 @@
14981498
)
14991499
del _negative_docstring_
15001500

1501+
# B28: ==== NEXTAFTER (x1, x2)
1502+
_nextafter_docstring_ = r"""
1503+
nextafter(x1, x2, /, \*, out=None, order='K')
1504+
1505+
Calculates the next floating-point value after element `x1_i` of the input
1506+
array `x1` toward the respective element `x2_i` of the input array `x2`.
1507+
1508+
Args:
1509+
x1 (usm_ndarray):
1510+
First input array, expected to have a real-valued floating-point data
1511+
type.
1512+
x2 (usm_ndarray):
1513+
Second input array, expected to have a real-valued floating-point data
1514+
type.
1515+
out (Union[usm_ndarray, None], optional):
1516+
Output array to populate.
1517+
Array must have the correct shape and the expected data type.
1518+
order ("C","F","A","K", optional):
1519+
Memory layout of the new output array, if parameter
1520+
`out` is ``None``.
1521+
Default: "K".
1522+
1523+
Returns:
1524+
usm_ndarray:
1525+
An array containing the element-wise next representable values of `x1`
1526+
in the direction of `x2`. The data type of the returned array is
1527+
determined by the Type Promotion Rules.
1528+
"""
1529+
nextafter = BinaryElementwiseFunc(
1530+
"nextafter",
1531+
ti._nextafter_result_type,
1532+
ti._nextafter,
1533+
_nextafter_docstring_,
1534+
)
1535+
del _nextafter_docstring_
1536+
1537+
# B20: ==== NOT_EQUAL (x1, x2)
1538+
_not_equal_docstring_ = r"""
1539+
not_equal(x1, x2, /, \*, out=None, order='K')
1540+
1541+
Calculates inequality test results for each element `x1_i` of the
1542+
input array `x1` with the respective element `x2_i` of the input array `x2`.
1543+
1544+
Args:
1545+
x1 (usm_ndarray):
1546+
First input array.
1547+
x2 (usm_ndarray):
1548+
Second input array.
1549+
out (Union[usm_ndarray, None], optional):
1550+
Output array to populate.
1551+
Array must have the correct shape and the expected data type.
1552+
order ("C","F","A","K", optional):
1553+
Memory layout of the new output array, if parameter
1554+
`out` is ``None``.
1555+
Default: "K".
1556+
1557+
Returns:
1558+
usm_ndarray:
1559+
An array containing the result of element-wise inequality comparison.
1560+
The returned array has a data type of `bool`.
1561+
"""
1562+
1563+
not_equal = BinaryElementwiseFunc(
1564+
"not_equal",
1565+
ti._not_equal_result_type,
1566+
ti._not_equal,
1567+
_not_equal_docstring_,
1568+
weak_type_resolver=_resolve_weak_types_all_py_ints,
1569+
)
1570+
del _not_equal_docstring_
1571+
15011572
# U26: ==== POSITIVE (x)
15021573
_positive_docstring_ = r"""
15031574
positive(x, /, \*, out=None, order='K')
@@ -1524,6 +1595,40 @@
15241595
)
15251596
del _positive_docstring_
15261597

1598+
# B21: ==== POW (x1, x2)
1599+
_pow_docstring_ = r"""
1600+
pow(x1, x2, /, \*, out=None, order='K')
1601+
1602+
Calculates `x1_i` raised to `x2_i` for each element `x1_i` of the input array
1603+
`x1` with the respective element `x2_i` of the input array `x2`.
1604+
1605+
Args:
1606+
x1 (usm_ndarray):
1607+
First input array, expected to have a numeric data type.
1608+
x2 (usm_ndarray):
1609+
Second input array, also expected to have a numeric data type.
1610+
out (usm_ndarray):
1611+
Output array to populate. Array must have the correct
1612+
shape and the expected data type.
1613+
order ("C","F","A","K", optional): memory layout of the new
1614+
output array, if parameter `out` is ``None``.
1615+
Default: "K".
1616+
1617+
Returns:
1618+
usm_ndarray:
1619+
An array containing the bases in `x1` raised to the exponents in `x2`
1620+
element-wise. The data type of the returned array is determined by the
1621+
Type Promotion Rules.
1622+
"""
1623+
pow = BinaryElementwiseFunc(
1624+
"pow",
1625+
ti._pow_result_type,
1626+
ti._pow,
1627+
_pow_docstring_,
1628+
binary_inplace_fn=ti._pow_inplace,
1629+
)
1630+
del _pow_docstring_
1631+
15271632
# U27: ==== REAL (x)
15281633
_real_docstring = r"""
15291634
real(x, /, \*, out=None, order='K')

0 commit comments

Comments
 (0)