|
6 | 6 | from itertools import product |
7 | 7 | from numbers import Integral, Number |
8 | 8 | from operator import add |
9 | | -from typing import TYPE_CHECKING, Any, Sequence, Tuple, Union |
| 9 | +from typing import TYPE_CHECKING, Any, Callable, Sequence, Tuple, Union |
10 | 10 | from warnings import warn |
11 | 11 |
|
12 | 12 | import ndindex |
|
23 | 23 | from cubed.core.plan import Plan, new_temp_path |
24 | 24 | from cubed.primitive.blockwise import blockwise as primitive_blockwise |
25 | 25 | from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise |
| 26 | +from cubed.primitive.blockwise import key_to_slices |
26 | 27 | from cubed.primitive.memory import get_buffer_copies |
27 | 28 | from cubed.primitive.rechunk import rechunk as primitive_rechunk |
28 | 29 | from cubed.spec import spec_from_config |
@@ -1678,3 +1679,117 @@ def smallest_blockdim(blockdims): |
1678 | 1679 | m = ntd[0] |
1679 | 1680 | out = ntd |
1680 | 1681 | return out |
| 1682 | + |
| 1683 | + |
| 1684 | +def _scan_binop( |
| 1685 | + out: np.ndarray, |
| 1686 | + left: "Array", |
| 1687 | + right: "Array", |
| 1688 | + *, |
| 1689 | + binop: Callable, |
| 1690 | + block_id: tuple[int, ...], |
| 1691 | + axis: int, |
| 1692 | + identity: Any, |
| 1693 | +) -> "Array": |
| 1694 | + # Get the underlying Zarr arrays so we can access directly |
| 1695 | + left = left.zarray |
| 1696 | + right = right.zarray |
| 1697 | + |
| 1698 | + left_slicer = key_to_slices(block_id, left) |
| 1699 | + right_slicer = list(left_slicer) |
| 1700 | + |
| 1701 | + # For the first block, we add the identity element |
| 1702 | + # For all other blocks `k`, we add the `k-1` element along `axis` |
| 1703 | + right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis]) |
| 1704 | + right_slicer = tuple(right_slicer) |
| 1705 | + right_ = right[right_slicer] if block_id[axis] > 0 else identity |
| 1706 | + return binop(left[left_slicer], right_) |
| 1707 | + |
| 1708 | + |
| 1709 | +def scan( |
| 1710 | + array: "Array", |
| 1711 | + func: Callable, |
| 1712 | + *, |
| 1713 | + preop: Callable, |
| 1714 | + binop: Callable, |
| 1715 | + identity: Any, |
| 1716 | + axis: int, |
| 1717 | + dtype=None, |
| 1718 | +) -> "Array": |
| 1719 | + """ |
| 1720 | + Generic parallel scan. |
| 1721 | +
|
| 1722 | + Parameters |
| 1723 | + ---------- |
| 1724 | + array: Cubed Array |
| 1725 | + func: callable |
| 1726 | + Scan or cumulative function like np.cumsum or np.cumprod |
| 1727 | + preop: callable |
| 1728 | + Function applied blockwise that reduces each block to a single value |
| 1729 | + along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``. |
| 1730 | + binop: callable |
| 1731 | + Associated binary operator, e.g. ``add`` for ``np.cumsum`` or ``mul`` for ``np.cumprod`` |
| 1732 | + identity: Any |
| 1733 | + Associated identity element for the scan, like 0 for ``np.cumsum`` and 1 for ``np.cumprod``. |
| 1734 | + axis: int |
| 1735 | + dtype: dtype |
| 1736 | +
|
| 1737 | + Notes |
| 1738 | + ----- |
| 1739 | + This method uses a variant of the Blelloch (1990) algorithm. |
| 1740 | +
|
| 1741 | + Returns |
| 1742 | + ------- |
| 1743 | + Array |
| 1744 | +
|
| 1745 | + See also |
| 1746 | + -------- |
| 1747 | + cumsum |
| 1748 | + cumprod |
| 1749 | + """ |
| 1750 | + axis = validate_axis(axis, array.ndim) |
| 1751 | + |
| 1752 | + # Blelloch (1990) out-of-core algorithm. |
| 1753 | + # 1. First, scan blockwise |
| 1754 | + scanned = map_blocks(func, array, axis=axis) |
| 1755 | + # If there is only a single chunk, we can be done |
| 1756 | + if array.numblocks[axis] == 1: |
| 1757 | + return scanned |
| 1758 | + |
| 1759 | + # 2. Calculate the blockwise reduction using `preop` |
| 1760 | + # TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned` |
| 1761 | + reduced_chunks = tuple( |
| 1762 | + (1,) * array.numblocks[i] if i == axis else c |
| 1763 | + for i, c in enumerate(array.chunks) |
| 1764 | + ) |
| 1765 | + reduced = map_blocks(preop, array, chunks=reduced_chunks, axis=axis, keepdims=True) |
| 1766 | + |
| 1767 | + # 3. Now scan `reduced` to generate the increments for each block of `scanned`. |
| 1768 | + # Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan. |
| 1769 | + # Instead we recursively apply the scan to `reduced`. |
| 1770 | + # 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1 |
| 1771 | + new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5) |
| 1772 | + new_chunks = ( |
| 1773 | + reduced.chunksize[:axis] + (new_chunksize,) + reduced.chunksize[axis + 1 :] |
| 1774 | + ) |
| 1775 | + |
| 1776 | + merged = merge_chunks(reduced, new_chunks) |
| 1777 | + |
| 1778 | + # 3b. Recursively scan this merged array to generate the increment for each block of `scanned` |
| 1779 | + increment = scan( |
| 1780 | + merged, func, preop=preop, binop=binop, identity=identity, axis=axis |
| 1781 | + ) |
| 1782 | + |
| 1783 | + # 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`. |
| 1784 | + # Use map_direct since the chunks of increment and scanned aren't aligned anymore. |
| 1785 | + assert increment.shape[axis] == scanned.numblocks[axis] |
| 1786 | + # 5. Bada-bing, bada-boom. |
| 1787 | + return map_direct( |
| 1788 | + partial(_scan_binop, binop=binop, axis=axis, identity=identity), |
| 1789 | + scanned, |
| 1790 | + increment, |
| 1791 | + shape=scanned.shape, |
| 1792 | + dtype=scanned.dtype, |
| 1793 | + chunks=scanned.chunks, |
| 1794 | + extra_projected_mem=scanned.chunkmem * 2, # arbitrary |
| 1795 | + ) |
0 commit comments