Skip to content

Commit fe7778d

Browse files
Move ti.equal()/floor_divide()/divide() and reuse them
1 parent ccb4c67 commit fe7778d

File tree

15 files changed

+2659
-17
lines changed

15 files changed

+2659
-17
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@ set(_elementwise_sources
9494
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/copysign.cpp
9595
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
9696
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
97-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/equal.cpp
97+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/equal.cpp
9898
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp.cpp
9999
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp2.cpp
100100
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/expm1.cpp
101-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor_divide.cpp
101+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor_divide.cpp
102102
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor.cpp
103103
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/greater_equal.cpp
104104
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/greater.cpp
@@ -141,7 +141,7 @@ set(_elementwise_sources
141141
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/subtract.cpp
142142
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tan.cpp
143143
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tanh.cpp
144-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp
144+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp
145145
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp
146146
)
147147
set(_reduction_sources

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,13 @@
111111
conj,
112112
cos,
113113
cosh,
114+
divide,
115+
equal,
114116
exp,
115117
exp2,
116118
expm1,
117119
floor,
120+
floor_divide,
118121
imag,
119122
isfinite,
120123
isinf,
@@ -205,8 +208,10 @@
205208
"cumulative_prod",
206209
"cumulative_sum",
207210
"diff",
211+
"divide",
208212
"empty",
209213
"empty_like",
214+
"equal",
210215
"extract",
211216
"expand_dims",
212217
"eye",
@@ -216,6 +221,7 @@
216221
"finfo",
217222
"flip",
218223
"floor",
224+
"floor_divide",
219225
"from_numpy",
220226
"full",
221227
"full_like",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@
3232

3333
from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
3434
from ._type_utils import (
35+
_acceptance_fn_divide,
3536
_acceptance_fn_negative,
3637
_acceptance_fn_reciprocal,
38+
_resolve_weak_types_all_py_ints,
3739
)
3840

3941
# U01: ==== ABS (x)
@@ -637,6 +639,78 @@
637639
)
638640
del _cosh_docstring
639641

642+
# B08: ==== DIVIDE (x1, x2)
643+
_divide_docstring_ = r"""
644+
divide(x1, x2, /, \*, out=None, order='K')
645+
646+
Calculates the ratio for each element `x1_i` of the input array `x1` with
647+
the respective element `x2_i` of the input array `x2`.
648+
649+
Args:
650+
x1 (usm_ndarray):
651+
First input array, expected to have a floating-point data type.
652+
x2 (usm_ndarray):
653+
Second input array, also expected to have a floating-point data type.
654+
out (Union[usm_ndarray, None], optional):
655+
Output array to populate.
656+
Array must have the correct shape and the expected data type.
657+
order ("C","F","A","K", optional):
658+
Memory layout of the new output array, if parameter
659+
`out` is ``None``.
660+
Default: "K".
661+
662+
Returns:
663+
usm_ndarray:
664+
An array containing the result of element-wise division. The data type
665+
of the returned array is determined by the Type Promotion Rules.
666+
"""
667+
668+
divide = BinaryElementwiseFunc(
669+
"divide",
670+
ti._divide_result_type,
671+
ti._divide,
672+
_divide_docstring_,
673+
binary_inplace_fn=ti._divide_inplace,
674+
acceptance_fn=_acceptance_fn_divide,
675+
weak_type_resolver=_resolve_weak_types_all_py_ints,
676+
)
677+
del _divide_docstring_
678+
679+
# B09: ==== EQUAL (x1, x2)
680+
_equal_docstring_ = r"""
681+
equal(x1, x2, /, \*, out=None, order='K')
682+
683+
Calculates equality test results for each element `x1_i` of the input array `x1`
684+
with the respective element `x2_i` of the input array `x2`.
685+
686+
Args:
687+
x1 (usm_ndarray):
688+
First input array. May have any data type.
689+
x2 (usm_ndarray):
690+
Second input array. May have any data type.
691+
out (Union[usm_ndarray, None], optional):
692+
Output array to populate.
693+
Array must have the correct shape and the expected data type.
694+
order ("C","F","A","K", optional):
695+
Memory layout of the new output array, if parameter
696+
`out` is ``None``.
697+
Default: "K".
698+
699+
Returns:
700+
usm_ndarray:
701+
An array containing the result of element-wise equality comparison.
702+
The returned array has a data type of `bool`.
703+
"""
704+
705+
equal = BinaryElementwiseFunc(
706+
"equal",
707+
ti._equal_result_type,
708+
ti._equal,
709+
_equal_docstring_,
710+
weak_type_resolver=_resolve_weak_types_all_py_ints,
711+
)
712+
del _equal_docstring_
713+
640714
# U13: ==== EXP (x)
641715
_exp_docstring = r"""
642716
exp(x, /, \*, out=None, order='K')
@@ -664,6 +738,43 @@
664738
exp = UnaryElementwiseFunc("exp", ti._exp_result_type, ti._exp, _exp_docstring)
665739
del _exp_docstring
666740

741+
# B10: ==== FLOOR_DIVIDE (x1, x2)
742+
_floor_divide_docstring_ = r"""
743+
floor_divide(x1, x2, /, \*, out=None, order='K')
744+
745+
Calculates the ratio for each element `x1_i` of the input array `x1` with
746+
the respective element `x2_i` of the input array `x2` to the greatest
747+
integer-value number that is not greater than the division result.
748+
749+
Args:
750+
x1 (usm_ndarray):
751+
First input array, expected to have a real-valued data type.
752+
x2 (usm_ndarray):
753+
Second input array, also expected to have a real-valued data type.
754+
out (Union[usm_ndarray, None], optional):
755+
Output array to populate.
756+
Array must have the correct shape and the expected data type.
757+
order ("C","F","A","K", optional):
758+
Memory layout of the new output array, if parameter
759+
`out` is ``None``.
760+
Default: "K".
761+
762+
Returns:
763+
usm_ndarray:
764+
An array containing the result of element-wise floor of division.
765+
The data type of the returned array is determined by the Type
766+
Promotion Rules.
767+
"""
768+
769+
floor_divide = BinaryElementwiseFunc(
770+
"floor_divide",
771+
ti._floor_divide_result_type,
772+
ti._floor_divide,
773+
_floor_divide_docstring_,
774+
binary_inplace_fn=ti._floor_divide_inplace,
775+
)
776+
del _floor_divide_docstring_
777+
667778
# U14: ==== EXPM1 (x)
668779
_expm1_docstring = r"""
669780
expm1(x, /, \*, out=None, order='K')

0 commit comments

Comments
 (0)