Skip to content

Commit d8c3680

Browse files
Move ti.cumulative_logsumexp() and reuse it in dpnp
1 parent 72d2109 commit d8c3680

File tree

3 files changed

+87
-6
lines changed

3 files changed

+87
-6
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
)
8181
from dpctl_ext.tensor._reshape import reshape
8282

83-
from ._accumulation import cumulative_prod, cumulative_sum
83+
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
8484
from ._clip import clip
8585
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
8686

@@ -95,6 +95,7 @@
9595
"concat",
9696
"copy",
9797
"clip",
98+
"cumulative_logsumexp",
9899
"cumulative_prod",
99100
"cumulative_sum",
100101
"empty",

dpctl_ext/tensor/_accumulation.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,18 @@
2828

2929
import dpctl
3030
import dpctl.tensor as dpt
31-
from dpctl.tensor._type_utils import ( # _default_accumulation_dtype_fp_types,
32-
_default_accumulation_dtype,
33-
_to_device_supported_dtype,
34-
)
3531
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3632

3733
# TODO: revert to `import dpctl.tensor...`
3834
# when dpnp fully migrates dpctl/tensor
3935
import dpctl_ext.tensor as dpt_ext
4036
import dpctl_ext.tensor._tensor_accumulation_impl as tai
4137
import dpctl_ext.tensor._tensor_impl as ti
38+
from dpctl_ext.tensor._type_utils import (
39+
_default_accumulation_dtype,
40+
_default_accumulation_dtype_fp_types,
41+
_to_device_supported_dtype,
42+
)
4243

4344
from ._numpy_helper import normalize_axis_index
4445

@@ -389,3 +390,81 @@ def cumulative_prod(
389390
tai._cumprod_dtype_supported,
390391
_default_accumulation_dtype,
391392
)
393+
394+
395+
def cumulative_logsumexp(
396+
x, /, *, axis=None, dtype=None, include_initial=False, out=None
397+
):
398+
"""
399+
cumulative_logsumexp(x, /, *, axis=None, dtype=None, include_initial=False,
400+
out=None)
401+
402+
Calculates the cumulative logsmumexp of elements in the input array `x`.
403+
404+
Args:
405+
x (usm_ndarray):
406+
input array.
407+
axis (Optional[int]):
408+
axis along which cumulative logsumexp must be computed.
409+
If `None`, the logsumexp is computed over the entire array.
410+
If `x` is a one-dimensional array, providing an `axis` is optional;
411+
however, if `x` has more than one dimension, providing an `axis`
412+
is required.
413+
Default: `None`.
414+
dtype (Optional[dtype]):
415+
data type of the returned array. If `None`, the default data
416+
type is inferred from the "kind" of the input array data type.
417+
418+
* If `x` has a real- or complex-valued floating-point data
419+
type, the returned array will have the same data type as
420+
`x`.
421+
* If `x` has signed integral data type, the returned array
422+
will have the default signed integral type for the device
423+
where input array `x` is allocated.
424+
* If `x` has unsigned integral data type, the returned array
425+
will have the default unsigned integral type for the device
426+
where input array `x` is allocated.
427+
* If `x` has a boolean data type, the returned array will
428+
have the default signed integral type for the device
429+
where input array `x` is allocated.
430+
431+
If the data type (either specified or resolved) differs from the
432+
data type of `x`, the input array elements are cast to the
433+
specified data type before computing the cumulative logsumexp.
434+
Default: `None`.
435+
include_initial (bool):
436+
boolean indicating whether to include the initial value (i.e., the
437+
additive identity, zero) as the first value along the provided axis
438+
in the output. Default: `False`.
439+
out (Optional[usm_ndarray]):
440+
the array into which the result is written.
441+
The data type of `out` must match the expected shape and the
442+
expected data type of the result or (if provided) `dtype`.
443+
If `None` then a new array is returned. Default: `None`.
444+
445+
Returns:
446+
usm_ndarray:
447+
an array containing cumulative logsumexp results. The returned
448+
array has the data type as described in the `dtype` parameter
449+
description above.
450+
451+
The returned array shape is determined as follows:
452+
453+
* If `include_initial` is `False`, the returned array will
454+
have the same shape as `x`
455+
* If `include_initial` is `True`, the returned array will
456+
have the same shape as `x` except the axis along which the
457+
cumulative logsumexp is calculated, which will have size
458+
`N+1`
459+
"""
460+
return _accumulate_common(
461+
x,
462+
axis,
463+
dtype,
464+
include_initial,
465+
out,
466+
tai._cumlogsumexp_over_axis,
467+
tai._cumlogsumexp_final_axis_include_initial,
468+
tai._cumlogsumexp_dtype_supported,
469+
_default_accumulation_dtype_fp_types,
470+
)

dpnp/dpnp_iface_trigonometric.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
# TODO: revert to `import dpctl.tensor...`
5050
# when dpnp fully migrates dpctl/tensor
51+
import dpctl_ext.tensor as dpt_ext
5152
import dpctl_ext.tensor._type_utils as dtu
5253
import dpnp
5354
import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
@@ -934,7 +935,7 @@ def cumlogsumexp(
934935
return dpnp_wrap_reduction_call(
935936
usm_x,
936937
out,
937-
dpt.cumulative_logsumexp,
938+
dpt_ext.cumulative_logsumexp,
938939
_get_accumulation_res_dt(x, dtype),
939940
axis=axis,
940941
dtype=dtype,

0 commit comments

Comments
 (0)