Skip to content

Commit 5d4e76d

Browse files
authored
Add MemoryVisualizationCallback (#734)
* Add MemoryVisualizationCallback based on Tom Nicholas' work in https://gist.github.com/TomNicholas/c6a28f7c22c6981f75bce280d3e28283 * Simplify TimelineVisualizationCallback * Rename 'max peak measured' to 'max actual usage'
1 parent a1c19b9 commit 5d4e76d

5 files changed

Lines changed: 138 additions & 63 deletions

File tree

cubed/diagnostics/history.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def on_compute_start(self, event):
1717
name=name,
1818
op_name=node["op_name"],
1919
projected_mem=primitive_op.projected_mem,
20+
allowed_mem=primitive_op.allowed_mem,
2021
reserved_mem=primitive_op.reserved_mem,
2122
num_tasks=primitive_op.num_tasks,
2223
)

cubed/diagnostics/mem_usage.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from dataclasses import asdict
2+
from pathlib import Path
3+
from typing import Optional
4+
5+
import matplotlib
6+
import matplotlib.pyplot as plt
7+
import pandas as pd
8+
9+
from cubed.runtime.pipeline import visit_nodes
10+
from cubed.runtime.types import Callback
11+
12+
matplotlib.use("Agg")
13+
14+
15+
class MemoryVisualizationCallback(Callback):
    """Callback that saves a memory-usage plot when a computation ends.

    The plot is written to ``history/<compute_id>/memory.<format>`` and the
    resulting path is stored on ``self.dst``.

    Parameters
    ----------
    format : str, optional
        File extension used for the saved figure (default ``"svg"``).
    """

    def __init__(self, format: Optional[str] = "svg") -> None:
        self.format = format

    def on_compute_start(self, event):
        """Record the per-operation memory figures of the plan and reset events."""
        plan = []
        for name, node in visit_nodes(event.dag):
            primitive_op = node["primitive_op"]
            plan.append(
                dict(
                    name=name,
                    op_name=node["op_name"],
                    projected_mem=primitive_op.projected_mem,
                    allowed_mem=primitive_op.allowed_mem,
                    reserved_mem=primitive_op.reserved_mem,
                    num_tasks=primitive_op.num_tasks,
                )
            )

        self.plan = plan
        self.events = []

    def on_task_end(self, event):
        """Store the task-end event as a plain dict for later tabulation."""
        self.events.append(asdict(event))

    def on_compute_end(self, event):
        """Render the memory plot and save it under ``history/<compute_id>``."""
        if not self.events:
            # No task finished, so there is nothing to plot; building the
            # figure from an empty frame would fail on the missing columns.
            return

        events_df = pd.DataFrame(self.events)
        plan_df = pd.DataFrame(self.plan)
        fig = generate_mem_usage(events_df, plan_df)

        try:
            self.dst = Path(f"history/{event.compute_id}")
            self.dst.mkdir(parents=True, exist_ok=True)
            self.dst = self.dst / f"memory.{self.format}"

            fig.savefig(self.dst)
        finally:
            # Figures created through pyplot stay registered until closed;
            # release this one so repeated computations do not accumulate them.
            plt.close(fig)
50+
51+
52+
def generate_mem_usage(events_df, plan_df):
    """Plot measured task memory against the plan's memory figures.

    Parameters
    ----------
    events_df : pandas.DataFrame
        One row per task. The columns read are ``task_create_tstamp``,
        ``name`` and ``peak_measured_mem_end``.
    plan_df : pandas.DataFrame
        One row per operation. The columns read are ``name``,
        ``projected_mem``, ``allowed_mem`` and ``reserved_mem``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the plot. The caller is responsible for saving
        and closing it.
    """
    # colours match those in https://cubed-dev.github.io/cubed/user-guide/memory.html

    # Renumber the rows after sorting: the plots below use the index as the
    # x axis ("Task number"), so it must follow the sorted order rather than
    # the order in which the rows were originally collected.
    events_df = events_df.sort_values(
        by=["task_create_tstamp", "name"], ascending=True
    ).reset_index(drop=True)
    projected_mem_map = plan_df.set_index("name")["projected_mem"].to_dict()

    # Memory values are divided by 1_000_000 to match the "MB" axis label.
    events_df["actual usage"] = events_df["peak_measured_mem_end"] / 1_000_000
    events_df["projected_mem"] = events_df["name"].map(projected_mem_map) / 1_000_000

    fig, ax = plt.subplots(figsize=(8, 6))

    events_df.plot(
        kind="area", y="actual usage", ax=ax, use_index=True, color="#9fc5e8"
    )

    allowed_mem = plan_df["allowed_mem"].max() / 1_000_000
    ax.axhline(allowed_mem, label="allowed", color="#e06666", linestyle="--")

    reserved_mem = plan_df["reserved_mem"].max() / 1_000_000
    ax.axhline(
        reserved_mem,
        label="reserved",
        color="#f6b26b",
        linestyle="--",
    )

    peak_measured_mem = events_df["peak_measured_mem_end"].max() / 1_000_000
    ax.axhline(peak_measured_mem, label="max actual usage", color="#6fa8dc")

    events_df.plot(
        kind="line",
        y="projected_mem",
        ax=ax,
        use_index=True,
        label="projected",
        color="#93c47d",
        linestyle="--",
    )

    ax.set_xlabel("Task number")
    # Leave headroom above the "allowed" line so it is not drawn on the frame.
    ax.set_ylim(top=allowed_mem + 100)
    ax.set_ylabel("Task memory (MB)")
    ax.legend()

    return fig

cubed/diagnostics/timeline.py

Lines changed: 35 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,63 @@
1-
import time
21
from dataclasses import asdict
32
from pathlib import Path
43
from typing import Optional
54

6-
import matplotlib.patches as mpatches
5+
import matplotlib
6+
import matplotlib.pyplot as plt
77
import numpy as np
88
import pandas as pd
9-
import pylab
109
import seaborn as sns
1110

1211
from cubed.runtime.types import Callback
1312

1413
sns.set_style("whitegrid")
15-
pylab.switch_backend("Agg")
14+
matplotlib.use("Agg")
1615

1716

1817
class TimelineVisualizationCallback(Callback):
1918
def __init__(self, format: Optional[str] = "svg") -> None:
2019
self.format = format
2120

2221
def on_compute_start(self, event):
23-
self.start_tstamp = time.time()
24-
self.stats = []
22+
self.events = []
2523

2624
def on_task_end(self, event):
27-
self.stats.append(asdict(event))
25+
self.events.append(asdict(event))
2826

2927
def on_compute_end(self, event):
30-
self.end_tstamp = time.time()
31-
32-
stats_df = pd.DataFrame(self.stats)
33-
stats_df = stats_df.sort_values(
34-
by=["task_create_tstamp", "name"], ascending=True
35-
)
36-
total_calls = len(stats_df)
37-
palette = sns.color_palette("deep", 6)
38-
39-
fig = pylab.figure(figsize=(10, 6))
40-
ax = fig.add_subplot(1, 1, 1)
41-
42-
y = np.arange(total_calls)
43-
point_size = 10
44-
45-
fields = [
46-
("task create", stats_df.task_create_tstamp - self.start_tstamp),
47-
("function start", stats_df.function_start_tstamp - self.start_tstamp),
48-
("function end", stats_df.function_end_tstamp - self.start_tstamp),
49-
("task result", stats_df.task_result_tstamp - self.start_tstamp),
50-
]
51-
52-
patches = []
53-
for f_i, (field_name, val) in enumerate(fields):
54-
ax.scatter(
55-
val, y, c=[palette[f_i]], edgecolor="none", s=point_size, alpha=0.8
56-
)
57-
patches.append(mpatches.Patch(color=palette[f_i], label=field_name))
58-
59-
ax.set_xlabel("Execution Time (sec)")
60-
ax.set_ylabel("Function Call")
61-
62-
legend = pylab.legend(handles=patches, loc="upper right", frameon=True)
63-
legend.get_frame().set_facecolor("#FFFFFF")
64-
65-
yplot_step = int(np.max([1, total_calls / 20]))
66-
y_ticks = np.arange(total_calls // yplot_step + 2) * yplot_step
67-
ax.set_yticks(y_ticks)
68-
ax.set_ylim(-0.02 * total_calls, total_calls * 1.02)
69-
for y in y_ticks:
70-
ax.axhline(y, c="k", alpha=0.1, linewidth=1)
71-
72-
max_seconds = np.max(self.end_tstamp - self.start_tstamp) * 1.25
73-
xplot_step = max(int(max_seconds / 8), 1)
74-
x_ticks = np.arange(max_seconds // xplot_step + 2) * xplot_step
75-
ax.set_xlim(0, max_seconds)
76-
77-
ax.set_xticks(x_ticks)
78-
for x in x_ticks:
79-
ax.axvline(x, c="k", alpha=0.2, linewidth=0.8)
80-
81-
ax.grid(False)
82-
fig.tight_layout()
28+
events_df = pd.DataFrame(self.events)
29+
fig = generate_timeline(events_df)
8330

8431
self.dst = Path(f"history/{event.compute_id}")
8532
self.dst.mkdir(parents=True, exist_ok=True)
8633
self.dst = self.dst / f"timeline.{self.format}"
8734

8835
fig.savefig(self.dst)
36+
37+
38+
def generate_timeline(events_df):
    """Draw a scatter timeline of task lifecycle timestamps.

    Rows are ordered by ``task_create_tstamp`` then ``name``; each row is
    one position on the y axis. Four timestamps per task are plotted on the
    x axis, relative to the earliest ``task_create_tstamp``.

    Returns the matplotlib figure containing the plot.
    """
    ordered = events_df.sort_values(by=["task_create_tstamp", "name"], ascending=True)
    origin = ordered["task_create_tstamp"].min()

    fig, ax = plt.subplots(figsize=(10, 8))

    task_numbers = np.arange(len(ordered))
    marker_area = 7

    # (legend label, offset from the first task creation) for each stage
    stages = [
        ("task create", ordered.task_create_tstamp - origin),
        ("function start", ordered.function_start_tstamp - origin),
        ("function end", ordered.function_end_tstamp - origin),
        ("task result", ordered.task_result_tstamp - origin),
    ]

    for stage_label, offsets in stages:
        ax.scatter(
            offsets,
            task_numbers,
            label=stage_label,
            edgecolor="none",
            s=marker_area,
            alpha=0.8,
        )

    ax.set_xlabel("Execution time (sec)")
    ax.set_ylabel("Task number")

    ax.legend()

    return fig

cubed/tests/test_executor_features.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import cubed.random
1515
from cubed.diagnostics import ProgressBar
1616
from cubed.diagnostics.history import HistoryCallback
17+
from cubed.diagnostics.mem_usage import MemoryVisualizationCallback
1718
from cubed.diagnostics.mem_warn import MemoryWarningCallback
1819
from cubed.diagnostics.rich import RichProgressBar
1920
from cubed.diagnostics.timeline import TimelineVisualizationCallback
@@ -101,13 +102,15 @@ def test_callbacks(spec, executor):
101102
progress = TqdmProgressBar()
102103
hist = HistoryCallback()
103104
timeline_viz = TimelineVisualizationCallback()
105+
memory_viz = MemoryVisualizationCallback()
104106

105107
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
106108
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
107109
c = xp.add(a, b)
108110
assert_array_equal(
109111
c.compute(
110-
executor=executor, callbacks=[task_counter, progress, hist, timeline_viz]
112+
executor=executor,
113+
callbacks=[task_counter, progress, hist, timeline_viz, memory_viz],
111114
),
112115
np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]),
113116
)

setup.cfg

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ ignore_missing_imports = True
6464
ignore_missing_imports = True
6565
[mypy-psutil.*]
6666
ignore_missing_imports = True
67-
[mypy-pylab.*]
68-
ignore_missing_imports = True
6967
[mypy-pytest.*]
7068
ignore_missing_imports = True
7169
[mypy-ray.*]

0 commit comments

Comments
 (0)