Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
4fb71b5
Refactor memory check method to find operations exceeding memory limits
neilSchroeder Dec 1, 2025
8beb9eb
Refactor memory check to identify operations exceeding allowed memory…
neilSchroeder Dec 1, 2025
f28a23b
Add operation memory check to finalize plan creation
neilSchroeder Dec 1, 2025
40ea158
Add optional parameter for operations exceeding memory in FinalizedPl…
neilSchroeder Dec 1, 2025
ade9dbb
Add handling for operations exceeding memory in FinalizedPlan
neilSchroeder Dec 1, 2025
d18ab14
Add memory validation and reporting in FinalizedPlan
neilSchroeder Dec 1, 2025
916b1c1
Add validation call in FinalizedPlan constructor
neilSchroeder Dec 1, 2025
900c8d9
Add warning for operations exceeding memory in FinalizedPlan visualiz…
neilSchroeder Dec 1, 2025
f6dc5ab
Add warnings import for enhanced memory management in FinalizedPlan
neilSchroeder Dec 1, 2025
ac2c403
Refactor imports in plan.py for improved organization
neilSchroeder Dec 1, 2025
f15bea8
Add HTML warning for memory exceeded in FinalizedPlan visualization
neilSchroeder Dec 1, 2025
eadb8dc
Refactor FinalizedPlan graph label to use a predefined variable
neilSchroeder Dec 1, 2025
bd9ca4f
Add missing line break for improved readability in FinalizedPlan class
neilSchroeder Dec 1, 2025
6031855
Highlight operations exceeding memory in red within FinalizedPlan vis…
neilSchroeder Dec 1, 2025
99c754c
Add test for plan exceeding memory in default spec
neilSchroeder Dec 1, 2025
a57491c
Add visualization test for memory exceeded warning in default spec
neilSchroeder Dec 1, 2025
4987167
Merge branch 'cubed-dev:main' into visualize-plans-that-exceed-memory
neilSchroeder Dec 1, 2025
bc76787
lint
neilSchroeder Dec 1, 2025
0d45b82
Merge branch 'visualize-plans-that-exceed-memory' of https://github.c…
neilSchroeder Dec 1, 2025
a778403
Update memory exceeded warning text format in FinalizedPlan class
neilSchroeder Dec 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 91 additions & 28 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import shutil
import tempfile
import uuid
import warnings
from datetime import datetime
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import networkx as nx

Expand Down Expand Up @@ -271,25 +272,21 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap

return dag

def _check_projected_mem(self, dag) -> None:
op_name = None
max_projected_mem_op = None
def _find_ops_exceeding_memory(self, dag) -> List[Tuple[str, "PrimitiveOperation"]]:
"""Find all operations where projected memory exceeds allowed memory.

Returns a list of (op_name, primitive_op) tuples for operations that
exceed memory limits, sorted by projected memory (highest first).
"""
ops_exceeding = []
for n, d in dag.nodes(data=True):
if "primitive_op" in d:
op = d["primitive_op"]
if (
max_projected_mem_op is None
or op.projected_mem > max_projected_mem_op.projected_mem
):
op_name = n
max_projected_mem_op = op
if max_projected_mem_op is not None:
op = max_projected_mem_op
if op.projected_mem > op.allowed_mem:
raise ValueError(
f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
)
if op.projected_mem > op.allowed_mem:
ops_exceeding.append((n, op))
# Sort by projected_mem descending so worst offenders are first
ops_exceeding.sort(key=lambda x: x[1].projected_mem, reverse=True)
return ops_exceeding

@lru_cache # noqa: B019
def _finalize(
Expand All @@ -304,8 +301,10 @@ def _finalize(
if callable(compile_function):
dag = self._compile_blockwise(dag, compile_function)
dag = self._create_lazy_zarr_arrays(dag)
self._check_projected_mem(dag)
return FinalizedPlan(nx.freeze(dag), self.array_names, optimize_graph)
ops_exceeding_memory = self._find_ops_exceeding_memory(dag)
return FinalizedPlan(
nx.freeze(dag), self.array_names, optimize_graph, ops_exceeding_memory
)


class ArrayRole(Enum):
Expand All @@ -324,10 +323,11 @@ class FinalizedPlan:
4. freezing the final DAG so it can't be changed
"""

def __init__(self, dag, array_names, optimized):
def __init__(self, dag, array_names, optimized, ops_exceeding_memory=None):
self.dag = dag
self.array_names = array_names
self.optimized = optimized
self._ops_exceeding_memory = ops_exceeding_memory or []
self._calculate_stats()

self.input_array_names = []
Expand Down Expand Up @@ -540,6 +540,34 @@ def total_nchunks(self) -> int:
"""The total number of chunks for all materialized arrays in this plan."""
return self._total_nchunks

@property
def exceeds_memory(self) -> bool:
"""True if any operation in this plan exceeds the allowed memory."""
return len(self._ops_exceeding_memory) > 0

@property
def ops_exceeding_memory(self) -> List[Tuple[str, "PrimitiveOperation"]]:
"""List of (op_name, primitive_op) tuples for operations exceeding memory.

Sorted by projected memory (highest first).
"""
return self._ops_exceeding_memory

def validate(self) -> None:
"""Validate that this plan can be executed.

Raises
------
ValueError
If any operation's projected memory exceeds the allowed memory.
"""
if self._ops_exceeding_memory:
op_name, op = self._ops_exceeding_memory[0] # Report worst offender
raise ValueError(
f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
)

def execute(
self,
executor=None,
Expand All @@ -548,6 +576,8 @@ def execute(
spec=None,
**kwargs,
):
self.validate()

dag = self.dag

if resume:
Expand Down Expand Up @@ -580,6 +610,15 @@ def visualize(
rankdir="TB",
show_hidden=False,
):
if self._ops_exceeding_memory:
op_names = [name for name, _ in self._ops_exceeding_memory]
warnings.warn(
f"Plan has {len(self._ops_exceeding_memory)} operation(s) that exceed allowed memory: {op_names}. "
"These are shown in red in the visualization.",
stacklevel=2,
)
ops_exceeding_names = {name for name, _ in self._ops_exceeding_memory}

dag = self.dag.copy() # make a copy since we mutate the DAG below

# remove edges from create-arrays output node to avoid cluttering the diagram
Expand All @@ -590,19 +629,39 @@ def visualize(
list(n for n, d in dag.nodes(data=True) if d.get("hidden", False))
)

# Build the graph label - use HTML-like label for mixed colors if memory exceeded
stats_text = (
f"num tasks: {self.num_tasks}<BR ALIGN='LEFT'/>"
f"max projected memory: {memory_repr(self.max_projected_mem)}<BR ALIGN='LEFT'/>"
f"total nbytes written: {memory_repr(self.total_nbytes_written)}<BR ALIGN='LEFT'/>"
f"optimized: {self.optimized}<BR ALIGN='LEFT'/>"
)

if self._ops_exceeding_memory:
# Build warning text in red
warning_lines = [
"<BR ALIGN='LEFT'/>!!! MEMORY EXCEEDED !!!<BR ALIGN='LEFT'/>"
]
for op_name, op in self._ops_exceeding_memory:
warning_lines.append(
f"{op_name}: requires {memory_repr(op.projected_mem)}, "
f"allowed {memory_repr(op.allowed_mem)}<BR ALIGN='LEFT'/>"
)
warning_text = "".join(warning_lines)
# HTML-like label with mixed colors
label = f"<<FONT>{stats_text}</FONT><FONT COLOR='#cc0000'>{warning_text}</FONT>>"
else:
# Simple HTML label (no warning)
label = f"<{stats_text}>"

dag.graph["graph"] = {
"rankdir": rankdir,
"label": (
# note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
rf"num tasks: {self.num_tasks}\l"
rf"max projected memory: {memory_repr(self.max_projected_mem)}\l"
rf"total nbytes written: {memory_repr(self.total_nbytes_written)}\l"
rf"optimized: {self.optimized}\l"
),
"label": label,
"labelloc": "bottom",
"labeljust": "left",
"fontsize": "10",
}

dag.graph["node"] = {"fontname": "helvetica", "shape": "box", "fontsize": "10"}

# do an initial pass to extract array variable names from stack summaries
Expand All @@ -627,7 +686,11 @@ def visualize(
func_name = d["func_name"]
label = f"{n}\n{func_name}".strip()
op_name = d["op_name"]
if op_name == "blockwise":
if n in ops_exceeding_names:
# operation exceeds memory - show in red
d["style"] = '"rounded,filled"'
d["fillcolor"] = "#ff6b6b"
elif op_name == "blockwise":
d["style"] = '"rounded,filled"'
d["fillcolor"] = "#dcbeff"
elif op_name == "rechunk":
Expand Down
20 changes: 19 additions & 1 deletion cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,11 +493,29 @@ def test_default_spec_allowed_mem_exceeded():
# default spec fails for large computations
a = xp.ones((20000, 10000), chunks=(10000, 10000))
b = xp.negative(a)
# plan() succeeds but marks plan as exceeding memory
plan = b.plan()
assert plan.exceeds_memory
assert len(plan.ops_exceeding_memory) == 1
# compute() raises the error
with pytest.raises(
ValueError,
match=r"Projected blockwise memory \(.+\) exceeds allowed_mem \(.+\), including reserved_mem \(.+\) for op-\d+",
):
b.plan()
b.compute()


def test_default_spec_allowed_mem_exceeded_visualize(tmp_path):
# visualize works but warns when memory is exceeded
import warnings

a = xp.ones((20000, 10000), chunks=(10000, 10000))
b = xp.negative(a)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
b.visualize(filename=str(tmp_path / "cubed"))
assert len(w) == 1
assert "exceed allowed memory" in str(w[0].message)


def test_default_spec_config_override():
Expand Down