Skip to content

Commit 6b27b1b

Browse files
Move ti.sum()/prod() to dpctl_ext.tensor and reuse them in dpnp
1 parent 2ec3cc7 commit 6b27b1b

File tree

4 files changed

+305
-11
lines changed

4 files changed

+305
-11
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@
8888
argmin,
8989
max,
9090
min,
91+
prod,
92+
sum,
9193
)
9294
from ._searchsorted import searchsorted
9395
from ._set_functions import (
@@ -143,6 +145,7 @@
143145
"ones",
144146
"ones_like",
145147
"place",
148+
"prod",
146149
"put",
147150
"put_along_axis",
148151
"repeat",
@@ -153,6 +156,7 @@
153156
"sort",
154157
"squeeze",
155158
"stack",
159+
"sum",
156160
"swapaxes",
157161
"take",
158162
"take_along_axis",

dpctl_ext/tensor/_reduction.py

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
import dpctl_ext.tensor._tensor_reductions_impl as tri
3838

3939
from ._numpy_helper import normalize_axis_tuple
40+
from ._type_utils import (
41+
_default_accumulation_dtype,
42+
_to_device_supported_dtype,
43+
)
4044

4145

4246
def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
@@ -137,6 +141,164 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
137141
return out
138142

139143

144+
def _reduction_over_axis(
145+
x,
146+
axis,
147+
dtype,
148+
keepdims,
149+
out,
150+
_reduction_fn,
151+
_dtype_supported,
152+
_default_reduction_type_fn,
153+
):
154+
if not isinstance(x, dpt.usm_ndarray):
155+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
156+
nd = x.ndim
157+
if axis is None:
158+
axis = tuple(range(nd))
159+
perm = list(axis)
160+
arr = x
161+
else:
162+
if not isinstance(axis, (tuple, list)):
163+
axis = (axis,)
164+
axis = normalize_axis_tuple(axis, nd, "axis")
165+
perm = [i for i in range(nd) if i not in axis] + list(axis)
166+
arr = dpt_ext.permute_dims(x, perm)
167+
red_nd = len(axis)
168+
res_shape = arr.shape[: nd - red_nd]
169+
q = x.sycl_queue
170+
inp_dt = x.dtype
171+
if dtype is None:
172+
res_dt = _default_reduction_type_fn(inp_dt, q)
173+
else:
174+
res_dt = dpt.dtype(dtype)
175+
res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
176+
177+
res_usm_type = x.usm_type
178+
179+
implemented_types = _dtype_supported(inp_dt, res_dt, res_usm_type, q)
180+
if dtype is None and not implemented_types:
181+
raise RuntimeError(
182+
"Automatically determined reduction data type does not "
183+
"have direct implementation"
184+
)
185+
orig_out = out
186+
if out is not None:
187+
if not isinstance(out, dpt.usm_ndarray):
188+
raise TypeError(
189+
f"output array must be of usm_ndarray type, got {type(out)}"
190+
)
191+
if not out.flags.writable:
192+
raise ValueError("provided `out` array is read-only")
193+
if not keepdims:
194+
final_res_shape = res_shape
195+
else:
196+
inp_shape = x.shape
197+
final_res_shape = tuple(
198+
inp_shape[i] if i not in axis else 1 for i in range(nd)
199+
)
200+
if not out.shape == final_res_shape:
201+
raise ValueError(
202+
"The shape of input and output arrays are inconsistent. "
203+
f"Expected output shape is {final_res_shape}, got {out.shape}"
204+
)
205+
if res_dt != out.dtype:
206+
raise ValueError(
207+
f"Output array of type {res_dt} is needed, got {out.dtype}"
208+
)
209+
if dpctl.utils.get_execution_queue((q, out.sycl_queue)) is None:
210+
raise ExecutionPlacementError(
211+
"Input and output allocation queues are not compatible"
212+
)
213+
if keepdims:
214+
out = dpt_ext.squeeze(out, axis=axis)
215+
orig_out = out
216+
if ti._array_overlap(x, out) and implemented_types:
217+
out = dpt_ext.empty_like(out)
218+
else:
219+
out = dpt_ext.empty(
220+
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
221+
)
222+
223+
_manager = SequentialOrderManager[q]
224+
dep_evs = _manager.submitted_events
225+
if red_nd == 0:
226+
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
227+
src=arr, dst=out, sycl_queue=q, depends=dep_evs
228+
)
229+
_manager.add_event_pair(ht_e_cpy, cpy_e)
230+
if not (orig_out is None or orig_out is out):
231+
ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray(
232+
src=out, dst=orig_out, sycl_queue=q, depends=[cpy_e]
233+
)
234+
_manager.add_event_pair(ht_e_cpy2, cpy2_e)
235+
out = orig_out
236+
return out
237+
238+
if implemented_types:
239+
ht_e, red_e = _reduction_fn(
240+
src=arr,
241+
trailing_dims_to_reduce=red_nd,
242+
dst=out,
243+
sycl_queue=q,
244+
depends=dep_evs,
245+
)
246+
_manager.add_event_pair(ht_e, red_e)
247+
if not (orig_out is None or orig_out is out):
248+
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
249+
src=out, dst=orig_out, sycl_queue=q, depends=[red_e]
250+
)
251+
_manager.add_event_pair(ht_e_cpy, cpy_e)
252+
out = orig_out
253+
else:
254+
if _dtype_supported(res_dt, res_dt, res_usm_type, q):
255+
tmp = dpt_ext.empty(
256+
arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
257+
)
258+
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
259+
src=arr, dst=tmp, sycl_queue=q, depends=dep_evs
260+
)
261+
_manager.add_event_pair(ht_e_cpy, cpy_e)
262+
ht_e_red, red_ev = _reduction_fn(
263+
src=tmp,
264+
trailing_dims_to_reduce=red_nd,
265+
dst=out,
266+
sycl_queue=q,
267+
depends=[cpy_e],
268+
)
269+
_manager.add_event_pair(ht_e_red, red_ev)
270+
else:
271+
buf_dt = _default_reduction_type_fn(inp_dt, q)
272+
tmp = dpt_ext.empty(
273+
arr.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
274+
)
275+
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
276+
src=arr, dst=tmp, sycl_queue=q, depends=dep_evs
277+
)
278+
_manager.add_event_pair(ht_e_cpy, cpy_e)
279+
tmp_res = dpt_ext.empty(
280+
res_shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
281+
)
282+
ht_e_red, r_e = _reduction_fn(
283+
src=tmp,
284+
trailing_dims_to_reduce=red_nd,
285+
dst=tmp_res,
286+
sycl_queue=q,
287+
depends=[cpy_e],
288+
)
289+
_manager.add_event_pair(ht_e_red, r_e)
290+
ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray(
291+
src=tmp_res, dst=out, sycl_queue=q, depends=[r_e]
292+
)
293+
_manager.add_event_pair(ht_e_cpy2, cpy2_e)
294+
295+
if keepdims:
296+
res_shape = res_shape + (1,) * red_nd
297+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
298+
out = dpt_ext.permute_dims(dpt_ext.reshape(out, res_shape), inv_perm)
299+
return out
300+
301+
140302
def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
141303
if not isinstance(x, dpt.usm_ndarray):
142304
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
@@ -374,3 +536,132 @@ def min(x, /, *, axis=None, keepdims=False, out=None):
374536
array has the same data type as ``x``.
375537
"""
376538
return _comparison_over_axis(x, axis, keepdims, out, tri._min_over_axis)
539+
540+
541+
def prod(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
542+
"""
543+
Calculates the product of elements in the input array ``x``.
544+
545+
Args:
546+
x (usm_ndarray):
547+
input array.
548+
axis (Optional[int, Tuple[int, ...]]):
549+
axis or axes along which products must be computed. If a tuple
550+
of unique integers, products are computed over multiple axes.
551+
If ``None``, the product is computed over the entire array.
552+
Default: ``None``.
553+
dtype (Optional[dtype]):
554+
data type of the returned array. If ``None``, the default data
555+
type is inferred from the "kind" of the input array data type.
556+
557+
* If ``x`` has a real- or complex-valued floating-point data
558+
type, the returned array will have the same data type as
559+
``x``.
560+
* If ``x`` has signed integral data type, the returned array
561+
will have the default signed integral type for the device
562+
where input array ``x`` is allocated.
563+
* If ``x`` has unsigned integral data type, the returned array
564+
will have the default unsigned integral type for the device
565+
where input array ``x`` is allocated.
566+
* If ``x`` has a boolean data type, the returned array will
567+
have the default signed integral type for the device
568+
where input array ``x`` is allocated.
569+
570+
If the data type (either specified or resolved) differs from the
571+
data type of ``x``, the input array elements are cast to the
572+
specified data type before computing the product.
573+
Default: ``None``.
574+
keepdims (Optional[bool]):
575+
if ``True``, the reduced axes (dimensions) are included in the
576+
result as singleton dimensions, so that the returned array remains
577+
compatible with the input arrays according to Array Broadcasting
578+
rules. Otherwise, if ``False``, the reduced axes are not included
579+
in the returned array. Default: ``False``.
580+
out (Optional[usm_ndarray]):
581+
the array into which the result is written.
582+
The data type of ``out`` must match the expected shape and the
583+
expected data type of the result or (if provided) ``dtype``.
584+
If ``None`` then a new array is returned. Default: ``None``.
585+
586+
Returns:
587+
usm_ndarray:
588+
an array containing the products. If the product was computed over
589+
the entire array, a zero-dimensional array is returned. The
590+
returned array has the data type as described in the ``dtype``
591+
parameter description above.
592+
"""
593+
return _reduction_over_axis(
594+
x,
595+
axis,
596+
dtype,
597+
keepdims,
598+
out,
599+
tri._prod_over_axis,
600+
tri._prod_over_axis_dtype_supported,
601+
_default_accumulation_dtype,
602+
)
603+
604+
605+
def sum(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
606+
"""
607+
Calculates the sum of elements in the input array ``x``.
608+
609+
Args:
610+
x (usm_ndarray):
611+
input array.
612+
axis (Optional[int, Tuple[int, ...]]):
613+
axis or axes along which sums must be computed. If a tuple
614+
of unique integers, sums are computed over multiple axes.
615+
If ``None``, the sum is computed over the entire array.
616+
Default: ``None``.
617+
dtype (Optional[dtype]):
618+
data type of the returned array. If ``None``, the default data
619+
type is inferred from the "kind" of the input array data type.
620+
621+
* If ``x`` has a real- or complex-valued floating-point data
622+
type, the returned array will have the same data type as
623+
``x``.
624+
* If ``x`` has signed integral data type, the returned array
625+
will have the default signed integral type for the device
626+
where input array ``x`` is allocated.
627+
* If ``x`` has unsigned integral data type, the returned array
628+
will have the default unsigned integral type for the device
629+
where input array ``x`` is allocated.
630+
array ``x`` is allocated.
631+
* If ``x`` has a boolean data type, the returned array will
632+
have the default signed integral type for the device
633+
where input array ``x`` is allocated.
634+
635+
If the data type (either specified or resolved) differs from the
636+
data type of ``x``, the input array elements are cast to the
637+
specified data type before computing the sum.
638+
Default: ``None``.
639+
keepdims (Optional[bool]):
640+
if ``True``, the reduced axes (dimensions) are included in the
641+
result as singleton dimensions, so that the returned array remains
642+
compatible with the input arrays according to Array Broadcasting
643+
rules. Otherwise, if ``False``, the reduced axes are not included
644+
in the returned array. Default: ``False``.
645+
out (Optional[usm_ndarray]):
646+
the array into which the result is written.
647+
The data type of ``out`` must match the expected shape and the
648+
expected data type of the result or (if provided) ``dtype``.
649+
If ``None`` then a new array is returned. Default: ``None``.
650+
651+
Returns:
652+
usm_ndarray:
653+
an array containing the sums. If the sum was computed over the
654+
entire array, a zero-dimensional array is returned. The returned
655+
array has the data type as described in the ``dtype`` parameter
656+
description above.
657+
"""
658+
return _reduction_over_axis(
659+
x,
660+
axis,
661+
dtype,
662+
keepdims,
663+
out,
664+
tri._sum_over_axis,
665+
tri._sum_over_axis_dtype_supported,
666+
_default_accumulation_dtype,
667+
)

dpnp/dpnp_iface_manipulation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,9 @@ def _get_first_nan_index(usm_a):
428428
if first_nan is not None:
429429
# all NaNs are collapsed, so need to put a count of all NaNs
430430
# at the last index
431-
dpt.sum(usm_res.counts[first_nan:], out=usm_res.counts[first_nan])
431+
dpt_ext.sum(
432+
usm_res.counts[first_nan:], out=usm_res.counts[first_nan]
433+
)
432434
result += (usm_res.counts[: first_nan + 1],)
433435
else:
434436
result += (usm_res.counts,)

0 commit comments

Comments
 (0)