|
1 | 1 | import gc |
2 | | -import time |
3 | 2 |
|
4 | 3 | import torch |
5 | 4 | from device import DEVICE |
|
24 | 23 | ) |
25 | 24 | from utils.tensors import make_inputs_and_targets |
26 | 25 |
|
| 26 | +from tests.speed.utils import print_times, time_call |
27 | 27 | from torchjd.aggregation import Mean |
28 | 28 | from torchjd.autogram import Engine |
29 | 29 |
|
@@ -96,50 +96,20 @@ def post_fn(): |
96 | 96 | optionally_cuda_sync() |
97 | 97 |
|
98 | 98 | n_runs = 10 |
99 | | - autograd_times = torch.tensor(time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs)) |
100 | | - print(f"autograd times (avg = {autograd_times.mean():.5f}, std = {autograd_times.std():.5f}") |
101 | | - print(autograd_times) |
102 | | - print() |
| 99 | + autograd_times = time_call(fn_autograd, init_fn_autograd, pre_fn, post_fn, n_runs) |
| 100 | + print_times("autograd", autograd_times) |
103 | 101 |
|
104 | | - autograd_gramian_times = torch.tensor( |
105 | | - time_call(fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs) |
| 102 | + autograd_gramian_times = time_call( |
| 103 | + fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs |
106 | 104 | ) |
107 | | - print( |
108 | | - f"autograd gramian times (avg = {autograd_gramian_times.mean():.5f}, std = " |
109 | | - f"{autograd_gramian_times.std():.5f}" |
110 | | - ) |
111 | | - print(autograd_gramian_times) |
112 | | - print() |
| 105 | + print_times("autograd gramian", autograd_gramian_times) |
113 | 106 |
|
114 | | - autojac_times = torch.tensor(time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs)) |
115 | | - print(f"autojac times (avg = {autojac_times.mean():.5f}, std = {autojac_times.std():.5f}") |
116 | | - print(autojac_times) |
117 | | - print() |
| 107 | + autojac_times = time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs) |
| 108 | + print_times("autojac", autojac_times) |
118 | 109 |
|
119 | 110 | engine = Engine(model) |
120 | | - autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) |
121 | | - print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") |
122 | | - print(autogram_times) |
123 | | - print() |
124 | | - |
125 | | - |
126 | | -def noop(): |
127 | | - pass |
128 | | - |
129 | | - |
130 | | -def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> list[float]: |
131 | | - init_fn() |
132 | | - |
133 | | - times = [] |
134 | | - for _ in range(n_runs): |
135 | | - pre_fn() |
136 | | - start = time.perf_counter() |
137 | | - fn() |
138 | | - post_fn() |
139 | | - elapsed_time = time.perf_counter() - start |
140 | | - times.append(elapsed_time) |
141 | | - |
142 | | - return times |
| 111 | + autogram_times = time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs) |
| 112 | + print_times("autogram", autogram_times) |
143 | 113 |
|
144 | 114 |
|
145 | 115 | def main(): |
|
0 commit comments