Skip to content

Commit 3a3c459

Browse files
refactor(aggregation): Explicitly cast arrays to float64 (#236)
When casting a Tensor to a numpy array, the obtained array is typically typed as float32 while numpy typically uses float64. This makes the cast explicit.
1 parent 95bb00d commit 3a3c459

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

src/torchjd/aggregation/cagrad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def forward(self, matrix: Tensor) -> Tensor:
7676
U, S, _ = torch.svd(gramian)
7777

7878
reduced_matrix = U @ S.sqrt().diag()
79-
reduced_array = reduced_matrix.cpu().detach().numpy()
79+
reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64)
8080

8181
dimension = matrix.shape[0]
8282
reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension

src/torchjd/aggregation/dualproj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,10 @@ def __init__(
107107

108108
def forward(self, matrix: Tensor) -> Tensor:
109109
weights = self.weighting(matrix)
110-
weights_array = weights.cpu().detach().numpy()
110+
weights_array = weights.cpu().detach().numpy().astype(np.float64)
111111

112112
gramian = _compute_normalized_gramian(matrix, self.norm_eps)
113-
gramian_array = gramian.cpu().detach().numpy()
113+
gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
114114
dimension = gramian.shape[0]
115115

116116
# Because of numerical errors, `gramian_array` might have slightly negative eigenvalue(s),

src/torchjd/aggregation/upgrad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def forward(self, matrix: Tensor) -> Tensor:
109109

110110
def _compute_lagrangian(self, matrix: Tensor, weights: Tensor) -> Tensor:
111111
gramian = _compute_normalized_gramian(matrix, self.norm_eps)
112-
gramian_array = gramian.cpu().detach().numpy()
112+
gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
113113
dimension = gramian.shape[0]
114114

115115
regularization_array = self.reg_eps * np.eye(dimension)

0 commit comments

Comments
 (0)