@@ -297,6 +297,23 @@ void init_ops(nb::module_& m) {
297297 Returns:
298298 array: The sign of ``a``.
299299 )pbdoc" );
300+ m.def (
301+ " positive" ,
302+ &mx::positive,
303+ nb::arg (),
304+ nb::kw_only (),
305+ " stream" _a = nb::none (),
306+ nb::sig (
307+ " def positive(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array" ),
308+ R"pbdoc(
309+ Element-wise unary plus. Returns a copy of the input.
310+
311+ Args:
312+ a (array): Input array.
313+
314+ Returns:
315+ array: A copy of ``a``.
316+ )pbdoc" );
300317 m.def (
301318 " negative" ,
302319 [](const ScalarOrArray& a, mx::StreamOrDevice s) {
@@ -733,6 +750,23 @@ void init_ops(nb::module_& m) {
733750 Returns:
734751 array: The matrix product of ``a`` and ``b``.
735752 )pbdoc" );
753+ m.def (
754+ " trunc" ,
755+ &mx::trunc,
756+ nb::arg (),
757+ nb::kw_only (),
758+ " stream" _a = nb::none (),
759+ nb::sig (
760+ " def trunc(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array" ),
761+ R"pbdoc(
762+ Element-wise truncation towards zero.
763+
764+ Args:
765+ a (array): Input array.
766+
767+ Returns:
768+ array: The truncated array.
769+ )pbdoc" );
736770 m.def (
737771 " square" ,
738772 [](const ScalarOrArray& a, mx::StreamOrDevice s) {
@@ -871,6 +905,27 @@ void init_ops(nb::module_& m) {
871905 Returns:
872906 array: The boolean array containing the logical or of ``a`` and ``b``.
873907 )pbdoc" );
908+ m.def (
909+ " logical_xor" ,
910+ [](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) {
911+ return mx::logical_xor (to_array (a), to_array (b), s);
912+ },
913+ nb::arg (),
914+ nb::arg (),
915+ nb::kw_only (),
916+ " stream" _a = nb::none (),
917+ nb::sig (
918+ " def logical_xor(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array" ),
919+ R"pbdoc(
920+ Element-wise logical exclusive or.
921+
922+ Args:
923+ a (array): First input array or scalar.
924+ b (array): Second input array or scalar.
925+
926+ Returns:
927+ array: The boolean array containing the logical xor of ``a`` and ``b``.
928+ )pbdoc" );
874929 m.def (
875930 " logaddexp" ,
876931 [](const ScalarOrArray& a_,
@@ -1792,6 +1847,34 @@ void init_ops(nb::module_& m) {
17921847 Returns:
17931848 array: The output array with the specified shape and values.
17941849 )pbdoc" );
1850+ m.def (
1851+ " full_like" ,
1852+ [](const mx::array& a,
1853+ const ScalarOrArray& vals,
1854+ std::optional<mx::Dtype> dtype,
1855+ mx::StreamOrDevice s) {
1856+ auto t = dtype.value_or (a.dtype ());
1857+ return mx::full_like (a, to_array (vals, t), t, s);
1858+ },
1859+ nb::arg (),
1860+ " vals" _a,
1861+ " dtype" _a = nb::none (),
1862+ nb::kw_only (),
1863+ " stream" _a = nb::none (),
1864+ nb::sig (
1865+ " def full_like(a: array, vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array" ),
1866+ R"pbdoc(
1867+ An array filled with ``vals`` with the same shape as the input.
1868+
1869+ Args:
1870+ a (array): The input to take the shape from.
1871+ vals (float or int or array): Values to fill the array with.
1872+ dtype (Dtype, optional): Data type of the output array. If
1873+ unspecified the type of the input is used.
1874+
1875+ Returns:
1876+ array: The output array.
1877+ )pbdoc" );
17951878 m.def (
17961879 " zeros" ,
17971880 [](const nb::object& shape,
@@ -2494,6 +2577,41 @@ void init_ops(nb::module_& m) {
24942577 Returns:
24952578 array: The output array with the corresponding axes reduced.
24962579 )pbdoc" );
2580+ m.def (
2581+ " count_nonzero" ,
2582+ [](const mx::array& a,
2583+ const IntOrVec& axis,
2584+ bool keepdims,
2585+ mx::StreamOrDevice s) {
2586+ if (std::holds_alternative<std::monostate>(axis)) {
2587+ return mx::count_nonzero (a, keepdims, s);
2588+ } else if (auto pv = std::get_if<int >(&axis); pv) {
2589+ return mx::count_nonzero (a, *pv, keepdims, s);
2590+ } else {
2591+ return mx::count_nonzero (
2592+ a, std::get<std::vector<int >>(axis), keepdims, s);
2593+ }
2594+ },
2595+ nb::arg (),
2596+ " axis" _a = nb::none (),
2597+ nb::kw_only (),
2598+ " keepdims" _a = false ,
2599+ " stream" _a = nb::none (),
2600+ nb::sig (
2601+ " def count_nonzero(a: array, /, *, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, stream: Union[None, Stream, Device] = None) -> array" ),
2602+ R"pbdoc(
2603+ Count the number of non-zero elements along the given axis.
2604+
2605+ Args:
2606+ a (array): Input array.
2607+ axis (int or tuple(int), optional): Axis or axes to count over.
2608+ Defaults to ``None`` in which case the whole array is counted.
2609+ keepdims (bool, optional): Keep the reduced axes as size one.
2610+ Default: ``False``.
2611+
2612+ Returns:
2613+ array: The counts as an ``int32`` array.
2614+ )pbdoc" );
24972615 m.def (
24982616 " prod" ,
24992617 [](const mx::array& a,
@@ -3585,6 +3703,28 @@ void init_ops(nb::module_& m) {
35853703 Returns:
35863704 array: The output array.
35873705 )pbdoc" );
3706+ m.def (
3707+ " diff" ,
3708+ &mx::diff,
3709+ nb::arg (),
3710+ " n" _a = 1 ,
3711+ " axis" _a = -1 ,
3712+ nb::kw_only (),
3713+ " stream" _a = nb::none (),
3714+ nb::sig (
3715+ " def diff(a: array, /, n: int = 1, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array" ),
3716+ R"pbdoc(
3717+ The n-th discrete difference along the given axis.
3718+
3719+ Args:
3720+ a (array): Input array.
3721+ n (int, optional): The number of times to difference. Default: ``1``.
3722+ axis (int, optional): The axis along which to difference.
3723+ Default: ``-1``.
3724+
3725+ Returns:
3726+ array: The n-th differences.
3727+ )pbdoc" );
35883728 m.def (
35893729 " conj" ,
35903730 [](const ScalarOrArray& a, mx::StreamOrDevice s) {
0 commit comments