Skip to content

Commit f63525b

Browse files
committed
Make dependence on gramian explicit in CAGrad
1 parent 58c4866 commit f63525b

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/torchjd/aggregation/cagrad.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,15 @@ def __init__(self, c: float, norm_eps: float):
7373

7474
def forward(self, matrix: Tensor) -> Tensor:
7575
gramian = normalize(compute_gramian(matrix), self.norm_eps)
76+
return self._compute_from_gramian(gramian)
77+
78+
def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
7679
U, S, _ = torch.svd(gramian)
7780

7881
reduced_matrix = U @ S.sqrt().diag()
7982
reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64)
8083

81-
dimension = matrix.shape[0]
84+
dimension = gramian.shape[0]
8285
reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension
8386
sqrt_phi = self.c * np.linalg.norm(reduced_g_0, 2)
8487

@@ -97,6 +100,6 @@ def forward(self, matrix: Tensor) -> Tensor:
97100
# We are approximately on the pareto front
98101
weight_array = np.zeros(dimension)
99102

100-
weights = torch.from_numpy(weight_array).to(device=matrix.device, dtype=matrix.dtype)
103+
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)
101104

102105
return weights

0 commit comments

Comments
 (0)