|
28 | 28 |
|
29 | 29 | import dpctl |
30 | 30 | import dpctl.tensor as dpt |
31 | | -from dpctl.tensor._type_utils import ( # _default_accumulation_dtype_fp_types, |
32 | | - _default_accumulation_dtype, |
33 | | - _to_device_supported_dtype, |
34 | | -) |
35 | 31 | from dpctl.utils import ExecutionPlacementError, SequentialOrderManager |
36 | 32 |
|
37 | 33 | # TODO: revert to `import dpctl.tensor...` |
38 | 34 | # when dpnp fully migrates dpctl/tensor |
39 | 35 | import dpctl_ext.tensor as dpt_ext |
40 | 36 | import dpctl_ext.tensor._tensor_accumulation_impl as tai |
41 | 37 | import dpctl_ext.tensor._tensor_impl as ti |
| 38 | +from dpctl_ext.tensor._type_utils import ( |
| 39 | + _default_accumulation_dtype, |
| 40 | + _default_accumulation_dtype_fp_types, |
| 41 | + _to_device_supported_dtype, |
| 42 | +) |
42 | 43 |
|
43 | 44 | from ._numpy_helper import normalize_axis_index |
44 | 45 |
|
@@ -389,3 +390,81 @@ def cumulative_prod( |
389 | 390 | tai._cumprod_dtype_supported, |
390 | 391 | _default_accumulation_dtype, |
391 | 392 | ) |
| 393 | + |
| 394 | + |
| 395 | +def cumulative_logsumexp( |
| 396 | + x, /, *, axis=None, dtype=None, include_initial=False, out=None |
| 397 | +): |
| 398 | + """ |
| 399 | + cumulative_logsumexp(x, /, *, axis=None, dtype=None, include_initial=False, |
| 400 | + out=None) |
| 401 | +
|
| 402 | + Calculates the cumulative logsmumexp of elements in the input array `x`. |
| 403 | +
|
| 404 | + Args: |
| 405 | + x (usm_ndarray): |
| 406 | + input array. |
| 407 | + axis (Optional[int]): |
| 408 | + axis along which cumulative logsumexp must be computed. |
| 409 | + If `None`, the logsumexp is computed over the entire array. |
| 410 | + If `x` is a one-dimensional array, providing an `axis` is optional; |
| 411 | + however, if `x` has more than one dimension, providing an `axis` |
| 412 | + is required. |
| 413 | + Default: `None`. |
| 414 | + dtype (Optional[dtype]): |
| 415 | + data type of the returned array. If `None`, the default data |
| 416 | + type is inferred from the "kind" of the input array data type. |
| 417 | +
|
| 418 | + * If `x` has a real- or complex-valued floating-point data |
| 419 | + type, the returned array will have the same data type as |
| 420 | + `x`. |
| 421 | + * If `x` has signed integral data type, the returned array |
| 422 | + will have the default signed integral type for the device |
| 423 | + where input array `x` is allocated. |
| 424 | + * If `x` has unsigned integral data type, the returned array |
| 425 | + will have the default unsigned integral type for the device |
| 426 | + where input array `x` is allocated. |
| 427 | + * If `x` has a boolean data type, the returned array will |
| 428 | + have the default signed integral type for the device |
| 429 | + where input array `x` is allocated. |
| 430 | +
|
| 431 | + If the data type (either specified or resolved) differs from the |
| 432 | + data type of `x`, the input array elements are cast to the |
| 433 | + specified data type before computing the cumulative logsumexp. |
| 434 | + Default: `None`. |
| 435 | + include_initial (bool): |
| 436 | + boolean indicating whether to include the initial value (i.e., the |
| 437 | + additive identity, zero) as the first value along the provided axis |
| 438 | + in the output. Default: `False`. |
| 439 | + out (Optional[usm_ndarray]): |
| 440 | + the array into which the result is written. |
| 441 | + The data type of `out` must match the expected shape and the |
| 442 | + expected data type of the result or (if provided) `dtype`. |
| 443 | + If `None` then a new array is returned. Default: `None`. |
| 444 | +
|
| 445 | + Returns: |
| 446 | + usm_ndarray: |
| 447 | + an array containing cumulative logsumexp results. The returned |
| 448 | + array has the data type as described in the `dtype` parameter |
| 449 | + description above. |
| 450 | +
|
| 451 | + The returned array shape is determined as follows: |
| 452 | +
|
| 453 | + * If `include_initial` is `False`, the returned array will |
| 454 | + have the same shape as `x` |
| 455 | + * If `include_initial` is `True`, the returned array will |
| 456 | + have the same shape as `x` except the axis along which the |
| 457 | + cumulative logsumexp is calculated, which will have size |
| 458 | + `N+1` |
| 459 | + """ |
| 460 | + return _accumulate_common( |
| 461 | + x, |
| 462 | + axis, |
| 463 | + dtype, |
| 464 | + include_initial, |
| 465 | + out, |
| 466 | + tai._cumlogsumexp_over_axis, |
| 467 | + tai._cumlogsumexp_final_axis_include_initial, |
| 468 | + tai._cumlogsumexp_dtype_supported, |
| 469 | + _default_accumulation_dtype_fp_types, |
| 470 | + ) |
0 commit comments