Skip to content

Commit 970f75b

Browse files
committed
Use general_blockwise for scan, rather than map_direct
1 parent f2c0ae3 commit 970f75b

3 files changed

Lines changed: 76 additions & 42 deletions

File tree

cubed/array_api/statistical_functions.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212

1313
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device=None):
14+
if include_initial:
15+
raise NotImplementedError("include_initial is not supported in cumulative_sum")
1416
dtype = _upcast_integral_dtypes(
1517
x,
1618
dtype,
@@ -21,9 +23,18 @@ def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device
2123
fname="cumulative_sum",
2224
device=device,
2325
)
24-
return scan(
25-
x, preop=nxp.sum, func=nxp.cumulative_sum, binop=nxp.add, identity=0, axis=axis
26-
)
26+
return scan(x, preop=nxp.sum, func=nxp.cumulative_sum, binop=nxp.add, axis=axis)
27+
28+
29+
def _cumulative_sum_func(a, /, *, axis=None, dtype=None, include_initial=False):
30+
out = nxp.cumulative_sum(a, axis=axis, dtype=dtype, include_initial=include_initial)
31+
if include_initial:
32+
# we don't yet support including the final element as it complicates chunk sizing
33+
ind = tuple(
34+
slice(a.shape[i] - 1) if i == axis else slice(None) for i in range(a.ndim)
35+
)
36+
out = out[ind]
37+
return out
2738

2839

2940
def max(x, /, *, axis=None, keepdims=False, split_every=None):

cubed/core/ops.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from cubed.core.plan import Plan, new_temp_path
2424
from cubed.primitive.blockwise import blockwise as primitive_blockwise
2525
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
26-
from cubed.primitive.blockwise import key_to_slices
2726
from cubed.primitive.memory import get_buffer_copies
2827
from cubed.primitive.rechunk import rechunk as primitive_rechunk
2928
from cubed.spec import spec_from_config
@@ -1684,40 +1683,16 @@ def smallest_blockdim(blockdims):
16841683
return out
16851684

16861685

1687-
def _scan_binop(
1688-
out: np.ndarray,
1689-
left: "Array",
1690-
right: "Array",
1691-
*,
1692-
binop: Callable,
1693-
block_id: tuple[int, ...],
1694-
axis: int,
1695-
identity: Any,
1696-
) -> "Array":
1697-
# Get the underlying Zarr arrays so we can access directly
1698-
left = left.zarray
1699-
right = right.zarray
1700-
1701-
left_slicer = key_to_slices(block_id, left)
1702-
right_slicer = list(left_slicer)
1703-
1704-
# For the first block, we add the identity element
1705-
# For all other blocks `k`, we add the `k-1` element along `axis`
1706-
right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis])
1707-
right_slicer = tuple(right_slicer)
1708-
right_ = right[right_slicer] if block_id[axis] > 0 else identity
1709-
return binop(left[left_slicer], right_)
1710-
1711-
17121686
def scan(
17131687
array: "Array",
17141688
func: Callable,
17151689
*,
17161690
preop: Callable,
17171691
binop: Callable,
1718-
identity: Any,
17191692
axis: int,
17201693
dtype=None,
1694+
include_initial=False,
1695+
split_every: int = 5,
17211696
) -> "Array":
17221697
"""
17231698
Generic parallel scan.
@@ -1732,10 +1707,10 @@ def scan(
17321707
along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``.
17331708
binop: callable
17341709
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
1735-
identity: Any
1736-
Associated identity element for the scan, like 0 for ``np.cumsum`` and 1 for ``np.cumprod``.
17371710
axis: int
17381711
dtype: dtype
1712+
include_initial: bool
1713+
Whether to include the identity value as the first value in the output.
17391714
17401715
Notes
17411716
-----
@@ -1750,17 +1725,22 @@ def scan(
17501725
cumsum
17511726
cumprod
17521727
"""
1728+
1729+
# Note that if include_initial=True the final value is *not* included in the output.
1730+
# To include the final value is tricky with constant chunk sizes, since if the last
1731+
# chunk is full then a new chunk of size one needs to be added for the final value.
1732+
# TODO: add an include_final argument (default True)
1733+
17531734
axis = validate_axis(axis, array.ndim)
17541735

17551736
# Blelloch (1990) out-of-core algorithm.
17561737
# 1. First, scan blockwise
1757-
scanned = map_blocks(func, array, axis=axis)
1738+
scanned = map_blocks(func, array, axis=axis, include_initial=include_initial)
17581739
# If there is only a single chunk, we can be done
17591740
if array.numblocks[axis] == 1:
17601741
return scanned
17611742

17621743
# 2. Calculate the blockwise reduction using `preop`
1763-
# TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned`
17641744
reduced_chunks = tuple(
17651745
(1,) * array.numblocks[i] if i == axis else c
17661746
for i, c in enumerate(array.chunks)
@@ -1771,28 +1751,56 @@ def scan(
17711751
# Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
17721752
# Instead we recursively apply the scan to `reduced`.
17731753
# 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
1774-
new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5)
1754+
new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * split_every)
17751755
new_chunks = (
17761756
reduced.chunksize[:axis] + (new_chunksize,) + reduced.chunksize[axis + 1 :]
17771757
)
17781758

17791759
merged = merge_chunks(reduced, new_chunks)
17801760

17811761
# 3b. Recursively scan this merged array to generate the increment for each block of `scanned`
1762+
# Note we always want to include the initial identity value (but not the final value)
1763+
# so blocks line up correctly.
17821764
increment = scan(
1783-
merged, func, preop=preop, binop=binop, identity=identity, axis=axis
1765+
merged,
1766+
func,
1767+
preop=preop,
1768+
binop=binop,
1769+
axis=axis,
1770+
include_initial=True,
17841771
)
17851772

17861773
# 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`.
1787-
# Use map_direct since the chunks of increment and scanned aren't aligned anymore.
1774+
# Use general_blockwise with a key function since the chunks of increment and scanned aren't aligned anymore.
17881775
assert increment.shape[axis] == scanned.numblocks[axis]
1776+
1777+
def key_function(out_key):
1778+
out_coords = out_key[1:]
1779+
inc_coords = tuple(
1780+
bi // split_every if i == axis else bi for i, bi in enumerate(out_coords)
1781+
)
1782+
return ((scanned.name,) + out_coords, (increment.name,) + inc_coords)
1783+
1784+
def _scan_binop(scn, inc, block_id=None, **kwargs):
1785+
bi = block_id[axis] % split_every
1786+
ind = tuple(
1787+
slice(bi, bi + 1) if i == axis else slice(None) for i in range(inc.ndim)
1788+
)
1789+
return binop(scn, inc[ind])
1790+
17891791
# 5. Bada-bing, bada-boom.
1790-
return map_direct(
1791-
partial(_scan_binop, binop=binop, axis=axis, identity=identity),
1792+
out = general_blockwise(
1793+
_scan_binop,
1794+
key_function,
17921795
scanned,
17931796
increment,
1794-
shape=scanned.shape,
1795-
dtype=scanned.dtype,
1796-
chunks=scanned.chunks,
1797+
shapes=[scanned.shape],
1798+
dtypes=[scanned.dtype],
1799+
chunkss=[scanned.chunks],
17971800
extra_projected_mem=scanned.chunkmem * 2, # arbitrary
17981801
)
1802+
1803+
from cubed import Array
1804+
1805+
assert isinstance(out, Array) # single output
1806+
return out

cubed/tests/test_array_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,15 @@ def test_cumulative_sum_2d(axis):
836836
)
837837

838838

839+
def test_cumulative_sum_2d_recursive(executor):
840+
a = xp.ones((10, 100), chunks=(10, 10))
841+
b = xp.cumulative_sum(a, axis=1)
842+
assert_array_equal(
843+
b.compute(executor=executor),
844+
np.cumulative_sum(np.ones((10, 100)), axis=1),
845+
)
846+
847+
839848
def test_cumulative_sum_1d():
840849
a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=(4,))
841850
b = xp.cumulative_sum(a, axis=0)
@@ -845,6 +854,12 @@ def test_cumulative_sum_1d():
845854
)
846855

847856

857+
def test_cumulative_sum_unsupported_include_initial(spec):
858+
with pytest.raises(NotImplementedError):
859+
a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=(4,))
860+
xp.cumulative_sum(a, axis=0, include_initial=True)
861+
862+
848863
def test_mean_axis_0(spec, executor):
849864
a = xp.asarray(
850865
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec

0 commit comments

Comments
 (0)