Skip to content

Commit 4872bb6

Browse files
Move ti.count_nonzero()/logsumexp()/reduce_hypot() to dpctl_ext.tensor and reuse them
1 parent 3041d7d commit 4872bb6

File tree

4 files changed

+178
-6
lines changed

4 files changed

+178
-6
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,12 @@
8686
from ._reduction import (
8787
argmax,
8888
argmin,
89+
count_nonzero,
90+
logsumexp,
8991
max,
9092
min,
9193
prod,
94+
reduce_hypot,
9295
sum,
9396
)
9497
from ._searchsorted import searchsorted
@@ -117,6 +120,7 @@
117120
"can_cast",
118121
"concat",
119122
"copy",
123+
"count_nonzero",
120124
"clip",
121125
"cumulative_logsumexp",
122126
"cumulative_prod",
@@ -136,6 +140,7 @@
136140
"isdtype",
137141
"isin",
138142
"linspace",
143+
"logsumexp",
139144
"max",
140145
"meshgrid",
141146
"min",
@@ -148,6 +153,7 @@
148153
"prod",
149154
"put",
150155
"put_along_axis",
156+
"reduce_hypot",
151157
"repeat",
152158
"reshape",
153159
"result_type",

dpctl_ext/tensor/_reduction.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from ._numpy_helper import normalize_axis_tuple
4040
from ._type_utils import (
4141
_default_accumulation_dtype,
42+
_default_accumulation_dtype_fp_types,
4243
_to_device_supported_dtype,
4344
)
4445

@@ -472,6 +473,111 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None):
472473
return _search_over_axis(x, axis, keepdims, out, tri._argmin_over_axis)
473474

474475

476+
def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
    """
    Returns the number of non-zero elements of the input array ``x``.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes over which non-zero elements are counted. When a
            tuple of unique integers is given, the count is taken over all
            of the listed axes. When ``None``, the count is taken over the
            entire array.
            Default: ``None``.
        keepdims (Optional[bool]):
            if ``True``, every reduced axis is retained in the result as a
            dimension of size one, so that the result stays compatible
            with the input array under Array Broadcasting rules. If
            ``False``, the reduced axes are dropped from the result.
            Default: ``False``.
        out (Optional[usm_ndarray]):
            the array into which the result is written. It must have the
            expected shape and data type.
            If ``None`` then a new array is returned. Default: ``None``.

    Returns:
        usm_ndarray:
            an array holding the count of non-zero values. When the count
            is taken over the entire array, a zero-dimensional array is
            returned. The returned array has the default array index data
            type.
    """
    # Counting is done by summing a boolean mask; an input that is already
    # boolean is used as is, anything else is cast (without a forced copy).
    mask = x if x.dtype == dpt.bool else dpt.astype(x, dpt.bool, copy=False)
    # Accumulate in the default index type of the device the mask lives on.
    index_dt = ti.default_device_index_type(mask.sycl_device)
    return sum(mask, axis=axis, dtype=index_dt, keepdims=keepdims, out=out)
517+
518+
519+
def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
    """
    Computes the logarithm of the sum of exponentials of the elements of
    the input array ``x``.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes over which the result is computed. When a tuple
            of unique integers is given, the reduction spans all of the
            listed axes. When ``None``, the reduction spans the entire
            array.
            Default: ``None``.
        dtype (Optional[dtype]):
            data type of the returned array. When ``None``, the data type
            is chosen from the "kind" of the data type of ``x``:

                * for a real-valued floating-point ``x``, the result has
                  the same data type as ``x``.
                * for a boolean or integral ``x``, the result has the
                  default floating point data type of the device where
                  ``x`` is allocated.
                * for a complex-valued floating-point ``x``, an error is
                  raised.

            When the data type (given or resolved) is not the data type
            of ``x``, the elements of ``x`` are cast to that data type
            before the result is computed.
            Default: ``None``.
        keepdims (Optional[bool]):
            if ``True``, every reduced axis is retained in the result as a
            dimension of size one, so that the result stays compatible
            with the input array under Array Broadcasting rules. If
            ``False``, the reduced axes are dropped from the result.
            Default: ``False``.
        out (Optional[usm_ndarray]):
            the array into which the result is written. It must have the
            expected shape and the expected data type of the result or
            (if provided) ``dtype``.
            If ``None`` then a new array is returned. Default: ``None``.

    Returns:
        usm_ndarray:
            an array holding the results. When the reduction spans the
            entire array, a zero-dimensional array is returned. The data
            type of the returned array follows the rules given for the
            ``dtype`` parameter above.
    """

    def _dtype_supported(inp_dt, res_dt, *_):
        # Only the input and result data types are forwarded to the
        # support query; any further positional arguments are ignored.
        return tri._logsumexp_over_axis_dtype_supported(inp_dt, res_dt)

    reduction_fn = tri._logsumexp_over_axis
    return _reduction_over_axis(
        x,
        axis,
        dtype,
        keepdims,
        out,
        reduction_fn,
        _dtype_supported,
        _default_accumulation_dtype_fp_types,
    )
579+
580+
475581
def max(x, /, *, axis=None, keepdims=False, out=None):
476582
"""
477583
Calculates the maximum value of the input array ``x``.
@@ -602,6 +708,67 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
602708
)
603709

604710

711+
def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
    """
    Computes the square root of the sum of squares of the elements of the
    input array ``x``.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes over which the result is computed. When a tuple
            of unique integers is given, the reduction spans all of the
            listed axes. When ``None``, the reduction spans the entire
            array.
            Default: ``None``.
        dtype (Optional[dtype]):
            data type of the returned array. When ``None``, the data type
            is chosen from the "kind" of the data type of ``x``:

                * for a real-valued floating-point ``x``, the result has
                  the same data type as ``x``.
                * for a boolean or integral ``x``, the result has the
                  default floating point data type of the device where
                  ``x`` is allocated.
                * for a complex-valued floating-point ``x``, an error is
                  raised.

            When the data type (given or resolved) is not the data type
            of ``x``, the elements of ``x`` are cast to that data type
            before the result is computed. Default: ``None``.
        keepdims (Optional[bool]):
            if ``True``, every reduced axis is retained in the result as a
            dimension of size one, so that the result stays compatible
            with the input array under Array Broadcasting rules. If
            ``False``, the reduced axes are dropped from the result.
            Default: ``False``.
        out (Optional[usm_ndarray]):
            the array into which the result is written. It must have the
            expected shape and the expected data type of the result or
            (if provided) ``dtype``.
            If ``None`` then a new array is returned. Default: ``None``.

    Returns:
        usm_ndarray:
            an array holding the results. When the reduction spans the
            entire array, a zero-dimensional array is returned. The data
            type of the returned array follows the rules given for the
            ``dtype`` parameter above.
    """

    def _dtype_supported(inp_dt, res_dt, *_):
        # Only the input and result data types are forwarded to the
        # support query; any further positional arguments are ignored.
        return tri._hypot_over_axis_dtype_supported(inp_dt, res_dt)

    reduction_fn = tri._hypot_over_axis
    return _reduction_over_axis(
        x,
        axis,
        dtype,
        keepdims,
        out,
        reduction_fn,
        _dtype_supported,
        _default_accumulation_dtype_fp_types,
    )
770+
771+
605772
def sum(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
606773
"""
607774
Calculates the sum of elements in the input array ``x``.

dpnp/dpnp_iface_counting.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
4040
"""
4141

42-
import dpctl.tensor as dpt
43-
42+
# TODO: revert to `import dpctl.tensor...`
43+
# when dpnp fully migrates dpctl/tensor
44+
import dpctl_ext.tensor as dpt
4445
import dpnp
4546

4647

dpnp/dpnp_iface_trigonometric.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,11 @@
4242
# pylint: disable=protected-access
4343
# pylint: disable=no-name-in-module
4444

45-
46-
import dpctl.tensor as dpt
4745
import dpctl.tensor._tensor_elementwise_impl as ti
4846

4947
# TODO: revert to `import dpctl.tensor...`
5048
# when dpnp fully migrates dpctl/tensor
51-
import dpctl_ext.tensor as dpt_ext
49+
import dpctl_ext.tensor as dpt
5250
import dpctl_ext.tensor._type_utils as dtu
5351
import dpnp
5452
import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
@@ -935,7 +933,7 @@ def cumlogsumexp(
935933
return dpnp_wrap_reduction_call(
936934
usm_x,
937935
out,
938-
dpt_ext.cumulative_logsumexp,
936+
dpt.cumulative_logsumexp,
939937
_get_accumulation_res_dt(x, dtype),
940938
axis=axis,
941939
dtype=dtype,

0 commit comments

Comments
 (0)