diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 10de200f6..b1eb17526 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -4,10 +4,11 @@ import shutil import tempfile import uuid +import warnings from datetime import datetime from enum import Enum from functools import lru_cache -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import networkx as nx @@ -271,25 +272,21 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap return dag - def _check_projected_mem(self, dag) -> None: - op_name = None - max_projected_mem_op = None + def _find_ops_exceeding_memory(self, dag) -> List[Tuple[str, "PrimitiveOperation"]]: + """Find all operations where projected memory exceeds allowed memory. + + Returns a list of (op_name, primitive_op) tuples for operations that + exceed memory limits, sorted by projected memory (highest first). + """ + ops_exceeding = [] for n, d in dag.nodes(data=True): if "primitive_op" in d: op = d["primitive_op"] - if ( - max_projected_mem_op is None - or op.projected_mem > max_projected_mem_op.projected_mem - ): - op_name = n - max_projected_mem_op = op - if max_projected_mem_op is not None: - op = max_projected_mem_op - if op.projected_mem > op.allowed_mem: - raise ValueError( - f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), " - f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}" - ) + if op.projected_mem > op.allowed_mem: + ops_exceeding.append((n, op)) + # Sort by projected_mem descending so worst offenders are first + ops_exceeding.sort(key=lambda x: x[1].projected_mem, reverse=True) + return ops_exceeding @lru_cache # noqa: B019 def _finalize( @@ -304,8 +301,10 @@ def _finalize( if callable(compile_function): dag = self._compile_blockwise(dag, compile_function) dag = self._create_lazy_zarr_arrays(dag) - self._check_projected_mem(dag) - return FinalizedPlan(nx.freeze(dag), self.array_names, optimize_graph) + ops_exceeding_memory = self._find_ops_exceeding_memory(dag) + return FinalizedPlan( + nx.freeze(dag), self.array_names, optimize_graph, ops_exceeding_memory + ) class ArrayRole(Enum): @@ -324,10 +323,11 @@ class FinalizedPlan: 4. freezing the final DAG so it can't be changed """ - def __init__(self, dag, array_names, optimized): + def __init__(self, dag, array_names, optimized, ops_exceeding_memory=None): self.dag = dag self.array_names = array_names self.optimized = optimized + self._ops_exceeding_memory = ops_exceeding_memory or [] self._calculate_stats() self.input_array_names = [] @@ -540,6 +540,34 @@ def total_nchunks(self) -> int: """The total number of chunks for all materialized arrays in this plan.""" return self._total_nchunks + @property + def exceeds_memory(self) -> bool: + """True if any operation in this plan exceeds the allowed memory.""" + return len(self._ops_exceeding_memory) > 0 + + @property + def ops_exceeding_memory(self) -> List[Tuple[str, "PrimitiveOperation"]]: + """List of (op_name, primitive_op) tuples for operations exceeding memory. + + Sorted by projected memory (highest first). + """ + return self._ops_exceeding_memory + + def validate(self) -> None: + """Validate that this plan can be executed. + + Raises + ------ + ValueError + If any operation's projected memory exceeds the allowed memory. + """ + if self._ops_exceeding_memory: + op_name, op = self._ops_exceeding_memory[0] # Report worst offender + raise ValueError( + f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), " + f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}" + ) + def execute( self, executor=None, @@ -548,6 +576,8 @@ def execute( spec=None, **kwargs, ): + self.validate() + dag = self.dag if resume: @@ -580,6 +610,15 @@ def visualize( rankdir="TB", show_hidden=False, ): + if self._ops_exceeding_memory: + op_names = [name for name, _ in self._ops_exceeding_memory] + warnings.warn( + f"Plan has {len(self._ops_exceeding_memory)} operation(s) that exceed allowed memory: {op_names}. " + "These are shown in red in the visualization.", + stacklevel=2, + ) + ops_exceeding_names = {name for name, _ in self._ops_exceeding_memory} + dag = self.dag.copy() # make a copy since we mutate the DAG below # remove edges from create-arrays output node to avoid cluttering the diagram @@ -590,19 +629,39 @@ def visualize( list(n for n, d in dag.nodes(data=True) if d.get("hidden", False)) ) + # Build the graph label - use HTML-like label for mixed colors if memory exceeded + stats_text = ( + f"num tasks: {self.num_tasks}
" + f"max projected memory: {memory_repr(self.max_projected_mem)}
" + f"total nbytes written: {memory_repr(self.total_nbytes_written)}
" + f"optimized: {self.optimized}
" + ) + + if self._ops_exceeding_memory: + # Build warning text in red + warning_lines = [ + "
!!! MEMORY EXCEEDED !!!
" + ] + for op_name, op in self._ops_exceeding_memory: + warning_lines.append( + f"{op_name}: requires {memory_repr(op.projected_mem)}, " + f"allowed {memory_repr(op.allowed_mem)}
" + ) + warning_text = "".join(warning_lines) + # HTML-like label with mixed colors + label = f"<{stats_text}{warning_text}>" + else: + # Simple HTML label (no warning) + label = f"<{stats_text}>" + dag.graph["graph"] = { "rankdir": rankdir, - "label": ( - # note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/) - rf"num tasks: {self.num_tasks}\l" - rf"max projected memory: {memory_repr(self.max_projected_mem)}\l" - rf"total nbytes written: {memory_repr(self.total_nbytes_written)}\l" - rf"optimized: {self.optimized}\l" - ), + "label": label, "labelloc": "bottom", "labeljust": "left", "fontsize": "10", } + dag.graph["node"] = {"fontname": "helvetica", "shape": "box", "fontsize": "10"} # do an initial pass to extract array variable names from stack summaries @@ -627,7 +686,11 @@ def visualize( func_name = d["func_name"] label = f"{n}\n{func_name}".strip() op_name = d["op_name"] - if op_name == "blockwise": + if n in ops_exceeding_names: + # operation exceeds memory - show in red + d["style"] = '"rounded,filled"' + d["fillcolor"] = "#ff6b6b" + elif op_name == "blockwise": d["style"] = '"rounded,filled"' d["fillcolor"] = "#dcbeff" elif op_name == "rechunk": diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 29baa60d9..e5599010f 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -493,11 +493,29 @@ def test_default_spec_allowed_mem_exceeded(): # default spec fails for large computations a = xp.ones((20000, 10000), chunks=(10000, 10000)) b = xp.negative(a) + # plan() succeeds but marks plan as exceeding memory + plan = b.plan() + assert plan.exceeds_memory + assert len(plan.ops_exceeding_memory) == 1 + # compute() raises the error with pytest.raises( ValueError, match=r"Projected blockwise memory \(.+\) exceeds allowed_mem \(.+\), including reserved_mem \(.+\) for op-\d+", ): - b.plan() + b.compute() + + +def test_default_spec_allowed_mem_exceeded_visualize(tmp_path): + # visualize works but warns when memory is exceeded + import warnings + + a = xp.ones((20000, 10000), chunks=(10000, 10000)) + b = xp.negative(a) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + b.visualize(filename=str(tmp_path / "cubed")) + assert len(w) == 1 + assert "exceed allowed memory" in str(w[0].message) def test_default_spec_config_override():