Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ jobs:
array_api_tests/test_has_names.py

# signatures of items not implemented
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
array_api_tests/test_signatures.py::test_func_signature[unique_all]
array_api_tests/test_signatures.py::test_func_signature[unique_counts]
array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
Expand All @@ -112,13 +111,14 @@ jobs:
array_api_tests/test_array_object.py::test_getitem
# test_searchsorted depends on sort which is not implemented
array_api_tests/test_searching_functions.py::test_searchsorted
# cumulative_sum with include_initial=True is not implemented
array_api_tests/test_statistical_functions.py::test_cumulative_sum

# not implemented
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_array_object.py::test_setitem_masking
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_sorting_functions.py
array_api_tests/test_statistical_functions.py::test_cumulative_sum

# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]
Expand Down
4 changes: 2 additions & 2 deletions api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ This table shows which parts of the [Array API](https://data-apis.org/array-
| | `unique_values` | :x: | | Shape is data dependent |
| Sorting Functions | `argsort` | :x: | | |
| | `sort` | :x: | | |
| Statistical Functions | `cumulative_prod` | :x: | 2024.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) |
| | `cumulative_sum` | :x: | 2023.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) |
| Statistical Functions | `cumulative_prod` | :x: | 2024.12 | |
| | `cumulative_sum` | :white_check_mark: | 2023.12 | |
| | `max` | :white_check_mark: | | |
| | `mean` | :white_check_mark: | | |
| | `min` | :white_check_mark: | | |
Expand Down
17 changes: 14 additions & 3 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@


def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device=None):
if include_initial:
raise NotImplementedError("include_initial is not supported in cumulative_sum")
dtype = _upcast_integral_dtypes(
x,
dtype,
Expand All @@ -21,9 +23,18 @@ def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, device
fname="cumulative_sum",
device=device,
)
return scan(
x, preop=nxp.sum, func=nxp.cumulative_sum, binop=nxp.add, identity=0, axis=axis
)
return scan(x, preop=nxp.sum, func=_cumulative_sum_func, binop=nxp.add, axis=axis)


def _cumulative_sum_func(a, /, *, axis=None, dtype=None, include_initial=False):
out = nxp.cumulative_sum(a, axis=axis, dtype=dtype, include_initial=include_initial)
if include_initial:
# we don't yet support including the final element as it complicates chunk sizing
ind = tuple(
slice(a.shape[i]) if i == axis else slice(None) for i in range(a.ndim)
)
out = out[ind]
return out


def max(x, /, *, axis=None, keepdims=False, split_every=None):
Expand Down
112 changes: 60 additions & 52 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from cubed.core.plan import Plan, new_temp_path
from cubed.primitive.blockwise import blockwise as primitive_blockwise
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
from cubed.primitive.blockwise import key_to_slices
from cubed.primitive.memory import get_buffer_copies
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.spec import spec_from_config
Expand Down Expand Up @@ -1684,40 +1683,16 @@ def smallest_blockdim(blockdims):
return out


def _scan_binop(
    out: np.ndarray,
    left: "Array",
    right: "Array",
    *,
    binop: Callable,
    block_id: tuple[int, ...],
    axis: int,
    identity: Any,
) -> "Array":
    """Combine one block of the blockwise scan with its increment.

    Applies ``binop`` between the block of ``left`` selected by ``block_id``
    and the preceding (``block_id[axis] - 1``) element of ``right`` along
    ``axis``; for the first block along ``axis`` it uses ``identity`` instead.

    ``out`` is unused here — presumably the output-buffer argument required
    by the map_direct calling convention; TODO confirm against the caller.
    """
    # Get the underlying Zarr arrays so we can access directly
    left = left.zarray
    right = right.zarray

    # Slices selecting this block within the full left array.
    left_slicer = key_to_slices(block_id, left)
    right_slicer = list(left_slicer)

    # For the first block, we add the identity element
    # For all other blocks `k`, we add the `k-1` element along `axis`
    right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis])
    right_slicer = tuple(right_slicer)
    right_ = right[right_slicer] if block_id[axis] > 0 else identity
    return binop(left[left_slicer], right_)


def scan(
array: "Array",
func: Callable,
*,
preop: Callable,
binop: Callable,
identity: Any,
axis: int,
dtype=None,
include_initial=False,
split_every: int = 5,
) -> "Array":
"""
Generic parallel scan.
Expand All @@ -1732,10 +1707,10 @@ def scan(
along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``.
binop: callable
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
identity: Any
Associated identity element more scan like 0 for ``np.cumsum`` and 1 for ``np.cumprod``.
axis: int
dtype: dtype
include_initial: bool
Whether to include the identity value as the first value in the output.

Notes
-----
Expand All @@ -1750,49 +1725,82 @@ def scan(
cumsum
cumprod
"""

# Note that if include_initial=True the final value is *not* included in the output.
# To include the final value is tricky with constant chunk sizes, since if the last
# chunk is full then a new chunk of size one needs to be added for the final value.
# TODO: add an include_final argument (default True)

axis = validate_axis(axis, array.ndim)

# Blelloch (1990) out-of-core algorithm.
# 1. First, scan blockwise
scanned = map_blocks(func, array, axis=axis)
scanned = map_blocks(func, array, axis=axis, include_initial=include_initial)
# If there is only a single chunk, we can be done
if array.numblocks[axis] == 1:
return scanned

# 2. Calculate the blockwise reduction using `preop`
# TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned`
reduced_chunks = tuple(
(1,) * array.numblocks[i] if i == axis else c
for i, c in enumerate(array.chunks)
# 2. Calculate the reduction using `preop`
# Use `partial_reduce` to also merge to a decent intermediate chunksize
# since reduced.chunksize[axis] == 1

def identity_func(a, **kwargs):
return a

split_size = min(split_every, array.numblocks[axis])
reduced = partial_reduce(
array,
initial_func=partial(preop, axis=axis, keepdims=True),
func=identity_func,
split_every={axis: split_size},
combine_sizes={axis: split_size},
)
reduced = map_blocks(preop, array, chunks=reduced_chunks, axis=axis, keepdims=True)

# 3. Now scan `reduced` to generate the increments for each block of `scanned`.
# Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
# Instead we generalize recursively apply the scan to `reduced`.
# 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5)
new_chunks = (
reduced.chunksize[:axis] + (new_chunksize,) + reduced.chunksize[axis + 1 :]
)

merged = merge_chunks(reduced, new_chunks)

# 3b. Recursively scan this merged array to generate the increment for each block of `scanned`
# Note we always want to include the initial identity value (but not the final value)
# so blocks line up correctly.
increment = scan(
merged, func, preop=preop, binop=binop, identity=identity, axis=axis
reduced,
func,
preop=preop,
binop=binop,
axis=axis,
include_initial=True,
)

# 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`.
# Use map_direct since the chunks of increment and scanned aren't aligned anymore.
# Use general_blockwise with a key function since the chunks of increment and scanned aren't aligned anymore.
assert increment.shape[axis] == scanned.numblocks[axis]

def key_function(out_key):
out_coords = out_key[1:]
inc_coords = tuple(
bi // split_every if i == axis else bi for i, bi in enumerate(out_coords)
)
return ((scanned.name,) + out_coords, (increment.name,) + inc_coords)

def _scan_binop(scn, inc, block_id=None, **kwargs):
bi = block_id[axis] % split_every
ind = tuple(
slice(bi, bi + 1) if i == axis else slice(None) for i in range(inc.ndim)
)
return binop(scn, inc[ind])

# 5. Bada-bing, bada-boom.
return map_direct(
partial(_scan_binop, binop=binop, axis=axis, identity=identity),
out = general_blockwise(
_scan_binop,
key_function,
scanned,
increment,
shape=scanned.shape,
dtype=scanned.dtype,
chunks=scanned.chunks,
shapes=[scanned.shape],
dtypes=[scanned.dtype],
chunkss=[scanned.chunks],
extra_projected_mem=scanned.chunkmem * 2, # arbitrary
)

from cubed import Array

assert isinstance(out, Array) # single output
return out
15 changes: 15 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,15 @@ def test_cumulative_sum_2d(axis):
)


def test_cumulative_sum_2d_recursive(executor):
    # 10 chunks along the scan axis forces scan() into its recursive
    # branch when combining the per-chunk increments.
    expected = np.cumulative_sum(np.ones((10, 100)), axis=1)
    actual = xp.cumulative_sum(xp.ones((10, 100), chunks=(10, 10)), axis=1)
    assert_array_equal(actual.compute(executor=executor), expected)


def test_cumulative_sum_1d():
a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=(4,))
b = xp.cumulative_sum(a, axis=0)
Expand All @@ -845,6 +854,12 @@ def test_cumulative_sum_1d():
)


def test_cumulative_sum_unsupported_include_initial(spec):
    """cumulative_sum must reject include_initial=True with NotImplementedError."""
    # Build the input *outside* the raises-block so only the call under test
    # can raise; otherwise an unrelated failure in asarray would be mistaken
    # for the expected NotImplementedError. Also pass spec=spec, consistent
    # with the other tests that take the spec fixture.
    a = xp.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunks=(4,), spec=spec)
    with pytest.raises(NotImplementedError):
        xp.cumulative_sum(a, axis=0, include_initial=True)


def test_mean_axis_0(spec, executor):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
Expand Down