Skip to content

Commit e988918

Browse files
committed
Allow store to be a lazy operation by passing compute=False
1 parent 662f04e commit e988918

3 files changed

Lines changed: 55 additions & 6 deletions

File tree

cubed/core/ops.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from cubed import config
1717
from cubed.backend_array_api import IS_IMMUTABLE_ARRAY, numpy_array_to_backend_array
1818
from cubed.backend_array_api import namespace as nxp
19-
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
19+
from cubed.core.array import CoreArray, check_array_specs, gensym
20+
from cubed.core.array import compute as compute_arrays
2021
from cubed.core.plan import Plan, intermediate_store
2122
from cubed.core.rechunk import multistage_regular_rechunking_plan
2223
from cubed.primitive.blockwise import blockwise as primitive_blockwise
@@ -128,16 +129,15 @@ def store(
128129
sources: Union["Array", Sequence["Array"]],
129130
targets,
130131
regions: tuple[slice, ...] | list[tuple[slice, ...]] | None = None,
132+
compute: bool = True,
133+
*,
131134
executor=None,
132135
**kwargs,
133136
):
134137
"""Save source arrays to array-like objects.
135138
136139
In the current implementation ``targets`` must be Zarr arrays.
137140
138-
Note that this operation is eager, and will run the computation
139-
immediately.
140-
141141
Parameters
142142
----------
143143
sources : cubed.Array or collection of cubed.Array
@@ -146,6 +146,8 @@ def store(
146146
Zarr arrays to write to
147147
regions : tuple of slices or list of tuple of slices, optional
148148
The regions of data that should be written to in targets.
149+
compute : boolean, optional
150+
If True compute immediately, return tuple of arrays otherwise.
149151
executor : cubed.runtime.types.Executor, optional
150152
The executor to use to run the computation.
151153
Defaults to using the in-process Python executor.
@@ -176,7 +178,12 @@ def store(
176178
for source, target, region in zip(sources, targets, regions_list):
177179
array = _store_array(source, target, region=region)
178180
arrays.append(array)
179-
compute(*arrays, executor=executor, _return_in_memory_array=False, **kwargs)
181+
if compute:
182+
compute_arrays(
183+
*arrays, executor=executor, _return_in_memory_array=False, **kwargs
184+
)
185+
else:
186+
return tuple(arrays)
180187

181188

182189
def _store_array(
@@ -207,6 +214,7 @@ def _store_array(
207214
dtype=source.dtype,
208215
align_arrays=False,
209216
target_store=target,
217+
fusable_with_successors=False,
210218
**blockwise_kwargs,
211219
)
212220
else:
@@ -260,6 +268,7 @@ def __iter__(self):
260268
target_stores=[target],
261269
output_blocks=output_blocks,
262270
num_tasks=source.npartitions,
271+
fusable_with_successors=False,
263272
**blockwise_kwargs,
264273
)
265274
from cubed import Array

cubed/core/optimization.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,12 @@ def predecessor_ops_and_arrays(dag, name):
121121
pre_list = list(predecessors_unordered(dag, input))
122122
assert len(pre_list) == 1 # each array is produced by a single op
123123
pre = pre_list[0]
124-
can_fuse = is_primitive_op(nodes[pre]) and out_degree_unique(dag, input) == 1
124+
node_dict = nodes[pre]
125+
can_fuse = (
126+
is_primitive_op(node_dict)
127+
and node_dict["primitive_op"].fusable_with_successors
128+
and out_degree_unique(dag, input) == 1
129+
)
125130
yield pre, input, can_fuse
126131

127132

cubed/tests/test_core.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,41 @@ def test_store(tmp_path, spec, executor):
132132
assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
133133

134134

135+
def test_store_lazy_compute(tmp_path, spec):
136+
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
137+
138+
store = tmp_path / "source.zarr"
139+
target = open_storage_array(
140+
store, mode="w", shape=a.shape, dtype=a.dtype, chunks=a.chunksize
141+
)
142+
143+
(b,) = cubed.store(a, target, compute=False)
144+
145+
# target has not been computed yet
146+
with pytest.raises(AssertionError):
147+
assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
148+
149+
b.compute()
150+
assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
151+
152+
153+
def test_store_lazy_compute_more(tmp_path, spec):
154+
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
155+
156+
store = tmp_path / "source.zarr"
157+
target = open_storage_array(
158+
store, mode="w", shape=a.shape, dtype=a.dtype, chunks=a.chunksize
159+
)
160+
161+
(b,) = cubed.store(a, target, compute=False)
162+
163+
# do a further computation and check that store has not been optimized away
164+
c = b + 1
165+
res = c.compute()
166+
assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
167+
assert_array_equal(res, np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + 1)
168+
169+
135170
def test_store_multiple(tmp_path, spec, executor):
136171
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
137172
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)

0 commit comments

Comments
 (0)