Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions cubed/tests/test_rechunk_hypothesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from math import prod

from hypothesis import Phase, assume, given, settings
from hypothesis import strategies as st
from hypothesis.extra.array_api import make_strategies_namespace

import cubed
import cubed.array_api as xp
from cubed._testing import assert_array_equal
from cubed.backend_array_api import namespace as nxp

xps = make_strategies_namespace(nxp)


@st.composite
def rechunk_shapes(draw):
    """Hypothesis strategy yielding ``(shape, source_chunks, target_chunks)``.

    The shape is 2-dimensional with every side at least 1001. Source and
    target chunk lengths are drawn independently for each axis, from 5 up
    to (and including) that axis's length.
    """

    def chunks_for(array_shape):
        # One independent integer draw per axis, bounded by the axis length.
        return tuple(
            draw(st.integers(min_value=5, max_value=side)) for side in array_shape
        )

    array_shape = draw(xps.array_shapes(min_dims=2, max_dims=2, min_side=1001))
    # Draw order matters for reproducibility: source chunks first, then target.
    src = chunks_for(array_shape)
    tgt = chunks_for(array_shape)
    return (array_shape, src, tgt)


@given(rechunk_shapes())
@settings(
    deadline=None,
    max_examples=100,
    # Note that Phase.shrink is not among the listed phases.
    phases=[Phase.explicit, Phase.reuse, Phase.generate, Phase.target],
)
def test_rechunk(rechunk_shapes):
    """Rechunk an all-ones array between generated chunkings and verify it.

    Examples are discarded (via ``assume``) when either chunking would give
    1000 or more chunks, when the chunk sizes differ by a factor of 100 or
    more, or when the resulting plan has 500 or more tasks.
    """
    shape, source_chunks, target_chunks = rechunk_shapes

    n_elements = prod(shape)
    src_elems = prod(source_chunks)
    tgt_elems = prod(target_chunks)
    larger = max(src_elems, tgt_elems)
    smaller = min(src_elems, tgt_elems)

    # Bound the number of chunks on each side of the rechunk.
    assume(n_elements / src_elems < 1000)
    assume(n_elements / tgt_elems < 1000)
    # Bound how different the source and target chunk sizes may be.
    assume(larger / smaller < 100)

    print(rechunk_shapes)

    # Bytes per element; presumably matches the default dtype of xp.ones
    # (TODO confirm).
    itemsize = 8
    # Five times the larger chunk's byte size, padded by 10%.
    allowed_mem = int(larger * itemsize * 1.1) * 5

    spec = cubed.Spec(allowed_mem=allowed_mem)
    source = xp.ones(shape, chunks=source_chunks, spec=spec)
    b = source.rechunk(target_chunks)
    plan = b.plan()
    assume(plan.num_tasks < 500)  # don't want to run too many tasks locally
    summary = f"plan: num_stages: {plan.num_stages}, num_tasks: {plan.num_tasks}, max_projected_mem: {plan.max_projected_mem}"
    print(summary)

    actual = b.compute()
    expected = nxp.ones(shape)
    assert_array_equal(actual, expected)


def test_rechunk_example():
    """Rechunk a fixed 1001x1001 all-ones array from (38, 376) to (5, 146) chunks.

    Prints a summary of the plan, then checks that the computed result
    equals an all-ones array of the same shape.
    """
    shape = (1001, 1001)
    source_chunks = (38, 376)
    target_chunks = (5, 146)

    # True division keeps allowed_mem a float, as in the original call.
    spec = cubed.Spec(allowed_mem=8000000 / 10)
    source = xp.ones(shape, chunks=source_chunks, spec=spec)
    b = source.rechunk(target_chunks)
    plan = b.plan()
    summary = f"plan: num_stages: {plan.num_stages}, num_tasks: {plan.num_tasks}, max_projected_mem: {plan.max_projected_mem}"
    print(summary)

    actual = b.compute()
    expected = nxp.ones(shape)
    assert_array_equal(actual, expected)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ coiled = [
test = [
"cubed[diagnostics]",
"dill",
"hypothesis",
"numpy_groupies",
"obstore",
"pytest",
Expand Down
Loading