Skip to content

Commit 121a9b0

Browse files
authored
tests: Add compare_traces script (#566)
1 parent 2f8fef8 commit 121a9b0

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed

tests/profiling/compare_traces.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Compare profiling traces between autojac_old and autojac_new."""
2+
3+
import json
4+
from pathlib import Path
5+
6+
7+
def find_event_duration(trace_data: dict, event_name: str) -> float | None:
8+
"""Find the duration of a specific event in the trace.
9+
10+
:param trace_data: The parsed JSON trace data
11+
:param event_name: The name of the event to find (e.g., "jac_to_grad")
12+
"""
13+
events = trace_data.get("traceEvents", [])
14+
for event in events:
15+
if "name" in event and "dur" in event and event["name"].endswith(f": {event_name}"):
16+
return event["dur"] / 1000.0 # Convert microseconds to milliseconds
17+
return None
18+
19+
20+
def parse_filename(filename: str) -> tuple[str, int, str]:
21+
"""Parse model name, batch size, and device from filename.
22+
23+
:param filename: The trace filename (e.g., "AlexNet()-bs4-cpu.json")
24+
"""
25+
# Remove .json extension
26+
name = filename.replace(".json", "")
27+
# Split by -bs and -
28+
parts = name.split("-bs")
29+
model = parts[0].replace("()", "")
30+
rest = parts[1].split("-")
31+
batch_size = int(rest[0])
32+
device = rest[1]
33+
return model, batch_size, device
34+
35+
36+
def compare_traces() -> None:
37+
"""Compare traces between autojac_old and autojac_new directories."""
38+
traces_dir = Path(__file__).parent.parent.parent / "traces"
39+
old_dir = traces_dir / "autojac_old"
40+
new_dir = traces_dir / "autojac_new"
41+
42+
# Collect data
43+
cpu_data = []
44+
cuda_data = []
45+
46+
for old_file in sorted(old_dir.glob("*.json")):
47+
model, batch_size, device = parse_filename(old_file.name)
48+
new_file = new_dir / old_file.name
49+
50+
if not new_file.exists():
51+
print(f"Warning: {new_file.name} not found in autojac_new")
52+
continue
53+
54+
# Load trace files
55+
with open(old_file) as f:
56+
old_trace = json.load(f)
57+
with open(new_file) as f:
58+
new_trace = json.load(f)
59+
60+
# Find durations for both events
61+
old_jac_to_grad = find_event_duration(old_trace, "jac_to_grad")
62+
new_jac_to_grad = find_event_duration(new_trace, "jac_to_grad")
63+
old_forward_backward = find_event_duration(old_trace, "autojac_forward_backward")
64+
new_forward_backward = find_event_duration(new_trace, "autojac_forward_backward")
65+
66+
if old_jac_to_grad is None or new_jac_to_grad is None:
67+
print(f"Warning: jac_to_grad not found in {old_file.name}")
68+
continue
69+
if old_forward_backward is None or new_forward_backward is None:
70+
print(f"Warning: autojac_forward_backward not found in {old_file.name}")
71+
continue
72+
73+
# Calculate differences
74+
diff_jac_to_grad = new_jac_to_grad - old_jac_to_grad
75+
diff_forward_backward = new_forward_backward - old_forward_backward
76+
pct_jac_to_grad = (diff_jac_to_grad / old_jac_to_grad) * 100
77+
pct_forward_backward = (diff_forward_backward / old_forward_backward) * 100
78+
79+
row = {
80+
"model": model,
81+
"batch_size": batch_size,
82+
"old_jac_to_grad": old_jac_to_grad,
83+
"new_jac_to_grad": new_jac_to_grad,
84+
"diff_jac_to_grad": diff_jac_to_grad,
85+
"pct_jac_to_grad": pct_jac_to_grad,
86+
"old_forward_backward": old_forward_backward,
87+
"new_forward_backward": new_forward_backward,
88+
"diff_forward_backward": diff_forward_backward,
89+
"pct_forward_backward": pct_forward_backward,
90+
}
91+
92+
if device == "cpu":
93+
cpu_data.append(row)
94+
else:
95+
cuda_data.append(row)
96+
97+
# Print tables
98+
print("CPU Traces Comparison")
99+
print_table(cpu_data)
100+
101+
print("\nCUDA Traces Comparison")
102+
print_table(cuda_data)
103+
104+
105+
def print_table(data: list[dict]) -> None:
106+
"""Print a formatted comparison table.
107+
108+
:param data: List of row dictionaries with timing data
109+
"""
110+
# Header
111+
header = (
112+
"|Model|Batch Size|Time before (jac_to_grad)|Time after (jac_to_grad)|"
113+
"Difference (jac_to_grad)|Time before (autojac_forward_backward)|"
114+
"Time after (autojac_forward_backward)|Difference (autojac_forward_backward)|"
115+
)
116+
separator = "|---|---|---|---|---|---|---|---|"
117+
118+
print(header)
119+
print(separator)
120+
121+
# Rows
122+
for row in data:
123+
# Format differences with + sign for positive values
124+
diff_jac = round(row["diff_jac_to_grad"])
125+
pct_jac = round(row["pct_jac_to_grad"])
126+
diff_fb = round(row["diff_forward_backward"])
127+
pct_fb = round(row["pct_forward_backward"])
128+
129+
diff_jac_str = f"+{diff_jac}" if diff_jac > 0 else str(diff_jac)
130+
pct_jac_str = f"+{pct_jac}" if pct_jac > 0 else str(pct_jac)
131+
diff_fb_str = f"+{diff_fb}" if diff_fb > 0 else str(diff_fb)
132+
pct_fb_str = f"+{pct_fb}" if pct_fb > 0 else str(pct_fb)
133+
134+
print(
135+
f"|{row['model']}|{row['batch_size']}|"
136+
f"{round(row['old_jac_to_grad'])} ms|"
137+
f"{round(row['new_jac_to_grad'])} ms|"
138+
f"{diff_jac_str} ms ({pct_jac_str}%)|"
139+
f"{round(row['old_forward_backward'])} ms|"
140+
f"{round(row['new_forward_backward'])} ms|"
141+
f"{diff_fb_str} ms ({pct_fb_str}%)|",
142+
)
143+
144+
145+
if __name__ == "__main__":
146+
compare_traces()

0 commit comments

Comments
 (0)