2323from cubed .core .plan import Plan , new_temp_path
2424from cubed .primitive .blockwise import blockwise as primitive_blockwise
2525from cubed .primitive .blockwise import general_blockwise as primitive_general_blockwise
26- from cubed .primitive .blockwise import key_to_slices
2726from cubed .primitive .memory import get_buffer_copies
2827from cubed .primitive .rechunk import rechunk as primitive_rechunk
2928from cubed .spec import spec_from_config
@@ -1684,40 +1683,16 @@ def smallest_blockdim(blockdims):
16841683 return out
16851684
16861685
1687- def _scan_binop (
1688- out : np .ndarray ,
1689- left : "Array" ,
1690- right : "Array" ,
1691- * ,
1692- binop : Callable ,
1693- block_id : tuple [int , ...],
1694- axis : int ,
1695- identity : Any ,
1696- ) -> "Array" :
1697- # Get the underlying Zarr arrays so we can access directly
1698- left = left .zarray
1699- right = right .zarray
1700-
1701- left_slicer = key_to_slices (block_id , left )
1702- right_slicer = list (left_slicer )
1703-
1704- # For the first block, we add the identity element
1705- # For all other blocks `k`, we add the `k-1` element along `axis`
1706- right_slicer [axis ] = slice (block_id [axis ] - 1 , block_id [axis ])
1707- right_slicer = tuple (right_slicer )
1708- right_ = right [right_slicer ] if block_id [axis ] > 0 else identity
1709- return binop (left [left_slicer ], right_ )
1710-
1711-
17121686def scan (
17131687 array : "Array" ,
17141688 func : Callable ,
17151689 * ,
17161690 preop : Callable ,
17171691 binop : Callable ,
1718- identity : Any ,
17191692 axis : int ,
17201693 dtype = None ,
1694+ include_initial = False ,
1695+ split_every : int = 5 ,
17211696) -> "Array" :
17221697 """
17231698 Generic parallel scan.
@@ -1732,10 +1707,10 @@ def scan(
17321707 along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``.
17331708 binop: callable
17341709 Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
1735- identity: Any
1736- Associated identity element more scan like 0 for ``np.cumsum`` and 1 for ``np.cumprod``.
17371710 axis: int
17381711 dtype: dtype
1712+ include_initial: bool
1713+ Whether to include the identity value as the first value in the output.
17391714
17401715 Notes
17411716 -----
@@ -1750,17 +1725,22 @@ def scan(
17501725 cumsum
17511726 cumprod
17521727 """
1728+
1729+ # Note that if include_initial=True the final value is *not* included in the output.
1730+ # To include the final value is tricky with constant chunk sizes, since if the last
1731+ # chunk is full then a new chunk of size one needs to be added for the final value.
1732+ # TODO: add an include_final argument (default True)
1733+
17531734 axis = validate_axis (axis , array .ndim )
17541735
17551736 # Blelloch (1990) out-of-core algorithm.
17561737 # 1. First, scan blockwise
1757- scanned = map_blocks (func , array , axis = axis )
1738+ scanned = map_blocks (func , array , axis = axis , include_initial = include_initial )
17581739 # If there is only a single chunk, we can be done
17591740 if array .numblocks [axis ] == 1 :
17601741 return scanned
17611742
17621743 # 2. Calculate the blockwise reduction using `preop`
1763- # TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned`
17641744 reduced_chunks = tuple (
17651745 (1 ,) * array .numblocks [i ] if i == axis else c
17661746 for i , c in enumerate (array .chunks )
@@ -1771,28 +1751,56 @@ def scan(
17711751 # Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
17721752 # Instead we generalize recursively apply the scan to `reduced`.
17731753 # 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
1774- new_chunksize = min (reduced .shape [axis ], reduced .chunksize [axis ] * 5 )
1754+ new_chunksize = min (reduced .shape [axis ], reduced .chunksize [axis ] * split_every )
17751755 new_chunks = (
17761756 reduced .chunksize [:axis ] + (new_chunksize ,) + reduced .chunksize [axis + 1 :]
17771757 )
17781758
17791759 merged = merge_chunks (reduced , new_chunks )
17801760
17811761 # 3b. Recursively scan this merged array to generate the increment for each block of `scanned`
1762+ # Note we always want to include the initial identity value (but not the final value)
1763+ # so blocks line up correctly.
17821764 increment = scan (
1783- merged , func , preop = preop , binop = binop , identity = identity , axis = axis
1765+ merged ,
1766+ func ,
1767+ preop = preop ,
1768+ binop = binop ,
1769+ axis = axis ,
1770+ include_initial = True ,
17841771 )
17851772
17861773 # 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`.
1787- # Use map_direct since the chunks of increment and scanned aren't aligned anymore.
1774+ # Use general_blockwise with a key function since the chunks of increment and scanned aren't aligned anymore.
17881775 assert increment .shape [axis ] == scanned .numblocks [axis ]
1776+
1777+ def key_function (out_key ):
1778+ out_coords = out_key [1 :]
1779+ inc_coords = tuple (
1780+ bi // split_every if i == axis else bi for i , bi in enumerate (out_coords )
1781+ )
1782+ return ((scanned .name ,) + out_coords , (increment .name ,) + inc_coords )
1783+
1784+ def _scan_binop (scn , inc , block_id = None , ** kwargs ):
1785+ bi = block_id [axis ] % split_every
1786+ ind = tuple (
1787+ slice (bi , bi + 1 ) if i == axis else slice (None ) for i in range (inc .ndim )
1788+ )
1789+ return binop (scn , inc [ind ])
1790+
17891791 # 5. Bada-bing, bada-boom.
1790- return map_direct (
1791- partial (_scan_binop , binop = binop , axis = axis , identity = identity ),
1792+ out = general_blockwise (
1793+ _scan_binop ,
1794+ key_function ,
17921795 scanned ,
17931796 increment ,
1794- shape = scanned .shape ,
1795- dtype = scanned .dtype ,
1796- chunks = scanned .chunks ,
1797+ shapes = [ scanned .shape ] ,
1798+ dtypes = [ scanned .dtype ] ,
1799+ chunkss = [ scanned .chunks ] ,
17971800 extra_projected_mem = scanned .chunkmem * 2 , # arbitrary
17981801 )
1802+
1803+ from cubed import Array
1804+
1805+ assert isinstance (out , Array ) # single output
1806+ return out
0 commit comments