Skip to content

Commit db2228f

Browse files
committed
Merge branch 'main' into dev-new-engine
# Conflicts: # tests/speed/autogram/grad_vs_jac_vs_gram.py
2 parents 131df9a + 0b3836a commit db2228f

2 files changed

Lines changed: 39 additions & 40 deletions

File tree

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 10 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import gc
2-
import time
32

43
import torch
54
from device import DEVICE
@@ -24,6 +23,7 @@
2423
)
2524
from utils.tensors import make_inputs_and_targets
2625

26+
from tests.speed.utils import print_times, time_call
2727
from torchjd.aggregation import Mean
2828
from torchjd.autogram import Engine
2929

@@ -96,50 +96,20 @@ def post_fn():
9696
optionally_cuda_sync()
9797

9898
n_runs = 10
99-
autograd_times = torch.tensor(time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs))
100-
print(f"autograd times (avg = {autograd_times.mean():.5f}, std = {autograd_times.std():.5f}")
101-
print(autograd_times)
102-
print()
99+
autograd_times = time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs)
100+
print_times("autograd", autograd_times)
103101

104-
autograd_gramian_times = torch.tensor(
105-
time_call(fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs)
102+
autograd_gramian_times = time_call(
103+
fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs
106104
)
107-
print(
108-
f"autograd gramian times (avg = {autograd_gramian_times.mean():.5f}, std = "
109-
f"{autograd_gramian_times.std():.5f}"
110-
)
111-
print(autograd_gramian_times)
112-
print()
105+
print_times("autograd gramian", autograd_gramian_times)
113106

114-
autojac_times = torch.tensor(time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs))
115-
print(f"autojac times (avg = {autojac_times.mean():.5f}, std = {autojac_times.std():.5f}")
116-
print(autojac_times)
117-
print()
107+
autojac_times = time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs)
108+
print_times("autojac", autojac_times)
118109

119110
engine = Engine(model)
120-
autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs))
121-
print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}")
122-
print(autogram_times)
123-
print()
124-
125-
126-
def noop():
127-
pass
128-
129-
130-
def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> list[float]:
131-
init_fn()
132-
133-
times = []
134-
for _ in range(n_runs):
135-
pre_fn()
136-
start = time.perf_counter()
137-
fn()
138-
post_fn()
139-
elapsed_time = time.perf_counter() - start
140-
times.append(elapsed_time)
141-
142-
return times
111+
autogram_times = time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)
112+
print_times("autogram", autogram_times)
143113

144114

145115
def main():

tests/speed/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import time
2+
3+
import torch
4+
from torch import Tensor
5+
6+
7+
def noop():
8+
pass
9+
10+
11+
def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor:
12+
init_fn()
13+
14+
times = []
15+
for _ in range(n_runs):
16+
pre_fn()
17+
start = time.perf_counter()
18+
fn()
19+
post_fn()
20+
elapsed_time = time.perf_counter() - start
21+
times.append(elapsed_time)
22+
23+
return torch.tensor(times)
24+
25+
26+
def print_times(name: str, times: Tensor) -> None:
27+
print(f"{name} times (avg = {times.mean():.5f}, std = {times.std():.5f}")
28+
print(times)
29+
print()

0 commit comments

Comments
 (0)