|
22 | 22 | ) |
23 | 23 | from cubed.core.plan import arrays_to_plan |
24 | 24 | from cubed.tests.test_core import sqrts |
25 | | -from cubed.tests.utils import TaskCounter |
| 25 | +from cubed.tests.utils import TaskCounter, create_zarr |
26 | 26 |
|
27 | 27 |
|
28 | 28 | @pytest.fixture |
@@ -679,15 +679,19 @@ def stack_add(*a): |
679 | 679 | # \ / |
680 | 680 | # p |
681 | 681 | # |
682 | | -def test_fuse_large_fan_in_default(spec): |
683 | | - a = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
684 | | - b = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
685 | | - c = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
686 | | - d = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
687 | | - e = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
688 | | - f = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
689 | | - g = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
690 | | - h = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
| 682 | +def test_fuse_large_fan_in_default(tmp_path, spec): |
| 683 | + # use zarr input so that they count towards max_total_source_arrays |
| 684 | + store = tmp_path / "source.zarr" |
| 685 | + create_zarr(np.ones((2, 2)), chunks=(2, 2), store=store) |
| 686 | + |
| 687 | + a = cubed.from_zarr(store, spec=spec) |
| 688 | + b = cubed.from_zarr(store, spec=spec) |
| 689 | + c = cubed.from_zarr(store, spec=spec) |
| 690 | + d = cubed.from_zarr(store, spec=spec) |
| 691 | + e = cubed.from_zarr(store, spec=spec) |
| 692 | + f = cubed.from_zarr(store, spec=spec) |
| 693 | + g = cubed.from_zarr(store, spec=spec) |
| 694 | + h = cubed.from_zarr(store, spec=spec) |
691 | 695 |
|
692 | 696 | i = xp.add(a, b) |
693 | 697 | j = xp.add(c, d) |
@@ -738,15 +742,19 @@ def test_fuse_large_fan_in_default(spec): |
738 | 742 | # \ / |
739 | 743 | # p |
740 | 744 | # |
741 | | -def test_fuse_large_fan_in_override(spec): |
742 | | - a = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
743 | | - b = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
744 | | - c = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
745 | | - d = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
746 | | - e = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
747 | | - f = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
748 | | - g = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
749 | | - h = xp.ones((2, 2), chunks=(2, 2), spec=spec) |
| 745 | +def test_fuse_large_fan_in_override(tmp_path, spec): |
| 746 | + # use zarr input so that they count towards max_total_source_arrays |
| 747 | + store = tmp_path / "source.zarr" |
| 748 | + create_zarr(np.ones((2, 2)), chunks=(2, 2), store=store) |
| 749 | + |
| 750 | + a = cubed.from_zarr(store, spec=spec) |
| 751 | + b = cubed.from_zarr(store, spec=spec) |
| 752 | + c = cubed.from_zarr(store, spec=spec) |
| 753 | + d = cubed.from_zarr(store, spec=spec) |
| 754 | + e = cubed.from_zarr(store, spec=spec) |
| 755 | + f = cubed.from_zarr(store, spec=spec) |
| 756 | + g = cubed.from_zarr(store, spec=spec) |
| 757 | + h = cubed.from_zarr(store, spec=spec) |
750 | 758 |
|
751 | 759 | i = xp.add(a, b) |
752 | 760 | j = xp.add(c, d) |
|
0 commit comments