44
55import torch
66from torch import Tensor
7- from torch .nn import Linear , MSELoss , ReLU , Sequential , Module
7+ from torch .nn import Linear , MSELoss , ReLU , Sequential
88from torch .optim import SGD
9- from torchjd .aggregation import Mean , Aggregator , UPGrad
9+ from torchjd .aggregation import Mean
1010
11+ from arena .architectures import Cifar10Model
1112from arena .matrix_samplers import MatrixSampler , NonWeakSampler , NormalSampler , StrictlyWeakSampler , StrongSampler
1213
1314
@@ -58,19 +59,18 @@ def __str__(self) -> str:
5859
5960
6061class ForwardBackwardTime (Objective ):
61- def __init__ (self , ns : list [ int ], device : str , iterations : int ):
62- self . ns = ns
62+ def __init__ (self , device : str , iterations : int ):
63+ torch . cuda . empty_cache ()
6364 self .device = device
64- shapes = zip (ns [:- 1 ], ns [1 :])
65- layers = [Linear (n , m ) for n , m in shapes ]
66- self .model = Sequential (* layers ).to (device = device )
65+ self .model = Cifar10Model ().to (device = device )
6766 self .iterations = iterations
67+ self .input_shape = (16 , 3 , 32 , 32 )
6868
6969 def __call__ (self , forward_backward : Callable ):
70- aggregator = UPGrad ()
70+ aggregator = Mean ()
7171 total_time = 0.0
7272 for i in range (self .iterations + 1 ):
73- x = torch .randn (self .ns [ 0 ] , device = self .device )
73+ x = torch .randn (self .input_shape , device = self .device )
7474
7575 if self .device .startswith ("cuda" ):
7676 torch .cuda .synchronize ()
@@ -86,12 +86,6 @@ def __call__(self, forward_backward: Callable):
8686 average_runtime = total_time / self .iterations
8787 return average_runtime
8888
89- def __repr__ (self ) -> str :
90- return f"{ self .__class__ .__name__ } (ns={ self .ns } , device={ self .device } ," f" iterations={ self .iterations } )"
91-
92- def __str__ (self ) -> str :
93- return f"AT({ self .matrix_sampler } , { self .device } , x{ self .iterations } )"
94-
9589
9690class DualProjectionPrimalFeasibilityObjective (Objective ):
9791 def __init__ (self , matrix_sampler : MatrixSampler , device : str , iterations : int ):
@@ -283,6 +277,7 @@ def compute_kkt_conditions(
283277 ],
284278 "mtl_backward_runtime" : [MTLBackwardTime (n_tasks = 50 , device = device , iterations = 100 ) for device in ["cpu" , "cuda" ]],
285279 "gramian_runtime" : [GramianTime (100 , 1000000 , "cuda" , 1 )],
280+ "forward_backward_runtime" : [ForwardBackwardTime ("cuda" , 1 )],
286281 "project_weights" : [
287282 DualProjectionPrimalFeasibilityObjective (matrix_sampler = cls (m , m , m - 1 , torch .float32 ), device = device , iterations = 10 )
288283 for cls in [NormalSampler , StrongSampler , StrictlyWeakSampler , NonWeakSampler ]
0 commit comments