Skip to content

Commit 1e03d34

Browse files
authored
test: Add plot_memory_timeline.py (#548)
* Add matplotlib as a plot dependency * Add paths.py file to store paths * Add plot_memory_timeline.py
1 parent c5bd2eb commit 1e03d34

File tree

4 files changed

+106
-5
lines changed

4 files changed

+106
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ plot = [
9090
"plotly>=5.19.0", # Recent version to avoid problems, could be relaxed
9191
"dash>=2.16.0", # Recent version to avoid problems, could be relaxed
9292
"kaleido==0.2.1", # Only works with locked version
93+
"matplotlib>=3.10.0", # Recent version to avoid problems, could be relaxed
9394
]
9495

9596
[project.optional-dependencies]

tests/paths.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from pathlib import Path
2+
3+
# Two directory levels above this file (tests/paths.py).
TORCHJD_DIR = Path(__file__).parent.parent
# The "traces" directory directly under TORCHJD_DIR.
TRACES_DIR = TORCHJD_DIR / "traces"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
Script to plot memory timeline evolution from profiling traces.
3+
Reads memory traces from json files and plots them on a single graph.
4+
"""
5+
6+
import argparse
7+
import json
8+
from dataclasses import dataclass
9+
from pathlib import Path
10+
11+
import matplotlib.pyplot as plt
12+
import numpy as np
13+
from paths import TRACES_DIR
14+
15+
16+
@dataclass
class MemoryFrame:
    """Single memory sample built from one event of a profiling trace."""

    timestamp: int  # value of the event's "ts" entry
    total_allocated: int  # in bytes

    @classmethod
    def from_event(cls, event: dict) -> "MemoryFrame":
        """Build a frame from one trace event.

        :param event: Event dictionary holding a ``"ts"`` entry and an ``"args"`` mapping that
            contains ``"Total Allocated"``.
        :raises KeyError: If ``"ts"``, ``"args"`` or ``"Total Allocated"`` is missing.
        """
        # Index the key instead of using ``.get``: a malformed event must fail here rather than
        # store ``None`` in a field that is declared (and later used) as an integer.
        return cls(
            timestamp=event["ts"],
            total_allocated=event["args"]["Total Allocated"],
        )
28+
29+
30+
def extract_memory_timeline(path: Path) -> np.ndarray:
    """Load a profiling trace and return its memory timeline.

    :param path: Path to a json trace whose top-level ``"traceEvents"`` entry lists the events.
    :returns: Array with one row per ``[memory]`` event, sorted by timestamp. Column 0 holds the
        timestamps and column 1 the total allocated memory (in bytes). The shape is ``(0, 2)``
        when the trace contains no ``[memory]`` event.
    """
    # JSON is UTF-8 by specification, so do not depend on the locale's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    events = data["traceEvents"]
    print(f"Total events in trace: {len(events):,}")
    print("Extracting memory frames...")

    # ``.get`` so that an event without a "name" entry is skipped instead of raising KeyError.
    frames = [MemoryFrame.from_event(e) for e in events if e.get("name") == "[memory]"]
    frames.sort(key=lambda frame: frame.timestamp)

    print(f"Found {len(frames):,} memory frames")

    timestamp_list = [frame.timestamp for frame in frames]
    total_allocated_list = [frame.total_allocated for frame in frames]

    # Stack the two columns, then transpose to get one (timestamp, total_allocated) row per frame.
    return np.array([timestamp_list, total_allocated_list]).T
47+
48+
49+
def plot_memory_timelines(experiment: str, folders: list[str]) -> None:
    """Plot the memory timelines of one experiment for several trace folders on a single graph.

    For each folder, the trace ``TRACES_DIR / folder / f"{experiment}.json"`` is read. The figure
    is saved to ``TRACES_DIR / "memory_timelines" / f"{experiment}.png"``.

    :param experiment: Name of the experiment, i.e. the trace file name without ``.json``.
    :param folders: Names of the folders under ``TRACES_DIR`` that contain the traces. They are
        also used as legend labels.
    """
    timelines = [
        extract_memory_timeline(TRACES_DIR / folder / f"{experiment}.json") for folder in folders
    ]

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for folder, timeline in zip(folders, timelines):
            if len(timeline) == 0:
                # ``timeline[0, 0]`` below would raise IndexError on a trace without memory frames.
                print(f"No memory frames found for {folder}, skipping it")
                continue

            # Make time start at 0 and convert to ms.
            time = (timeline[:, 0] - timeline[0, 0]) // 1000
            memory = timeline[:, 1]
            ax.plot(time, memory, label=folder, linewidth=1.5)

        ax.set_xlabel("Time (ms)", fontsize=12)
        ax.set_ylabel("Total Allocated (bytes)", fontsize=12)
        ax.set_title(f"Memory Timeline: {experiment}", fontsize=14, fontweight="bold")
        ax.legend(loc="best", fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(bottom=0)
        fig.tight_layout()

        output_dir = TRACES_DIR / "memory_timelines"
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{experiment}.png"
        print(f"\nSaving plot to: {output_path}")
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
        print("Plot saved successfully!")
    finally:
        # pyplot keeps every figure it created alive until it is explicitly closed.
        plt.close(fig)
75+
76+
77+
def main():
    """Parse the command-line arguments and plot the requested memory timelines."""
    cli = argparse.ArgumentParser(description="Plot memory timeline from profiling traces.")

    # Positional arguments: first the experiment name, then one or more trace folders.
    experiment_help = (
        "Name of the experiment under profiling (e.g., 'WithTransformerLarge()-bs4-cpu')"
    )
    folders_help = "Folder names containing the traces (e.g., autojac_old autojac_new)"
    cli.add_argument("experiment", type=str, help=experiment_help)
    cli.add_argument("folders", nargs="+", type=str, help=folders_help)

    namespace = cli.parse_args()
    return plot_memory_timelines(namespace.experiment, namespace.folders)
94+
95+
96+
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()

tests/profiling/run_profiler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import gc
2-
from pathlib import Path
32
from typing import Callable
43

54
import torch
@@ -22,6 +21,7 @@
2221
)
2322
from utils.tensors import make_inputs_and_targets
2423

24+
from tests.paths import TRACES_DIR
2525
from torchjd.aggregation import UPGrad, UPGradWeighting
2626
from torchjd.autogram import Engine
2727

@@ -93,10 +93,9 @@ def _save_and_print_trace(
9393
prof: profile, method_name: str, factory: ModuleFactory, batch_size: int
9494
) -> None:
9595
filename = f"{factory}-bs{batch_size}-{DEVICE.type}.json"
96-
torchjd_dir = Path(__file__).parent.parent.parent
97-
traces_dir = torchjd_dir / "traces" / method_name
98-
traces_dir.mkdir(parents=True, exist_ok=True)
99-
trace_path = traces_dir / filename
96+
output_dir = TRACES_DIR / method_name
97+
output_dir.mkdir(parents=True, exist_ok=True)
98+
trace_path = output_dir / filename
10099

101100
prof.export_chrome_trace(str(trace_path))
102101
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20))

0 commit comments

Comments
 (0)