Skip to content

Commit 2c720d3

Browse files
authored
num_tasks is wrong when region is specified (#854)
1 parent 40aa3df commit 2c720d3

3 files changed

Lines changed: 16 additions & 2 deletions

File tree

cubed/core/ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def key_function(out_key):
243243
chunkss=[chunks],
244244
target_stores=[target],
245245
output_blocks=output_blocks,
246+
num_tasks=source.npartitions,
246247
**blockwise_kwargs,
247248
)
248249
from cubed import Array

cubed/primitive/blockwise.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def general_blockwise(
296296
target_chunks_: Optional[T_RegularChunks] = None,
297297
return_writes_stores: bool = False,
298298
output_blocks: Optional[Iterator[List[int]]] = None,
299+
num_tasks=None,
299300
**kwargs,
300301
) -> PrimitiveOperation:
301302
"""A more general form of ``blockwise`` that uses a function to specify the block
@@ -418,7 +419,8 @@ def general_blockwise(
418419
output_blocks = map(
419420
list, itertools.product(*[range(len(c)) for c in chunks_normal])
420421
)
421-
num_tasks = math.prod(len(c) for c in chunks_normal)
422+
if num_tasks is None:
423+
num_tasks = math.prod(len(c) for c in chunks_normal)
422424

423425
pipeline = CubedPipeline(
424426
apply_blockwise,

cubed/tests/test_core.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
from cubed._testing import assert_array_equal
1414
from cubed.array_api.dtypes import _floating_dtypes
1515
from cubed.backend_array_api import namespace as nxp
16-
from cubed.core.ops import general_blockwise, merge_chunks, partial_reduce, tree_reduce
16+
from cubed.core.ops import (
17+
_store_array,
18+
general_blockwise,
19+
merge_chunks,
20+
partial_reduce,
21+
tree_reduce,
22+
)
1723
from cubed.core.optimization import fuse_all_optimize_dag, multiple_inputs_optimize_dag
1824
from cubed.core.plan import ArrayRole
1925
from cubed.storage.store import open_storage_array
@@ -197,6 +203,11 @@ def test_to_zarr_region(tmp_path, spec, executor):
197203
# note that the same zarr store is overwritten in the following tests
198204

199205
region = (slice(0, 2), slice(0, 2))
206+
207+
# need to use internal _store_array function to access plan
208+
# since to_zarr is eager
209+
assert _store_array(a[:2, :2], z, region=region).plan().num_tasks == 1
210+
200211
cubed.to_zarr(a[:2, :2], z, region=region, executor=executor)
201212
res = open_storage_array(store, mode="r")
202213
assert_array_equal(res[region], np.array([[1, 2], [5, 6]]))

0 commit comments

Comments
 (0)