Skip to content

Commit 668f3fb

Browse files
Move ti.cumulative_prod() and reuse it in dpnp
1 parent 91547cc commit 668f3fb

File tree

3 files changed

+84
-3
lines changed

3 files changed

+84
-3
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
)
8181
from dpctl_ext.tensor._reshape import reshape
8282

83-
from ._accumulation import cumulative_sum
83+
from ._accumulation import cumulative_prod, cumulative_sum
8484
from ._clip import clip
8585
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
8686

@@ -95,6 +95,7 @@
9595
"concat",
9696
"copy",
9797
"clip",
98+
"cumulative_prod",
9899
"cumulative_sum",
99100
"empty",
100101
"empty_like",

dpctl_ext/tensor/_accumulation.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,83 @@ def cumulative_sum(
309309
tai._cumsum_dtype_supported,
310310
_default_accumulation_dtype,
311311
)
312+
313+
314+
def cumulative_prod(
315+
x, /, *, axis=None, dtype=None, include_initial=False, out=None
316+
):
317+
"""
318+
cumulative_prod(x, /, *, axis=None, dtype=None, include_initial=False,
319+
out=None)
320+
321+
Calculates the cumulative product of elements in the input array `x`.
322+
323+
Args:
324+
x (usm_ndarray):
325+
input array.
326+
axis (Optional[int]):
327+
axis along which cumulative product must be computed.
328+
If `None`, the product is computed over the entire array.
329+
If `x` is a one-dimensional array, providing an `axis` is optional;
330+
however, if `x` has more than one dimension, providing an `axis`
331+
is required.
332+
Default: `None`.
333+
dtype (Optional[dtype]):
334+
data type of the returned array. If `None`, the default data
335+
type is inferred from the "kind" of the input array data type.
336+
337+
* If `x` has a real- or complex-valued floating-point data
338+
type, the returned array will have the same data type as
339+
`x`.
340+
* If `x` has signed integral data type, the returned array
341+
will have the default signed integral type for the device
342+
where input array `x` is allocated.
343+
* If `x` has unsigned integral data type, the returned array
344+
will have the default unsigned integral type for the device
345+
where input array `x` is allocated.
346+
* If `x` has a boolean data type, the returned array will
347+
have the default signed integral type for the device
348+
where input array `x` is allocated.
349+
350+
If the data type (either specified or resolved) differs from the
351+
data type of `x`, the input array elements are cast to the
352+
specified data type before computing the cumulative product.
353+
Default: `None`.
354+
include_initial (bool):
355+
boolean indicating whether to include the initial value (i.e., the
356+
additive identity, zero) as the first value along the provided
357+
axis in the output. Default: `False`.
358+
out (Optional[usm_ndarray]):
359+
the array into which the result is written.
360+
The data type of `out` must match the expected shape and the
361+
expected data type of the result or (if provided) `dtype`.
362+
If `None` then a new array is returned. Default: `None`.
363+
364+
Returns:
365+
usm_ndarray:
366+
an array containing cumulative products. The returned array has
367+
the data type as described in the `dtype` parameter description
368+
above.
369+
370+
The returned array shape is determined as follows:
371+
372+
* If `include_initial` is `False`, the returned array will
373+
have the same shape as `x`
374+
* If `include_initial` is `True`, the returned array will
375+
have the same shape as `x` except the axis along which the
376+
cumulative product is calculated, which will have size `N+1`
377+
378+
where `N` is the size of the axis the cumulative products are
379+
computed along.
380+
"""
381+
return _accumulate_common(
382+
x,
383+
axis,
384+
dtype,
385+
include_initial,
386+
out,
387+
tai._cumprod_over_axis,
388+
tai._cumprod_final_axis_include_initial,
389+
tai._cumprod_dtype_supported,
390+
_default_accumulation_dtype,
391+
)

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,7 @@ def cumprod(a, axis=None, dtype=None, out=None):
11261126
return dpnp_wrap_reduction_call(
11271127
usm_a,
11281128
out,
1129-
dpt.cumulative_prod,
1129+
dpt_ext.cumulative_prod,
11301130
_get_reduction_res_dt(a, dtype),
11311131
axis=axis,
11321132
dtype=dtype,
@@ -1307,7 +1307,7 @@ def cumulative_prod(
13071307
return dpnp_wrap_reduction_call(
13081308
dpnp.get_usm_ndarray(x),
13091309
out,
1310-
dpt.cumulative_prod,
1310+
dpt_ext.cumulative_prod,
13111311
_get_reduction_res_dt(x, dtype),
13121312
axis=axis,
13131313
dtype=dtype,

0 commit comments

Comments
 (0)