Skip to content

Commit 6afcc92

Browse files
committed
Make dependence on gramian explicit in IMTL-G
1 parent f63525b commit 6afcc92

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/torchjd/aggregation/imtl_g.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,13 @@ class _IMTLGWeighting(_Weighting):
3838
"""
3939

4040
def forward(self, matrix: Tensor) -> Tensor:
41-
d = torch.linalg.norm(matrix, dim=1)
42-
v = torch.linalg.pinv(matrix @ matrix.T) @ d
41+
gramian = matrix @ matrix.T
42+
return self._compute_from_gramian(gramian)
43+
44+
@staticmethod
45+
def _compute_from_gramian(gramian: Tensor) -> Tensor:
46+
d = torch.sqrt(torch.diagonal(gramian))
47+
v = torch.linalg.pinv(gramian) @ d
4348
v_sum = v.sum()
4449

4550
if v_sum.abs() < 1e-12:

0 commit comments

Comments
 (0)