Skip to content

Commit a258254

Browse files
authored
test: Add profiler (#519)
* Add traces/ to .gitignore. * Add tests.profiling package. * Move speed tests to tests.profiling. * Add run_profiler in tests.profiling. * Fix call to jac_to_grad in autojac_forward_backward to give it directly the list of model params (instead of a generator).
1 parent b0adff6 commit a258254

File tree

7 files changed

+169
-36
lines changed

7 files changed

+169
-36
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Profiling results
2+
traces/
3+
14
# uv
25
uv.lock
36

tests/profiling/run_profiler.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import gc
2+
from pathlib import Path
3+
from typing import Callable
4+
5+
import torch
6+
from settings import DEVICE
7+
from torch.profiler import ProfilerActivity, profile
8+
from utils.architectures import (
9+
AlexNet,
10+
Cifar10Model,
11+
GroupNormMobileNetV3Small,
12+
InstanceNormMobileNetV2,
13+
InstanceNormResNet18,
14+
ModuleFactory,
15+
SqueezeNet,
16+
WithTransformerLarge,
17+
)
18+
from utils.forward_backwards import (
19+
autogram_forward_backward,
20+
autojac_forward_backward,
21+
make_mse_loss_fn,
22+
)
23+
from utils.tensors import make_inputs_and_targets
24+
25+
from torchjd.aggregation import UPGrad, UPGradWeighting
26+
from torchjd.autogram import Engine
27+
28+
# (model factory, batch size) pairs to profile. Each entry is profiled with both
# the autojac and the autogram method (see main()).
PARAMETRIZATIONS = [
    (ModuleFactory(WithTransformerLarge), 4),
    (ModuleFactory(Cifar10Model), 64),
    (ModuleFactory(AlexNet), 4),
    (ModuleFactory(InstanceNormResNet18), 4),
    (ModuleFactory(GroupNormMobileNetV3Small), 8),
    (ModuleFactory(SqueezeNet), 4),
    (ModuleFactory(InstanceNormMobileNetV2), 2),
]
37+
38+
39+
def profile_method(
    method_name: str,
    forward_backward_fn: Callable,
    factory: ModuleFactory,
    batch_size: int,
) -> None:
    """
    Profiles memory and computation time of a forward and backward pass.

    :param method_name: Name of the method being profiled (for output paths)
    :param forward_backward_fn: Function to execute the forward and backward pass.
    :param factory: A ModuleFactory that creates the model to profile.
    :param batch_size: The batch size to use for profiling.
    """
    print(f"{method_name}: {factory} with batch_size={batch_size} on {DEVICE}:")

    # Start from a clean memory state before building the model and inputs.
    _clear_unused_memory()
    model = factory()
    inputs, targets = make_inputs_and_targets(model, batch_size)
    loss_fn = make_mse_loss_fn(targets)

    activities = _get_profiler_activities()

    # Warmup run: executed once outside the profiler so that one-time setup
    # costs are not recorded in the trace. Gradients and cached memory are
    # cleared afterwards so the profiled run starts from the same state.
    forward_backward_fn(model, inputs, loss_fn)
    model.zero_grad()
    _clear_unused_memory()

    # Profiled run
    with profile(
        activities=activities,
        profile_memory=True,
        record_shapes=False,  # Otherwise some tensors may be referenced longer than normal
        with_stack=True,
    ) as prof:
        forward_backward_fn(model, inputs, loss_fn)

    _save_and_print_trace(prof, method_name, factory, batch_size)
77+
78+
79+
def _clear_unused_memory() -> None:
80+
gc.collect()
81+
if torch.cuda.is_available():
82+
torch.cuda.empty_cache()
83+
84+
85+
def _get_profiler_activities() -> list[ProfilerActivity]:
    """Return the activities to record: CPU always, plus CUDA when profiling on a GPU."""
    if DEVICE.type == "cuda":
        return [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    return [ProfilerActivity.CPU]
90+
91+
92+
def _save_and_print_trace(
    prof: profile, method_name: str, factory: ModuleFactory, batch_size: int
) -> None:
    """
    Export the profiling results to a chrome trace file and print a summary table.

    The trace is written to <repo root>/traces/<method_name>/, creating the
    directory if needed; the summary is sorted by self CPU memory usage.
    """
    repo_root = Path(__file__).parents[2]
    output_dir = repo_root / "traces" / method_name
    output_dir.mkdir(parents=True, exist_ok=True)
    trace_file = output_dir / f"{factory}-bs{batch_size}-{DEVICE.type}.json"

    prof.export_chrome_trace(str(trace_file))
    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20))
103+
104+
105+
def profile_autojac(factory: ModuleFactory, batch_size: int) -> None:
    """Profile a forward and backward pass of the autojac method with a UPGrad aggregator."""

    def run(model, inputs, loss_fn):
        # The aggregator is built inside the call, so its construction happens
        # on every (warmup and profiled) run, exactly as before.
        autojac_forward_backward(model, inputs, loss_fn, UPGrad())

    profile_method("autojac", run, factory, batch_size)
111+
112+
113+
def profile_autogram(factory: ModuleFactory, batch_size: int) -> None:
    """Profile a forward and backward pass of the autogram Engine with a UPGrad weighting."""

    def run(model, inputs, loss_fn):
        # Engine and weighting are built inside the call (left-to-right, engine
        # first), so their construction happens on every run, exactly as before.
        autogram_forward_backward(
            model, inputs, loss_fn, Engine(model, batch_dim=0), UPGradWeighting()
        )

    profile_method("autogram", run, factory, batch_size)
120+
121+
122+
def main():
    """Profile both methods on every parametrization, separating outputs with a divider."""
    divider = "\n" + "=" * 80 + "\n"
    for factory, batch_size in PARAMETRIZATIONS:
        profile_autojac(factory, batch_size)
        print(divider)
        profile_autogram(factory, batch_size)
        print(divider)
128+
129+
130+
if __name__ == "__main__":
    # To test this on cuda, add the following environment variables when running this:
    # CUBLAS_WORKSPACE_CONFIG=:4096:8;PYTEST_TORCH_DEVICE=cuda:0
    main()

tests/speed/autogram/grad_vs_jac_vs_gram.py renamed to tests/profiling/speed_grad_vs_jac_vs_gram.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import gc
2+
import time
23

34
import torch
45
from settings import DEVICE
6+
from torch import Tensor
57
from utils.architectures import (
68
AlexNet,
79
Cifar10Model,
@@ -23,7 +25,6 @@
2325
)
2426
from utils.tensors import make_inputs_and_targets
2527

26-
from tests.speed.utils import print_times, time_call
2728
from torchjd.aggregation import Mean
2829
from torchjd.autogram import Engine
2930

@@ -40,6 +41,12 @@
4041
]
4142

4243

44+
def main():
    # Run the speed comparison for every (model factory, batch size) pair.
    for factory, batch_size in PARAMETRIZATIONS:
        compare_autograd_autojac_and_autogram_speed(factory, batch_size)
        print("\n")
48+
49+
4350
def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
4451
model = factory()
4552
inputs, targets = make_inputs_and_targets(model, batch_size)
@@ -85,7 +92,7 @@ def init_fn_autogram():
8592
fn_autogram()
8693

8794
def optionally_cuda_sync():
88-
if str(DEVICE).startswith("cuda"):
95+
if DEVICE.type == "cuda":
8996
torch.cuda.synchronize()
9097

9198
def pre_fn():
@@ -112,10 +119,29 @@ def post_fn():
112119
print_times("autogram", autogram_times)
113120

114121

115-
def main():
116-
for factory, batch_size in PARAMETRIZATIONS:
117-
compare_autograd_autojac_and_autogram_speed(factory, batch_size)
118-
print("\n")
122+
def noop():
    """Default hook that does nothing."""
    pass


def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor:
    """
    Time ``n_runs`` calls of ``fn`` and return the elapsed times as a 1-D tensor.

    ``init_fn`` runs once before all measurements. ``pre_fn`` runs before each
    timed call, outside the timing. ``post_fn`` runs after each call and is
    included in the measured time (e.g. so a device synchronization is counted).
    """
    init_fn()

    measurements = []
    for _ in range(n_runs):
        pre_fn()
        t0 = time.perf_counter()
        fn()
        post_fn()
        measurements.append(time.perf_counter() - t0)

    return torch.tensor(measurements)
139+
140+
141+
def print_times(name: str, times: Tensor) -> None:
    """Print the average, standard deviation and raw values of the measured times."""
    avg = times.mean()
    std = times.std()
    print(f"{name} times (avg = {avg:.5f}, std = {std:.5f})")
    print(times)
    print()
119145

120146

121147
if __name__ == "__main__":

tests/speed/autogram/__init__.py

Whitespace-only changes.

tests/speed/utils.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/utils/forward_backwards.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def autojac_forward_backward(
3131
) -> None:
3232
losses = forward_pass(model, inputs, loss_fn, reduce_to_vector)
3333
backward(losses)
34-
jac_to_grad(model.parameters(), aggregator)
34+
jac_to_grad(list(model.parameters()), aggregator)
3535

3636

3737
def autograd_gramian_forward_backward(

0 commit comments

Comments
 (0)