Skip to content

Commit 7de9081

Browse files
dcherian and tomwhite authored
Add scan (#531)
* Add scan. Closes #277 * Add `cumulative_sum` and tests. Support arrays other than 2D in `scan`. Don't assume axis is last dimension. * Convert Cubed arrays to Zarr arrays for direct access in map_direct function. --------- Co-authored-by: Tom White <tom.e.white@gmail.com>
1 parent d7c44df commit 7de9081

7 files changed

Lines changed: 175 additions & 6 deletions

File tree

cubed/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,18 @@
330330

331331
__all__ += ["argmax", "argmin", "searchsorted", "where"]
332332

333-
from .array_api.statistical_functions import max, mean, min, prod, std, sum, var
333+
from .array_api.statistical_functions import (
334+
cumulative_sum,
335+
max,
336+
mean,
337+
min,
338+
prod,
339+
std,
340+
sum,
341+
var,
342+
)
334343

335-
__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
344+
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
336345

337346
from .array_api.utility_functions import all, any
338347

cubed/array_api/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,9 +268,9 @@
268268

269269
__all__ += ["argmax", "argmin", "searchsorted", "where"]
270270

271-
from .statistical_functions import max, mean, min, prod, std, sum, var
271+
from .statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
272272

273-
__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
273+
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
274274

275275
from .utility_functions import all, any
276276

cubed/array_api/statistical_functions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,23 @@
77
)
88
from cubed.array_api.elementwise_functions import sqrt
99
from cubed.backend_array_api import namespace as nxp
10-
from cubed.core import reduction
10+
from cubed.core import reduction, scan
11+
12+
13+
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device=None):
    """Compute cumulative sums of ``x`` along ``axis`` via a parallel scan.

    Parameters
    ----------
    x : Array
        Input array with a numeric or boolean dtype.
    axis : int
        Axis along which the cumulative sum is computed.
    dtype : dtype, optional
        Output dtype; integral and boolean inputs are upcast by
        ``_upcast_integral_dtypes`` following the array API spec.
    include_initial : bool
        Including the identity element before the scan is not yet supported.
    device
        Forwarded to the dtype upcasting check.
    """
    if include_initial:
        # Fix: this flag was previously ignored, silently returning results
        # without the leading identity element. Fail loudly instead.
        raise NotImplementedError("include_initial=True is not supported")
    dtype = _upcast_integral_dtypes(
        x,
        dtype,
        allowed_dtypes=(
            "numeric",
            "boolean",
        ),
        fname="cumulative_sum",
        device=device,
    )
    # Fix: the upcast ``dtype`` was previously computed but discarded;
    # forward it to ``scan`` so integral/boolean inputs accumulate in the
    # upcast dtype rather than overflowing in the input dtype.
    return scan(
        x,
        preop=nxp.sum,
        func=nxp.cumulative_sum,
        binop=nxp.add,
        identity=0,
        axis=axis,
        dtype=dtype,
    )
1127

1228

1329
def max(x, /, *, axis=None, keepdims=False, split_every=None):

cubed/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
map_blocks,
1010
rechunk,
1111
reduction,
12+
scan,
1213
squeeze,
1314
store,
1415
to_zarr,

cubed/core/ops.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from itertools import product
77
from numbers import Integral, Number
88
from operator import add
9-
from typing import TYPE_CHECKING, Any, Sequence, Tuple, Union
9+
from typing import TYPE_CHECKING, Any, Callable, Sequence, Tuple, Union
1010
from warnings import warn
1111

1212
import ndindex
@@ -23,6 +23,7 @@
2323
from cubed.core.plan import Plan, new_temp_path
2424
from cubed.primitive.blockwise import blockwise as primitive_blockwise
2525
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
26+
from cubed.primitive.blockwise import key_to_slices
2627
from cubed.primitive.memory import get_buffer_copies
2728
from cubed.primitive.rechunk import rechunk as primitive_rechunk
2829
from cubed.spec import spec_from_config
@@ -1678,3 +1679,117 @@ def smallest_blockdim(blockdims):
16781679
m = ntd[0]
16791680
out = ntd
16801681
return out
1682+
1683+
1684+
def _scan_binop(
    out: np.ndarray,
    left: "Array",
    right: "Array",
    *,
    binop: Callable,
    block_id: tuple[int, ...],
    axis: int,
    identity: Any,
) -> np.ndarray:
    """Combine one block of the blockwise scan with its increment.

    Called per output block by ``map_direct`` inside ``scan``: takes block
    ``block_id`` of ``left`` (the blockwise-scanned array) and adds — via
    ``binop`` — the increment accumulated over all preceding blocks along
    ``axis``, stored in block ``block_id[axis] - 1`` of ``right``.

    Parameters
    ----------
    out : np.ndarray
        Output block placeholder; unused here, but required by the
        ``map_direct`` calling convention.
    left, right : Array
        Cubed arrays; accessed directly through their backing Zarr stores.
    binop : Callable
        Associative binary operator (e.g. ``np.add``).
    block_id : tuple[int, ...]
        Index of the output block being computed.
    axis : int
        Axis the scan runs along.
    identity : Any
        Identity element of ``binop``, used for the first block along ``axis``.

    Returns
    -------
    np.ndarray
        The combined block (result of ``binop`` on in-memory slices).
    """
    # Get the underlying Zarr arrays so we can access directly
    left = left.zarray
    right = right.zarray

    # Slices selecting this block's region of `left`; reuse them for `right`,
    # overriding only the scan axis below.
    left_slicer = key_to_slices(block_id, left)
    right_slicer = list(left_slicer)

    # For the first block, we add the identity element
    # For all other blocks `k`, we add the `k-1` element along `axis`
    right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis])
    right_slicer = tuple(right_slicer)
    right_ = right[right_slicer] if block_id[axis] > 0 else identity
    return binop(left[left_slicer], right_)
1707+
1708+
1709+
def scan(
    array: "Array",
    func: Callable,
    *,
    preop: Callable,
    binop: Callable,
    identity: Any,
    axis: int,
    dtype=None,
) -> "Array":
    """
    Generic parallel scan.

    Parameters
    ----------
    array: Cubed Array
    func: callable
        Scan or cumulative function like np.cumsum or np.cumprod
    preop: callable
        Function applied blockwise that reduces each block to a single value
        along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``.
    binop: callable
        Associated binary operator: ``np.add`` for ``np.cumsum``,
        ``np.multiply`` for ``np.cumprod``.
    identity: Any
        Identity element of ``binop``: 0 for ``np.cumsum`` and 1 for ``np.cumprod``.
    axis: int
    dtype: dtype, optional
        Output dtype, forwarded to the blockwise scan and reduction steps.

    Notes
    -----
    This method uses a variant of the Blelloch (1990) out-of-core algorithm.

    Returns
    -------
    Array

    See also
    --------
    cumsum
    cumprod
    """
    axis = validate_axis(axis, array.ndim)

    # Blelloch (1990) out-of-core algorithm.
    # 1. First, scan blockwise.
    # Fix: ``dtype`` was previously accepted but never used; forward it so the
    # accumulation dtype can differ from the input dtype (e.g. upcast ints).
    scanned = map_blocks(func, array, dtype=dtype, axis=axis)
    # If there is only a single chunk, we can be done
    if array.numblocks[axis] == 1:
        return scanned

    # 2. Calculate the blockwise reduction using `preop`
    # TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned`
    reduced_chunks = tuple(
        (1,) * array.numblocks[i] if i == axis else c
        for i, c in enumerate(array.chunks)
    )
    reduced = map_blocks(
        preop, array, dtype=dtype, chunks=reduced_chunks, axis=axis, keepdims=True
    )

    # 3. Now scan `reduced` to generate the increments for each block of `scanned`.
    # Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
    # Instead we recursively apply the scan to `reduced`.
    # 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
    new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5)
    new_chunks = (
        reduced.chunksize[:axis] + (new_chunksize,) + reduced.chunksize[axis + 1 :]
    )

    merged = merge_chunks(reduced, new_chunks)

    # 3b. Recursively scan this merged array to generate the increment for each block of `scanned`
    increment = scan(
        merged, func, preop=preop, binop=binop, identity=identity, axis=axis, dtype=dtype
    )

    # 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`.
    # Use map_direct since the chunks of increment and scanned aren't aligned anymore.
    assert increment.shape[axis] == scanned.numblocks[axis]
    # 5. Bada-bing, bada-boom.
    return map_direct(
        partial(_scan_binop, binop=binop, axis=axis, identity=identity),
        scanned,
        increment,
        shape=scanned.shape,
        dtype=scanned.dtype,
        chunks=scanned.chunks,
        extra_projected_mem=scanned.chunkmem * 2,  # arbitrary
    )

cubed/tests/test_array_api.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,25 @@ def test_where_scalars():
826826
# Statistical functions
827827

828828

829+
@pytest.mark.parametrize("axis", [0, 1])
def test_cumulative_sum_2d(axis):
    """cumulative_sum on a chunked 2-d array matches NumPy along each axis."""
    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    actual = xp.cumulative_sum(xp.asarray(data, chunks=(2, 2)), axis=axis)
    expected = np.cumulative_sum(np.array(data), axis=axis)
    assert_array_equal(actual.compute(), expected)
837+
838+
839+
def test_cumulative_sum_1d():
    """cumulative_sum on a chunked 1-d array matches NumPy."""
    data = list(range(10))
    actual = xp.cumulative_sum(xp.asarray(data, chunks=(4,)), axis=0)
    expected = np.cumulative_sum(np.array(data), axis=0)
    assert_array_equal(actual.compute(), expected)
846+
847+
829848
def test_mean_axis_0(spec, executor):
830849
a = xp.asarray(
831850
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec

cubed/tests/test_mem_utilization.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,15 @@ def test_argmax(tmp_path, spec, executor):
341341
# Statistical Functions
342342

343343

344+
@pytest.mark.slow
def test_cumulative_sum(tmp_path, spec, executor):
    """Measure peak memory of cumulative_sum over 200MB chunks."""
    arr = cubed.random.random((50000, 5000), chunks=(5000, 5000), spec=spec)
    result = xp.cumulative_sum(arr, axis=0)
    run_operation(tmp_path, executor, "cumulative_sum", result)
351+
352+
344353
@pytest.mark.slow
345354
def test_max(tmp_path, spec, executor):
346355
a = cubed.random.random(

0 commit comments

Comments
 (0)