We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a781602 commit 58c4866Copy full SHA for 58c4866
src/torchjd/aggregation/aligned_mtl.py
@@ -91,16 +91,14 @@ def __init__(self, weighting: _Weighting):
91
def forward(self, matrix: Tensor) -> Tensor:
92
w = self.weighting(matrix)
93
94
- G = matrix.T
95
- B = self._compute_balance_transformation(G)
+ M = matrix @ matrix.T
+ B = self._compute_balance_transformation(M)
96
alpha = B @ w
97
98
return alpha
99
100
@staticmethod
101
- def _compute_balance_transformation(G: Tensor) -> Tensor:
102
- M = G.T @ G
103
-
+ def _compute_balance_transformation(M: Tensor) -> Tensor:
104
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
105
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
106
rank = sum(lambda_ > tol)
0 commit comments