Skip to content

Commit 2f53b64

Browse files
committed
Add type annotations in run_profiler.py
1 parent 4c2b730 commit 2f53b64

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

tests/profiling/run_profiler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
import torch
55
from settings import DEVICE
6+
from torch import Tensor, nn
67
from torch.profiler import ProfilerActivity, profile
8+
from torch.utils._pytree import PyTree
79
from utils.architectures import (
810
AlexNet,
911
Cifar10Model,
@@ -105,15 +107,19 @@ def _save_and_print_trace(
105107

106108

107109
def profile_autojac(factory: ModuleFactory, batch_size: int) -> None:
108-
def forward_backward_fn(model, inputs, loss_fn) -> None:
110+
def forward_backward_fn(
111+
model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]]
112+
) -> None:
109113
aggregator = UPGrad()
110114
autojac_forward_backward(model, inputs, loss_fn, aggregator)
111115

112116
profile_method("autojac", forward_backward_fn, factory, batch_size)
113117

114118

115119
def profile_autogram(factory: ModuleFactory, batch_size: int) -> None:
116-
def forward_backward_fn(model, inputs, loss_fn) -> None:
120+
def forward_backward_fn(
121+
model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]]
122+
) -> None:
117123
engine = Engine(model, batch_dim=0)
118124
weighting = UPGradWeighting()
119125
autogram_forward_backward(model, inputs, loss_fn, engine, weighting)

0 commit comments

Comments
 (0)