|
39 | 39 | from ._numpy_helper import normalize_axis_tuple |
40 | 40 | from ._type_utils import ( |
41 | 41 | _default_accumulation_dtype, |
| 42 | + _default_accumulation_dtype_fp_types, |
42 | 43 | _to_device_supported_dtype, |
43 | 44 | ) |
44 | 45 |
|
@@ -472,6 +473,111 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None): |
472 | 473 | return _search_over_axis(x, axis, keepdims, out, tri._argmin_over_axis) |
473 | 474 |
|
474 | 475 |
|
| 476 | +def count_nonzero(x, /, *, axis=None, keepdims=False, out=None): |
| 477 | + """ |
| 478 | + Counts the number of elements in the input array ``x`` which are non-zero. |
| 479 | +
|
| 480 | + Args: |
| 481 | + x (usm_ndarray): |
| 482 | + input array. |
| 483 | + axis (Optional[int, Tuple[int, ...]]): |
| 484 | + axis or axes along which to count. If a tuple of unique integers, |
| 485 | + the number of non-zero values are computed over multiple axes. |
| 486 | + If ``None``, the number of non-zero values is computed over the |
| 487 | + entire array. |
| 488 | + Default: ``None``. |
| 489 | + keepdims (Optional[bool]): |
| 490 | + if ``True``, the reduced axes (dimensions) are included in the |
| 491 | + result as singleton dimensions, so that the returned array remains |
| 492 | + compatible with the input arrays according to Array Broadcasting |
| 493 | + rules. Otherwise, if ``False``, the reduced axes are not included |
| 494 | + in the returned array. Default: ``False``. |
| 495 | + out (Optional[usm_ndarray]): |
| 496 | + the array into which the result is written. |
| 497 | + The data type of ``out`` must match the expected shape and data |
| 498 | + type. |
| 499 | + If ``None`` then a new array is returned. Default: ``None``. |
| 500 | +
|
| 501 | + Returns: |
| 502 | + usm_ndarray: |
| 503 | + an array containing the count of non-zero values. If the sum was |
| 504 | + computed over the entire array, a zero-dimensional array is |
| 505 | + returned. The returned array will have the default array index data |
| 506 | + type. |
| 507 | + """ |
| 508 | + if x.dtype != dpt.bool: |
| 509 | + x = dpt.astype(x, dpt.bool, copy=False) |
| 510 | + return sum( |
| 511 | + x, |
| 512 | + axis=axis, |
| 513 | + dtype=ti.default_device_index_type(x.sycl_device), |
| 514 | + keepdims=keepdims, |
| 515 | + out=out, |
| 516 | + ) |
| 517 | + |
| 518 | + |
| 519 | +def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None): |
| 520 | + """ |
| 521 | + Calculates the logarithm of the sum of exponentials of elements in the |
| 522 | + input array ``x``. |
| 523 | +
|
| 524 | + Args: |
| 525 | + x (usm_ndarray): |
| 526 | + input array. |
| 527 | + axis (Optional[int, Tuple[int, ...]]): |
| 528 | + axis or axes along which values must be computed. If a tuple |
| 529 | + of unique integers, values are computed over multiple axes. |
| 530 | + If ``None``, the result is computed over the entire array. |
| 531 | + Default: ``None``. |
| 532 | + dtype (Optional[dtype]): |
| 533 | + data type of the returned array. If ``None``, the default data |
| 534 | + type is inferred from the "kind" of the input array data type. |
| 535 | +
|
| 536 | + * If ``x`` has a real-valued floating-point data type, the |
| 537 | + returned array will have the same data type as ``x``. |
| 538 | + * If ``x`` has a boolean or integral data type, the returned array |
| 539 | + will have the default floating point data type for the device |
| 540 | + where input array ``x`` is allocated. |
| 541 | + * If ``x`` has a complex-valued floating-point data type, |
| 542 | + an error is raised. |
| 543 | +
|
| 544 | + If the data type (either specified or resolved) differs from the |
| 545 | + data type of ``x``, the input array elements are cast to the |
| 546 | + specified data type before computing the result. |
| 547 | + Default: ``None``. |
| 548 | + keepdims (Optional[bool]): |
| 549 | + if ``True``, the reduced axes (dimensions) are included in the |
| 550 | + result as singleton dimensions, so that the returned array remains |
| 551 | + compatible with the input arrays according to Array Broadcasting |
| 552 | + rules. Otherwise, if ``False``, the reduced axes are not included |
| 553 | + in the returned array. Default: ``False``. |
| 554 | + out (Optional[usm_ndarray]): |
| 555 | + the array into which the result is written. |
| 556 | + The data type of ``out`` must match the expected shape and the |
| 557 | + expected data type of the result or (if provided) ``dtype``. |
| 558 | + If ``None`` then a new array is returned. Default: ``None``. |
| 559 | +
|
| 560 | + Returns: |
| 561 | + usm_ndarray: |
| 562 | + an array containing the results. If the result was computed over |
| 563 | + the entire array, a zero-dimensional array is returned. |
| 564 | + The returned array has the data type as described in the |
| 565 | + ``dtype`` parameter description above. |
| 566 | + """ |
| 567 | + return _reduction_over_axis( |
| 568 | + x, |
| 569 | + axis, |
| 570 | + dtype, |
| 571 | + keepdims, |
| 572 | + out, |
| 573 | + tri._logsumexp_over_axis, |
| 574 | + lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported( |
| 575 | + inp_dt, res_dt |
| 576 | + ), |
| 577 | + _default_accumulation_dtype_fp_types, |
| 578 | + ) |
| 579 | + |
| 580 | + |
475 | 581 | def max(x, /, *, axis=None, keepdims=False, out=None): |
476 | 582 | """ |
477 | 583 | Calculates the maximum value of the input array ``x``. |
@@ -602,6 +708,67 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False, out=None): |
602 | 708 | ) |
603 | 709 |
|
604 | 710 |
|
| 711 | +def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False, out=None): |
| 712 | + """ |
| 713 | + Calculates the square root of the sum of squares of elements in the input |
| 714 | + array ``x``. |
| 715 | +
|
| 716 | + Args: |
| 717 | + x (usm_ndarray): |
| 718 | + input array. |
| 719 | + axis (Optional[int, Tuple[int, ...]]): |
| 720 | + axis or axes along which values must be computed. If a tuple |
| 721 | + of unique integers, values are computed over multiple axes. |
| 722 | + If ``None``, the result is computed over the entire array. |
| 723 | + Default: ``None``. |
| 724 | + dtype (Optional[dtype]): |
| 725 | + data type of the returned array. If ``None``, the default data |
| 726 | + type is inferred from the "kind" of the input array data type. |
| 727 | +
|
| 728 | + * If ``x`` has a real-valued floating-point data type, the |
| 729 | + returned array will have the same data type as ``x``. |
| 730 | + * If ``x`` has a boolean or integral data type, the returned array |
| 731 | + will have the default floating point data type for the device |
| 732 | + where input array ``x`` is allocated. |
| 733 | + * If ``x`` has a complex-valued floating-point data type, |
| 734 | + an error is raised. |
| 735 | +
|
| 736 | + If the data type (either specified or resolved) differs from the |
| 737 | + data type of ``x``, the input array elements are cast to the |
| 738 | + specified data type before computing the result. Default: ``None``. |
| 739 | + keepdims (Optional[bool]): |
| 740 | + if ``True``, the reduced axes (dimensions) are included in the |
| 741 | + result as singleton dimensions, so that the returned array remains |
| 742 | + compatible with the input arrays according to Array Broadcasting |
| 743 | + rules. Otherwise, if ``False``, the reduced axes are not included |
| 744 | + in the returned array. Default: ``False``. |
| 745 | + out (Optional[usm_ndarray]): |
| 746 | + the array into which the result is written. |
| 747 | + The data type of ``out`` must match the expected shape and the |
| 748 | + expected data type of the result or (if provided) ``dtype``. |
| 749 | + If ``None`` then a new array is returned. Default: ``None``. |
| 750 | +
|
| 751 | + Returns: |
| 752 | + usm_ndarray: |
| 753 | + an array containing the results. If the result was computed over |
| 754 | + the entire array, a zero-dimensional array is returned. The |
| 755 | + returned array has the data type as described in the ``dtype`` |
| 756 | + parameter description above. |
| 757 | + """ |
| 758 | + return _reduction_over_axis( |
| 759 | + x, |
| 760 | + axis, |
| 761 | + dtype, |
| 762 | + keepdims, |
| 763 | + out, |
| 764 | + tri._hypot_over_axis, |
| 765 | + lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported( |
| 766 | + inp_dt, res_dt |
| 767 | + ), |
| 768 | + _default_accumulation_dtype_fp_types, |
| 769 | + ) |
| 770 | + |
| 771 | + |
605 | 772 | def sum(x, /, *, axis=None, dtype=None, keepdims=False, out=None): |
606 | 773 | """ |
607 | 774 | Calculates the sum of elements in the input array ``x``. |
|
0 commit comments