diff --git a/cubed/core/plan.py b/cubed/core/plan.py
index 10de200f6..b1eb17526 100644
--- a/cubed/core/plan.py
+++ b/cubed/core/plan.py
@@ -4,10 +4,11 @@
import shutil
import tempfile
import uuid
+import warnings
from datetime import datetime
from enum import Enum
from functools import lru_cache
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
import networkx as nx
@@ -271,25 +272,21 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
return dag
- def _check_projected_mem(self, dag) -> None:
- op_name = None
- max_projected_mem_op = None
+ def _find_ops_exceeding_memory(self, dag) -> List[Tuple[str, "PrimitiveOperation"]]:
+ """Find all operations where projected memory exceeds allowed memory.
+
+ Returns a list of (op_name, primitive_op) tuples for operations that
+ exceed memory limits, sorted by projected memory (highest first).
+ """
+ ops_exceeding = []
for n, d in dag.nodes(data=True):
if "primitive_op" in d:
op = d["primitive_op"]
- if (
- max_projected_mem_op is None
- or op.projected_mem > max_projected_mem_op.projected_mem
- ):
- op_name = n
- max_projected_mem_op = op
- if max_projected_mem_op is not None:
- op = max_projected_mem_op
- if op.projected_mem > op.allowed_mem:
- raise ValueError(
- f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
- f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
- )
+ if op.projected_mem > op.allowed_mem:
+ ops_exceeding.append((n, op))
+ # Sort by projected_mem descending so worst offenders are first
+ ops_exceeding.sort(key=lambda x: x[1].projected_mem, reverse=True)
+ return ops_exceeding
@lru_cache # noqa: B019
def _finalize(
@@ -304,8 +301,10 @@ def _finalize(
if callable(compile_function):
dag = self._compile_blockwise(dag, compile_function)
dag = self._create_lazy_zarr_arrays(dag)
- self._check_projected_mem(dag)
- return FinalizedPlan(nx.freeze(dag), self.array_names, optimize_graph)
+ ops_exceeding_memory = self._find_ops_exceeding_memory(dag)
+ return FinalizedPlan(
+ nx.freeze(dag), self.array_names, optimize_graph, ops_exceeding_memory
+ )
class ArrayRole(Enum):
@@ -324,10 +323,11 @@ class FinalizedPlan:
4. freezing the final DAG so it can't be changed
"""
- def __init__(self, dag, array_names, optimized):
+ def __init__(self, dag, array_names, optimized, ops_exceeding_memory=None):
self.dag = dag
self.array_names = array_names
self.optimized = optimized
+ self._ops_exceeding_memory = ops_exceeding_memory or []
self._calculate_stats()
self.input_array_names = []
@@ -540,6 +540,34 @@ def total_nchunks(self) -> int:
"""The total number of chunks for all materialized arrays in this plan."""
return self._total_nchunks
+ @property
+ def exceeds_memory(self) -> bool:
+ """True if any operation in this plan exceeds the allowed memory."""
+ return len(self._ops_exceeding_memory) > 0
+
+ @property
+ def ops_exceeding_memory(self) -> List[Tuple[str, "PrimitiveOperation"]]:
+ """List of (op_name, primitive_op) tuples for operations exceeding memory.
+
+ Sorted by projected memory (highest first).
+ """
+ return self._ops_exceeding_memory
+
+ def validate(self) -> None:
+ """Validate that this plan can be executed.
+
+ Raises
+ ------
+ ValueError
+ If any operation's projected memory exceeds the allowed memory.
+ """
+ if self._ops_exceeding_memory:
+ op_name, op = self._ops_exceeding_memory[0] # Report worst offender
+ raise ValueError(
+ f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
+ f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
+ )
+
def execute(
self,
executor=None,
@@ -548,6 +576,8 @@ def execute(
spec=None,
**kwargs,
):
+ self.validate()
+
dag = self.dag
if resume:
@@ -580,6 +610,15 @@ def visualize(
rankdir="TB",
show_hidden=False,
):
+ if self._ops_exceeding_memory:
+ op_names = [name for name, _ in self._ops_exceeding_memory]
+ warnings.warn(
+ f"Plan has {len(self._ops_exceeding_memory)} operation(s) that exceed allowed memory: {op_names}. "
+ "These are shown in red in the visualization.",
+ stacklevel=2,
+ )
+ ops_exceeding_names = {name for name, _ in self._ops_exceeding_memory}
+
dag = self.dag.copy() # make a copy since we mutate the DAG below
# remove edges from create-arrays output node to avoid cluttering the diagram
@@ -590,19 +629,39 @@ def visualize(
list(n for n, d in dag.nodes(data=True) if d.get("hidden", False))
)
+ # Build the graph label - use HTML-like label for mixed colors if memory exceeded
+ stats_text = (
+ f"num tasks: {self.num_tasks}
"
+ f"max projected memory: {memory_repr(self.max_projected_mem)}
"
+ f"total nbytes written: {memory_repr(self.total_nbytes_written)}
"
+ f"optimized: {self.optimized}
"
+ )
+
+ if self._ops_exceeding_memory:
+ # Build warning text in red
+ warning_lines = [
+ "
!!! MEMORY EXCEEDED !!!
"
+ ]
+ for op_name, op in self._ops_exceeding_memory:
+ warning_lines.append(
+ f"{op_name}: requires {memory_repr(op.projected_mem)}, "
+ f"allowed {memory_repr(op.allowed_mem)}
"
+ )
+ warning_text = "".join(warning_lines)
+ # HTML-like label with mixed colors
+ label = f"<{stats_text}{warning_text}>"
+ else:
+ # Simple HTML label (no warning)
+ label = f"<{stats_text}>"
+
dag.graph["graph"] = {
"rankdir": rankdir,
- "label": (
- # note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
- rf"num tasks: {self.num_tasks}\l"
- rf"max projected memory: {memory_repr(self.max_projected_mem)}\l"
- rf"total nbytes written: {memory_repr(self.total_nbytes_written)}\l"
- rf"optimized: {self.optimized}\l"
- ),
+ "label": label,
"labelloc": "bottom",
"labeljust": "left",
"fontsize": "10",
}
+
dag.graph["node"] = {"fontname": "helvetica", "shape": "box", "fontsize": "10"}
# do an initial pass to extract array variable names from stack summaries
@@ -627,7 +686,11 @@ def visualize(
func_name = d["func_name"]
label = f"{n}\n{func_name}".strip()
op_name = d["op_name"]
- if op_name == "blockwise":
+ if n in ops_exceeding_names:
+ # operation exceeds memory - show in red
+ d["style"] = '"rounded,filled"'
+ d["fillcolor"] = "#ff6b6b"
+ elif op_name == "blockwise":
d["style"] = '"rounded,filled"'
d["fillcolor"] = "#dcbeff"
elif op_name == "rechunk":
diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py
index 29baa60d9..e5599010f 100644
--- a/cubed/tests/test_core.py
+++ b/cubed/tests/test_core.py
@@ -493,11 +493,29 @@ def test_default_spec_allowed_mem_exceeded():
# default spec fails for large computations
a = xp.ones((20000, 10000), chunks=(10000, 10000))
b = xp.negative(a)
+ # plan() succeeds but marks plan as exceeding memory
+ plan = b.plan()
+ assert plan.exceeds_memory
+ assert len(plan.ops_exceeding_memory) == 1
+ # compute() raises the error
with pytest.raises(
ValueError,
match=r"Projected blockwise memory \(.+\) exceeds allowed_mem \(.+\), including reserved_mem \(.+\) for op-\d+",
):
- b.plan()
+ b.compute()
+
+
+def test_default_spec_allowed_mem_exceeded_visualize(tmp_path):
+ # visualize works but warns when memory is exceeded
+ import warnings
+
+ a = xp.ones((20000, 10000), chunks=(10000, 10000))
+ b = xp.negative(a)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ b.visualize(filename=str(tmp_path / "cubed"))
+ assert len(w) == 1
+ assert "exceed allowed memory" in str(w[0].message)
def test_default_spec_config_override():