Skip to content

Commit 19a2a4d

Browse files
committed
Rename Frank-Wolfe solver into _from_gramian
1 parent 6f8f3c8 commit 19a2a4d

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

src/torchjd/aggregation/mgda.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ def __init__(self, epsilon: float, max_iters: int):
5656
self.epsilon = epsilon
5757
self.max_iters = max_iters
5858

59-
def _frank_wolfe_solver(self, gramian: Tensor) -> Tensor:
59+
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
60+
"""
61+
This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective
62+
Optimization
63+
<https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf>`_.
64+
"""
6065
device = gramian.device
6166
dtype = gramian.dtype
6267

@@ -81,5 +86,5 @@ def _frank_wolfe_solver(self, gramian: Tensor) -> Tensor:
8186

8287
def forward(self, matrix: Tensor) -> Tensor:
8388
gramian = compute_gramian(matrix)
84-
weights = self._frank_wolfe_solver(gramian)
89+
weights = self._compute_from_gramian(gramian)
8590
return weights

0 commit comments

Comments
 (0)