Skip to content

Commit e17e42d

Browse files
committed
BUG: fix cumsum/cumprod with arrays on a non-default device
1 parent 60688c7 commit e17e42d

2 files changed

Lines changed: 24 additions & 2 deletions

File tree

array_api_strict/_statistical_functions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ def cumulative_sum(
3939
if include_initial:
4040
if axis < 0:
4141
axis += x.ndim
42-
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
42+
x = concat(
43+
[zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype, device=x.device), x],
44+
axis=axis
45+
)
4346
return Array._new(np.cumsum(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device)
4447

4548

@@ -66,7 +69,10 @@ def cumulative_prod(
6669
if include_initial:
6770
if axis < 0:
6871
axis += x.ndim
69-
x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
72+
x = concat(
73+
[ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype, device=x.device), x],
74+
axis=axis
75+
)
7076
return Array._new(np.cumprod(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device)
7177

7278

array_api_strict/tests/test_statistical_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,19 @@ def test_mean_complex():
5555
with pytest.raises(TypeError):
5656
xp.mean(xp.arange(3))
5757

58+
59+
def test_cumsum_device():
60+
x = xp.arange(3, device=xp.Device('device1'))
61+
y = xp.cumulative_sum(x, include_initial=True)
62+
expected = xp.asarray([0, 0, 1, 3], device=x.device)
63+
assert y.device == expected.device
64+
assert xp.all(y == expected)
65+
66+
67+
def test_cumprod_device():
68+
x = xp.arange(1, 4, device=xp.Device('device1'))
69+
y = xp.cumulative_prod(x, include_initial=True)
70+
expected = xp.asarray([1, 1, 2, 6], device=x.device)
71+
assert y.device == expected.device
72+
assert xp.all(y == expected)
73+

0 commit comments

Comments
 (0)