Skip to content

Commit 8b42b71

Browse files
authored
Minor improvements to rechunk_new (#732)
* Fix errors in new rechunk implementation * Change rechunk unit tests to test both implementations
1 parent 79d8616 commit 8b42b71

2 files changed

Lines changed: 14 additions & 14 deletions

File tree

cubed/core/ops.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,14 +1154,10 @@ def _rechunk_plan(x, chunks, *, min_mem=None):
11541154

11551155
normalized_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
11561156
if x.chunks == normalized_chunks:
1157-
return x
1157+
return
11581158
# normalizing takes care of dict args for chunks
11591159
target_chunks = to_chunksize(normalized_chunks)
11601160

1161-
# merge chunks special case
1162-
if all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunks)):
1163-
return merge_chunks(x, target_chunks)
1164-
11651161
spec = x.spec
11661162
source_chunks = to_chunksize(normalize_chunks(x.chunks, x.shape, dtype=x.dtype))
11671163

cubed/tests/test_core.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -270,19 +270,21 @@ def test_multiple_ops(spec, executor):
270270
({}, ((2, 1), (1, 1, 1))), # unchanged
271271
],
272272
)
273-
def test_rechunk(spec, executor, new_chunks, expected_chunks):
273+
@pytest.mark.parametrize("use_new_impl", [True, False])
274+
def test_rechunk(spec, executor, new_chunks, expected_chunks, use_new_impl):
274275
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec)
275-
b = a.rechunk(new_chunks)
276+
b = a.rechunk(new_chunks, use_new_impl=use_new_impl)
276277
assert b.chunks == expected_chunks
277278
assert_array_equal(
278279
b.compute(executor=executor),
279280
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
280281
)
281282

282283

283-
def test_rechunk_same_chunks(spec):
284+
@pytest.mark.parametrize("use_new_impl", [True, False])
285+
def test_rechunk_same_chunks(spec, use_new_impl):
284286
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec)
285-
b = a.rechunk((2, 1))
287+
b = a.rechunk((2, 1), use_new_impl=use_new_impl)
286288
assert b is a
287289
task_counter = TaskCounter()
288290
res = b.compute(callbacks=[task_counter])
@@ -293,11 +295,12 @@ def test_rechunk_same_chunks(spec):
293295

294296

295297
# see also test_rechunk.py
296-
def test_rechunk_intermediate(tmp_path):
297-
# factor of 4 is for chunks copies, extra 8 is for map_selection
298-
spec = cubed.Spec(tmp_path, allowed_mem=5 * 8 * 4 + 8)
298+
@pytest.mark.parametrize(("use_new_impl", "factor"), [(True, 5), (False, 4)])
299+
def test_rechunk_intermediate(tmp_path, use_new_impl, factor):
300+
# factor is for chunks copies, extra 8 is for map_selection
301+
spec = cubed.Spec(tmp_path, allowed_mem=5 * 8 * factor + 8)
299302
a = xp.ones((5, 5), chunks=(1, 5), spec=spec)
300-
b = a.rechunk((5, 1))
303+
b = a.rechunk((5, 1), use_new_impl=use_new_impl)
301304
assert_array_equal(b.compute(), np.ones((5, 5)))
302305
# intermediates = [n for (n, d) in b.plan.dag.nodes(data=True) if "-int" in d["name"]]
303306
# assert len(intermediates) == 1
@@ -310,11 +313,12 @@ def test_rechunk_intermediate(tmp_path):
310313

311314

312315
def test_rechunk_merge_chunks_optimization():
316+
# new impl doesn't use this optimization
313317
a = xp.asarray(
314318
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
315319
chunks=(2, 1),
316320
)
317-
b = a.rechunk((4, 2))
321+
b = a.rechunk((4, 2), use_new_impl=False)
318322
assert b.chunks == ((4,), (2, 2))
319323
assert_array_equal(
320324
b.compute(),

0 commit comments

Comments
 (0)