2323from cubed .core .plan import Plan , new_temp_path
2424from cubed .primitive .blockwise import blockwise as primitive_blockwise
2525from cubed .primitive .blockwise import general_blockwise as primitive_general_blockwise
26- from cubed .primitive .blockwise import key_to_slices
2726from cubed .primitive .memory import get_buffer_copies
2827from cubed .primitive .rechunk import rechunk as primitive_rechunk
2928from cubed .spec import spec_from_config
@@ -1680,40 +1679,16 @@ def smallest_blockdim(blockdims):
16801679 return out
16811680
16821681
1683- def _scan_binop (
1684- out : np .ndarray ,
1685- left : "Array" ,
1686- right : "Array" ,
1687- * ,
1688- binop : Callable ,
1689- block_id : tuple [int , ...],
1690- axis : int ,
1691- identity : Any ,
1692- ) -> "Array" :
1693- # Get the underlying Zarr arrays so we can access directly
1694- left = left .zarray
1695- right = right .zarray
1696-
1697- left_slicer = key_to_slices (block_id , left )
1698- right_slicer = list (left_slicer )
1699-
1700- # For the first block, we add the identity element
1701- # For all other blocks `k`, we add the `k-1` element along `axis`
1702- right_slicer [axis ] = slice (block_id [axis ] - 1 , block_id [axis ])
1703- right_slicer = tuple (right_slicer )
1704- right_ = right [right_slicer ] if block_id [axis ] > 0 else identity
1705- return binop (left [left_slicer ], right_ )
1706-
1707-
17081682def scan (
17091683 array : "Array" ,
17101684 func : Callable ,
17111685 * ,
17121686 preop : Callable ,
17131687 binop : Callable ,
1714- identity : Any ,
17151688 axis : int ,
17161689 dtype = None ,
1690+ include_initial = False ,
1691+ split_every : int = 5 ,
17171692) -> "Array" :
17181693 """
17191694 Generic parallel scan.
@@ -1728,10 +1703,10 @@ def scan(
17281703 along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``.
17291704 binop: callable
17301705 Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
1731- identity: Any
1732- Associated identity element more scan like 0 for ``np.cumsum`` and 1 for ``np.cumprod``.
17331706 axis: int
17341707 dtype: dtype
1708+ include_initial: bool
1709+ Whether to include the identity value as the first value in the output.
17351710
17361711 Notes
17371712 -----
@@ -1746,49 +1721,82 @@ def scan(
17461721 cumsum
17471722 cumprod
17481723 """
1724+
1725+ # Note that if include_initial=True the final value is *not* included in the output.
1726+ # To include the final value is tricky with constant chunk sizes, since if the last
1727+ # chunk is full then a new chunk of size one needs to be added for the final value.
1728+ # TODO: add an include_final argument (default True)
1729+
17491730 axis = validate_axis (axis , array .ndim )
17501731
17511732 # Blelloch (1990) out-of-core algorithm.
17521733 # 1. First, scan blockwise
1753- scanned = map_blocks (func , array , axis = axis )
1734+ scanned = map_blocks (func , array , axis = axis , include_initial = include_initial )
17541735 # If there is only a single chunk, we can be done
17551736 if array .numblocks [axis ] == 1 :
17561737 return scanned
17571738
1758- # 2. Calculate the blockwise reduction using `preop`
1759- # TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned`
1760- reduced_chunks = tuple (
1761- (1 ,) * array .numblocks [i ] if i == axis else c
1762- for i , c in enumerate (array .chunks )
1739+ # 2. Calculate the reduction using `preop`
1740+ # Use `partial_reduce` to also merge to a decent intermediate chunksize
1741+ # since reduced.chunksize[axis] == 1
1742+
1743+ def identity_func (a , ** kwargs ):
1744+ return a
1745+
1746+ split_size = min (split_every , array .numblocks [axis ])
1747+ reduced = partial_reduce (
1748+ array ,
1749+ initial_func = partial (preop , axis = axis , keepdims = True ),
1750+ func = identity_func ,
1751+ split_every = {axis : split_size },
1752+ combine_sizes = {axis : split_size },
17631753 )
1764- reduced = map_blocks (preop , array , chunks = reduced_chunks , axis = axis , keepdims = True )
17651754
17661755 # 3. Now scan `reduced` to generate the increments for each block of `scanned`.
17671756 # Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
17681757 # Instead we generalize recursively apply the scan to `reduced`.
1769- # 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
1770- new_chunksize = min (reduced .shape [axis ], reduced .chunksize [axis ] * 5 )
1771- new_chunks = (
1772- reduced .chunksize [:axis ] + (new_chunksize ,) + reduced .chunksize [axis + 1 :]
1773- )
1774-
1775- merged = merge_chunks (reduced , new_chunks )
1776-
1777- # 3b. Recursively scan this merged array to generate the increment for each block of `scanned`
1758+ # Note we always want to include the initial identity value (but not the final value)
1759+ # so blocks line up correctly.
17781760 increment = scan (
1779- merged , func , preop = preop , binop = binop , identity = identity , axis = axis
1761+ reduced ,
1762+ func ,
1763+ preop = preop ,
1764+ binop = binop ,
1765+ axis = axis ,
1766+ include_initial = True ,
17801767 )
17811768
17821769 # 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`.
1783- # Use map_direct since the chunks of increment and scanned aren't aligned anymore.
1770+ # Use general_blockwise with a key function since the chunks of increment and scanned aren't aligned anymore.
17841771 assert increment .shape [axis ] == scanned .numblocks [axis ]
1772+
1773+ def key_function (out_key ):
1774+ out_coords = out_key [1 :]
1775+ inc_coords = tuple (
1776+ bi // split_every if i == axis else bi for i , bi in enumerate (out_coords )
1777+ )
1778+ return ((scanned .name ,) + out_coords , (increment .name ,) + inc_coords )
1779+
1780+ def _scan_binop (scn , inc , block_id = None , ** kwargs ):
1781+ bi = block_id [axis ] % split_every
1782+ ind = tuple (
1783+ slice (bi , bi + 1 ) if i == axis else slice (None ) for i in range (inc .ndim )
1784+ )
1785+ return binop (scn , inc [ind ])
1786+
17851787 # 5. Bada-bing, bada-boom.
1786- return map_direct (
1787- partial (_scan_binop , binop = binop , axis = axis , identity = identity ),
1788+ out = general_blockwise (
1789+ _scan_binop ,
1790+ key_function ,
17881791 scanned ,
17891792 increment ,
1790- shape = scanned .shape ,
1791- dtype = scanned .dtype ,
1792- chunks = scanned .chunks ,
1793+ shapes = [ scanned .shape ] ,
1794+ dtypes = [ scanned .dtype ] ,
1795+ chunkss = [ scanned .chunks ] ,
17931796 extra_projected_mem = scanned .chunkmem * 2 , # arbitrary
17941797 )
1798+
1799+ from cubed import Array
1800+
1801+ assert isinstance (out , Array ) # single output
1802+ return out
0 commit comments