Skip to content

Commit 6f8f3c8

Browse files
committed
Move normalization of gramian into _from_gramian
1 parent b923410 commit 6f8f3c8

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/torchjd/aggregation/cagrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ def __init__(self, c: float, norm_eps: float):
7272
self.norm_eps = norm_eps
7373

7474
def forward(self, matrix: Tensor) -> Tensor:
75-
gramian = normalize(compute_gramian(matrix), self.norm_eps)
75+
gramian = compute_gramian(matrix)
7676
return self._compute_from_gramian(gramian)
7777

7878
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
79-
U, S, _ = torch.svd(gramian)
79+
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))
8080

8181
reduced_matrix = U @ S.sqrt().diag()
8282
reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64)

0 commit comments

Comments
 (0)