Skip to content

Commit d9c5aff

Browse files
committed
Improve ForwardBackwardTime
1 parent bc217c5 commit d9c5aff

File tree

3 files changed

+38
-16
lines changed

3 files changed

+38
-16
lines changed

src/arena/architectures.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from torch import Tensor, nn
2+
3+
4+
class Cifar10Model(nn.Module):
    """Small convolutional network producing ``N_CLASSES`` scores per input sample.

    The whole network lives in ``self.seq``: one plain 3x3 convolution, two grouped 3x3
    convolutions each followed by max-pooling, then two linear layers. Every learnable layer
    except the last is followed by an instance of ``activation_function``.

    ``nn.Linear(1024, 128)`` fixes the number of flattened features at 1024, which is what a
    ``3 x 32 x 32`` input yields (64 channels x 4 x 4 after the second pooling).

    The final ``nn.Flatten(start_dim=0)`` also collapses the batch dimension, so ``forward``
    returns a 1-D tensor of ``batch_size * N_CLASSES`` values rather than a
    ``(batch_size, N_CLASSES)`` matrix.
    """

    N_CLASSES = 10

    def __init__(self, activation_function: type[nn.Module] = nn.ELU):
        """Build the layer stack.

        :param activation_function: Module class (not an instance) that is instantiated with no
            arguments after each learnable layer except the output layer.
        """
        super().__init__()

        self.seq = nn.Sequential(
            nn.Conv2d(3, 32, 3),
            activation_function(),
            nn.Conv2d(32, 64, 3, groups=32),
            nn.MaxPool2d(2),
            activation_function(),
            nn.Conv2d(64, 64, 3, groups=64),
            nn.MaxPool2d(3),
            activation_function(),
            nn.Flatten(),
            nn.Linear(1024, 128),
            activation_function(),
            nn.Linear(128, self.N_CLASSES),
            # Collapse the batch dimension too: the model output is a single vector.
            nn.Flatten(start_dim=0),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the sequential stack to ``x`` and return the flattened 1-D output."""
        return self.seq(x)

src/arena/interfaces.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __call__(self, _: str):
117117
def forward_backward(model: Module, input: Tensor, aggregator: GramianWeightedAggregator) -> None:
118118
output, vgp_fn = vgp_from_module_2(model, input)
119119
gramian = get_gramian(vgp_fn, output)
120-
weights = aggregator.weighting.weighting(gramian)
120+
weights = aggregator.weighting(gramian)
121121
output.backward(weights)
122122

123123
return forward_backward

src/arena/objectives.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
import torch
66
from torch import Tensor
7-
from torch.nn import Linear, MSELoss, ReLU, Sequential, Module
7+
from torch.nn import Linear, MSELoss, ReLU, Sequential
88
from torch.optim import SGD
9-
from torchjd.aggregation import Mean, Aggregator, UPGrad
9+
from torchjd.aggregation import Mean
1010

11+
from arena.architectures import Cifar10Model
1112
from arena.matrix_samplers import MatrixSampler, NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler
1213

1314

@@ -58,19 +59,18 @@ def __str__(self) -> str:
5859

5960

6061
class ForwardBackwardTime(Objective):
61-
def __init__(self, ns: list[int], device: str, iterations: int):
62-
self.ns = ns
62+
def __init__(self, device: str, iterations: int):
63+
torch.cuda.empty_cache()
6364
self.device = device
64-
shapes = zip(ns[:-1], ns[1:])
65-
layers = [Linear(n, m) for n, m in shapes]
66-
self.model = Sequential(*layers).to(device=device)
65+
self.model = Cifar10Model().to(device=device)
6766
self.iterations = iterations
67+
self.input_shape = (16, 3, 32, 32)
6868

6969
def __call__(self, forward_backward: Callable):
70-
aggregator = UPGrad()
70+
aggregator = Mean()
7171
total_time = 0.0
7272
for i in range(self.iterations + 1):
73-
x = torch.randn(self.ns[0], device=self.device)
73+
x = torch.randn(self.input_shape, device=self.device)
7474

7575
if self.device.startswith("cuda"):
7676
torch.cuda.synchronize()
@@ -86,12 +86,6 @@ def __call__(self, forward_backward: Callable):
8686
average_runtime = total_time / self.iterations
8787
return average_runtime
8888

89-
def __repr__(self) -> str:
90-
return f"{self.__class__.__name__}(ns={self.ns}, device={self.device}," f" iterations={self.iterations})"
91-
92-
def __str__(self) -> str:
93-
return f"AT({self.matrix_sampler}, {self.device}, x{self.iterations})"
94-
9589

9690
class DualProjectionPrimalFeasibilityObjective(Objective):
9791
def __init__(self, matrix_sampler: MatrixSampler, device: str, iterations: int):
@@ -283,6 +277,7 @@ def compute_kkt_conditions(
283277
],
284278
"mtl_backward_runtime": [MTLBackwardTime(n_tasks=50, device=device, iterations=100) for device in ["cpu", "cuda"]],
285279
"gramian_runtime": [GramianTime(100, 1000000, "cuda", 1)],
280+
"forward_backward_runtime": [ForwardBackwardTime("cuda", 1)],
286281
"project_weights": [
287282
DualProjectionPrimalFeasibilityObjective(matrix_sampler=cls(m, m, m - 1, torch.float32), device=device, iterations=10)
288283
for cls in [NormalSampler, StrongSampler, StrictlyWeakSampler, NonWeakSampler]

0 commit comments

Comments
 (0)