Skip to content

Commit b18af43

Browse files
authored
Don't count virtual arrays towards max_total_source_arrays in optimization (#838)
1 parent d3862ed commit b18af43

3 files changed

Lines changed: 41 additions & 28 deletions

File tree

cubed/core/optimization.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
fuse,
99
fuse_multiple,
1010
)
11+
from cubed.storage.virtual import VirtualArray
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -152,15 +153,15 @@ def is_fusable_with_predecessors(node_dict):
152153

153154

154155
def num_source_arrays(dag, name):
155-
"""Return the number of (non-hidden) arrays that are inputs to an op.
156+
"""Return the number of (non-virtual) arrays that are inputs to an op.
156157
157-
Hidden arrays are used for internal bookkeeping, are very small virtual arrays
158-
(empty, or offsets for example), and are not shown on the plan visualization.
159-
For these reasons they shouldn't count towards ``max_total_source_arrays``.
158+
Virtual arrays are very small arrays that are held in memory and not read from storage,
159+
so they shouldn't count towards ``max_total_source_arrays``.
160160
"""
161161
nodes = dict(dag.nodes(data=True))
162162
return sum(
163-
not nodes[array]["hidden"] for array in predecessors_unordered(dag, name)
163+
not isinstance(nodes[array]["target"], VirtualArray)
164+
for array in predecessors_unordered(dag, name)
164165
)
165166

166167

cubed/storage/virtual.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from cubed.utils import array_memory, broadcast_trick, memory_repr
1212

1313

14-
class VirtualEmptyArray(ArrayMetadata):
14+
class VirtualArray(ArrayMetadata):
15+
pass
16+
17+
18+
class VirtualEmptyArray(VirtualArray):
1519
"""An array that is never materialized (in memory or on disk) and contains empty values."""
1620

1721
def __init__(
@@ -34,7 +38,7 @@ def chunkmem(self):
3438
return array_memory(self.dtype, (1,))
3539

3640

37-
class VirtualFullArray(ArrayMetadata):
41+
class VirtualFullArray(VirtualArray):
3842
"""An array that is never materialized (in memory or on disk) and contains a single fill value."""
3943

4044
def __init__(
@@ -61,7 +65,7 @@ def chunkmem(self):
6165
return array_memory(self.dtype, (1,))
6266

6367

64-
class VirtualOffsetsArray(ArrayMetadata):
68+
class VirtualOffsetsArray(VirtualArray):
6569
"""An array that is never materialized (in memory or on disk) and contains sequentially incrementing integers."""
6670

6771
def __init__(self, shape: T_Shape):
@@ -77,7 +81,7 @@ def __getitem__(self, key):
7781
)
7882

7983

80-
class VirtualInMemoryArray(ArrayMetadata):
84+
class VirtualInMemoryArray(VirtualArray):
8185
"""A small array that is held in memory but never materialized on disk."""
8286

8387
def __init__(

cubed/tests/test_optimization.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
from cubed.core.plan import arrays_to_plan
2424
from cubed.tests.test_core import sqrts
25-
from cubed.tests.utils import TaskCounter
25+
from cubed.tests.utils import TaskCounter, create_zarr
2626

2727

2828
@pytest.fixture
@@ -679,15 +679,19 @@ def stack_add(*a):
679679
# \ /
680680
# p
681681
#
682-
def test_fuse_large_fan_in_default(spec):
683-
a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
684-
b = xp.ones((2, 2), chunks=(2, 2), spec=spec)
685-
c = xp.ones((2, 2), chunks=(2, 2), spec=spec)
686-
d = xp.ones((2, 2), chunks=(2, 2), spec=spec)
687-
e = xp.ones((2, 2), chunks=(2, 2), spec=spec)
688-
f = xp.ones((2, 2), chunks=(2, 2), spec=spec)
689-
g = xp.ones((2, 2), chunks=(2, 2), spec=spec)
690-
h = xp.ones((2, 2), chunks=(2, 2), spec=spec)
682+
def test_fuse_large_fan_in_default(tmp_path, spec):
683+
# use zarr input so that they count towards max_total_source_arrays
684+
store = tmp_path / "source.zarr"
685+
create_zarr(np.ones((2, 2)), chunks=(2, 2), store=store)
686+
687+
a = cubed.from_zarr(store, spec=spec)
688+
b = cubed.from_zarr(store, spec=spec)
689+
c = cubed.from_zarr(store, spec=spec)
690+
d = cubed.from_zarr(store, spec=spec)
691+
e = cubed.from_zarr(store, spec=spec)
692+
f = cubed.from_zarr(store, spec=spec)
693+
g = cubed.from_zarr(store, spec=spec)
694+
h = cubed.from_zarr(store, spec=spec)
691695

692696
i = xp.add(a, b)
693697
j = xp.add(c, d)
@@ -738,15 +742,19 @@ def test_fuse_large_fan_in_default(spec):
738742
# \ /
739743
# p
740744
#
741-
def test_fuse_large_fan_in_override(spec):
742-
a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
743-
b = xp.ones((2, 2), chunks=(2, 2), spec=spec)
744-
c = xp.ones((2, 2), chunks=(2, 2), spec=spec)
745-
d = xp.ones((2, 2), chunks=(2, 2), spec=spec)
746-
e = xp.ones((2, 2), chunks=(2, 2), spec=spec)
747-
f = xp.ones((2, 2), chunks=(2, 2), spec=spec)
748-
g = xp.ones((2, 2), chunks=(2, 2), spec=spec)
749-
h = xp.ones((2, 2), chunks=(2, 2), spec=spec)
745+
def test_fuse_large_fan_in_override(tmp_path, spec):
746+
# use zarr input so that they count towards max_total_source_arrays
747+
store = tmp_path / "source.zarr"
748+
create_zarr(np.ones((2, 2)), chunks=(2, 2), store=store)
749+
750+
a = cubed.from_zarr(store, spec=spec)
751+
b = cubed.from_zarr(store, spec=spec)
752+
c = cubed.from_zarr(store, spec=spec)
753+
d = cubed.from_zarr(store, spec=spec)
754+
e = cubed.from_zarr(store, spec=spec)
755+
f = cubed.from_zarr(store, spec=spec)
756+
g = cubed.from_zarr(store, spec=spec)
757+
h = cubed.from_zarr(store, spec=spec)
750758

751759
i = xp.add(a, b)
752760
j = xp.add(c, d)

0 commit comments

Comments
 (0)