Skip to content

Commit 4c2b730

Browse files
committed
Add type annotations to time_call
1 parent aacf671 commit 4c2b730

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

tests/profiling/speed_grad_vs_jac_vs_gram.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import gc
22
import time
3+
from collections.abc import Callable
34

45
import torch
56
from settings import DEVICE
@@ -125,7 +126,13 @@ def noop() -> None:
125126
pass
126127

127128

128-
def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor:
129+
def time_call(
130+
fn: Callable[[], None],
131+
init_fn: Callable[[], None] = noop,
132+
pre_fn: Callable[[], None] = noop,
133+
post_fn: Callable[[], None] = noop,
134+
n_runs: int = 10,
135+
) -> Tensor:
129136
init_fn()
130137

131138
times = []

0 commit comments

Comments
 (0)