Skip to content

Commit e94b415

Browse files
katlun-lgtmzcbenz
andauthored
array API: add positive, logical_xor, trunc, count_nonzero, diff, full_like (#3730)
Co-authored-by: katlun-lgtm <264247399+katlun-lgtm@users.noreply.github.com> Co-authored-by: Cheng <git@zcbenz.com>
1 parent c9ccaba commit e94b415

5 files changed

Lines changed: 275 additions & 0 deletions

File tree

docs/src/python/ops.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,13 @@ Operations
6565
cummin
6666
cumprod
6767
cumsum
68+
count_nonzero
6869
degrees
6970
depends
7071
dequantize
7172
diag
7273
diagonal
74+
diff
7375
divide
7476
divmod
7577
einsum
@@ -87,6 +89,7 @@ Operations
8789
floor_divide
8890
full
8991
from_dlpack
92+
full_like
9093
from_fp8
9194
gather_mm
9295
gather_qmm
@@ -121,6 +124,7 @@ Operations
121124
logical_not
122125
logical_and
123126
logical_or
127+
logical_xor
124128
logsumexp
125129
matmul
126130
max
@@ -141,6 +145,7 @@ Operations
141145
partition
142146
pad
143147
permute_dims
148+
positive
144149
power
145150
prod
146151
put_along_axis
@@ -195,6 +200,7 @@ Operations
195200
tri
196201
tril
197202
triu
203+
trunc
198204
unflatten
199205
unstack
200206
vecdot

mlx/ops.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,6 +2158,32 @@ array sum(
21582158
return sum(a, std::vector<int>{axis}, keepdims, s);
21592159
}
21602160

2161+
array count_nonzero(
2162+
const array& a,
2163+
bool keepdims /* = false */,
2164+
StreamOrDevice s /* = {} */) {
2165+
std::vector<int> axes(a.ndim());
2166+
std::iota(axes.begin(), axes.end(), 0);
2167+
return count_nonzero(a, axes, keepdims, s);
2168+
}
2169+
2170+
array count_nonzero(
2171+
const array& a,
2172+
int axis,
2173+
bool keepdims /* = false */,
2174+
StreamOrDevice s /* = {} */) {
2175+
return count_nonzero(a, std::vector<int>{axis}, keepdims, s);
2176+
}
2177+
2178+
array count_nonzero(
2179+
const array& a,
2180+
const std::vector<int>& axes,
2181+
bool keepdims /* = false */,
2182+
StreamOrDevice s /* = {} */) {
2183+
auto nz = astype(not_equal(a, array(0, a.dtype()), s), int32, s);
2184+
return sum(nz, axes, keepdims, s);
2185+
}
2186+
21612187
array mean(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
21622188
std::vector<int> axes(a.ndim());
21632189
std::iota(axes.begin(), axes.end(), 0);
@@ -2848,6 +2874,10 @@ array abs(const array& a, StreamOrDevice s /* = {} */) {
28482874
return out;
28492875
}
28502876

2877+
array positive(const array& a, StreamOrDevice s /* = {} */) {
2878+
return array(a);
2879+
}
2880+
28512881
array negative(const array& a, StreamOrDevice s /* = {} */) {
28522882
if (a.dtype() == bool_) {
28532883
auto msg = "[negative] Not supported for bool, use logical_not instead.";
@@ -2900,6 +2930,10 @@ array operator||(const array& a, const array& b) {
29002930
return logical_or(a, b);
29012931
}
29022932

2933+
array logical_xor(const array& a, const array& b, StreamOrDevice s /* = {} */) {
2934+
return not_equal(astype(a, bool_, s), astype(b, bool_, s), s);
2935+
}
2936+
29032937
array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
29042938
auto dtype = at_least_float(a.dtype());
29052939
return divide(array(1.0f, dtype), a, to_stream(s));
@@ -3052,6 +3086,17 @@ array ceil(const array& a, StreamOrDevice s /* = {} */) {
30523086
return array(a.shape(), a.dtype(), std::make_shared<Ceil>(to_stream(s)), {a});
30533087
}
30543088

3089+
array trunc(const array& a, StreamOrDevice s /* = {} */) {
3090+
if (a.dtype() == complex64) {
3091+
throw std::invalid_argument("[trunc] Not supported for complex64.");
3092+
}
3093+
if (issubdtype(a.dtype(), integer)) {
3094+
return array(a);
3095+
}
3096+
auto zero = array(0, a.dtype());
3097+
return where(less(a, zero, s), ceil(a, s), floor(a, s), s);
3098+
}
3099+
30553100
array square(const array& a, StreamOrDevice s /* = {} */) {
30563101
return array(
30573102
a.shape(), a.dtype(), std::make_shared<Square>(to_stream(s)), {a});
@@ -4063,6 +4108,33 @@ array cummin(
40634108
return cummin(flatten(a, s), 0, reverse, inclusive, s);
40644109
}
40654110

4111+
array diff(
4112+
const array& a,
4113+
int n /* = 1 */,
4114+
int axis /* = -1 */,
4115+
StreamOrDevice s /* = {} */) {
4116+
int ndim = static_cast<int>(a.ndim());
4117+
int ax = axis < 0 ? axis + ndim : axis;
4118+
if (ax < 0 || ax >= ndim) {
4119+
throw std::invalid_argument("[diff] Axis is out of bounds for the array.");
4120+
}
4121+
if (n < 0) {
4122+
throw std::invalid_argument("[diff] Order `n` must be non-negative.");
4123+
}
4124+
array x = a;
4125+
for (int i = 0; i < n; ++i) {
4126+
Shape upper_start(x.ndim(), 0);
4127+
Shape lower_stop = x.shape();
4128+
Shape strides(x.ndim(), 1);
4129+
upper_start[ax] = 1;
4130+
lower_stop[ax] = x.shape(ax) - 1;
4131+
auto upper = slice(x, upper_start, x.shape(), strides, s);
4132+
auto lower = slice(x, Shape(x.ndim(), 0), lower_stop, strides, s);
4133+
x = subtract(upper, lower, s);
4134+
}
4135+
return x;
4136+
}
4137+
40664138
array logcumsumexp(
40674139
const array& a,
40684140
int axis,

mlx/ops.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,20 @@ sum(const array& a,
622622
MLX_API array
623623
sum(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});
624624

625+
/** Count the number of non-zero elements in an array. */
626+
MLX_API array
627+
count_nonzero(const array& a, bool keepdims = false, StreamOrDevice s = {});
628+
MLX_API array count_nonzero(
629+
const array& a,
630+
int axis,
631+
bool keepdims = false,
632+
StreamOrDevice s = {});
633+
MLX_API array count_nonzero(
634+
const array& a,
635+
const std::vector<int>& axes,
636+
bool keepdims = false,
637+
StreamOrDevice s = {});
638+
625639
/** Computes the mean of the elements of an array. */
626640
MLX_API array mean(const array& a, bool keepdims, StreamOrDevice s = {});
627641
inline array mean(const array& a, StreamOrDevice s = {}) {
@@ -883,6 +897,9 @@ MLX_API array logsumexp(
883897
/** Absolute value of elements in an array. */
884898
MLX_API array abs(const array& a, StreamOrDevice s = {});
885899

900+
/** Unary plus — return a copy of the array unchanged. */
901+
MLX_API array positive(const array& a, StreamOrDevice s = {});
902+
886903
/** Negate an array. */
887904
MLX_API array negative(const array& a, StreamOrDevice s = {});
888905
MLX_API array operator-(const array& a);
@@ -902,6 +919,10 @@ MLX_API array operator&&(const array& a, const array& b);
902919
MLX_API array logical_or(const array& a, const array& b, StreamOrDevice s = {});
903920
MLX_API array operator||(const array& a, const array& b);
904921

922+
/** Logical exclusive or of two arrays */
923+
MLX_API array
924+
logical_xor(const array& a, const array& b, StreamOrDevice s = {});
925+
905926
/** The reciprocal (1/x) of the elements in an array. */
906927
MLX_API array reciprocal(const array& a, StreamOrDevice s = {});
907928

@@ -979,6 +1000,9 @@ MLX_API array floor(const array& a, StreamOrDevice s = {});
9791000
/** Ceil the element of an array. **/
9801001
MLX_API array ceil(const array& a, StreamOrDevice s = {});
9811002

1003+
/** Truncate the elements of an array towards zero. **/
1004+
MLX_API array trunc(const array& a, StreamOrDevice s = {});
1005+
9821006
/** Square the elements of an array. */
9831007
MLX_API array square(const array& a, StreamOrDevice s = {});
9841008

@@ -1388,6 +1412,10 @@ MLX_API array cummin(
13881412
bool inclusive = true,
13891413
StreamOrDevice s = {});
13901414

1415+
/** The n-th discrete difference along the given axis. */
1416+
MLX_API array
1417+
diff(const array& a, int n = 1, int axis = -1, StreamOrDevice s = {});
1418+
13911419
/** General convolution with a filter */
13921420
MLX_API array conv_general(
13931421
array input,

python/src/ops.cpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,23 @@ void init_ops(nb::module_& m) {
297297
Returns:
298298
array: The sign of ``a``.
299299
)pbdoc");
300+
m.def(
301+
"positive",
302+
&mx::positive,
303+
nb::arg(),
304+
nb::kw_only(),
305+
"stream"_a = nb::none(),
306+
nb::sig(
307+
"def positive(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
308+
R"pbdoc(
309+
Element-wise unary plus. Returns a copy of the input.
310+
311+
Args:
312+
a (array): Input array.
313+
314+
Returns:
315+
array: A copy of ``a``.
316+
)pbdoc");
300317
m.def(
301318
"negative",
302319
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
@@ -733,6 +750,23 @@ void init_ops(nb::module_& m) {
733750
Returns:
734751
array: The matrix product of ``a`` and ``b``.
735752
)pbdoc");
753+
m.def(
754+
"trunc",
755+
&mx::trunc,
756+
nb::arg(),
757+
nb::kw_only(),
758+
"stream"_a = nb::none(),
759+
nb::sig(
760+
"def trunc(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
761+
R"pbdoc(
762+
Element-wise truncation towards zero.
763+
764+
Args:
765+
a (array): Input array.
766+
767+
Returns:
768+
array: The truncated array.
769+
)pbdoc");
736770
m.def(
737771
"square",
738772
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
@@ -871,6 +905,27 @@ void init_ops(nb::module_& m) {
871905
Returns:
872906
array: The boolean array containing the logical or of ``a`` and ``b``.
873907
)pbdoc");
908+
m.def(
909+
"logical_xor",
910+
[](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) {
911+
return mx::logical_xor(to_array(a), to_array(b), s);
912+
},
913+
nb::arg(),
914+
nb::arg(),
915+
nb::kw_only(),
916+
"stream"_a = nb::none(),
917+
nb::sig(
918+
"def logical_xor(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"),
919+
R"pbdoc(
920+
Element-wise logical exclusive or.
921+
922+
Args:
923+
a (array): First input array or scalar.
924+
b (array): Second input array or scalar.
925+
926+
Returns:
927+
array: The boolean array containing the logical xor of ``a`` and ``b``.
928+
)pbdoc");
874929
m.def(
875930
"logaddexp",
876931
[](const ScalarOrArray& a_,
@@ -1792,6 +1847,34 @@ void init_ops(nb::module_& m) {
17921847
Returns:
17931848
array: The output array with the specified shape and values.
17941849
)pbdoc");
1850+
m.def(
1851+
"full_like",
1852+
[](const mx::array& a,
1853+
const ScalarOrArray& vals,
1854+
std::optional<mx::Dtype> dtype,
1855+
mx::StreamOrDevice s) {
1856+
auto t = dtype.value_or(a.dtype());
1857+
return mx::full_like(a, to_array(vals, t), t, s);
1858+
},
1859+
nb::arg(),
1860+
"vals"_a,
1861+
"dtype"_a = nb::none(),
1862+
nb::kw_only(),
1863+
"stream"_a = nb::none(),
1864+
nb::sig(
1865+
"def full_like(a: array, vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
1866+
R"pbdoc(
1867+
An array filled with ``vals`` with the same shape as the input.
1868+
1869+
Args:
1870+
a (array): The input to take the shape from.
1871+
vals (float or int or array): Values to fill the array with.
1872+
dtype (Dtype, optional): Data type of the output array. If
1873+
unspecified the type of the input is used.
1874+
1875+
Returns:
1876+
array: The output array.
1877+
)pbdoc");
17951878
m.def(
17961879
"zeros",
17971880
[](const nb::object& shape,
@@ -2494,6 +2577,41 @@ void init_ops(nb::module_& m) {
24942577
Returns:
24952578
array: The output array with the corresponding axes reduced.
24962579
)pbdoc");
2580+
m.def(
2581+
"count_nonzero",
2582+
[](const mx::array& a,
2583+
const IntOrVec& axis,
2584+
bool keepdims,
2585+
mx::StreamOrDevice s) {
2586+
if (std::holds_alternative<std::monostate>(axis)) {
2587+
return mx::count_nonzero(a, keepdims, s);
2588+
} else if (auto pv = std::get_if<int>(&axis); pv) {
2589+
return mx::count_nonzero(a, *pv, keepdims, s);
2590+
} else {
2591+
return mx::count_nonzero(
2592+
a, std::get<std::vector<int>>(axis), keepdims, s);
2593+
}
2594+
},
2595+
nb::arg(),
2596+
"axis"_a = nb::none(),
2597+
nb::kw_only(),
2598+
"keepdims"_a = false,
2599+
"stream"_a = nb::none(),
2600+
nb::sig(
2601+
"def count_nonzero(a: array, /, *, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
2602+
R"pbdoc(
2603+
Count the number of non-zero elements along the given axis.
2604+
2605+
Args:
2606+
a (array): Input array.
2607+
axis (int or tuple(int), optional): Axis or axes to count over.
2608+
Defaults to ``None`` in which case the whole array is counted.
2609+
keepdims (bool, optional): Keep the reduced axes as size one.
2610+
Default: ``False``.
2611+
2612+
Returns:
2613+
array: The counts as an ``int32`` array.
2614+
)pbdoc");
24972615
m.def(
24982616
"prod",
24992617
[](const mx::array& a,
@@ -3585,6 +3703,28 @@ void init_ops(nb::module_& m) {
35853703
Returns:
35863704
array: The output array.
35873705
)pbdoc");
3706+
m.def(
3707+
"diff",
3708+
&mx::diff,
3709+
nb::arg(),
3710+
"n"_a = 1,
3711+
"axis"_a = -1,
3712+
nb::kw_only(),
3713+
"stream"_a = nb::none(),
3714+
nb::sig(
3715+
"def diff(a: array, /, n: int = 1, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
3716+
R"pbdoc(
3717+
The n-th discrete difference along the given axis.
3718+
3719+
Args:
3720+
a (array): Input array.
3721+
n (int, optional): The number of times to difference. Default: ``1``.
3722+
axis (int, optional): The axis along which to difference.
3723+
Default: ``-1``.
3724+
3725+
Returns:
3726+
array: The n-th differences.
3727+
)pbdoc");
35883728
m.def(
35893729
"conj",
35903730
[](const ScalarOrArray& a, mx::StreamOrDevice s) {

0 commit comments

Comments
 (0)