|
3 | 3 |
|
4 | 4 | import torch |
5 | 5 | from settings import DEVICE |
| 6 | +from torch import Tensor, nn |
6 | 7 | from torch.profiler import ProfilerActivity, profile |
| 8 | +from torch.utils._pytree import PyTree |
7 | 9 | from utils.architectures import ( |
8 | 10 | AlexNet, |
9 | 11 | Cifar10Model, |
@@ -105,15 +107,19 @@ def _save_and_print_trace( |
105 | 107 |
|
106 | 108 |
|
107 | 109 | def profile_autojac(factory: ModuleFactory, batch_size: int) -> None: |
108 | | - def forward_backward_fn(model, inputs, loss_fn) -> None: |
| 110 | + def forward_backward_fn( |
| 111 | + model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]] |
| 112 | + ) -> None: |
109 | 113 | aggregator = UPGrad() |
110 | 114 | autojac_forward_backward(model, inputs, loss_fn, aggregator) |
111 | 115 |
|
112 | 116 | profile_method("autojac", forward_backward_fn, factory, batch_size) |
113 | 117 |
|
114 | 118 |
|
115 | 119 | def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: |
116 | | - def forward_backward_fn(model, inputs, loss_fn) -> None: |
| 120 | + def forward_backward_fn( |
| 121 | + model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]] |
| 122 | + ) -> None: |
117 | 123 | engine = Engine(model, batch_dim=0) |
118 | 124 | weighting = UPGradWeighting() |
119 | 125 | autogram_forward_backward(model, inputs, loss_fn, engine, weighting) |
|
0 commit comments