diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 2dba8e819..89bb0e1f9 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -91,7 +91,6 @@ jobs: array_api_tests/test_has_names.py # signatures of items not implemented - array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] array_api_tests/test_signatures.py::test_func_signature[unique_all] array_api_tests/test_signatures.py::test_func_signature[unique_counts] array_api_tests/test_signatures.py::test_func_signature[unique_inverse] @@ -112,13 +111,14 @@ jobs: array_api_tests/test_array_object.py::test_getitem # test_searchsorted depends on sort which is not implemented array_api_tests/test_searching_functions.py::test_searchsorted + # cumulative_sum with include_initial=True is not implemented + array_api_tests/test_statistical_functions.py::test_cumulative_sum # not implemented array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_sorting_functions.py - array_api_tests/test_statistical_functions.py::test_cumulative_sum # finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] diff --git a/api_status.md b/api_status.md index 692da9bc7..caf18255c 100644 --- a/api_status.md +++ b/api_status.md @@ -79,8 +79,8 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | | `unique_values` | :x: | | Shape is data dependent | | Sorting Functions | `argsort` | :x: | | | | | `sort` | :x: | | | -| Statistical Functions | `cumulative_prod` | :x: | 2024.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) | -| | `cumulative_sum` | :x: | 2023.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) | +| Statistical Functions | `cumulative_prod` | :x: | 2024.12 | | +| | `cumulative_sum` | :white_check_mark: | 2023.12 | | | | `max` | :white_check_mark: | | | | | `mean` | :white_check_mark: | | | | | `min` | :white_check_mark: | | | diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index fd6651ad5..95106268d 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -11,6 +11,8 @@ def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device=None): + if include_initial: + raise NotImplementedError("include_initial is not supported in cumulative_sum") dtype = _upcast_integral_dtypes( x, dtype, @@ -21,9 +23,18 @@ def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device fname="cumulative_sum", device=device, ) - return scan( - x, preop=nxp.sum, func=nxp.cumulative_sum, binop=nxp.add, identity=0, axis=axis - ) + return scan(x, preop=nxp.sum, func=_cumulative_sum_func, binop=nxp.add, axis=axis) + + +def _cumulative_sum_func(a, /, *, axis=None, dtype=None, include_initial=False): + out = nxp.cumulative_sum(a, axis=axis, dtype=dtype, include_initial=include_initial) + if include_initial: + # we don't yet support including the final element as it complicates chunk sizing + ind = tuple( + slice(a.shape[i]) if i == axis else slice(None) for i in range(a.ndim) + ) + out = out[ind] + return out def max(x, /, *, axis=None, keepdims=False, split_every=None): diff --git a/cubed/core/ops.py b/cubed/core/ops.py index e9393a671..21d15d2cb 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -23,7 +23,6 @@ from cubed.core.plan import Plan, new_temp_path from cubed.primitive.blockwise import blockwise as primitive_blockwise from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise -from cubed.primitive.blockwise import key_to_slices from cubed.primitive.memory import get_buffer_copies from cubed.primitive.rechunk import rechunk as primitive_rechunk from cubed.spec import spec_from_config @@ -1684,40 +1683,16 @@ def smallest_blockdim(blockdims): return out -def _scan_binop( - out: np.ndarray, - left: "Array", - right: "Array", - *, - binop: Callable, - block_id: tuple[int, ...], - axis: int, - identity: Any, -) -> "Array": - # Get the underlying Zarr arrays so we can access directly - left = left.zarray - right = right.zarray - - left_slicer = key_to_slices(block_id, left) - right_slicer = list(left_slicer) - - # For the first block, we add the identity element - # For all other blocks `k`, we add the `k-1` element along `axis` - right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis]) - right_slicer = tuple(right_slicer) - right_ = right[right_slicer] if block_id[axis] > 0 else identity - return binop(left[left_slicer], right_) - - def scan( array: "Array", func: Callable, *, preop: Callable, binop: Callable, - identity: Any, axis: int, dtype=None, + include_initial=False, + split_every: int = 5, ) -> "Array": """ Generic parallel scan. @@ -1732,10 +1707,10 @@ def scan( along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``. binop: callable Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul`` - identity: Any - Associated identity element more scan like 0 for ``np.cumsum`` and 1 for ``np.cumprod``. axis: int dtype: dtype + include_initial: bool + Whether to include the identity value as the first value in the output. Notes ----- @@ -1750,49 +1725,82 @@ def scan( cumsum cumprod """ + + # Note that if include_initial=True the final value is *not* included in the output. + # To include the final value is tricky with constant chunk sizes, since if the last + # chunk is full then a new chunk of size one needs to be added for the final value. + # TODO: add an include_final argument (default True) + axis = validate_axis(axis, array.ndim) # Blelloch (1990) out-of-core algorithm. # 1. First, scan blockwise - scanned = map_blocks(func, array, axis=axis) + scanned = map_blocks(func, array, axis=axis, include_initial=include_initial) # If there is only a single chunk, we can be done if array.numblocks[axis] == 1: return scanned - # 2. Calculate the blockwise reduction using `preop` - # TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned` - reduced_chunks = tuple( - (1,) * array.numblocks[i] if i == axis else c - for i, c in enumerate(array.chunks) + # 2. Calculate the reduction using `preop` + # Use `partial_reduce` to also merge to a decent intermediate chunksize + # since reduced.chunksize[axis] == 1 + + def identity_func(a, **kwargs): + return a + + split_size = min(split_every, array.numblocks[axis]) + reduced = partial_reduce( + array, + initial_func=partial(preop, axis=axis, keepdims=True), + func=identity_func, + split_every={axis: split_size}, + combine_sizes={axis: split_size}, ) - reduced = map_blocks(preop, array, chunks=reduced_chunks, axis=axis, keepdims=True) # 3. Now scan `reduced` to generate the increments for each block of `scanned`. # Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan. # Instead we generalize recursively apply the scan to `reduced`. - # 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1 - new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5) - new_chunks = ( - reduced.chunksize[:axis] + (new_chunksize,) + reduced.chunksize[axis + 1 :] - ) - - merged = merge_chunks(reduced, new_chunks) - - # 3b. Recursively scan this merged array to generate the increment for each block of `scanned` + # Note we always want to include the initial identity value (but not the final value) + # so blocks line up correctly. increment = scan( - merged, func, preop=preop, binop=binop, identity=identity, axis=axis + reduced, + func, + preop=preop, + binop=binop, + axis=axis, + include_initial=True, ) # 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`. - # Use map_direct since the chunks of increment and scanned aren't aligned anymore. + # Use general_blockwise with a key function since the chunks of increment and scanned aren't aligned anymore. assert increment.shape[axis] == scanned.numblocks[axis] + + def key_function(out_key): + out_coords = out_key[1:] + inc_coords = tuple( + bi // split_every if i == axis else bi for i, bi in enumerate(out_coords) + ) + return ((scanned.name,) + out_coords, (increment.name,) + inc_coords) + + def _scan_binop(scn, inc, block_id=None, **kwargs): + bi = block_id[axis] % split_every + ind = tuple( + slice(bi, bi + 1) if i == axis else slice(None) for i in range(inc.ndim) + ) + return binop(scn, inc[ind]) + # 5. Bada-bing, bada-boom. - return map_direct( - partial(_scan_binop, binop=binop, axis=axis, identity=identity), + out = general_blockwise( + _scan_binop, + key_function, scanned, increment, - shape=scanned.shape, - dtype=scanned.dtype, - chunks=scanned.chunks, + shapes=[scanned.shape], + dtypes=[scanned.dtype], + chunkss=[scanned.chunks], extra_projected_mem=scanned.chunkmem * 2, # arbitrary ) + + from cubed import Array + + assert isinstance(out, Array) # single output + return out diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index c62cd81ba..91a95d767 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -836,6 +836,15 @@ def test_cumulative_sum_2d(axis): ) +def test_cumulative_sum_2d_recursive(executor): + a = xp.ones((10, 100), chunks=(10, 10)) + b = xp.cumulative_sum(a, axis=1) + assert_array_equal( + b.compute(executor=executor), + np.cumulative_sum(np.ones((10, 100)), axis=1), + ) + + def test_cumulative_sum_1d(): a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=(4,)) b = xp.cumulative_sum(a, axis=0) @@ -845,6 +854,12 @@ def test_cumulative_sum_1d(): ) +def test_cumulative_sum_unsupported_include_initial(spec): + with pytest.raises(NotImplementedError): + a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=(4,)) + xp.cumulative_sum(a, axis=0, include_initial=True) + + def test_mean_axis_0(spec, executor): a = xp.asarray( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec