Skip to content

Commit 17a39c9

Browse files
committed
Address review: drop include_initial from cumsum/cumprod
Per maintainer request (#3731), remove the include_initial arg. It needed an extra concatenate to be correct, which is the inefficient pattern angeloskath flagged. cumsum/cumprod keep the efficient dtype arg only; cumulative_sum and cumulative_prod remain pure aliases of them.
1 parent 674d058 commit 17a39c9

4 files changed

Lines changed: 10 additions & 52 deletions

File tree

mlx/ops.cpp

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3942,7 +3942,6 @@ array cumsum(
39423942
bool reverse /* = false*/,
39433943
bool inclusive /* = true*/,
39443944
std::optional<Dtype> dtype /* = std::nullopt*/,
3945-
bool include_initial /* = false*/,
39463945
StreamOrDevice s /* = {}*/) {
39473946
int ndim = a.ndim();
39483947
if (axis >= ndim || axis < -ndim) {
@@ -3954,31 +3953,22 @@ array cumsum(
39543953
axis = (axis + a.ndim()) % a.ndim();
39553954
auto x = dtype ? astype(a, *dtype, s) : a;
39563955
auto out_type = x.dtype() == bool_ ? int32 : x.dtype();
3957-
auto out = array(
3956+
return array(
39583957
x.shape(),
39593958
out_type,
39603959
std::make_shared<Scan>(
39613960
to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),
39623961
{x});
3963-
if (include_initial) {
3964-
Shape init_shape = out.shape();
3965-
init_shape[axis] = 1;
3966-
auto init = zeros(init_shape, out.dtype(), s);
3967-
out = reverse ? concatenate({out, init}, axis, s)
3968-
: concatenate({init, out}, axis, s);
3969-
}
3970-
return out;
39713962
}
39723963

39733964
array cumsum(
39743965
const array& a,
39753966
bool reverse /* = false*/,
39763967
bool inclusive /* = true*/,
39773968
std::optional<Dtype> dtype /* = std::nullopt*/,
3978-
bool include_initial /* = false*/,
39793969
StreamOrDevice s /* = {}*/) {
39803970
return cumsum(
3981-
flatten(a, to_stream(s)), 0, reverse, inclusive, dtype, include_initial, to_stream(s));
3971+
flatten(a, to_stream(s)), 0, reverse, inclusive, dtype, to_stream(s));
39823972
}
39833973

39843974
array cumprod(
@@ -3987,7 +3977,6 @@ array cumprod(
39873977
bool reverse /* = false*/,
39883978
bool inclusive /* = true*/,
39893979
std::optional<Dtype> dtype /* = std::nullopt*/,
3990-
bool include_initial /* = false*/,
39913980
StreamOrDevice s /* = {}*/) {
39923981
int ndim = a.ndim();
39933982
if (axis >= ndim || axis < -ndim) {
@@ -3998,30 +3987,21 @@ array cumprod(
39983987
}
39993988
axis = (axis + a.ndim()) % a.ndim();
40003989
auto x = dtype ? astype(a, *dtype, s) : a;
4001-
auto out = array(
3990+
return array(
40023991
x.shape(),
40033992
x.dtype(),
40043993
std::make_shared<Scan>(
40053994
to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),
40063995
{x});
4007-
if (include_initial) {
4008-
Shape init_shape = out.shape();
4009-
init_shape[axis] = 1;
4010-
auto init = ones(init_shape, out.dtype(), s);
4011-
out = reverse ? concatenate({out, init}, axis, s)
4012-
: concatenate({init, out}, axis, s);
4013-
}
4014-
return out;
40153996
}
40163997

40173998
array cumprod(
40183999
const array& a,
40194000
bool reverse /* = false*/,
40204001
bool inclusive /* = true*/,
40214002
std::optional<Dtype> dtype /* = std::nullopt*/,
4022-
bool include_initial /* = false*/,
40234003
StreamOrDevice s /* = {}*/) {
4024-
return cumprod(flatten(a, s), 0, reverse, inclusive, dtype, include_initial, s);
4004+
return cumprod(flatten(a, s), 0, reverse, inclusive, dtype, s);
40254005
}
40264006

40274007
array cummax(

mlx/ops.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,7 +1334,6 @@ MLX_API array cumsum(
13341334
bool reverse = false,
13351335
bool inclusive = true,
13361336
std::optional<Dtype> dtype = std::nullopt,
1337-
bool include_initial = false,
13381337
StreamOrDevice s = {});
13391338

13401339
/** Cumulative sum of an array along the given axis. */
@@ -1344,7 +1343,6 @@ MLX_API array cumsum(
13441343
bool reverse = false,
13451344
bool inclusive = true,
13461345
std::optional<Dtype> dtype = std::nullopt,
1347-
bool include_initial = false,
13481346
StreamOrDevice s = {});
13491347

13501348
/** Cumulative product of an array. */
@@ -1353,7 +1351,6 @@ MLX_API array cumprod(
13531351
bool reverse = false,
13541352
bool inclusive = true,
13551353
std::optional<Dtype> dtype = std::nullopt,
1356-
bool include_initial = false,
13571354
StreamOrDevice s = {});
13581355

13591356
/** Cumulative product of an array along the given axis. */
@@ -1363,7 +1360,6 @@ MLX_API array cumprod(
13631360
bool reverse = false,
13641361
bool inclusive = true,
13651362
std::optional<Dtype> dtype = std::nullopt,
1366-
bool include_initial = false,
13671363
StreamOrDevice s = {});
13681364

13691365
/** Cumulative max of an array. */

python/src/ops.cpp

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3448,23 +3448,21 @@ void init_ops(nb::module_& m) {
34483448
bool reverse,
34493449
bool inclusive,
34503450
std::optional<mx::Dtype> dtype,
3451-
bool include_initial,
34523451
mx::StreamOrDevice s) {
34533452
if (axis) {
3454-
return mx::cumsum(a, *axis, reverse, inclusive, dtype, include_initial, s);
3453+
return mx::cumsum(a, *axis, reverse, inclusive, dtype, s);
34553454
}
3456-
return mx::cumsum(a, reverse, inclusive, dtype, include_initial, s);
3455+
return mx::cumsum(a, reverse, inclusive, dtype, s);
34573456
},
34583457
nb::arg(),
34593458
"axis"_a = nb::none(),
34603459
nb::kw_only(),
34613460
"reverse"_a = false,
34623461
"inclusive"_a = true,
34633462
"dtype"_a = nb::none(),
3464-
"include_initial"_a = false,
34653463
"stream"_a = nb::none(),
34663464
nb::sig(
3467-
"def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
3465+
"def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, stream: Union[None, Stream, Device] = None) -> array"),
34683466
R"pbdoc(
34693467
Return the cumulative sum of the elements along the given axis.
34703468
@@ -3477,8 +3475,6 @@ void init_ops(nb::module_& m) {
34773475
inclusive (bool): The i-th element of the output includes the i-th
34783476
element of the input.
34793477
dtype (Dtype, optional): Cast the input to this type before summing.
3480-
include_initial (bool): Prepend the identity element (0) so the
3481-
output has one extra element along the given axis.
34823478
34833479
Returns:
34843480
array: The output array.
@@ -3490,23 +3486,21 @@ void init_ops(nb::module_& m) {
34903486
bool reverse,
34913487
bool inclusive,
34923488
std::optional<mx::Dtype> dtype,
3493-
bool include_initial,
34943489
mx::StreamOrDevice s) {
34953490
if (axis) {
3496-
return mx::cumprod(a, *axis, reverse, inclusive, dtype, include_initial, s);
3491+
return mx::cumprod(a, *axis, reverse, inclusive, dtype, s);
34973492
}
3498-
return mx::cumprod(a, reverse, inclusive, dtype, include_initial, s);
3493+
return mx::cumprod(a, reverse, inclusive, dtype, s);
34993494
},
35003495
nb::arg(),
35013496
"axis"_a = nb::none(),
35023497
nb::kw_only(),
35033498
"reverse"_a = false,
35043499
"inclusive"_a = true,
35053500
"dtype"_a = nb::none(),
3506-
"include_initial"_a = false,
35073501
"stream"_a = nb::none(),
35083502
nb::sig(
3509-
"def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
3503+
"def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, stream: Union[None, Stream, Device] = None) -> array"),
35103504
R"pbdoc(
35113505
Return the cumulative product of the elements along the given axis.
35123506
@@ -3519,8 +3513,6 @@ void init_ops(nb::module_& m) {
35193513
inclusive (bool): The i-th element of the output includes the i-th
35203514
element of the input.
35213515
dtype (Dtype, optional): Cast the input to this type before multiplying.
3522-
include_initial (bool): Prepend the identity element (1) so the
3523-
output has one extra element along the given axis.
35243516
35253517
Returns:
35263518
array: The output array.

python/tests/test_ops.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3449,21 +3449,11 @@ def test_to_from_fp8(self):
34493449
def test_cumulative_sum_prod(self):
34503450
a = mx.array([1, 2, 3, 4])
34513451
self.assertEqual(mx.cumulative_sum(a).tolist(), [1, 3, 6, 10])
3452-
self.assertEqual(
3453-
mx.cumulative_sum(a, include_initial=True).tolist(), [0, 1, 3, 6, 10]
3454-
)
34553452
self.assertEqual(mx.cumulative_prod(a).tolist(), [1, 2, 6, 24])
3456-
self.assertEqual(
3457-
mx.cumulative_prod(a, include_initial=True).tolist(), [1, 1, 2, 6, 24]
3458-
)
34593453

34603454
m = mx.array([[1, 2], [3, 4]])
34613455
self.assertEqual(mx.cumulative_sum(m, axis=0).tolist(), [[1, 2], [4, 6]])
34623456
self.assertEqual(mx.cumulative_sum(m, axis=1).tolist(), [[1, 3], [3, 7]])
3463-
self.assertEqual(
3464-
mx.cumulative_sum(m, axis=1, include_initial=True).tolist(),
3465-
[[0, 1, 3], [0, 3, 7]],
3466-
)
34673457
# axis=None flattens.
34683458
self.assertEqual(mx.cumulative_sum(m).tolist(), [1, 3, 6, 10])
34693459
self.assertEqual(mx.cumulative_sum(a, dtype=mx.float32).dtype, mx.float32)

0 commit comments

Comments
 (0)