Skip to content

Commit 75b8bf8

Browse files
fix(aggregation): Fix CAGrad norm typing (#355)
* Call .item() on norm arrays to turn them into float explicitly * Make explicit the order used for the norm
1 parent 6c09b99 commit 75b8bf8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/torchjd/aggregation/_cagrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def forward(self, gramian: Tensor) -> Tensor:
9494

9595
dimension = gramian.shape[0]
9696
reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension
97-
sqrt_phi = self.c * np.linalg.norm(reduced_g_0, 2)
97+
sqrt_phi = self.c * np.linalg.norm(reduced_g_0, 2).item()
9898

9999
w = cp.Variable(shape=dimension)
100100
cost = (reduced_array @ reduced_g_0).T @ w + sqrt_phi * cp.norm(reduced_array.T @ w, 2)
@@ -103,7 +103,7 @@ def forward(self, gramian: Tensor) -> Tensor:
103103
problem.solve(cp.CLARABEL)
104104
w_opt = w.value
105105

106-
g_w_norm = np.linalg.norm(reduced_array.T @ w_opt)
106+
g_w_norm = np.linalg.norm(reduced_array.T @ w_opt, 2).item()
107107
if g_w_norm >= self.norm_eps:
108108
weight_array = np.ones(dimension) / dimension
109109
weight_array += (sqrt_phi / g_w_norm) * w_opt

0 commit comments

Comments
 (0)