We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f63525b commit 6afcc92Copy full SHA for 6afcc92
src/torchjd/aggregation/imtl_g.py
@@ -38,8 +38,13 @@ class _IMTLGWeighting(_Weighting):
38
"""
39
40
def forward(self, matrix: Tensor) -> Tensor:
41
- d = torch.linalg.norm(matrix, dim=1)
42
- v = torch.linalg.pinv(matrix @ matrix.T) @ d
+ gramian = matrix @ matrix.T
+ return self._compute_from_gramian(gramian)
43
+
44
+ @staticmethod
45
+ def _compute_from_gramian(gramian: Tensor) -> Tensor:
46
+ d = torch.sqrt(torch.diagonal(gramian))
47
+ v = torch.linalg.pinv(gramian) @ d
48
v_sum = v.sum()
49
50
if v_sum.abs() < 1e-12:
0 commit comments