Skip to content

Commit 34b658a

Browse files
authored
Change list of output arrays in BlockwiseSpec to a map (#724)
1 parent 7de9081 commit 34b658a

4 files changed

Lines changed: 31 additions & 20 deletions

File tree

cubed/core/ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def blockwise(
306306
reserved_mem=spec.reserved_mem,
307307
extra_projected_mem=extra_projected_mem,
308308
target_store=target_store,
309+
target_name=name,
309310
target_path=target_path,
310311
storage_options=spec.storage_options,
311312
compressor=spec.zarr_compressor,
@@ -314,7 +315,6 @@ def blockwise(
314315
chunks=_chunks,
315316
new_axes=new_axes,
316317
in_names=in_names,
317-
out_name=name,
318318
buffer_copies=buffer_copies,
319319
extra_func_kwargs=extra_func_kwargs,
320320
fusable_with_predecessors=fusable_with_predecessors,
@@ -454,10 +454,12 @@ def _general_blockwise(
454454
ts if ts is not None else new_temp_path(name=n, spec=spec)
455455
for n, ts in zip(name, target_stores)
456456
]
457+
target_names = name
457458
else: # single output
458459
name = gensym()
459460
if target_stores is None:
460461
target_stores = [new_temp_path(name=name, spec=spec)]
462+
target_names = [name]
461463

462464
op = primitive_general_blockwise(
463465
func,
@@ -468,6 +470,7 @@ def _general_blockwise(
468470
extra_projected_mem=extra_projected_mem,
469471
buffer_copies=buffer_copies,
470472
target_stores=target_stores,
473+
target_names=target_names,
471474
target_paths=target_paths,
472475
storage_options=spec.storage_options,
473476
compressor=spec.zarr_compressor,

cubed/primitive/blockwise.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ class BlockwiseSpec:
6363
Whether the input blocks read from each input array are supplied as an iterable or not.
6464
reads_map : Dict[str, CubedArrayProxy]
6565
Read proxy dictionary keyed by array name.
66-
writes_list : List[CubedArrayProxy]
67-
Write proxy list where entries have an ``array`` attribute that supports ``__setitem__``.
66+
writes_map : Dict[str, CubedArrayProxy]
67+
Write proxy dictionary keyed by array name where entries have an ``array`` attribute that supports ``__setitem__``.
68+
Note that the order of the entries must correspond to the order of the outputs created by the function.
6869
"""
6970

7071
key_function: Callable[..., Any]
@@ -74,7 +75,7 @@ class BlockwiseSpec:
7475
num_output_blocks: Tuple[int, ...]
7576
iterable_input_blocks: Tuple[bool, ...]
7677
reads_map: Dict[str, CubedArrayProxy]
77-
writes_list: List[CubedArrayProxy]
78+
writes_map: Dict[str, CubedArrayProxy]
7879
return_writes_stores: bool = False
7980

8081

@@ -92,22 +93,20 @@ def apply_blockwise(
9293
results, tuple
9394
):
9495
results = (results,)
95-
for i, result in enumerate(results):
96+
for result, write_proxy in zip(results, config.writes_map.values()):
9697
out_chunk_key = key_to_slices(
97-
out_coords_tuple, config.writes_list[i].array, config.writes_list[i].chunks
98+
out_coords_tuple, write_proxy.array, write_proxy.chunks
9899
)
99100
if isinstance(result, dict): # group of arrays with named fields
100101
for k, v in result.items():
101102
v = backend_array_to_numpy_array(v)
102-
config.writes_list[i].open().set_basic_selection(
103-
out_chunk_key, v, fields=k
104-
)
103+
write_proxy.open().set_basic_selection(out_chunk_key, v, fields=k)
105104
else:
106105
result = backend_array_to_numpy_array(result)
107-
config.writes_list[i].open()[out_chunk_key] = result
106+
write_proxy.open()[out_chunk_key] = result
108107

109108
if config.return_writes_stores:
110-
return [write_proxy.open().store for write_proxy in config.writes_list]
109+
return [write_proxy.open().store for write_proxy in config.writes_map.values()]
111110
return None
112111

113112

@@ -152,6 +151,7 @@ def blockwise(
152151
allowed_mem: int,
153152
reserved_mem: int,
154153
target_store: T_Store,
154+
target_name: str,
155155
target_path: Optional[str] = None,
156156
storage_options: Optional[Dict[str, Any]] = None,
157157
compressor: Union[dict, str, None] = "default",
@@ -160,7 +160,6 @@ def blockwise(
160160
chunks: T_Chunks,
161161
new_axes: Optional[Dict[int, int]] = None,
162162
in_names: Optional[List[str]] = None,
163-
out_name: Optional[str] = None,
164163
extra_projected_mem: int = 0,
165164
buffer_copies: Optional[BufferCopies] = None,
166165
extra_func_kwargs: Optional[Dict[str, Any]] = None,
@@ -231,7 +230,7 @@ def blockwise(
231230

232231
key_function = make_blockwise_key_function_flattened(
233232
func,
234-
out_name or "out",
233+
target_name,
235234
out_ind,
236235
*argindsstr,
237236
numblocks=numblocks,
@@ -245,6 +244,7 @@ def blockwise(
245244
allowed_mem=allowed_mem,
246245
reserved_mem=reserved_mem,
247246
target_stores=[target_store],
247+
target_names=[target_name],
248248
target_paths=[target_path] if target_path is not None else None,
249249
storage_options=storage_options,
250250
compressor=compressor,
@@ -270,6 +270,7 @@ def general_blockwise(
270270
allowed_mem: int,
271271
reserved_mem: int,
272272
target_stores: List[T_Store],
273+
target_names: List[str],
273274
target_paths: Optional[List[str]] = None,
274275
storage_options: Optional[Dict[str, Any]] = None,
275276
compressor: Union[dict, str, None] = "default",
@@ -341,7 +342,7 @@ def general_blockwise(
341342
name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items()
342343
}
343344

344-
write_proxies = []
345+
write_proxies = {}
345346
output_chunk_memory = 0
346347
target_arrays = []
347348

@@ -372,7 +373,7 @@ def general_blockwise(
372373
)
373374
target_arrays.append(ta)
374375

375-
write_proxies.append(CubedArrayProxy(ta, chunksize))
376+
write_proxies[target_names[i]] = CubedArrayProxy(ta, chunksize)
376377

377378
# only one output chunk is read into memory at a time, so we find the largest
378379
output_chunk_memory = max(
@@ -559,7 +560,7 @@ def fused_func(*args):
559560

560561
function_nargs = pipeline1.config.function_nargs
561562
read_proxies = pipeline1.config.reads_map
562-
write_proxies = pipeline2.config.writes_list
563+
write_proxies = pipeline2.config.writes_map
563564
return_writes_stores = pipeline2.config.return_writes_stores
564565
num_input_blocks = tuple(
565566
n * pipeline2.config.num_input_blocks[0]
@@ -621,7 +622,7 @@ def fuse_multiple(
621622
num_output_blocks=(1,),
622623
iterable_input_blocks=(False,),
623624
reads_map={},
624-
writes_list=[],
625+
writes_map={},
625626
)
626627
predecessor_bw_specs = [
627628
primitive_op.pipeline.config
@@ -710,7 +711,7 @@ def fuse_blockwise_specs(
710711
read_proxies = dict(bw_spec.reads_map)
711712
for bws in predecessor_bw_specs:
712713
read_proxies.update(bws.reads_map)
713-
write_proxies = bw_spec.writes_list
714+
write_proxies = bw_spec.writes_map
714715
return_writes_stores = bw_spec.return_writes_stores
715716
return BlockwiseSpec(
716717
fused_key_func,

cubed/tests/primitive/test_blockwise.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_blockwise(tmp_path, executor, reserved_mem):
4040
allowed_mem=allowed_mem,
4141
reserved_mem=reserved_mem,
4242
target_store=target_store,
43+
target_name="target",
4344
shape=(3, 3),
4445
dtype=int,
4546
chunks=(2, 2),
@@ -70,7 +71,7 @@ def test_blockwise(tmp_path, executor, reserved_mem):
7071
assert_array_equal(res[:], np.outer([0, 1, 2], [10, 50, 100]))
7172

7273

73-
def _permute_dims(x, /, axes, allowed_mem, reserved_mem, target_store):
74+
def _permute_dims(x, /, axes, allowed_mem, reserved_mem, target_store, target_name):
7475
# From dask transpose
7576
if axes:
7677
if len(axes) != x.ndim:
@@ -86,6 +87,7 @@ def _permute_dims(x, /, axes, allowed_mem, reserved_mem, target_store):
8687
allowed_mem=allowed_mem,
8788
reserved_mem=reserved_mem,
8889
target_store=target_store,
90+
target_name=target_name,
8991
shape=x.shape,
9092
dtype=x.dtype,
9193
chunks=x.chunks,
@@ -109,6 +111,7 @@ def test_blockwise_with_args(tmp_path, executor):
109111
allowed_mem=allowed_mem,
110112
reserved_mem=0,
111113
target_store=target_store,
114+
target_name="target",
112115
)
113116

114117
assert op.target_array.shape == (3, 3)
@@ -160,6 +163,7 @@ def test_blockwise_allowed_mem_exceeded(tmp_path, reserved_mem):
160163
allowed_mem=allowed_mem,
161164
reserved_mem=reserved_mem,
162165
target_store=target_store,
166+
target_name="target",
163167
shape=(3, 3),
164168
dtype=np.int64,
165169
chunks=(2, 2),
@@ -204,6 +208,7 @@ def key_function(out_key):
204208
allowed_mem=allowed_mem,
205209
reserved_mem=0,
206210
target_stores=[target_store],
211+
target_names=["target"],
207212
shapes=[(20,)],
208213
dtypes=[int],
209214
chunkss=[(6,)],
@@ -252,6 +257,7 @@ def block_function(out_key):
252257
allowed_mem=allowed_mem,
253258
reserved_mem=0,
254259
target_stores=[target_store1, target_store2],
260+
target_names=["target1", "target2"],
255261
shapes=[(3, 3), (3, 3)],
256262
dtypes=[float, float],
257263
chunkss=[(2, 2), (2, 2)],
@@ -317,6 +323,7 @@ def block_function(out_key):
317323
allowed_mem=allowed_mem,
318324
reserved_mem=0,
319325
target_stores=[target_store1, target_store2],
326+
target_names=["target1", "target2"],
320327
shapes=[(3, 3), (3, 3)],
321328
dtypes=[float, float],
322329
chunkss=[(2, 2), (4, 2)], # numblocks differ

cubed/tests/primitive/test_blockwise_fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def make_blockwise_spec(
121121
num_output_blocks=num_output_blocks,
122122
iterable_input_blocks=iterable_input_blocks,
123123
reads_map={}, # unused
124-
writes_list=[], # unused
124+
writes_map={}, # unused
125125
)
126126

127127

0 commit comments

Comments
 (0)