Skip to content

Commit 58c4866

Browse files
committed
Make dependence on gramian explicit in AlignedMTL
1 parent a781602 commit 58c4866

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,14 @@ def __init__(self, weighting: _Weighting):
9191
def forward(self, matrix: Tensor) -> Tensor:
9292
w = self.weighting(matrix)
9393

94-
G = matrix.T
95-
B = self._compute_balance_transformation(G)
94+
M = matrix @ matrix.T
95+
B = self._compute_balance_transformation(M)
9696
alpha = B @ w
9797

9898
return alpha
9999

100100
@staticmethod
101-
def _compute_balance_transformation(G: Tensor) -> Tensor:
102-
M = G.T @ G
103-
101+
def _compute_balance_transformation(M: Tensor) -> Tensor:
104102
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
105103
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
106104
rank = sum(lambda_ > tol)

0 commit comments

Comments
 (0)