Skip to content

Commit e962c99

Browse files
committed
Make dependence on gramian explicit in MGDA
1 parent 0a9fd42 commit e962c99

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/torchjd/aggregation/mgda.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,14 @@ def __init__(self, epsilon: float, max_iters: int):
5656
self.epsilon = epsilon
5757
self.max_iters = max_iters
5858

59-
def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
60-
gramian = compute_gramian(matrix)
61-
device = matrix.device
62-
dtype = matrix.dtype
59+
def _frank_wolfe_solver(self, gramian: Tensor) -> Tensor:
60+
device = gramian.device
61+
dtype = gramian.dtype
6362

64-
alpha = torch.ones(matrix.shape[0], device=device, dtype=dtype) / matrix.shape[0]
63+
alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0]
6564
for i in range(self.max_iters):
6665
t = torch.argmin(gramian @ alpha)
67-
e_t = torch.zeros(matrix.shape[0], device=device, dtype=dtype)
66+
e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype)
6867
e_t[t] = 1.0
6968
a = alpha @ (gramian @ e_t)
7069
b = alpha @ (gramian @ alpha)
@@ -81,5 +80,6 @@ def _frank_wolfe_solver(self, matrix: Tensor) -> Tensor:
8180
return alpha
8281

8382
def forward(self, matrix: Tensor) -> Tensor:
84-
weights = self._frank_wolfe_solver(matrix)
83+
gramian = compute_gramian(matrix)
84+
weights = self._frank_wolfe_solver(gramian)
8585
return weights

0 commit comments

Comments
 (0)