@@ -63,8 +63,9 @@ class BlockwiseSpec:
6363 Whether the input blocks read from each input array are supplied as an iterable or not.
6464 reads_map : Dict[str, CubedArrayProxy]
6565 Read proxy dictionary keyed by array name.
66- writes_list : List[CubedArrayProxy]
67- Write proxy list where entries have an ``array`` attribute that supports ``__setitem__``.
66+ writes_map : Dict[str, CubedArrayProxy]
67+ Write proxy dictionary keyed by array name where entries have an ``array`` attribute that supports ``__setitem__``.
68+ Note that the order of the entries must correspond to the order of the outputs created by the function.
6869 """
6970
7071 key_function : Callable [..., Any ]
@@ -74,7 +75,7 @@ class BlockwiseSpec:
7475 num_output_blocks : Tuple [int , ...]
7576 iterable_input_blocks : Tuple [bool , ...]
7677 reads_map : Dict [str , CubedArrayProxy ]
77- writes_list : List [ CubedArrayProxy ]
78+ writes_map : Dict [ str , CubedArrayProxy ]
7879 return_writes_stores : bool = False
7980
8081
@@ -92,22 +93,20 @@ def apply_blockwise(
9293 results , tuple
9394 ):
9495 results = (results ,)
95- for i , result in enumerate (results ):
96+ for result , write_proxy in zip (results , config . writes_map . values () ):
9697 out_chunk_key = key_to_slices (
97- out_coords_tuple , config . writes_list [ i ]. array , config . writes_list [ i ] .chunks
98+ out_coords_tuple , write_proxy . array , write_proxy .chunks
9899 )
99100 if isinstance (result , dict ): # group of arrays with named fields
100101 for k , v in result .items ():
101102 v = backend_array_to_numpy_array (v )
102- config .writes_list [i ].open ().set_basic_selection (
103- out_chunk_key , v , fields = k
104- )
103+ write_proxy .open ().set_basic_selection (out_chunk_key , v , fields = k )
105104 else :
106105 result = backend_array_to_numpy_array (result )
107- config . writes_list [ i ] .open ()[out_chunk_key ] = result
106+ write_proxy .open ()[out_chunk_key ] = result
108107
109108 if config .return_writes_stores :
110- return [write_proxy .open ().store for write_proxy in config .writes_list ]
109+ return [write_proxy .open ().store for write_proxy in config .writes_map . values () ]
111110 return None
112111
113112
@@ -152,6 +151,7 @@ def blockwise(
152151 allowed_mem : int ,
153152 reserved_mem : int ,
154153 target_store : T_Store ,
154+ target_name : str ,
155155 target_path : Optional [str ] = None ,
156156 storage_options : Optional [Dict [str , Any ]] = None ,
157157 compressor : Union [dict , str , None ] = "default" ,
@@ -160,7 +160,6 @@ def blockwise(
160160 chunks : T_Chunks ,
161161 new_axes : Optional [Dict [int , int ]] = None ,
162162 in_names : Optional [List [str ]] = None ,
163- out_name : Optional [str ] = None ,
164163 extra_projected_mem : int = 0 ,
165164 buffer_copies : Optional [BufferCopies ] = None ,
166165 extra_func_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -231,7 +230,7 @@ def blockwise(
231230
232231 key_function = make_blockwise_key_function_flattened (
233232 func ,
234- out_name or "out" ,
233+ target_name ,
235234 out_ind ,
236235 * argindsstr ,
237236 numblocks = numblocks ,
@@ -245,6 +244,7 @@ def blockwise(
245244 allowed_mem = allowed_mem ,
246245 reserved_mem = reserved_mem ,
247246 target_stores = [target_store ],
247+ target_names = [target_name ],
248248 target_paths = [target_path ] if target_path is not None else None ,
249249 storage_options = storage_options ,
250250 compressor = compressor ,
@@ -270,6 +270,7 @@ def general_blockwise(
270270 allowed_mem : int ,
271271 reserved_mem : int ,
272272 target_stores : List [T_Store ],
273+ target_names : List [str ],
273274 target_paths : Optional [List [str ]] = None ,
274275 storage_options : Optional [Dict [str , Any ]] = None ,
275276 compressor : Union [dict , str , None ] = "default" ,
@@ -341,7 +342,7 @@ def general_blockwise(
341342 name : CubedArrayProxy (array , array .chunks ) for name , array in array_map .items ()
342343 }
343344
344- write_proxies = []
345+ write_proxies = {}
345346 output_chunk_memory = 0
346347 target_arrays = []
347348
@@ -372,7 +373,7 @@ def general_blockwise(
372373 )
373374 target_arrays .append (ta )
374375
375- write_proxies . append ( CubedArrayProxy (ta , chunksize ) )
376+ write_proxies [ target_names [ i ]] = CubedArrayProxy (ta , chunksize )
376377
377378 # only one output chunk is read into memory at a time, so we find the largest
378379 output_chunk_memory = max (
@@ -559,7 +560,7 @@ def fused_func(*args):
559560
560561 function_nargs = pipeline1 .config .function_nargs
561562 read_proxies = pipeline1 .config .reads_map
562- write_proxies = pipeline2 .config .writes_list
563+ write_proxies = pipeline2 .config .writes_map
563564 return_writes_stores = pipeline2 .config .return_writes_stores
564565 num_input_blocks = tuple (
565566 n * pipeline2 .config .num_input_blocks [0 ]
@@ -621,7 +622,7 @@ def fuse_multiple(
621622 num_output_blocks = (1 ,),
622623 iterable_input_blocks = (False ,),
623624 reads_map = {},
624- writes_list = [] ,
625+ writes_map = {} ,
625626 )
626627 predecessor_bw_specs = [
627628 primitive_op .pipeline .config
@@ -710,7 +711,7 @@ def fuse_blockwise_specs(
710711 read_proxies = dict (bw_spec .reads_map )
711712 for bws in predecessor_bw_specs :
712713 read_proxies .update (bws .reads_map )
713- write_proxies = bw_spec .writes_list
714+ write_proxies = bw_spec .writes_map
714715 return_writes_stores = bw_spec .return_writes_stores
715716 return BlockwiseSpec (
716717 fused_key_func ,
0 commit comments