diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 46d29b4d..186afa70 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -16,7 +16,8 @@ from cubed import config from cubed.backend_array_api import IS_IMMUTABLE_ARRAY, numpy_array_to_backend_array from cubed.backend_array_api import namespace as nxp -from cubed.core.array import CoreArray, check_array_specs, compute, gensym +from cubed.core.array import CoreArray, check_array_specs, gensym +from cubed.core.array import compute as compute_arrays from cubed.core.plan import Plan, intermediate_store from cubed.core.rechunk import multistage_regular_rechunking_plan from cubed.primitive.blockwise import blockwise as primitive_blockwise @@ -128,6 +129,8 @@ def store( sources: Union["Array", Sequence["Array"]], targets, regions: tuple[slice, ...] | list[tuple[slice, ...]] | None = None, + compute: bool = True, + *, executor=None, **kwargs, ): @@ -135,9 +138,6 @@ def store( In the current implementation ``targets`` must be Zarr arrays. - Note that this operation is eager, and will run the computation - immediately. - Parameters ---------- sources : cubed.Array or collection of cubed.Array @@ -146,6 +146,8 @@ def store( Zarr arrays to write to regions : tuple of slices or list of tuple of slices, optional The regions of data that should be written to in targets. + compute : boolean, optional + If True compute immediately, return tuple of arrays otherwise. executor : cubed.runtime.types.Executor, optional The executor to use to run the computation. Defaults to using the in-process Python executor. @@ -176,7 +178,12 @@ def store( for source, target, region in zip(sources, targets, regions_list): array = _store_array(source, target, region=region) arrays.append(array) - compute(*arrays, executor=executor, _return_in_memory_array=False, **kwargs) + if compute: + compute_arrays( + *arrays, executor=executor, _return_in_memory_array=False, **kwargs + ) + else: + return tuple(arrays) def _store_array( @@ -207,6 +214,7 @@ def _store_array( dtype=source.dtype, align_arrays=False, target_store=target, + fusable_with_successors=False, **blockwise_kwargs, ) else: @@ -260,6 +268,7 @@ def __iter__(self): target_stores=[target], output_blocks=output_blocks, num_tasks=source.npartitions, + fusable_with_successors=False, **blockwise_kwargs, ) from cubed import Array diff --git a/cubed/core/optimization.py b/cubed/core/optimization.py index ff2bb951..021632f2 100644 --- a/cubed/core/optimization.py +++ b/cubed/core/optimization.py @@ -121,7 +121,12 @@ def predecessor_ops_and_arrays(dag, name): pre_list = list(predecessors_unordered(dag, input)) assert len(pre_list) == 1 # each array is produced by a single op pre = pre_list[0] - can_fuse = is_primitive_op(nodes[pre]) and out_degree_unique(dag, input) == 1 + node_dict = nodes[pre] + can_fuse = ( + is_primitive_op(node_dict) + and node_dict["primitive_op"].fusable_with_successors + and out_degree_unique(dag, input) == 1 + ) yield pre, input, can_fuse diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index dd8bd756..e00fcffa 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -132,6 +132,41 @@ def test_store(tmp_path, spec, executor): assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) +def test_store_lazy_compute(tmp_path, spec): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + + store = tmp_path / "source.zarr" + target = open_storage_array( + store, mode="w", shape=a.shape, dtype=a.dtype, chunks=a.chunksize + ) + + (b,) = cubed.store(a, target, compute=False) + + # target has not been computed yet + with pytest.raises(AssertionError): + assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + + b.compute() + assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + + +def test_store_lazy_compute_more(tmp_path, spec): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + + store = tmp_path / "source.zarr" + target = open_storage_array( + store, mode="w", shape=a.shape, dtype=a.dtype, chunks=a.chunksize + ) + + (b,) = cubed.store(a, target, compute=False) + + # do a further computation and check that store has not been optimized away + c = b + 1 + res = c.compute() + assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + assert_array_equal(res, np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + 1) + + def test_store_multiple(tmp_path, spec, executor): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)