@@ -56,15 +56,14 @@ def __init__(self, epsilon: float, max_iters: int):
5656 self .epsilon = epsilon
5757 self .max_iters = max_iters
5858
59- def _frank_wolfe_solver (self , matrix : Tensor ) -> Tensor :
60- gramian = compute_gramian (matrix )
61- device = matrix .device
62- dtype = matrix .dtype
59+ def _frank_wolfe_solver (self , gramian : Tensor ) -> Tensor :
60+ device = gramian .device
61+ dtype = gramian .dtype
6362
64- alpha = torch .ones (matrix .shape [0 ], device = device , dtype = dtype ) / matrix .shape [0 ]
63+ alpha = torch .ones (gramian .shape [0 ], device = device , dtype = dtype ) / gramian .shape [0 ]
6564 for i in range (self .max_iters ):
6665 t = torch .argmin (gramian @ alpha )
67- e_t = torch .zeros (matrix .shape [0 ], device = device , dtype = dtype )
66+ e_t = torch .zeros (gramian .shape [0 ], device = device , dtype = dtype )
6867 e_t [t ] = 1.0
6968 a = alpha @ (gramian @ e_t )
7069 b = alpha @ (gramian @ alpha )
@@ -81,5 +80,6 @@ def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
8180 return alpha
8281
8382 def forward (self , matrix : Tensor ) -> Tensor :
84- weights = self ._frank_wolfe_solver (matrix )
83+ gramian = compute_gramian (matrix )
84+ weights = self ._frank_wolfe_solver (gramian )
8585 return weights
0 commit comments