Skip to content

Commit ff67ca2

Browse files
authored
Use general_blockwise for scan, rather than map_direct (#726)
* Use general_blockwise for scan, rather than map_direct
* Use `partial_reduce` in `scan` implementation for better memory management.
* Update Array API test skips file, and API status page
1 parent 8b42b71 commit ff67ca2

5 files changed

Lines changed: 93 additions & 59 deletions

File tree

.github/workflows/array-api-tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ jobs:
9191
array_api_tests/test_has_names.py
9292
9393
# signatures of items not implemented
94-
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
9594
array_api_tests/test_signatures.py::test_func_signature[unique_all]
9695
array_api_tests/test_signatures.py::test_func_signature[unique_counts]
9796
array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
@@ -112,13 +111,14 @@ jobs:
112111
array_api_tests/test_array_object.py::test_getitem
113112
# test_searchsorted depends on sort which is not implemented
114113
array_api_tests/test_searching_functions.py::test_searchsorted
114+
# cumulative_sum with include_initial=True is not implemented
115+
array_api_tests/test_statistical_functions.py::test_cumulative_sum
115116
116117
# not implemented
117118
array_api_tests/test_array_object.py::test_setitem
118119
array_api_tests/test_array_object.py::test_setitem_masking
119120
array_api_tests/test_manipulation_functions.py::test_repeat
120121
array_api_tests/test_sorting_functions.py
121-
array_api_tests/test_statistical_functions.py::test_cumulative_sum
122122
123123
# finfo(float32).eps returns float32 but should return float
124124
array_api_tests/test_data_type_functions.py::test_finfo[float32]

api_status.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ This table shows which parts of the [Array API](https://data-apis.org/array-
7979
| | `unique_values` | :x: | | Shape is data dependent |
8080
| Sorting Functions | `argsort` | :x: | | |
8181
| | `sort` | :x: | | |
82-
| Statistical Functions | `cumulative_prod` | :x: | 2024.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) |
83-
| | `cumulative_sum` | :x: | 2023.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) |
82+
| Statistical Functions | `cumulative_prod` | :x: | 2024.12 | |
83+
| | `cumulative_sum` | :white_check_mark: | 2023.12 | |
8484
| | `max` | :white_check_mark: | | |
8585
| | `mean` | :white_check_mark: | | |
8686
| | `min` | :white_check_mark: | | |

cubed/array_api/statistical_functions.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212

1313
def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device=None):
14+
if include_initial:
15+
raise NotImplementedError("include_initial is not supported in cumulative_sum")
1416
dtype = _upcast_integral_dtypes(
1517
x,
1618
dtype,
@@ -21,9 +23,18 @@ def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device
2123
fname="cumulative_sum",
2224
device=device,
2325
)
24-
return scan(
25-
x, preop=nxp.sum, func=nxp.cumulative_sum, binop=nxp.add, identity=0, axis=axis
26-
)
26+
return scan(x, preop=nxp.sum, func=_cumulative_sum_func, binop=nxp.add, axis=axis)
27+
28+
29+
def _cumulative_sum_func(a, /, *, axis=None, dtype=None, include_initial=False):
    """Blockwise cumulative sum.

    When ``include_initial`` is true the identity value is prepended along
    ``axis``, and the trailing element is dropped so the output keeps the
    same extent as the input block.
    """
    result = nxp.cumulative_sum(
        a, axis=axis, dtype=dtype, include_initial=include_initial
    )
    if not include_initial:
        return result
    # we don't yet support including the final element as it complicates chunk sizing
    trim = tuple(
        slice(a.shape[dim]) if dim == axis else slice(None) for dim in range(a.ndim)
    )
    return result[trim]
2738

2839

2940
def max(x, /, *, axis=None, keepdims=False, split_every=None):

cubed/core/ops.py

Lines changed: 60 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from cubed.core.plan import Plan, new_temp_path
2424
from cubed.primitive.blockwise import blockwise as primitive_blockwise
2525
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
26-
from cubed.primitive.blockwise import key_to_slices
2726
from cubed.primitive.memory import get_buffer_copies
2827
from cubed.primitive.rechunk import rechunk as primitive_rechunk
2928
from cubed.spec import spec_from_config
@@ -1680,40 +1679,16 @@ def smallest_blockdim(blockdims):
16801679
return out
16811680

16821681

1683-
def _scan_binop(
1684-
out: np.ndarray,
1685-
left: "Array",
1686-
right: "Array",
1687-
*,
1688-
binop: Callable,
1689-
block_id: tuple[int, ...],
1690-
axis: int,
1691-
identity: Any,
1692-
) -> "Array":
1693-
# Get the underlying Zarr arrays so we can access directly
1694-
left = left.zarray
1695-
right = right.zarray
1696-
1697-
left_slicer = key_to_slices(block_id, left)
1698-
right_slicer = list(left_slicer)
1699-
1700-
# For the first block, we add the identity element
1701-
# For all other blocks `k`, we add the `k-1` element along `axis`
1702-
right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis])
1703-
right_slicer = tuple(right_slicer)
1704-
right_ = right[right_slicer] if block_id[axis] > 0 else identity
1705-
return binop(left[left_slicer], right_)
1706-
1707-
17081682
def scan(
17091683
array: "Array",
17101684
func: Callable,
17111685
*,
17121686
preop: Callable,
17131687
binop: Callable,
1714-
identity: Any,
17151688
axis: int,
17161689
dtype=None,
1690+
include_initial=False,
1691+
split_every: int = 5,
17171692
) -> "Array":
17181693
"""
17191694
Generic parallel scan.
@@ -1728,10 +1703,10 @@ def scan(
17281703
along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``.
17291704
binop: callable
17301705
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
1731-
identity: Any
1732-
Associated identity element more scan like 0 for ``np.cumsum`` and 1 for ``np.cumprod``.
17331706
axis: int
17341707
dtype: dtype
1708+
include_initial: bool
1709+
Whether to include the identity value as the first value in the output.
17351710
17361711
Notes
17371712
-----
@@ -1746,49 +1721,82 @@ def scan(
17461721
cumsum
17471722
cumprod
17481723
"""
1724+
1725+
# Note that if include_initial=True the final value is *not* included in the output.
1726+
# To include the final value is tricky with constant chunk sizes, since if the last
1727+
# chunk is full then a new chunk of size one needs to be added for the final value.
1728+
# TODO: add an include_final argument (default True)
1729+
17491730
axis = validate_axis(axis, array.ndim)
17501731

17511732
# Blelloch (1990) out-of-core algorithm.
17521733
# 1. First, scan blockwise
1753-
scanned = map_blocks(func, array, axis=axis)
1734+
scanned = map_blocks(func, array, axis=axis, include_initial=include_initial)
17541735
# If there is only a single chunk, we can be done
17551736
if array.numblocks[axis] == 1:
17561737
return scanned
17571738

1758-
# 2. Calculate the blockwise reduction using `preop`
1759-
# TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned`
1760-
reduced_chunks = tuple(
1761-
(1,) * array.numblocks[i] if i == axis else c
1762-
for i, c in enumerate(array.chunks)
1739+
# 2. Calculate the reduction using `preop`
1740+
# Use `partial_reduce` to also merge to a decent intermediate chunksize
1741+
# since reduced.chunksize[axis] == 1
1742+
1743+
def identity_func(a, **kwargs):
1744+
return a
1745+
1746+
split_size = min(split_every, array.numblocks[axis])
1747+
reduced = partial_reduce(
1748+
array,
1749+
initial_func=partial(preop, axis=axis, keepdims=True),
1750+
func=identity_func,
1751+
split_every={axis: split_size},
1752+
combine_sizes={axis: split_size},
17631753
)
1764-
reduced = map_blocks(preop, array, chunks=reduced_chunks, axis=axis, keepdims=True)
17651754

17661755
# 3. Now scan `reduced` to generate the increments for each block of `scanned`.
17671756
# Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
17681757
# Instead we generalize by recursively applying the scan to `reduced`.
1769-
# 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
1770-
new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5)
1771-
new_chunks = (
1772-
reduced.chunksize[:axis] + (new_chunksize,) + reduced.chunksize[axis + 1 :]
1773-
)
1774-
1775-
merged = merge_chunks(reduced, new_chunks)
1776-
1777-
# 3b. Recursively scan this merged array to generate the increment for each block of `scanned`
1758+
# Note we always want to include the initial identity value (but not the final value)
1759+
# so blocks line up correctly.
17781760
increment = scan(
1779-
merged, func, preop=preop, binop=binop, identity=identity, axis=axis
1761+
reduced,
1762+
func,
1763+
preop=preop,
1764+
binop=binop,
1765+
axis=axis,
1766+
include_initial=True,
17801767
)
17811768

17821769
# 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`.
1783-
# Use map_direct since the chunks of increment and scanned aren't aligned anymore.
1770+
# Use general_blockwise with a key function since the chunks of increment and scanned aren't aligned anymore.
17841771
assert increment.shape[axis] == scanned.numblocks[axis]
1772+
1773+
def key_function(out_key):
1774+
out_coords = out_key[1:]
1775+
inc_coords = tuple(
1776+
bi // split_every if i == axis else bi for i, bi in enumerate(out_coords)
1777+
)
1778+
return ((scanned.name,) + out_coords, (increment.name,) + inc_coords)
1779+
1780+
def _scan_binop(scn, inc, block_id=None, **kwargs):
1781+
bi = block_id[axis] % split_every
1782+
ind = tuple(
1783+
slice(bi, bi + 1) if i == axis else slice(None) for i in range(inc.ndim)
1784+
)
1785+
return binop(scn, inc[ind])
1786+
17851787
# 5. Bada-bing, bada-boom.
1786-
return map_direct(
1787-
partial(_scan_binop, binop=binop, axis=axis, identity=identity),
1788+
out = general_blockwise(
1789+
_scan_binop,
1790+
key_function,
17881791
scanned,
17891792
increment,
1790-
shape=scanned.shape,
1791-
dtype=scanned.dtype,
1792-
chunks=scanned.chunks,
1793+
shapes=[scanned.shape],
1794+
dtypes=[scanned.dtype],
1795+
chunkss=[scanned.chunks],
17931796
extra_projected_mem=scanned.chunkmem * 2, # arbitrary
17941797
)
1798+
1799+
from cubed import Array
1800+
1801+
assert isinstance(out, Array) # single output
1802+
return out

cubed/tests/test_array_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,15 @@ def test_cumulative_sum_2d(axis):
836836
)
837837

838838

839+
def test_cumulative_sum_2d_recursive(executor):
    # Ten chunks along axis 1 exercises the recursive scan of the reduced array.
    arr = xp.ones((10, 100), chunks=(10, 10))
    result = xp.cumulative_sum(arr, axis=1)
    expected = np.cumulative_sum(np.ones((10, 100)), axis=1)
    assert_array_equal(result.compute(executor=executor), expected)
846+
847+
839848
def test_cumulative_sum_1d():
840849
a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=(4,))
841850
b = xp.cumulative_sum(a, axis=0)
@@ -845,6 +854,12 @@ def test_cumulative_sum_1d():
845854
)
846855

847856

857+
def test_cumulative_sum_unsupported_include_initial(spec):
    """include_initial=True is not implemented and must fail fast."""
    a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=(4,))
    # Keep the raises block tight: only the call under test may raise,
    # so an unexpected NotImplementedError from asarray can't mask a bug.
    with pytest.raises(NotImplementedError):
        xp.cumulative_sum(a, axis=0, include_initial=True)
861+
862+
848863
def test_mean_axis_0(spec, executor):
849864
a = xp.asarray(
850865
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec

0 commit comments

Comments
 (0)