@@ -270,19 +270,21 @@ def test_multiple_ops(spec, executor):
270270 ({}, ((2 , 1 ), (1 , 1 , 1 ))), # unchanged
271271 ],
272272)
273- def test_rechunk (spec , executor , new_chunks , expected_chunks ):
273+ @pytest .mark .parametrize ("use_new_impl" , [True , False ])
274+ def test_rechunk (spec , executor , new_chunks , expected_chunks , use_new_impl ):
274275 a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 1 ), spec = spec )
275- b = a .rechunk (new_chunks )
276+ b = a .rechunk (new_chunks , use_new_impl = use_new_impl )
276277 assert b .chunks == expected_chunks
277278 assert_array_equal (
278279 b .compute (executor = executor ),
279280 np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]]),
280281 )
281282
282283
283- def test_rechunk_same_chunks (spec ):
284+ @pytest .mark .parametrize ("use_new_impl" , [True , False ])
285+ def test_rechunk_same_chunks (spec , use_new_impl ):
284286 a = xp .asarray ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]], chunks = (2 , 1 ), spec = spec )
285- b = a .rechunk ((2 , 1 ))
287+ b = a .rechunk ((2 , 1 ), use_new_impl = use_new_impl )
286288 assert b is a
287289 task_counter = TaskCounter ()
288290 res = b .compute (callbacks = [task_counter ])
@@ -293,11 +295,12 @@ def test_rechunk_same_chunks(spec):
293295
294296
295297# see also test_rechunk.py
296- def test_rechunk_intermediate (tmp_path ):
297- # factor of 4 is for chunks copies, extra 8 is for map_selection
298- spec = cubed .Spec (tmp_path , allowed_mem = 5 * 8 * 4 + 8 )
298+ @pytest .mark .parametrize (("use_new_impl" , "factor" ), [(True , 5 ), (False , 4 )])
299+ def test_rechunk_intermediate (tmp_path , use_new_impl , factor ):
300+ # factor is for chunks copies, extra 8 is for map_selection
301+ spec = cubed .Spec (tmp_path , allowed_mem = 5 * 8 * factor + 8 )
299302 a = xp .ones ((5 , 5 ), chunks = (1 , 5 ), spec = spec )
300- b = a .rechunk ((5 , 1 ))
303+ b = a .rechunk ((5 , 1 ), use_new_impl = use_new_impl )
301304 assert_array_equal (b .compute (), np .ones ((5 , 5 )))
302305 # intermediates = [n for (n, d) in b.plan.dag.nodes(data=True) if "-int" in d["name"]]
303306 # assert len(intermediates) == 1
@@ -310,11 +313,12 @@ def test_rechunk_intermediate(tmp_path):
310313
311314
312315def test_rechunk_merge_chunks_optimization ():
316+ # new impl doesn't use this optimization
313317 a = xp .asarray (
314318 [[1 , 2 , 3 , 4 ], [5 , 6 , 7 , 8 ], [9 , 10 , 11 , 12 ], [13 , 14 , 15 , 16 ]],
315319 chunks = (2 , 1 ),
316320 )
317- b = a .rechunk ((4 , 2 ))
321+ b = a .rechunk ((4 , 2 ), use_new_impl = False )
318322 assert b .chunks == ((4 ,), (2 , 2 ))
319323 assert_array_equal (
320324 b .compute (),
0 commit comments